In [1]:
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import pickle
import os
from tqdm import tqdm


In [2]:
resnet18 = models.resnet18(weights=None)
num_classes = 10
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, num_classes)
print(resnet18)
num_params=sum(p.numel() for p in resnet18.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params:,}")



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [6]:
# transform_train = transforms.Compose([
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomCrop(32, padding=4),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalizing the images
# ])

# transform_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# ])



def calculate_mean_std(dataset):
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for images, _ in dataloader:
        for i in range(3):
            mean[i] += images[:,i,:,:].mean()
            std[i] += images[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

# Path for the dataset and the mean/std file
data_path = '../data'
mean_std_file = os.path.join(data_path, 'cifar10_mean_std.pkl')


# Check if mean/std file exists, calculate if not
if not os.path.exists(mean_std_file):
    # Load CIFAR-10 without normalization
    trainset_raw = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transforms.ToTensor())
    mean, std = calculate_mean_std(trainset_raw)
    with open(mean_std_file, 'wb') as f:
        pickle.dump((mean, std), f)
    print("Mean and Std Dev calculated and saved.")
else:
    with open(mean_std_file, 'rb') as f:
        mean, std = pickle.load(f)
    print("Mean and Std Dev loaded from file.")

print(f"Mean: {mean}, Std: {std}")

# Use the calculated/loaded mean and std for normalization
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0),
    transforms.Normalize(mean, std)
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean, std)
])


Mean and Std Dev loaded from file.
Mean: tensor([0.0077, 0.0075, 0.0070]), Std: tensor([0.0039, 0.0038, 0.0041])


In [7]:
device=torch.device('mps')
print(f"Using device: {device}")
resnet18 = resnet18.to(device)


# Load the CIFAR-10 dataset with transforms
trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(resnet18.parameters(), lr=0.003,weight_decay=3e-2)

Using device: mps
Files already downloaded and verified
Files already downloaded and verified


In [8]:


num_epochs = 20  # Set the number of epochs

for epoch in range(num_epochs):
    resnet18.train()  # Set the model to training mode
    running_loss = 0.0

    # Training loop with tqdm progress bar
    train_bar = tqdm(trainloader, desc=f"Training Epoch {epoch + 1}")
    for i, data in enumerate(train_bar):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()  # Zero the parameter gradients
        outputs = resnet18(inputs)  # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        running_loss += loss.item()
        train_bar.set_postfix(loss=(running_loss / (i + 1)))

    # Validation loop with tqdm progress bar
    resnet18.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    val_bar = tqdm(testloader, desc=f"Validation Epoch {epoch + 1}")
    with torch.no_grad():  # No gradients needed for validation
        for data in val_bar:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = resnet18(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')

    # Optional: Save model checkpoint if this is the best epoch so far

print('Finished Training')


Training Epoch 1: 100%|██████████| 391/391 [00:24<00:00, 15.86it/s, loss=0.88] 
Validation Epoch 1: 100%|██████████| 100/100 [00:03<00:00, 25.68it/s]


Accuracy of the network on the test images: 71.11%


Training Epoch 2: 100%|██████████| 391/391 [00:24<00:00, 15.82it/s, loss=0.826]
Validation Epoch 2: 100%|██████████| 100/100 [00:04<00:00, 22.50it/s]


Accuracy of the network on the test images: 68.65%


Training Epoch 3: 100%|██████████| 391/391 [00:24<00:00, 15.89it/s, loss=0.792]
Validation Epoch 3: 100%|██████████| 100/100 [00:04<00:00, 24.72it/s]


Accuracy of the network on the test images: 74.58%


Training Epoch 4: 100%|██████████| 391/391 [00:24<00:00, 15.83it/s, loss=0.759]
Validation Epoch 4: 100%|██████████| 100/100 [00:04<00:00, 23.16it/s]


Accuracy of the network on the test images: 76.06%


Training Epoch 5: 100%|██████████| 391/391 [00:25<00:00, 15.53it/s, loss=0.748]
Validation Epoch 5: 100%|██████████| 100/100 [00:04<00:00, 23.81it/s]


Accuracy of the network on the test images: 76.50%


Training Epoch 6: 100%|██████████| 391/391 [00:24<00:00, 16.05it/s, loss=0.729]
Validation Epoch 6: 100%|██████████| 100/100 [00:04<00:00, 24.76it/s]


Accuracy of the network on the test images: 78.20%


Training Epoch 7: 100%|██████████| 391/391 [00:23<00:00, 16.43it/s, loss=0.712]
Validation Epoch 7: 100%|██████████| 100/100 [00:03<00:00, 25.61it/s]


Accuracy of the network on the test images: 75.57%


Training Epoch 8: 100%|██████████| 391/391 [00:23<00:00, 16.32it/s, loss=0.691]
Validation Epoch 8: 100%|██████████| 100/100 [00:03<00:00, 26.87it/s]


Accuracy of the network on the test images: 70.82%


Training Epoch 9: 100%|██████████| 391/391 [00:23<00:00, 16.42it/s, loss=0.684]
Validation Epoch 9: 100%|██████████| 100/100 [00:03<00:00, 25.57it/s]


Accuracy of the network on the test images: 78.25%


Training Epoch 10: 100%|██████████| 391/391 [00:23<00:00, 16.30it/s, loss=0.683]
Validation Epoch 10: 100%|██████████| 100/100 [00:04<00:00, 20.28it/s]


Accuracy of the network on the test images: 73.45%


Training Epoch 11: 100%|██████████| 391/391 [00:24<00:00, 16.07it/s, loss=0.667]
Validation Epoch 11: 100%|██████████| 100/100 [00:03<00:00, 27.22it/s]


Accuracy of the network on the test images: 76.23%


Training Epoch 12: 100%|██████████| 391/391 [00:24<00:00, 16.18it/s, loss=0.656]
Validation Epoch 12: 100%|██████████| 100/100 [00:03<00:00, 27.20it/s]


Accuracy of the network on the test images: 79.19%


Training Epoch 13: 100%|██████████| 391/391 [00:23<00:00, 16.35it/s, loss=0.648]
Validation Epoch 13: 100%|██████████| 100/100 [00:03<00:00, 26.17it/s]


Accuracy of the network on the test images: 77.03%


Training Epoch 14: 100%|██████████| 391/391 [00:23<00:00, 16.39it/s, loss=0.639]
Validation Epoch 14: 100%|██████████| 100/100 [00:03<00:00, 25.85it/s]


Accuracy of the network on the test images: 78.36%


Training Epoch 15: 100%|██████████| 391/391 [00:24<00:00, 16.23it/s, loss=0.627]
Validation Epoch 15: 100%|██████████| 100/100 [00:03<00:00, 26.70it/s]


Accuracy of the network on the test images: 79.00%


Training Epoch 16: 100%|██████████| 391/391 [00:24<00:00, 16.24it/s, loss=0.619]
Validation Epoch 16: 100%|██████████| 100/100 [00:03<00:00, 27.00it/s]


Accuracy of the network on the test images: 75.18%


Training Epoch 17: 100%|██████████| 391/391 [00:24<00:00, 16.22it/s, loss=0.62] 
Validation Epoch 17: 100%|██████████| 100/100 [00:03<00:00, 26.82it/s]


Accuracy of the network on the test images: 79.50%


Training Epoch 18: 100%|██████████| 391/391 [00:23<00:00, 16.54it/s, loss=0.618]
Validation Epoch 18: 100%|██████████| 100/100 [00:04<00:00, 24.92it/s]


Accuracy of the network on the test images: 76.65%


Training Epoch 19: 100%|██████████| 391/391 [00:23<00:00, 16.38it/s, loss=0.605]
Validation Epoch 19: 100%|██████████| 100/100 [00:03<00:00, 28.47it/s]


Accuracy of the network on the test images: 77.35%


Training Epoch 20: 100%|██████████| 391/391 [00:23<00:00, 16.52it/s, loss=0.594]
Validation Epoch 20: 100%|██████████| 100/100 [00:03<00:00, 27.41it/s]

Accuracy of the network on the test images: 81.79%
Finished Training



