In [1]:
import time

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from tqdm.notebook import tqdm

# Device used by every later `.to(device)` call. Prefer the GPU, but fall back
# to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [12]:
batch_size = 128

# Preprocessing pipeline, applied in this order to every image:
# convert to a tensor, normalise with mean 0.5 / std 0.5, then resize to 224x224.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Resize((224, 224)),
])

# FashionMNIST splits stored under ./data; only the training split requests a download.
train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = FashionMNIST(root='./data', train=False, transform=transform)

# Mini-batch iterators: training batches are reshuffled on every pass, test order is fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [13]:
def init_cnn(module):
    """Xavier-uniform initialise the weight of a plain ``nn.Linear`` / ``nn.Conv2d``.

    Intended to be passed to ``nn.Module.apply``. The match is on the exact
    class, so subclasses and still-lazy layers (e.g. ``nn.LazyConv2d`` before
    its first forward pass) are left untouched. Biases are not modified.

    Args:
        module: Any ``nn.Module``; modules of other types are ignored.
    """
    # Exact-type lookup on purpose: an isinstance() check would also match
    # LazyConv2d, whose weight cannot be initialised before it is materialised.
    if type(module) in (nn.Linear, nn.Conv2d):
        nn.init.xavier_uniform_(module.weight)

def nin_block(out_channels, kernel_size, strides, padding):
    """Build one NiN block: a spatial convolution followed by two 1x1 convolutions.

    Every convolution is an ``nn.LazyConv2d`` producing ``out_channels`` feature
    maps and is followed by a ReLU.

    Args:
        out_channels: Number of output channels of all three convolutions.
        kernel_size: Kernel size of the first (spatial) convolution.
        strides: Stride of the first convolution.
        padding: Padding of the first convolution.

    Returns:
        An ``nn.Sequential`` of six modules (conv, ReLU) x 3.
    """
    layers = [
        nn.LazyConv2d(out_channels, kernel_size, stride=strides, padding=padding),
        nn.ReLU(),
    ]
    # Two per-pixel (1x1) convolutions mixing channels at each spatial location.
    for _ in range(2):
        layers.extend([nn.LazyConv2d(out_channels, kernel_size=1), nn.ReLU()])
    return nn.Sequential(*layers)

class NiN(nn.Module):
    """Network-in-Network classifier built from ``nin_block`` stages.

    The network stacks four NiN blocks separated by max-pooling, applies
    dropout before the last block, and finishes with global average pooling
    and a flatten, so the output has one value per class for each sample.

    The convolutions are lazy (``nn.LazyConv2d``): their weights do not exist
    until the first forward pass. Call :meth:`apply_init` with a sample batch
    before creating an optimizer or applying a weight initialiser.

    Args:
        num_classes: Number of output channels of the final block, i.e. the
            number of scores returned per sample.
    """

    def __init__(self, num_classes = 10):
        super().__init__()
        self.net = nn.Sequential(
            nin_block(96, kernel_size=11, strides=4, padding=0),
            nn.MaxPool2d(3, stride=2),
            nin_block(256, kernel_size=5, strides=1, padding=2),
            nn.MaxPool2d(3, stride=2),
            nin_block(384, kernel_size=3, strides=1, padding=1),
            nn.MaxPool2d(3, stride=2),
            nn.Dropout(0.5),
            nin_block(num_classes, kernel_size=3, strides=1, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten())
        # No weight initialisation here: the lazy convolutions have no
        # materialised weights yet, so an initialiser applied at this point
        # cannot take effect. Initialisation happens in apply_init().

    def forward(self, x):
        """Run the input batch through the network and return the class scores."""
        return self.net(x)

    def apply_init(self, inputs, init=None):
        """Materialise the lazy layers with a dry run, then optionally initialise.

        Args:
            inputs: Sequence of positional arguments for ``forward``
                (e.g. ``[sample_batch]``); used only to infer layer shapes.
            init: Optional callable applied to every sub-module of ``self.net``
                via ``nn.Module.apply`` after the shapes are known.
        """
        # The output of the dry run is discarded, so skip autograd bookkeeping.
        with torch.no_grad():
            self.forward(*inputs)
        if init is not None:
            self.net.apply(init)


In [18]:
# Number of passes over the training set; used by both the loop and the log line
# so the reported "Epoch [i/N]" denominator always matches the real epoch count.
num_epochs = 10

model = NiN().to(device)

# Dry run on one batch so the lazy layers get concrete shapes, then apply the
# Xavier initialiser. This must happen before the optimizer is created, because
# the parameters do not exist until the first forward pass.
input_data = next(iter(train_loader))[0].to(device)
model.apply_init([input_data], init_cnn)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

total_step = len(train_loader)
model.train()  # make sure dropout is active even if the model was put in eval mode earlier
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for i, (images, labels) in tqdm(enumerate(train_loader), total=total_step):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() copies the scalar loss to a Python float for accumulation
        epoch_loss += loss.item()

    # Print average epoch loss (mean of the per-batch losses)
    average_loss = epoch_loss / total_step
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}")

Epoch [4/5], Average Loss: 0.3537


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [5/5], Average Loss: 0.3227


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [6/5], Average Loss: 0.2995


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [7/5], Average Loss: 0.2862


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [8/5], Average Loss: 0.2654


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [9/5], Average Loss: 0.2468


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [10/5], Average Loss: 0.2394


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [1/5], Average Loss: 1.1404


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [2/5], Average Loss: 0.5912


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [3/5], Average Loss: 0.4819


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [4/5], Average Loss: 0.4007


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [5/5], Average Loss: 0.3427


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [6/5], Average Loss: 0.3157


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [7/5], Average Loss: 0.2869


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [8/5], Average Loss: 0.2703


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [9/5], Average Loss: 0.2561


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [10/5], Average Loss: 0.2413


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch [1/5], Average Loss: 1.0604


  0%|          | 0/469 [00:00<?, ?it/s]

KeyboardInterrupt: 