In [None]:
!git clone https://github.com/rahilsinghi/DL_PROJECT_1.git


Cloning into 'DL_PROJECT_1'...
remote: Enumerating objects: 45, done.[K
remote: Counting objects: 100% (11/11), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 45 (delta 4), reused 7 (delta 3), pack-reused 34 (from 1)[K
Receiving objects: 100% (45/45), 208.90 MiB | 19.28 MiB/s, done.
Resolving deltas: 100% (13/13), done.
Updating files: 100% (27/27), done.


In [None]:
!pip install torch torchvision tqdm


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [16]:
# cifar_resnet/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, initial_planes=32):
        super(ResNet, self).__init__()
        self.in_planes = initial_planes

        # Initial convolution (for CIFAR-10 32x32 images)
        self.conv1 = nn.Conv2d(3, initial_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_planes)

        # Residual layers: we use 4 stages with [2,2,2,2] blocks each
        self.layer1 = self._make_layer(block, initial_planes,   num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, initial_planes*2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, initial_planes*4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, initial_planes*8, num_blocks[3], stride=2)

        # Final linear layer
        self.linear = nn.Linear(initial_planes*8*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))   # (N, initial_planes, 32, 32)
        out = self.layer1(out)                    # (N, initial_planes, 32, 32)
        out = self.layer2(out)                    # (N, initial_planes*2, 16, 16)
        out = self.layer3(out)                    # (N, initial_planes*4, 8, 8)
        out = self.layer4(out)                    # (N, initial_planes*8, 4, 4)
        out = F.avg_pool2d(out, 4)                # (N, initial_planes*8, 1, 1)
        out = out.view(out.size(0), -1)           # (N, initial_planes*8)
        out = self.linear(out)                    # (N, num_classes)
        return out

def ResNet18CIFAR(num_classes=10, initial_planes=32):
    """
    Constructs a ResNet-18 model for CIFAR-10.
    Using initial_planes=32 typically yields a model under 5M parameters.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, initial_planes=initial_planes)

In [17]:
#This train file is for model2.py for achieving better than 0.82 (4:51 pm, march 4)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm import tqdm
#from model import ResNet18CIFAR

def train():
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Data augmentation & normalization
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=256, shuffle=True, num_workers=4)

    valset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_val)
    valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=4)

    model = ResNet18CIFAR().to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    best_val_acc = 0.0
    for epoch in range(200):
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/200"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_acc = 100.0 * correct / total
        print(f"Epoch [{epoch+1}/200] Loss: {running_loss/len(trainloader):.4f}, Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            print(f"Saved best model at epoch {epoch+1} with Val Acc: {val_acc:.2f}%")
    print("Training complete. Best Accuracy: {:.2f}%".format(best_val_acc))

if __name__ == "__main__":
    train()

Using device: cuda
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:13<00:00, 12.4MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Epoch 1/200: 100%|██████████| 196/196 [00:40<00:00,  4.80it/s]


Epoch [1/200] Loss: 2.1203, Val Acc: 33.89%
Saved best model at epoch 1 with Val Acc: 33.89%


Epoch 2/200: 100%|██████████| 196/196 [00:38<00:00,  5.13it/s]


Epoch [2/200] Loss: 1.7562, Val Acc: 52.58%
Saved best model at epoch 2 with Val Acc: 52.58%


Epoch 3/200: 100%|██████████| 196/196 [00:37<00:00,  5.26it/s]


Epoch [3/200] Loss: 1.5579, Val Acc: 59.85%
Saved best model at epoch 3 with Val Acc: 59.85%


Epoch 4/200: 100%|██████████| 196/196 [00:37<00:00,  5.25it/s]


Epoch [4/200] Loss: 1.4391, Val Acc: 62.48%
Saved best model at epoch 4 with Val Acc: 62.48%


Epoch 5/200: 100%|██████████| 196/196 [00:38<00:00,  5.09it/s]


Epoch [5/200] Loss: 1.3334, Val Acc: 72.15%
Saved best model at epoch 5 with Val Acc: 72.15%


Epoch 6/200: 100%|██████████| 196/196 [00:40<00:00,  4.87it/s]


Epoch [6/200] Loss: 1.2553, Val Acc: 71.54%


Epoch 7/200: 100%|██████████| 196/196 [00:38<00:00,  5.06it/s]


Epoch [7/200] Loss: 1.1999, Val Acc: 75.09%
Saved best model at epoch 7 with Val Acc: 75.09%


Epoch 8/200: 100%|██████████| 196/196 [00:38<00:00,  5.05it/s]


Epoch [8/200] Loss: 1.1667, Val Acc: 76.25%
Saved best model at epoch 8 with Val Acc: 76.25%


Epoch 9/200: 100%|██████████| 196/196 [00:38<00:00,  5.05it/s]


Epoch [9/200] Loss: 1.1298, Val Acc: 78.25%
Saved best model at epoch 9 with Val Acc: 78.25%


Epoch 10/200: 100%|██████████| 196/196 [00:40<00:00,  4.84it/s]


Epoch [10/200] Loss: 1.1089, Val Acc: 77.64%


Epoch 11/200: 100%|██████████| 196/196 [00:39<00:00,  5.02it/s]


Epoch [11/200] Loss: 1.0879, Val Acc: 78.02%


Epoch 12/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [12/200] Loss: 1.0771, Val Acc: 76.77%


Epoch 13/200: 100%|██████████| 196/196 [00:39<00:00,  5.02it/s]


Epoch [13/200] Loss: 1.0610, Val Acc: 77.67%


Epoch 14/200: 100%|██████████| 196/196 [00:40<00:00,  4.84it/s]


Epoch [14/200] Loss: 1.0430, Val Acc: 75.20%


Epoch 15/200: 100%|██████████| 196/196 [00:38<00:00,  5.12it/s]


Epoch [15/200] Loss: 1.0352, Val Acc: 80.38%
Saved best model at epoch 15 with Val Acc: 80.38%


Epoch 16/200: 100%|██████████| 196/196 [00:37<00:00,  5.26it/s]


Epoch [16/200] Loss: 1.0227, Val Acc: 79.58%


Epoch 17/200: 100%|██████████| 196/196 [00:38<00:00,  5.13it/s]


Epoch [17/200] Loss: 1.0087, Val Acc: 79.53%


Epoch 18/200: 100%|██████████| 196/196 [00:39<00:00,  4.98it/s]


Epoch [18/200] Loss: 1.0083, Val Acc: 78.63%


Epoch 19/200: 100%|██████████| 196/196 [00:38<00:00,  5.15it/s]


Epoch [19/200] Loss: 0.9998, Val Acc: 80.11%


Epoch 20/200: 100%|██████████| 196/196 [00:38<00:00,  5.13it/s]


Epoch [20/200] Loss: 0.9939, Val Acc: 80.50%
Saved best model at epoch 20 with Val Acc: 80.50%


Epoch 21/200: 100%|██████████| 196/196 [00:38<00:00,  5.04it/s]


Epoch [21/200] Loss: 0.9854, Val Acc: 79.88%


Epoch 22/200: 100%|██████████| 196/196 [00:40<00:00,  4.84it/s]


Epoch [22/200] Loss: 0.9800, Val Acc: 80.72%
Saved best model at epoch 22 with Val Acc: 80.72%


Epoch 23/200: 100%|██████████| 196/196 [00:39<00:00,  4.98it/s]


Epoch [23/200] Loss: 0.9738, Val Acc: 81.82%
Saved best model at epoch 23 with Val Acc: 81.82%


Epoch 24/200: 100%|██████████| 196/196 [00:39<00:00,  4.95it/s]


Epoch [24/200] Loss: 0.9709, Val Acc: 83.48%
Saved best model at epoch 24 with Val Acc: 83.48%


Epoch 25/200: 100%|██████████| 196/196 [00:39<00:00,  5.01it/s]


Epoch [25/200] Loss: 0.9697, Val Acc: 82.01%


Epoch 26/200: 100%|██████████| 196/196 [00:40<00:00,  4.89it/s]


Epoch [26/200] Loss: 0.9555, Val Acc: 82.97%


Epoch 27/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [27/200] Loss: 0.9600, Val Acc: 82.98%


Epoch 28/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [28/200] Loss: 0.9479, Val Acc: 83.93%
Saved best model at epoch 28 with Val Acc: 83.93%


Epoch 29/200: 100%|██████████| 196/196 [00:38<00:00,  5.10it/s]


Epoch [29/200] Loss: 0.9450, Val Acc: 80.73%


Epoch 30/200: 100%|██████████| 196/196 [00:38<00:00,  5.14it/s]


Epoch [30/200] Loss: 0.9421, Val Acc: 83.61%


Epoch 31/200: 100%|██████████| 196/196 [00:38<00:00,  5.10it/s]


Epoch [31/200] Loss: 0.9397, Val Acc: 82.29%


Epoch 32/200: 100%|██████████| 196/196 [00:37<00:00,  5.23it/s]


Epoch [32/200] Loss: 0.9408, Val Acc: 80.04%


Epoch 33/200: 100%|██████████| 196/196 [00:37<00:00,  5.18it/s]


Epoch [33/200] Loss: 0.9390, Val Acc: 84.24%
Saved best model at epoch 33 with Val Acc: 84.24%


Epoch 34/200: 100%|██████████| 196/196 [00:39<00:00,  5.02it/s]


Epoch [34/200] Loss: 0.9299, Val Acc: 83.25%


Epoch 35/200: 100%|██████████| 196/196 [00:40<00:00,  4.82it/s]


Epoch [35/200] Loss: 0.9306, Val Acc: 84.12%


Epoch 36/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [36/200] Loss: 0.9254, Val Acc: 84.45%
Saved best model at epoch 36 with Val Acc: 84.45%


Epoch 37/200: 100%|██████████| 196/196 [00:38<00:00,  5.05it/s]


Epoch [37/200] Loss: 0.9257, Val Acc: 83.91%


Epoch 38/200: 100%|██████████| 196/196 [00:39<00:00,  5.00it/s]


Epoch [38/200] Loss: 0.9177, Val Acc: 83.41%


Epoch 39/200: 100%|██████████| 196/196 [00:40<00:00,  4.82it/s]


Epoch [39/200] Loss: 0.9179, Val Acc: 84.69%
Saved best model at epoch 39 with Val Acc: 84.69%


Epoch 40/200: 100%|██████████| 196/196 [00:39<00:00,  5.00it/s]


Epoch [40/200] Loss: 0.9152, Val Acc: 86.08%
Saved best model at epoch 40 with Val Acc: 86.08%


Epoch 41/200: 100%|██████████| 196/196 [00:39<00:00,  5.02it/s]


Epoch [41/200] Loss: 0.9158, Val Acc: 83.84%


Epoch 42/200: 100%|██████████| 196/196 [00:38<00:00,  5.11it/s]


Epoch [42/200] Loss: 0.9138, Val Acc: 79.55%


Epoch 43/200: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s]


Epoch [43/200] Loss: 0.9102, Val Acc: 85.79%


Epoch 44/200: 100%|██████████| 196/196 [00:37<00:00,  5.20it/s]


Epoch [44/200] Loss: 0.9090, Val Acc: 84.63%


Epoch 45/200: 100%|██████████| 196/196 [00:38<00:00,  5.10it/s]


Epoch [45/200] Loss: 0.9043, Val Acc: 83.16%


Epoch 46/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [46/200] Loss: 0.9101, Val Acc: 83.74%


Epoch 47/200: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s]


Epoch [47/200] Loss: 0.9048, Val Acc: 84.88%


Epoch 48/200: 100%|██████████| 196/196 [00:39<00:00,  4.90it/s]


Epoch [48/200] Loss: 0.9012, Val Acc: 84.72%


Epoch 49/200: 100%|██████████| 196/196 [00:39<00:00,  4.90it/s]


Epoch [49/200] Loss: 0.9028, Val Acc: 85.85%


Epoch 50/200: 100%|██████████| 196/196 [00:40<00:00,  4.87it/s]


Epoch [50/200] Loss: 0.8905, Val Acc: 83.38%


Epoch 51/200: 100%|██████████| 196/196 [00:40<00:00,  4.79it/s]


Epoch [51/200] Loss: 0.8921, Val Acc: 84.41%


Epoch 52/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [52/200] Loss: 0.8899, Val Acc: 84.59%


Epoch 53/200: 100%|██████████| 196/196 [00:38<00:00,  5.04it/s]


Epoch [53/200] Loss: 0.8810, Val Acc: 80.32%


Epoch 54/200: 100%|██████████| 196/196 [00:39<00:00,  4.98it/s]


Epoch [54/200] Loss: 0.8811, Val Acc: 83.51%


Epoch 55/200: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Epoch [55/200] Loss: 0.8876, Val Acc: 85.38%


Epoch 56/200: 100%|██████████| 196/196 [00:38<00:00,  5.10it/s]


Epoch [56/200] Loss: 0.8847, Val Acc: 81.99%


Epoch 57/200: 100%|██████████| 196/196 [00:38<00:00,  5.05it/s]


Epoch [57/200] Loss: 0.8830, Val Acc: 87.59%
Saved best model at epoch 57 with Val Acc: 87.59%


Epoch 58/200: 100%|██████████| 196/196 [00:38<00:00,  5.13it/s]


Epoch [58/200] Loss: 0.8789, Val Acc: 85.92%


Epoch 59/200: 100%|██████████| 196/196 [00:40<00:00,  4.88it/s]


Epoch [59/200] Loss: 0.8816, Val Acc: 85.11%


Epoch 60/200: 100%|██████████| 196/196 [00:37<00:00,  5.20it/s]


Epoch [60/200] Loss: 0.8757, Val Acc: 84.71%


Epoch 61/200: 100%|██████████| 196/196 [00:37<00:00,  5.19it/s]


Epoch [61/200] Loss: 0.8719, Val Acc: 88.21%
Saved best model at epoch 61 with Val Acc: 88.21%


Epoch 62/200: 100%|██████████| 196/196 [00:38<00:00,  5.06it/s]


Epoch [62/200] Loss: 0.8703, Val Acc: 85.60%


Epoch 63/200: 100%|██████████| 196/196 [00:40<00:00,  4.81it/s]


Epoch [63/200] Loss: 0.8692, Val Acc: 85.42%


Epoch 64/200: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s]


Epoch [64/200] Loss: 0.8681, Val Acc: 85.06%


Epoch 65/200: 100%|██████████| 196/196 [00:38<00:00,  5.06it/s]


Epoch [65/200] Loss: 0.8709, Val Acc: 87.30%


Epoch 66/200: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s]


Epoch [66/200] Loss: 0.8616, Val Acc: 86.54%


Epoch 67/200: 100%|██████████| 196/196 [00:41<00:00,  4.69it/s]


Epoch [67/200] Loss: 0.8607, Val Acc: 84.95%


Epoch 68/200: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s]


Epoch [68/200] Loss: 0.8624, Val Acc: 86.89%


Epoch 69/200: 100%|██████████| 196/196 [00:39<00:00,  4.92it/s]


Epoch [69/200] Loss: 0.8605, Val Acc: 87.37%


Epoch 70/200: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s]


Epoch [70/200] Loss: 0.8555, Val Acc: 87.52%


Epoch 71/200: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s]


Epoch [71/200] Loss: 0.8523, Val Acc: 87.10%


Epoch 72/200: 100%|██████████| 196/196 [00:37<00:00,  5.20it/s]


Epoch [72/200] Loss: 0.8524, Val Acc: 84.85%


Epoch 73/200: 100%|██████████| 196/196 [00:37<00:00,  5.28it/s]


Epoch [73/200] Loss: 0.8507, Val Acc: 84.43%


Epoch 74/200: 100%|██████████| 196/196 [00:37<00:00,  5.18it/s]


Epoch [74/200] Loss: 0.8477, Val Acc: 86.21%


Epoch 75/200: 100%|██████████| 196/196 [00:39<00:00,  4.94it/s]


Epoch [75/200] Loss: 0.8458, Val Acc: 85.81%


Epoch 76/200: 100%|██████████| 196/196 [00:37<00:00,  5.22it/s]


Epoch [76/200] Loss: 0.8446, Val Acc: 87.64%


Epoch 77/200: 100%|██████████| 196/196 [00:38<00:00,  5.09it/s]


Epoch [77/200] Loss: 0.8427, Val Acc: 86.41%


Epoch 78/200: 100%|██████████| 196/196 [00:38<00:00,  5.11it/s]


Epoch [78/200] Loss: 0.8421, Val Acc: 86.11%


Epoch 79/200: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s]


Epoch [79/200] Loss: 0.8403, Val Acc: 87.01%


Epoch 80/200: 100%|██████████| 196/196 [00:38<00:00,  5.08it/s]


Epoch [80/200] Loss: 0.8391, Val Acc: 88.51%
Saved best model at epoch 80 with Val Acc: 88.51%


Epoch 81/200: 100%|██████████| 196/196 [00:39<00:00,  4.94it/s]


Epoch [81/200] Loss: 0.8287, Val Acc: 86.49%


Epoch 82/200: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s]


Epoch [82/200] Loss: 0.8351, Val Acc: 87.99%


Epoch 83/200: 100%|██████████| 196/196 [00:41<00:00,  4.77it/s]


Epoch [83/200] Loss: 0.8310, Val Acc: 87.81%


Epoch 84/200: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s]


Epoch [84/200] Loss: 0.8245, Val Acc: 86.35%


Epoch 85/200: 100%|██████████| 196/196 [00:39<00:00,  5.01it/s]


Epoch [85/200] Loss: 0.8287, Val Acc: 83.97%


Epoch 86/200: 100%|██████████| 196/196 [00:39<00:00,  4.94it/s]


Epoch [86/200] Loss: 0.8237, Val Acc: 87.95%


Epoch 87/200: 100%|██████████| 196/196 [00:38<00:00,  5.13it/s]


Epoch [87/200] Loss: 0.8203, Val Acc: 86.20%


Epoch 88/200: 100%|██████████| 196/196 [00:38<00:00,  5.15it/s]


Epoch [88/200] Loss: 0.8246, Val Acc: 87.47%


Epoch 89/200: 100%|██████████| 196/196 [00:38<00:00,  5.08it/s]


Epoch [89/200] Loss: 0.8156, Val Acc: 85.11%


Epoch 90/200: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s]


Epoch [90/200] Loss: 0.8191, Val Acc: 88.03%


Epoch 91/200: 100%|██████████| 196/196 [00:38<00:00,  5.09it/s]


Epoch [91/200] Loss: 0.8166, Val Acc: 87.66%


Epoch 92/200: 100%|██████████| 196/196 [00:38<00:00,  5.08it/s]


Epoch [92/200] Loss: 0.8109, Val Acc: 88.22%


Epoch 93/200: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s]


Epoch [93/200] Loss: 0.8088, Val Acc: 87.17%


Epoch 94/200: 100%|██████████| 196/196 [00:41<00:00,  4.78it/s]


Epoch [94/200] Loss: 0.8136, Val Acc: 89.16%
Saved best model at epoch 94 with Val Acc: 89.16%


Epoch 95/200: 100%|██████████| 196/196 [00:38<00:00,  5.04it/s]


Epoch [95/200] Loss: 0.8092, Val Acc: 87.22%


Epoch 96/200: 100%|██████████| 196/196 [00:39<00:00,  5.02it/s]


Epoch [96/200] Loss: 0.8047, Val Acc: 89.70%
Saved best model at epoch 96 with Val Acc: 89.70%


Epoch 97/200: 100%|██████████| 196/196 [00:38<00:00,  5.04it/s]


Epoch [97/200] Loss: 0.8027, Val Acc: 88.53%


Epoch 98/200: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Epoch [98/200] Loss: 0.8015, Val Acc: 87.37%


Epoch 99/200: 100%|██████████| 196/196 [00:39<00:00,  5.03it/s]


Epoch [99/200] Loss: 0.7984, Val Acc: 89.10%


Epoch 100/200: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s]


Epoch [100/200] Loss: 0.7978, Val Acc: 89.64%


Epoch 101/200: 100%|██████████| 196/196 [00:39<00:00,  4.92it/s]


Epoch [101/200] Loss: 0.7901, Val Acc: 88.99%


Epoch 102/200: 100%|██████████| 196/196 [00:40<00:00,  4.79it/s]


Epoch [102/200] Loss: 0.7908, Val Acc: 88.25%


Epoch 103/200: 100%|██████████| 196/196 [00:39<00:00,  5.00it/s]


Epoch [103/200] Loss: 0.7934, Val Acc: 88.79%


Epoch 104/200: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s]


Epoch [104/200] Loss: 0.7855, Val Acc: 89.35%


Epoch 105/200: 100%|██████████| 196/196 [00:40<00:00,  4.80it/s]


Epoch [105/200] Loss: 0.7857, Val Acc: 88.52%


Epoch 106/200: 100%|██████████| 196/196 [00:39<00:00,  4.92it/s]


Epoch [106/200] Loss: 0.7855, Val Acc: 86.53%


Epoch 107/200: 100%|██████████| 196/196 [00:38<00:00,  5.07it/s]


Epoch [107/200] Loss: 0.7813, Val Acc: 90.44%
Saved best model at epoch 107 with Val Acc: 90.44%


Epoch 108/200: 100%|██████████| 196/196 [00:38<00:00,  5.15it/s]


Epoch [108/200] Loss: 0.7791, Val Acc: 88.59%


Epoch 109/200: 100%|██████████| 196/196 [00:39<00:00,  4.98it/s]


Epoch [109/200] Loss: 0.7769, Val Acc: 89.76%


Epoch 110/200: 100%|██████████| 196/196 [00:37<00:00,  5.19it/s]


Epoch [110/200] Loss: 0.7716, Val Acc: 85.53%


Epoch 111/200: 100%|██████████| 196/196 [00:37<00:00,  5.16it/s]


Epoch [111/200] Loss: 0.7710, Val Acc: 89.14%


Epoch 112/200: 100%|██████████| 196/196 [00:38<00:00,  5.13it/s]


Epoch [112/200] Loss: 0.7672, Val Acc: 89.61%


Epoch 113/200: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Epoch [113/200] Loss: 0.7681, Val Acc: 89.56%


Epoch 114/200: 100%|██████████| 196/196 [00:39<00:00,  4.94it/s]


Epoch [114/200] Loss: 0.7686, Val Acc: 90.01%


Epoch 115/200: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s]


Epoch [115/200] Loss: 0.7599, Val Acc: 89.97%


Epoch 116/200: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s]


Epoch [116/200] Loss: 0.7566, Val Acc: 91.28%
Saved best model at epoch 116 with Val Acc: 91.28%


Epoch 117/200: 100%|██████████| 196/196 [00:41<00:00,  4.77it/s]


Epoch [117/200] Loss: 0.7556, Val Acc: 90.10%


Epoch 118/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [118/200] Loss: 0.7549, Val Acc: 88.58%


Epoch 119/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [119/200] Loss: 0.7507, Val Acc: 90.00%


Epoch 120/200: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s]


Epoch [120/200] Loss: 0.7490, Val Acc: 90.71%


Epoch 121/200: 100%|██████████| 196/196 [00:38<00:00,  5.11it/s]


Epoch [121/200] Loss: 0.7473, Val Acc: 90.26%


Epoch 122/200: 100%|██████████| 196/196 [00:37<00:00,  5.20it/s]


Epoch [122/200] Loss: 0.7419, Val Acc: 89.98%


Epoch 123/200: 100%|██████████| 196/196 [00:38<00:00,  5.11it/s]


Epoch [123/200] Loss: 0.7411, Val Acc: 90.99%


Epoch 124/200: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s]


Epoch [124/200] Loss: 0.7351, Val Acc: 89.58%


Epoch 125/200: 100%|██████████| 196/196 [00:38<00:00,  5.05it/s]


Epoch [125/200] Loss: 0.7375, Val Acc: 90.99%


Epoch 126/200: 100%|██████████| 196/196 [00:38<00:00,  5.05it/s]


Epoch [126/200] Loss: 0.7326, Val Acc: 91.15%


Epoch 127/200: 100%|██████████| 196/196 [00:39<00:00,  5.01it/s]


Epoch [127/200] Loss: 0.7335, Val Acc: 88.87%


Epoch 128/200: 100%|██████████| 196/196 [00:41<00:00,  4.76it/s]


Epoch [128/200] Loss: 0.7280, Val Acc: 90.83%


Epoch 129/200: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s]


Epoch [129/200] Loss: 0.7286, Val Acc: 90.77%


Epoch 130/200: 100%|██████████| 196/196 [00:39<00:00,  5.02it/s]


Epoch [130/200] Loss: 0.7237, Val Acc: 91.26%


Epoch 131/200: 100%|██████████| 196/196 [00:38<00:00,  5.11it/s]


Epoch [131/200] Loss: 0.7203, Val Acc: 91.04%


Epoch 132/200: 100%|██████████| 196/196 [00:40<00:00,  4.88it/s]


Epoch [132/200] Loss: 0.7179, Val Acc: 91.98%
Saved best model at epoch 132 with Val Acc: 91.98%


Epoch 133/200: 100%|██████████| 196/196 [00:38<00:00,  5.13it/s]


Epoch [133/200] Loss: 0.7104, Val Acc: 90.95%


Epoch 134/200: 100%|██████████| 196/196 [00:38<00:00,  5.13it/s]


Epoch [134/200] Loss: 0.7122, Val Acc: 91.23%


Epoch 135/200: 100%|██████████| 196/196 [00:37<00:00,  5.20it/s]


Epoch [135/200] Loss: 0.7078, Val Acc: 91.85%


Epoch 136/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [136/200] Loss: 0.7058, Val Acc: 90.40%


Epoch 137/200: 100%|██████████| 196/196 [00:38<00:00,  5.05it/s]


Epoch [137/200] Loss: 0.7055, Val Acc: 91.59%


Epoch 138/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [138/200] Loss: 0.7020, Val Acc: 91.55%


Epoch 139/200: 100%|██████████| 196/196 [00:40<00:00,  4.83it/s]


Epoch [139/200] Loss: 0.6986, Val Acc: 91.10%


Epoch 140/200: 100%|██████████| 196/196 [00:39<00:00,  5.01it/s]


Epoch [140/200] Loss: 0.6937, Val Acc: 92.17%
Saved best model at epoch 140 with Val Acc: 92.17%


Epoch 141/200: 100%|██████████| 196/196 [00:38<00:00,  5.04it/s]


Epoch [141/200] Loss: 0.6946, Val Acc: 91.80%


Epoch 142/200: 100%|██████████| 196/196 [00:39<00:00,  5.00it/s]


Epoch [142/200] Loss: 0.6866, Val Acc: 91.17%


Epoch 143/200: 100%|██████████| 196/196 [00:42<00:00,  4.65it/s]


Epoch [143/200] Loss: 0.6848, Val Acc: 91.24%


Epoch 144/200: 100%|██████████| 196/196 [00:39<00:00,  4.94it/s]


Epoch [144/200] Loss: 0.6842, Val Acc: 91.17%


Epoch 145/200: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s]


Epoch [145/200] Loss: 0.6797, Val Acc: 92.46%
Saved best model at epoch 145 with Val Acc: 92.46%


Epoch 146/200: 100%|██████████| 196/196 [00:38<00:00,  5.10it/s]


Epoch [146/200] Loss: 0.6784, Val Acc: 92.28%


Epoch 147/200: 100%|██████████| 196/196 [00:39<00:00,  4.90it/s]


Epoch [147/200] Loss: 0.6700, Val Acc: 91.96%


Epoch 148/200: 100%|██████████| 196/196 [00:38<00:00,  5.09it/s]


Epoch [148/200] Loss: 0.6740, Val Acc: 92.87%
Saved best model at epoch 148 with Val Acc: 92.87%


Epoch 149/200: 100%|██████████| 196/196 [00:39<00:00,  5.02it/s]


Epoch [149/200] Loss: 0.6653, Val Acc: 91.91%


Epoch 150/200: 100%|██████████| 196/196 [00:40<00:00,  4.85it/s]


Epoch [150/200] Loss: 0.6664, Val Acc: 92.78%


Epoch 151/200: 100%|██████████| 196/196 [00:37<00:00,  5.21it/s]


Epoch [151/200] Loss: 0.6582, Val Acc: 92.25%


Epoch 152/200: 100%|██████████| 196/196 [00:37<00:00,  5.23it/s]


Epoch [152/200] Loss: 0.6561, Val Acc: 92.86%


Epoch 153/200: 100%|██████████| 196/196 [00:37<00:00,  5.18it/s]


Epoch [153/200] Loss: 0.6560, Val Acc: 92.41%


Epoch 154/200: 100%|██████████| 196/196 [00:39<00:00,  4.93it/s]


Epoch [154/200] Loss: 0.6510, Val Acc: 93.24%
Saved best model at epoch 154 with Val Acc: 93.24%


Epoch 155/200: 100%|██████████| 196/196 [00:38<00:00,  5.09it/s]


Epoch [155/200] Loss: 0.6452, Val Acc: 92.41%


Epoch 156/200: 100%|██████████| 196/196 [00:38<00:00,  5.04it/s]


Epoch [156/200] Loss: 0.6440, Val Acc: 92.38%


Epoch 157/200: 100%|██████████| 196/196 [00:38<00:00,  5.06it/s]


Epoch [157/200] Loss: 0.6393, Val Acc: 92.54%


Epoch 158/200: 100%|██████████| 196/196 [00:40<00:00,  4.78it/s]


Epoch [158/200] Loss: 0.6375, Val Acc: 93.03%


Epoch 159/200: 100%|██████████| 196/196 [00:39<00:00,  4.93it/s]


Epoch [159/200] Loss: 0.6358, Val Acc: 92.99%


Epoch 160/200: 100%|██████████| 196/196 [00:39<00:00,  4.90it/s]


Epoch [160/200] Loss: 0.6288, Val Acc: 92.89%


Epoch 161/200: 100%|██████████| 196/196 [00:40<00:00,  4.78it/s]


Epoch [161/200] Loss: 0.6291, Val Acc: 93.21%


Epoch 162/200: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s]


Epoch [162/200] Loss: 0.6277, Val Acc: 93.11%


Epoch 163/200: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s]


Epoch [163/200] Loss: 0.6210, Val Acc: 93.50%
Saved best model at epoch 163 with Val Acc: 93.50%


Epoch 164/200: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s]


Epoch [164/200] Loss: 0.6192, Val Acc: 93.38%


Epoch 165/200: 100%|██████████| 196/196 [00:40<00:00,  4.81it/s]


Epoch [165/200] Loss: 0.6174, Val Acc: 93.76%
Saved best model at epoch 165 with Val Acc: 93.76%


Epoch 166/200: 100%|██████████| 196/196 [00:39<00:00,  5.00it/s]


Epoch [166/200] Loss: 0.6104, Val Acc: 93.53%


Epoch 167/200: 100%|██████████| 196/196 [00:38<00:00,  5.12it/s]


Epoch [167/200] Loss: 0.6090, Val Acc: 93.66%


Epoch 168/200: 100%|██████████| 196/196 [00:38<00:00,  5.12it/s]


Epoch [168/200] Loss: 0.6056, Val Acc: 93.67%


Epoch 169/200: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s]


Epoch [169/200] Loss: 0.6081, Val Acc: 93.66%


Epoch 170/200: 100%|██████████| 196/196 [00:37<00:00,  5.17it/s]


Epoch [170/200] Loss: 0.6003, Val Acc: 93.88%
Saved best model at epoch 170 with Val Acc: 93.88%


Epoch 171/200: 100%|██████████| 196/196 [00:38<00:00,  5.13it/s]


Epoch [171/200] Loss: 0.6003, Val Acc: 93.96%
Saved best model at epoch 171 with Val Acc: 93.96%


Epoch 172/200: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s]


Epoch [172/200] Loss: 0.5984, Val Acc: 94.03%
Saved best model at epoch 172 with Val Acc: 94.03%


Epoch 173/200: 100%|██████████| 196/196 [00:38<00:00,  5.06it/s]


Epoch [173/200] Loss: 0.5935, Val Acc: 94.03%


Epoch 174/200: 100%|██████████| 196/196 [00:38<00:00,  5.15it/s]


Epoch [174/200] Loss: 0.5925, Val Acc: 93.75%


Epoch 175/200: 100%|██████████| 196/196 [00:37<00:00,  5.18it/s]


Epoch [175/200] Loss: 0.5886, Val Acc: 93.97%


Epoch 176/200: 100%|██████████| 196/196 [00:40<00:00,  4.90it/s]


Epoch [176/200] Loss: 0.5892, Val Acc: 94.10%
Saved best model at epoch 176 with Val Acc: 94.10%


Epoch 177/200: 100%|██████████| 196/196 [00:37<00:00,  5.19it/s]


Epoch [177/200] Loss: 0.5826, Val Acc: 94.00%


Epoch 178/200: 100%|██████████| 196/196 [00:38<00:00,  5.09it/s]


Epoch [178/200] Loss: 0.5815, Val Acc: 94.01%


Epoch 179/200: 100%|██████████| 196/196 [00:38<00:00,  5.06it/s]


Epoch [179/200] Loss: 0.5793, Val Acc: 94.07%


Epoch 180/200: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s]


Epoch [180/200] Loss: 0.5782, Val Acc: 94.15%
Saved best model at epoch 180 with Val Acc: 94.15%


Epoch 181/200: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s]


Epoch [181/200] Loss: 0.5794, Val Acc: 93.97%


Epoch 182/200: 100%|██████████| 196/196 [00:39<00:00,  4.94it/s]


Epoch [182/200] Loss: 0.5740, Val Acc: 94.33%
Saved best model at epoch 182 with Val Acc: 94.33%


Epoch 183/200: 100%|██████████| 196/196 [00:41<00:00,  4.72it/s]


Epoch [183/200] Loss: 0.5739, Val Acc: 94.33%


Epoch 184/200: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s]


Epoch [184/200] Loss: 0.5743, Val Acc: 94.22%


Epoch 185/200: 100%|██████████| 196/196 [00:39<00:00,  4.92it/s]


Epoch [185/200] Loss: 0.5706, Val Acc: 94.39%
Saved best model at epoch 185 with Val Acc: 94.39%


Epoch 186/200: 100%|██████████| 196/196 [00:39<00:00,  4.90it/s]


Epoch [186/200] Loss: 0.5682, Val Acc: 94.34%


Epoch 187/200: 100%|██████████| 196/196 [00:41<00:00,  4.74it/s]


Epoch [187/200] Loss: 0.5697, Val Acc: 94.33%


Epoch 188/200: 100%|██████████| 196/196 [00:39<00:00,  4.98it/s]


Epoch [188/200] Loss: 0.5685, Val Acc: 94.34%


Epoch 189/200: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s]


Epoch [189/200] Loss: 0.5668, Val Acc: 94.46%
Saved best model at epoch 189 with Val Acc: 94.46%


Epoch 190/200: 100%|██████████| 196/196 [00:41<00:00,  4.75it/s]


Epoch [190/200] Loss: 0.5652, Val Acc: 94.33%


Epoch 191/200: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s]


Epoch [191/200] Loss: 0.5658, Val Acc: 94.32%


Epoch 192/200: 100%|██████████| 196/196 [00:39<00:00,  5.02it/s]


Epoch [192/200] Loss: 0.5647, Val Acc: 94.50%
Saved best model at epoch 192 with Val Acc: 94.50%


Epoch 193/200: 100%|██████████| 196/196 [00:38<00:00,  5.15it/s]


Epoch [193/200] Loss: 0.5637, Val Acc: 94.37%


Epoch 194/200: 100%|██████████| 196/196 [00:40<00:00,  4.88it/s]


Epoch [194/200] Loss: 0.5640, Val Acc: 94.33%


Epoch 195/200: 100%|██████████| 196/196 [00:37<00:00,  5.19it/s]


Epoch [195/200] Loss: 0.5619, Val Acc: 94.37%


Epoch 196/200: 100%|██████████| 196/196 [00:37<00:00,  5.23it/s]


Epoch [196/200] Loss: 0.5635, Val Acc: 94.51%
Saved best model at epoch 196 with Val Acc: 94.51%


Epoch 197/200: 100%|██████████| 196/196 [00:37<00:00,  5.17it/s]


Epoch [197/200] Loss: 0.5619, Val Acc: 94.51%


Epoch 198/200: 100%|██████████| 196/196 [00:40<00:00,  4.87it/s]


Epoch [198/200] Loss: 0.5614, Val Acc: 94.53%
Saved best model at epoch 198 with Val Acc: 94.53%


Epoch 199/200: 100%|██████████| 196/196 [00:38<00:00,  5.15it/s]


Epoch [199/200] Loss: 0.5622, Val Acc: 94.45%


Epoch 200/200: 100%|██████████| 196/196 [00:38<00:00,  5.03it/s]


Epoch [200/200] Loss: 0.5620, Val Acc: 94.51%
Training complete. Best Accuracy: 94.53%


In [23]:
# cifar_resnet/inference.py

import os
import pickle
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
#from model import ResNet18CIFAR

def unpickle(file):
    with open(file, 'rb') as fo:
        data_dict = pickle.load(fo, encoding='bytes')
    return data_dict

class CIFARTestDataset(Dataset):
    def __init__(self, data, ids, transform=None):
        """
        data: numpy array of shape (N, 32, 32, 3)
        ids: array-like of test image IDs
        transform: transforms to apply
        """
        self.data = data
        self.ids = ids
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # If data is stored as (32,32,3)
        img = self.data[idx].astype("uint8")
        img_id = self.ids[idx]
        if self.transform:
            img = self.transform(img)
        return img, img_id

def main():
    # Device selection: CUDA > MPS > CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Load the best model (make sure parameters match training)
    model = ResNet18CIFAR(num_classes=10, initial_planes=32).to(device)
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    model.eval()

    # Load custom test set (provided by competition) – adjust path as needed
    # The custom test file should contain keys: b'data' (shape (N, 32,32,3)) and b'ids'
    test_file = "/content/cifar_test_nolabel.pkl"
    test_dict = unpickle(test_file)
    print("Keys in test_dict:", test_dict.keys())

    test_images = test_dict[b'data']  # shape (N, 32, 32, 3)
    test_ids = test_dict[b'ids']

    # Define test transforms (no random augmentation)
    transform_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2470, 0.2435, 0.2616))
    ])

    test_dataset = CIFARTestDataset(test_images, test_ids, transform=transform_test)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

    # Inference loop
    all_preds = []
    all_ids = []
    with torch.no_grad():
        for images, ids in test_loader:
            images = images.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy().tolist())
            all_ids.extend(ids.tolist())

    # Create submission CSV (must have exactly two columns: "ID" and "Labels")
    submission_df = pd.DataFrame({
        "ID": all_ids,
        "Labels": all_preds
    })
    submission_df.to_csv("submission.csv", index=False)
    print("submission.csv generated! Ready for Kaggle submission.")

if __name__ == "__main__":
    main()

Using device: cuda
Keys in test_dict: dict_keys([b'data', b'ids'])


  model.load_state_dict(torch.load("best_model.pth", map_location=device))


submission.csv generated! Ready for Kaggle submission.


Do't touch anything beyond this. The followig cells got us 0.82861. Do all edits and chnages above this.

In [None]:
#model.py and train.py
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

 # Ensure this function is at the module level in model.py

def unpickle(file):
    with open(file, 'rb') as fo:
        data_dict = pickle.load(fo, encoding='bytes')
    return data_dict

class CIFARDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data  # shape (N, 3072)
        self.labels = labels  # shape (N,)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Reshape vector (3072,) -> (32, 32, 3), then convert to PIL Image
        img = self.data[idx].reshape(3, 32, 32).transpose(1, 2, 0).astype("uint8")
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

def load_cifar10_batches(root_dir):
    data_list, labels_list = [], []
    for i in range(1, 6):
        batch_file = os.path.join(root_dir, f"data_batch_{i}")
        batch = unpickle(batch_file)
        data_list.append(batch[b'data'])
        labels_list.extend(batch[b'labels'])
    X = np.concatenate(data_list, axis=0)
    y = np.array(labels_list)
    return X, y

def main():
    # Set up device: if MPS is available on your MacBook Pro, use it.

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("Using device:", device)


    # Load CIFAR-10 training data
    root_dir = "/content/DL_PROJECT_1/cifar-10-python/cifar-10-batches-py"  # adjust path if needed
    X, y = load_cifar10_batches(root_dir)

    # Split into training and validation sets (90/10 split)
    num_samples = len(X)
    split_idx = int(num_samples * 0.9)
    X_train, X_val = X[:split_idx], X[split_idx:]
    y_train, y_val = y[:split_idx], y[split_idx:]

    # Define data augmentation and normalization for training; only normalization for validation
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        # Uncomment the following line to try color jitter (optional)
        # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])
    transform_val = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    train_dataset = CIFARDataset(X_train, y_train, transform=transform_train)
    val_dataset = CIFARDataset(X_val, y_val, transform=transform_val)

    # Increase num_workers if your system can support it; pin_memory only if not on MPS.
    train_loader = DataLoader(
    train_dataset,
    batch_size=256,         # larger batch size
    shuffle=True,
    num_workers=2,          # or 8 if your CPU can handle it
    pin_memory=False        # MPS typically doesn’t benefit
)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=False)

    # Instantiate the model; use the same architecture as during training.
    # For example, if you trained with 64 initial channels, pass initial_planes=64.
    model = ResNet18CIFAR(num_classes=10, initial_planes=32).to(device)

    # Define loss function and optimizer (SGD with momentum and weight decay is standard for CIFAR-10)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # Use cosine annealing for a smooth LR schedule (T_max set to number of epochs)
    num_epochs = 200
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_val_acc = 0.0

    # Optionally, for CUDA you can use AMP (MPS support is experimental)
    use_amp = (device == "cuda")
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        train_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for images, labels in train_progress:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            if use_amp:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            train_progress.set_postfix(loss=loss.item())

        scheduler.step()

        # Evaluate on validation set
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100.0 * correct / total
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Save best model checkpoint
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            print(f"Saved new best model at epoch {epoch+1} with Val Acc: {val_acc:.2f}%")

    print("Training complete.")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")

if __name__ == "__main__":
    main()


Using device: cuda


  scaler = torch.cuda.amp.GradScaler() if use_amp else None
  with torch.cuda.amp.autocast():


Epoch [1/200] - Loss: 1.7395 | Val Acc: 47.20%
Saved new best model at epoch 1 with Val Acc: 47.20%




Epoch [2/200] - Loss: 1.3147 | Val Acc: 55.80%
Saved new best model at epoch 2 with Val Acc: 55.80%




Epoch [3/200] - Loss: 1.0314 | Val Acc: 67.44%
Saved new best model at epoch 3 with Val Acc: 67.44%




Epoch [4/200] - Loss: 0.8487 | Val Acc: 72.26%
Saved new best model at epoch 4 with Val Acc: 72.26%




Epoch [5/200] - Loss: 0.7229 | Val Acc: 69.62%




Epoch [6/200] - Loss: 0.6443 | Val Acc: 76.50%
Saved new best model at epoch 6 with Val Acc: 76.50%




Epoch [7/200] - Loss: 0.5813 | Val Acc: 74.24%




Epoch [8/200] - Loss: 0.5319 | Val Acc: 80.42%
Saved new best model at epoch 8 with Val Acc: 80.42%




Epoch [9/200] - Loss: 0.5047 | Val Acc: 79.96%




Epoch [10/200] - Loss: 0.4727 | Val Acc: 81.74%
Saved new best model at epoch 10 with Val Acc: 81.74%




Epoch [11/200] - Loss: 0.4521 | Val Acc: 81.42%




Epoch [12/200] - Loss: 0.4319 | Val Acc: 80.90%




Epoch [13/200] - Loss: 0.4173 | Val Acc: 78.52%




Epoch [14/200] - Loss: 0.4025 | Val Acc: 83.88%
Saved new best model at epoch 14 with Val Acc: 83.88%




Epoch [15/200] - Loss: 0.3853 | Val Acc: 84.10%
Saved new best model at epoch 15 with Val Acc: 84.10%




Epoch [16/200] - Loss: 0.3801 | Val Acc: 82.48%




Epoch [17/200] - Loss: 0.3665 | Val Acc: 83.78%




Epoch [18/200] - Loss: 0.3648 | Val Acc: 83.64%




Epoch [19/200] - Loss: 0.3559 | Val Acc: 85.28%
Saved new best model at epoch 19 with Val Acc: 85.28%




Epoch [20/200] - Loss: 0.3465 | Val Acc: 81.68%




Epoch [21/200] - Loss: 0.3414 | Val Acc: 85.82%
Saved new best model at epoch 21 with Val Acc: 85.82%




Epoch [22/200] - Loss: 0.3386 | Val Acc: 82.26%




Epoch [23/200] - Loss: 0.3237 | Val Acc: 85.02%




Epoch [24/200] - Loss: 0.3257 | Val Acc: 79.42%




Epoch [25/200] - Loss: 0.3166 | Val Acc: 82.54%




Epoch [26/200] - Loss: 0.3180 | Val Acc: 84.76%




Epoch [27/200] - Loss: 0.3077 | Val Acc: 82.76%




Epoch [28/200] - Loss: 0.2954 | Val Acc: 83.44%




Epoch [29/200] - Loss: 0.3050 | Val Acc: 83.36%




Epoch [30/200] - Loss: 0.2906 | Val Acc: 83.42%




Epoch [31/200] - Loss: 0.2924 | Val Acc: 81.88%




Epoch [32/200] - Loss: 0.2885 | Val Acc: 85.28%




Epoch [33/200] - Loss: 0.2846 | Val Acc: 86.34%
Saved new best model at epoch 33 with Val Acc: 86.34%




Epoch [34/200] - Loss: 0.2778 | Val Acc: 85.36%




Epoch [35/200] - Loss: 0.2751 | Val Acc: 84.96%




Epoch [36/200] - Loss: 0.2740 | Val Acc: 85.46%




Epoch [37/200] - Loss: 0.2756 | Val Acc: 86.24%




Epoch [38/200] - Loss: 0.2696 | Val Acc: 84.76%




Epoch [39/200] - Loss: 0.2681 | Val Acc: 88.26%
Saved new best model at epoch 39 with Val Acc: 88.26%




Epoch [40/200] - Loss: 0.2639 | Val Acc: 83.36%




Epoch [41/200] - Loss: 0.2532 | Val Acc: 86.48%




Epoch [42/200] - Loss: 0.2617 | Val Acc: 85.30%




Epoch [43/200] - Loss: 0.2484 | Val Acc: 87.04%




Epoch [44/200] - Loss: 0.2562 | Val Acc: 86.96%




Epoch [45/200] - Loss: 0.2413 | Val Acc: 85.32%




Epoch [46/200] - Loss: 0.2389 | Val Acc: 86.08%




Epoch [47/200] - Loss: 0.2509 | Val Acc: 87.72%




Epoch [48/200] - Loss: 0.2401 | Val Acc: 87.10%




Epoch [49/200] - Loss: 0.2434 | Val Acc: 86.34%




Epoch [50/200] - Loss: 0.2364 | Val Acc: 86.86%




Epoch [51/200] - Loss: 0.2319 | Val Acc: 87.12%




Epoch [52/200] - Loss: 0.2366 | Val Acc: 87.84%




Epoch [53/200] - Loss: 0.2293 | Val Acc: 85.90%




Epoch [54/200] - Loss: 0.2238 | Val Acc: 86.74%




Epoch [55/200] - Loss: 0.2295 | Val Acc: 85.40%




Epoch [56/200] - Loss: 0.2290 | Val Acc: 85.46%




Epoch [57/200] - Loss: 0.2202 | Val Acc: 87.70%




Epoch [58/200] - Loss: 0.2212 | Val Acc: 86.74%




Epoch [59/200] - Loss: 0.2134 | Val Acc: 87.30%




Epoch [60/200] - Loss: 0.2190 | Val Acc: 87.50%




Epoch [61/200] - Loss: 0.2128 | Val Acc: 85.20%




Epoch [62/200] - Loss: 0.2151 | Val Acc: 88.60%
Saved new best model at epoch 62 with Val Acc: 88.60%




Epoch [63/200] - Loss: 0.2107 | Val Acc: 87.84%




Epoch [64/200] - Loss: 0.2032 | Val Acc: 86.16%




Epoch [65/200] - Loss: 0.2029 | Val Acc: 88.94%
Saved new best model at epoch 65 with Val Acc: 88.94%




Epoch [66/200] - Loss: 0.2041 | Val Acc: 87.08%




Epoch [67/200] - Loss: 0.2067 | Val Acc: 87.82%




Epoch [68/200] - Loss: 0.1918 | Val Acc: 85.20%




Epoch [69/200] - Loss: 0.2009 | Val Acc: 86.82%




Epoch [70/200] - Loss: 0.1914 | Val Acc: 88.70%




Epoch [71/200] - Loss: 0.1978 | Val Acc: 87.06%




Epoch [72/200] - Loss: 0.1917 | Val Acc: 87.90%




Epoch [73/200] - Loss: 0.1841 | Val Acc: 88.74%




Epoch [74/200] - Loss: 0.1850 | Val Acc: 88.24%




Epoch [75/200] - Loss: 0.1914 | Val Acc: 85.94%




Epoch [76/200] - Loss: 0.1797 | Val Acc: 84.26%




Epoch [77/200] - Loss: 0.1816 | Val Acc: 87.78%




Epoch [78/200] - Loss: 0.1754 | Val Acc: 87.18%




Epoch [79/200] - Loss: 0.1780 | Val Acc: 88.60%




Epoch [80/200] - Loss: 0.1787 | Val Acc: 88.86%




Epoch [81/200] - Loss: 0.1657 | Val Acc: 88.16%




Epoch [82/200] - Loss: 0.1678 | Val Acc: 84.80%




Epoch [83/200] - Loss: 0.1699 | Val Acc: 87.88%




Epoch [84/200] - Loss: 0.1667 | Val Acc: 85.66%




Epoch [85/200] - Loss: 0.1645 | Val Acc: 89.06%
Saved new best model at epoch 85 with Val Acc: 89.06%




Epoch [86/200] - Loss: 0.1648 | Val Acc: 89.10%
Saved new best model at epoch 86 with Val Acc: 89.10%




Epoch [87/200] - Loss: 0.1552 | Val Acc: 87.82%




Epoch [88/200] - Loss: 0.1588 | Val Acc: 86.08%




Epoch [89/200] - Loss: 0.1569 | Val Acc: 87.88%




Epoch [90/200] - Loss: 0.1545 | Val Acc: 87.76%




Epoch [91/200] - Loss: 0.1500 | Val Acc: 87.64%




Epoch [92/200] - Loss: 0.1552 | Val Acc: 89.00%




Epoch [93/200] - Loss: 0.1409 | Val Acc: 88.92%




Epoch [94/200] - Loss: 0.1439 | Val Acc: 88.88%




Epoch [95/200] - Loss: 0.1440 | Val Acc: 89.06%




Epoch [96/200] - Loss: 0.1402 | Val Acc: 89.40%
Saved new best model at epoch 96 with Val Acc: 89.40%




Epoch [97/200] - Loss: 0.1396 | Val Acc: 88.06%




Epoch [98/200] - Loss: 0.1361 | Val Acc: 88.08%




Epoch [99/200] - Loss: 0.1291 | Val Acc: 88.18%




Epoch [100/200] - Loss: 0.1328 | Val Acc: 90.38%
Saved new best model at epoch 100 with Val Acc: 90.38%




Epoch [101/200] - Loss: 0.1262 | Val Acc: 86.12%




Epoch [102/200] - Loss: 0.1303 | Val Acc: 89.20%




Epoch [103/200] - Loss: 0.1264 | Val Acc: 89.86%




Epoch [104/200] - Loss: 0.1206 | Val Acc: 89.86%




Epoch [105/200] - Loss: 0.1168 | Val Acc: 89.62%




Epoch [106/200] - Loss: 0.1194 | Val Acc: 88.30%




Epoch [107/200] - Loss: 0.1152 | Val Acc: 89.98%




Epoch [108/200] - Loss: 0.1108 | Val Acc: 89.32%




Epoch [109/200] - Loss: 0.1072 | Val Acc: 89.62%




Epoch [110/200] - Loss: 0.1136 | Val Acc: 87.28%




Epoch [111/200] - Loss: 0.1094 | Val Acc: 88.64%




Epoch [112/200] - Loss: 0.1020 | Val Acc: 89.38%




Epoch [113/200] - Loss: 0.0950 | Val Acc: 87.22%




Epoch [114/200] - Loss: 0.0984 | Val Acc: 89.98%




Epoch [115/200] - Loss: 0.0951 | Val Acc: 90.98%
Saved new best model at epoch 115 with Val Acc: 90.98%




Epoch [116/200] - Loss: 0.0946 | Val Acc: 88.42%




Epoch [117/200] - Loss: 0.0928 | Val Acc: 90.02%




Epoch [118/200] - Loss: 0.0940 | Val Acc: 90.66%




Epoch [119/200] - Loss: 0.0824 | Val Acc: 90.52%




Epoch [120/200] - Loss: 0.0833 | Val Acc: 90.96%




Epoch [121/200] - Loss: 0.0841 | Val Acc: 89.78%




Epoch [122/200] - Loss: 0.0892 | Val Acc: 91.26%
Saved new best model at epoch 122 with Val Acc: 91.26%




Epoch [123/200] - Loss: 0.0765 | Val Acc: 90.80%




Epoch [124/200] - Loss: 0.0751 | Val Acc: 91.58%
Saved new best model at epoch 124 with Val Acc: 91.58%




Epoch [125/200] - Loss: 0.0677 | Val Acc: 91.40%




Epoch [126/200] - Loss: 0.0713 | Val Acc: 90.72%




Epoch [127/200] - Loss: 0.0726 | Val Acc: 91.80%
Saved new best model at epoch 127 with Val Acc: 91.80%




Epoch [128/200] - Loss: 0.0617 | Val Acc: 90.98%




Epoch [129/200] - Loss: 0.0622 | Val Acc: 91.40%




Epoch [130/200] - Loss: 0.0690 | Val Acc: 91.54%




Epoch [131/200] - Loss: 0.0593 | Val Acc: 91.26%




Epoch [132/200] - Loss: 0.0577 | Val Acc: 91.92%
Saved new best model at epoch 132 with Val Acc: 91.92%




Epoch [133/200] - Loss: 0.0532 | Val Acc: 91.84%




Epoch [134/200] - Loss: 0.0521 | Val Acc: 91.60%




Epoch [135/200] - Loss: 0.0485 | Val Acc: 90.72%




Epoch [136/200] - Loss: 0.0485 | Val Acc: 90.98%




Epoch [137/200] - Loss: 0.0471 | Val Acc: 91.50%




Epoch [138/200] - Loss: 0.0446 | Val Acc: 92.20%
Saved new best model at epoch 138 with Val Acc: 92.20%




Epoch [139/200] - Loss: 0.0400 | Val Acc: 92.14%




Epoch [140/200] - Loss: 0.0337 | Val Acc: 91.96%




Epoch [141/200] - Loss: 0.0352 | Val Acc: 91.80%




Epoch [142/200] - Loss: 0.0350 | Val Acc: 92.34%
Saved new best model at epoch 142 with Val Acc: 92.34%




Epoch [143/200] - Loss: 0.0324 | Val Acc: 92.34%




Epoch [144/200] - Loss: 0.0313 | Val Acc: 91.82%




Epoch [145/200] - Loss: 0.0288 | Val Acc: 92.70%
Saved new best model at epoch 145 with Val Acc: 92.70%




Epoch [146/200] - Loss: 0.0287 | Val Acc: 93.38%
Saved new best model at epoch 146 with Val Acc: 93.38%




Epoch [147/200] - Loss: 0.0235 | Val Acc: 92.58%




Epoch [148/200] - Loss: 0.0256 | Val Acc: 92.98%




Epoch [149/200] - Loss: 0.0210 | Val Acc: 92.28%




Epoch [150/200] - Loss: 0.0181 | Val Acc: 92.72%




Epoch [151/200] - Loss: 0.0211 | Val Acc: 91.84%




Epoch [152/200] - Loss: 0.0164 | Val Acc: 93.16%




Epoch [153/200] - Loss: 0.0153 | Val Acc: 93.00%




Epoch [154/200] - Loss: 0.0132 | Val Acc: 93.72%
Saved new best model at epoch 154 with Val Acc: 93.72%




Epoch [155/200] - Loss: 0.0096 | Val Acc: 93.56%




Epoch [156/200] - Loss: 0.0084 | Val Acc: 93.62%




Epoch [157/200] - Loss: 0.0088 | Val Acc: 93.74%
Saved new best model at epoch 157 with Val Acc: 93.74%




Epoch [158/200] - Loss: 0.0075 | Val Acc: 93.54%




Epoch [159/200] - Loss: 0.0062 | Val Acc: 93.80%
Saved new best model at epoch 159 with Val Acc: 93.80%




Epoch [160/200] - Loss: 0.0067 | Val Acc: 93.90%
Saved new best model at epoch 160 with Val Acc: 93.90%




Epoch [161/200] - Loss: 0.0058 | Val Acc: 93.96%
Saved new best model at epoch 161 with Val Acc: 93.96%




Epoch [162/200] - Loss: 0.0047 | Val Acc: 94.20%
Saved new best model at epoch 162 with Val Acc: 94.20%




Epoch [163/200] - Loss: 0.0041 | Val Acc: 94.00%




Epoch [164/200] - Loss: 0.0038 | Val Acc: 94.08%




Epoch [165/200] - Loss: 0.0037 | Val Acc: 94.26%
Saved new best model at epoch 165 with Val Acc: 94.26%




Epoch [166/200] - Loss: 0.0030 | Val Acc: 94.18%




Epoch [167/200] - Loss: 0.0027 | Val Acc: 94.46%
Saved new best model at epoch 167 with Val Acc: 94.46%




Epoch [168/200] - Loss: 0.0031 | Val Acc: 94.46%




Epoch [169/200] - Loss: 0.0030 | Val Acc: 94.22%




Epoch [170/200] - Loss: 0.0027 | Val Acc: 94.36%




Epoch [171/200] - Loss: 0.0024 | Val Acc: 94.24%




Epoch [172/200] - Loss: 0.0027 | Val Acc: 94.28%




Epoch [173/200] - Loss: 0.0025 | Val Acc: 94.30%




Epoch [174/200] - Loss: 0.0023 | Val Acc: 94.42%




Epoch [175/200] - Loss: 0.0021 | Val Acc: 94.28%




Epoch [176/200] - Loss: 0.0021 | Val Acc: 94.38%




Epoch [177/200] - Loss: 0.0019 | Val Acc: 94.32%




Epoch [178/200] - Loss: 0.0019 | Val Acc: 94.62%
Saved new best model at epoch 178 with Val Acc: 94.62%




Epoch [179/200] - Loss: 0.0018 | Val Acc: 94.62%




Epoch [180/200] - Loss: 0.0019 | Val Acc: 94.36%




Epoch [181/200] - Loss: 0.0021 | Val Acc: 94.54%




Epoch [182/200] - Loss: 0.0020 | Val Acc: 94.38%




Epoch [183/200] - Loss: 0.0019 | Val Acc: 94.54%




Epoch [184/200] - Loss: 0.0019 | Val Acc: 94.38%




Epoch [185/200] - Loss: 0.0019 | Val Acc: 94.56%




Epoch [186/200] - Loss: 0.0019 | Val Acc: 94.42%




Epoch [187/200] - Loss: 0.0018 | Val Acc: 94.54%




Epoch [188/200] - Loss: 0.0019 | Val Acc: 94.28%




Epoch [189/200] - Loss: 0.0020 | Val Acc: 94.40%




Epoch [190/200] - Loss: 0.0019 | Val Acc: 94.32%




Epoch [191/200] - Loss: 0.0020 | Val Acc: 94.38%




Epoch [192/200] - Loss: 0.0019 | Val Acc: 94.46%




Epoch [193/200] - Loss: 0.0019 | Val Acc: 94.58%




Epoch [194/200] - Loss: 0.0018 | Val Acc: 94.46%




Epoch [195/200] - Loss: 0.0019 | Val Acc: 94.34%




Epoch [196/200] - Loss: 0.0018 | Val Acc: 94.50%




Epoch [197/200] - Loss: 0.0018 | Val Acc: 94.46%




Epoch [198/200] - Loss: 0.0017 | Val Acc: 94.40%




Epoch [199/200] - Loss: 0.0018 | Val Acc: 94.40%




Epoch [200/200] - Loss: 0.0018 | Val Acc: 94.48%
Training complete.
Best Validation Accuracy: 94.62%


In [None]:
# inference.py
import os
import pickle
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
# test_file = "/Users/rahilsinghi/Desktop/DL_PROJECT_1/data/cifar_test_nolabel.pkl"

# test_dict = unpickle(test_file)
# print("Keys in test_dict:", test_dict.keys())

class CIFARTestDataset(Dataset):
    def __init__(self, data, ids, transform=None):
        """
        data: numpy array of shape (N, 3072)
        ids: list or numpy array of shape (N,)
        transform: torchvision transform to apply
        """
        self.data = data
        self.ids = ids
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Convert 3072 vector -> (3, 32, 32) -> (32, 32, 3)
        img = self.data[idx].reshape(3, 32, 32).transpose(1, 2, 0).astype("uint8")
        img_id = self.ids[idx]

        if self.transform:
            img = self.transform(img)

        return img, img_id

def main():
    # 1) Load best model
    if torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Instantiate your model and move it to the device
    model = ResNet18CIFAR(num_classes=10, initial_planes=32).to(device)
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    model.eval()


    # 2) Load the custom test set
    #   Path is relative to the script location (cifar_resnet/).
    #   Adjust if your file is somewhere else.
    test_file = "/content/cifar_test_nolabel.pkl"
    test_dict = unpickle(test_file)
    # Typically, test_dict might have keys like b'images' and b'ids'
    test_images = test_dict[b'data']  # shape (N, 3072)
    test_ids = test_dict[b'ids']        # shape (N,)

    # 3) Define transforms
    transform_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2470, 0.2435, 0.2616))
    ])

    # 4) Create Dataset and DataLoader
    test_dataset = CIFARTestDataset(test_images, test_ids, transform=transform_test)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    # 5) Run inference and collect predictions
    all_preds = []
    all_ids = []
    with torch.no_grad():
        for images, ids in test_loader:
            images = images.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy().tolist())
            all_ids.extend(ids.numpy().tolist())  # or just .tolist()

    # 6) Create submission.csv
    submission_df = pd.DataFrame({
        "ID": all_ids,
        "Labels": all_preds
    })
    submission_df.to_csv("submission.csv", index=False)
    print("submission.csv generated! You can submit this file to Kaggle.")

if __name__ == "__main__":
    main()
