In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from einops import rearrange

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [None]:
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
class ModularCNN(nn.Module):
    def __init__(self, n_classes: int):
        super(ModularCNN, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.Dropout(p=0.4, inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.Dropout(p=0.4, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
        )
        self.dropout=nn.Dropout(p=0.4, inplace=True)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, n_classes)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        out = self.layers(x)
        out = self.dropout(out)
        out = self.pool(out)
        out = rearrange(out, 'b c h w -> b (c h w)')
        out = self.relu(self.fc1(out))
        return out

In [None]:
model = ModularCNN(n_classes = 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
from tqdm import tqdm

def main(train_loader, test_loader, num_epochs):
    for epoch in range(num_epochs):
        train_loss, train_acc = 0, 0
        for inputs, labels in tqdm(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, labels)
            _, pred = torch.max(output, 1)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            train_acc += torch.sum(pred == labels.data)

        epoch_loss = train_loss / len(train_dataset)
        epoch_acc = (train_acc.double() / len(train_dataset)) * 100
        print(f"Epoch {epoch+1}/{num_epochs} Loss: {epoch_loss:.3f} Acc: {epoch_acc:.2f}%")

    with torch.no_grad():
        run_loss, run_acc = 0, 0
        for inputs, labels in tqdm(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            output = model(inputs)
            loss = criterion(output, labels)
            _, pred = torch.max(output, 1)
            run_loss += loss.item() * inputs.size(0)
            run_acc += torch.sum(pred == labels.data)
    test_loss = run_loss / len(test_dataset)
    test_acc = (run_acc.double() / len(test_dataset)) * 100
    print(f"Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%")

In [None]:
n_epochs: int = 50
main(train_loader, test_loader, n_epochs)

100%|██████████| 782/782 [00:17<00:00, 44.94it/s]


Epoch 1/50 Loss: 1.832 Acc: 38.21%


100%|██████████| 782/782 [00:17<00:00, 44.02it/s]


Epoch 2/50 Loss: 1.417 Acc: 50.19%


100%|██████████| 782/782 [00:16<00:00, 47.33it/s]


Epoch 3/50 Loss: 1.283 Acc: 54.99%


100%|██████████| 782/782 [00:17<00:00, 45.61it/s]


Epoch 4/50 Loss: 1.196 Acc: 58.32%


100%|██████████| 782/782 [00:17<00:00, 44.41it/s]


Epoch 5/50 Loss: 1.112 Acc: 61.62%


100%|██████████| 782/782 [00:17<00:00, 45.60it/s]


Epoch 6/50 Loss: 1.056 Acc: 63.34%


100%|██████████| 782/782 [00:16<00:00, 46.40it/s]


Epoch 7/50 Loss: 1.005 Acc: 65.05%


100%|██████████| 782/782 [00:16<00:00, 46.64it/s]


Epoch 8/50 Loss: 0.975 Acc: 66.44%


100%|██████████| 782/782 [00:17<00:00, 44.54it/s]


Epoch 9/50 Loss: 0.950 Acc: 67.48%


100%|██████████| 782/782 [00:16<00:00, 46.53it/s]


Epoch 10/50 Loss: 0.921 Acc: 68.33%


100%|██████████| 782/782 [00:16<00:00, 46.63it/s]


Epoch 11/50 Loss: 0.902 Acc: 68.68%


100%|██████████| 782/782 [00:17<00:00, 43.94it/s]


Epoch 12/50 Loss: 0.889 Acc: 69.43%


100%|██████████| 782/782 [00:16<00:00, 46.42it/s]


Epoch 13/50 Loss: 0.870 Acc: 70.06%


100%|██████████| 782/782 [00:16<00:00, 46.33it/s]


Epoch 14/50 Loss: 0.854 Acc: 70.73%


100%|██████████| 782/782 [00:17<00:00, 45.13it/s]


Epoch 15/50 Loss: 0.842 Acc: 71.05%


100%|██████████| 782/782 [00:16<00:00, 46.46it/s]


Epoch 16/50 Loss: 0.830 Acc: 71.24%


100%|██████████| 782/782 [00:16<00:00, 46.23it/s]


Epoch 17/50 Loss: 0.825 Acc: 71.40%


100%|██████████| 782/782 [00:17<00:00, 45.47it/s]


Epoch 18/50 Loss: 0.809 Acc: 72.19%


100%|██████████| 782/782 [00:17<00:00, 44.17it/s]


Epoch 19/50 Loss: 0.801 Acc: 72.24%


100%|██████████| 782/782 [00:17<00:00, 45.56it/s]


Epoch 20/50 Loss: 0.791 Acc: 72.69%


100%|██████████| 782/782 [00:17<00:00, 44.86it/s]


Epoch 21/50 Loss: 0.784 Acc: 72.82%


100%|██████████| 782/782 [00:18<00:00, 42.89it/s]


Epoch 22/50 Loss: 0.776 Acc: 72.98%


100%|██████████| 782/782 [00:17<00:00, 45.08it/s]


Epoch 23/50 Loss: 0.768 Acc: 73.39%


100%|██████████| 782/782 [00:17<00:00, 44.64it/s]


Epoch 24/50 Loss: 0.759 Acc: 73.69%


100%|██████████| 782/782 [00:17<00:00, 43.86it/s]


Epoch 25/50 Loss: 0.755 Acc: 73.82%


100%|██████████| 782/782 [00:17<00:00, 45.88it/s]


Epoch 26/50 Loss: 0.745 Acc: 74.28%


100%|██████████| 782/782 [00:17<00:00, 45.90it/s]


Epoch 27/50 Loss: 0.740 Acc: 74.15%


100%|██████████| 782/782 [00:18<00:00, 42.98it/s]


Epoch 28/50 Loss: 0.735 Acc: 74.52%


100%|██████████| 782/782 [00:16<00:00, 46.07it/s]


Epoch 29/50 Loss: 0.729 Acc: 74.50%


100%|██████████| 782/782 [00:16<00:00, 46.02it/s]


Epoch 30/50 Loss: 0.725 Acc: 74.80%


100%|██████████| 782/782 [00:18<00:00, 43.32it/s]


Epoch 31/50 Loss: 0.714 Acc: 75.14%


100%|██████████| 782/782 [00:16<00:00, 46.53it/s]


Epoch 32/50 Loss: 0.705 Acc: 75.44%


100%|██████████| 782/782 [00:16<00:00, 46.47it/s]


Epoch 33/50 Loss: 0.704 Acc: 75.43%


100%|██████████| 782/782 [00:18<00:00, 42.14it/s]


Epoch 34/50 Loss: 0.699 Acc: 75.60%


100%|██████████| 782/782 [00:17<00:00, 44.99it/s]


Epoch 35/50 Loss: 0.694 Acc: 75.67%


100%|██████████| 782/782 [00:16<00:00, 46.02it/s]


Epoch 36/50 Loss: 0.693 Acc: 75.69%


100%|██████████| 782/782 [00:18<00:00, 43.25it/s]


Epoch 37/50 Loss: 0.689 Acc: 75.88%


100%|██████████| 782/782 [00:17<00:00, 43.84it/s]


Epoch 38/50 Loss: 0.683 Acc: 76.07%


100%|██████████| 782/782 [00:17<00:00, 44.74it/s]


Epoch 39/50 Loss: 0.678 Acc: 76.37%


100%|██████████| 782/782 [00:18<00:00, 43.02it/s]


Epoch 40/50 Loss: 0.674 Acc: 76.35%


100%|██████████| 782/782 [00:17<00:00, 44.73it/s]


Epoch 41/50 Loss: 0.668 Acc: 76.59%


100%|██████████| 782/782 [00:17<00:00, 45.43it/s]


Epoch 42/50 Loss: 0.668 Acc: 76.40%


100%|██████████| 782/782 [00:18<00:00, 42.11it/s]


Epoch 43/50 Loss: 0.659 Acc: 77.03%


100%|██████████| 782/782 [00:17<00:00, 44.35it/s]


Epoch 44/50 Loss: 0.658 Acc: 76.96%


100%|██████████| 782/782 [00:17<00:00, 45.16it/s]


Epoch 45/50 Loss: 0.650 Acc: 77.14%


100%|██████████| 782/782 [00:18<00:00, 42.43it/s]


Epoch 46/50 Loss: 0.652 Acc: 77.12%


100%|██████████| 782/782 [00:16<00:00, 46.01it/s]


Epoch 47/50 Loss: 0.651 Acc: 77.10%


100%|██████████| 782/782 [00:17<00:00, 43.76it/s]


Epoch 48/50 Loss: 0.648 Acc: 77.36%


100%|██████████| 782/782 [00:18<00:00, 41.99it/s]


Epoch 49/50 Loss: 0.642 Acc: 77.54%


100%|██████████| 782/782 [00:17<00:00, 45.10it/s]


Epoch 50/50 Loss: 0.638 Acc: 77.53%


100%|██████████| 157/157 [00:02<00:00, 52.85it/s]

Test Loss: 1.095, Test Acc: 66.08%



