In [2]:
import torch
import torchvision
train = torchvision.datasets.CIFAR10(root='./data',train=True,download=False,transform=torchvision.transforms.ToTensor())
test = torchvision.datasets.CIFAR10(root='./data',train=False,download=False,transform=torchvision.transforms.ToTensor())

In [3]:
trainloader = torch.utils.data.DataLoader(train,batch_size=512,shuffle=True)
testloader = torch.utils.data.DataLoader(test,batch_size=512,shuffle=False)

In [15]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class SEBlock(nn.Module):
    """ Squeeze-and-Excitation Block """
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        batch, channels, _, _ = x.size()
        y = x.mean(dim=[2, 3])  # Global Average Pooling
        y = F.silu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y)).view(batch, channels, 1, 1)
        return x * y

class ImprovedCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(ImprovedCNN, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # Downsample
        self.bn2 = nn.BatchNorm2d(64)
        self.se1 = SEBlock(64)  # Squeeze-and-Excitation

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)  # Downsample
        self.bn4 = nn.BatchNorm2d(256)
        self.se2 = SEBlock(256)

        self.fc1 = nn.Linear(16384, 4096)
        self.bn_fc1 = nn.BatchNorm1d(4096)  # BatchNorm for stability
        self.fc2 = nn.Linear(4096, 128)
        self.bn_fc2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, num_classes)

        self.dropout = nn.Dropout(0.3)  # Increased Dropout

    def forward(self, x):
        x = F.silu(self.bn1(self.conv1(x)))
        x = self.se1(F.silu(self.bn2(self.conv2(x))))

        x = F.silu(self.bn3(self.conv3(x)))
        x = self.se2(F.silu(self.bn4(self.conv4(x))))

        x = x.flatten(1)
        x = self.dropout(F.silu(self.bn_fc1(self.fc1(x))))
        x = self.dropout(F.silu(self.bn_fc2(self.fc2(x))))
        return self.fc3(x)

model = ImprovedCNN(num_classes=10)

In [16]:
with torch.no_grad():
  a = next(iter(trainloader))
  print(a[0].shape)
  a = model(a[0])
  print(a.shape)

torch.Size([512, 3, 32, 32])
torch.Size([512, 10])


In [17]:
device = "cuda"
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model = model.to(device)
from tqdm import tqdm
from tqdm import tqdm

from tqdm import tqdm

for i in range(100):
    correct, total, running_loss = 0, 0, 0
    train_bar = tqdm(trainloader, desc=f'Train Epoch {i}')

    for image, label in train_bar:
        image, label = image.to(device), label.to(device)

        out = model(image)
        loss_value = loss(out, label)

        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

        predictions = out.argmax(dim=1)
        correct += (predictions == label).sum().item()
        total += label.size(0)
        running_loss += loss_value.item()

        train_bar.set_postfix(loss=running_loss / total, acc=100 * correct / total)

    correct, total, test_loss = 0, 0, 0
    test_bar = tqdm(testloader, desc=f'Test Epoch {i}')

    with torch.no_grad():
        for test, test_label in test_bar:
            test, test_label = test.to(device), test_label.to(device)

            test_out = model(test)
            loss_value = loss(test_out, test_label)
            test_loss += loss_value.item()

            predictions = test_out.argmax(dim=1)
            correct += (predictions == test_label).sum().item()
            total += test_label.size(0)

            test_bar.set_postfix(loss=test_loss / total, acc=100 * correct / total)

Train Epoch 0: 100%|██████████| 98/98 [00:12<00:00,  7.63it/s, acc=54.2, loss=0.00254]
Test Epoch 0: 100%|██████████| 20/20 [00:00<00:00, 22.66it/s, acc=64, loss=0.00205]  
Train Epoch 1: 100%|██████████| 98/98 [00:13<00:00,  7.29it/s, acc=69.9, loss=0.00168]
Test Epoch 1: 100%|██████████| 20/20 [00:00<00:00, 21.46it/s, acc=71.2, loss=0.00166]
Train Epoch 2:  98%|█████████▊| 96/98 [00:13<00:00,  7.29it/s, acc=78, loss=0.00124]  


KeyboardInterrupt: 