In [2]:
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [6]:
trainset = datasets.FashionMNIST(root = 'data', train=True, transform = ToTensor(), download = True)
testset = datasets.FashionMNIST(root = 'data', train=False, transform = ToTensor(), download = True)

batch_size = 8 
trainloader = DataLoader(dataset= trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(dataset= testset, batch_size=batch_size, shuffle=False)

In [10]:
trainset.class_to_idx

{'T-shirt/top': 0,
 'Trouser': 1,
 'Pullover': 2,
 'Dress': 3,
 'Coat': 4,
 'Sandal': 5,
 'Shirt': 6,
 'Sneaker': 7,
 'Bag': 8,
 'Ankle boot': 9}

In [7]:
img, lab = next(iter(trainloader))

In [8]:
img.shape

torch.Size([8, 1, 28, 28])

In [35]:
class cnn_classifier(nn.Module):
    """
    custom cnn classifier
    """
    def __init__(self):
        super().__init__()

        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size = 3, padding=1),\
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )
        self.conv = nn.Sequential(
            conv_block(1, 32), # [B, 32, 28, 28]
            nn.Dropout2d(p = 0.3), 
            conv_block(32, 64), # [B, 64, 28, 28]
            nn.Dropout2d(p= 0.3),
            conv_block(64, 128), # [B, 128, 28, 28]
            nn.Dropout2d(p= 0.3),
            nn.MaxPool2d(kernel_size=2), # [B, 128, 14, 14]
            nn.AdaptiveAvgPool2d((1,1))  # [B, 128, 1, 1]
        )
        self.linear = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv(x)  # [B, 64, 1, 1]
        x = x.view((x.shape[0], -1)) # [B, 64]
        x = self.linear(x)
        return x

In [36]:
from tqdm import tqdm

In [37]:
cnn_model = cnn_classifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(cnn_model.parameters(), lr = 1e-3)

In [38]:
def train(trainloader, model, loss_fn, optimizer):
    model.train()
    train_loss = 0
    for img, label in tqdm(trainloader):
        # forward pass to compute output tensor (prediction)
        logits = model(img)

        # compute loss
        loss = loss_fn(logits, label)

        optimizer.zero_grad()

        # backward propagation
        loss.backward()

        optimizer.step()
        
        train_loss += loss.item()
    print(f"train loss: {train_loss / len(trainloader)}")

def evaluation(testloader, model, loss_fn):
    model.eval()
    test_loss = 0
    correct, total = 0, 0
    with torch.no_grad():
        for img, label in testloader:
            logits = model(img)

            loss = loss_fn(logits, label)
            
            pred = logits.argmax(dim=1)
            
            correct += (pred == label).sum().item()
            total += len(pred)

            test_loss += loss.item()
    print(f"test loss: {test_loss/ len(testloader)}")
    print(f"Accuracy: {correct / total}%")

            
            

In [39]:
for t in range(5):
    print(f"epoch: {t +1}")
    train(trainloader, cnn_model, criterion, optimizer)
    evaluation(testloader, cnn_model, criterion)

epoch: 1


100%|██████████| 7500/7500 [05:21<00:00, 23.31it/s]


train loss: 1.6292589291334152
test loss: 1.173403601169586
Accuracy: 0.6202%
epoch: 2


100%|██████████| 7500/7500 [05:22<00:00, 23.24it/s]


train loss: 1.252363410282135
test loss: 1.0014738932371139
Accuracy: 0.6774%
epoch: 3


100%|██████████| 7500/7500 [05:22<00:00, 23.24it/s]


train loss: 1.0880736953616141
test loss: 0.8819733852624894
Accuracy: 0.6963%
epoch: 4


100%|██████████| 7500/7500 [05:21<00:00, 23.30it/s]


train loss: 0.9673961407740911
test loss: 0.8407222236156464
Accuracy: 0.7032%
epoch: 5


100%|██████████| 7500/7500 [05:22<00:00, 23.26it/s]


train loss: 0.8832026705900828
test loss: 0.7404620177924633
Accuracy: 0.7352%


In [12]:

rand.reshape()
rand.view()


tensor([[0.6552]])