In [None]:
import argparse
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
import time
from tqdm import tqdm

In [None]:
def load_data(batch_size, root='./data', num_workers=0):
    """Build CIFAR-10 train/test DataLoaders preprocessed for an ImageNet-pretrained model.

    Each image is resized so its smaller edge is 224 px, converted to a tensor
    and normalised with the standard ImageNet per-channel mean/std. CIFAR-10 is
    downloaded into ``root`` if it is not already there.

    Args:
        batch_size: number of samples per batch for both loaders.
        root: directory where CIFAR-10 is stored / downloaded (default ``'./data'``).
        num_workers: DataLoader worker processes; 0 (the default) loads batches
            in the main process.

    Returns:
        ``(trainloader, testloader)``. Only the training loader shuffles.
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return trainloader, testloader

def modify_resnet50(num_classes=10, pretrained=True):
    """Build a ResNet-50 with a frozen backbone and a new trainable classification head.

    Every parameter of the loaded network is frozen (``requires_grad=False``).
    The final ``fc`` layer is then replaced by a freshly initialised
    ``Linear(in_features, num_classes)``, whose parameters are trainable.

    Args:
        num_classes: number of outputs of the new head (10 for CIFAR-10).
        pretrained: if True (the default), load the ImageNet IMAGENET1K_V1
            weights; if False, use a randomly initialised backbone (useful for
            offline smoke tests).

    Returns:
        The modified torchvision ResNet-50 model.
    """
    if hasattr(models, "ResNet50_Weights"):
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
        # IMAGENET1K_V1 is the checkpoint that legacy ``pretrained=True`` loaded.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        resnet50 = models.resnet50(weights=weights)
    else:
        resnet50 = models.resnet50(pretrained=pretrained)

    # Freeze the whole network in a single pass. Nothing set on the old fc
    # layer matters, because that layer is replaced just below.
    for param in resnet50.parameters():
        param.requires_grad = False

    # Parameters of a newly constructed Linear layer have requires_grad=True by default.
    resnet50.fc = torch.nn.Linear(resnet50.fc.in_features, num_classes)
    return resnet50

def train_model(model, trainloader, criterion, optimizer, num_epochs):
    """Train ``model`` for ``num_epochs`` full passes over ``trainloader``.

    The model is moved to ``cuda:0`` when CUDA is available, otherwise to the
    CPU. Each batch is moved to the same device before the forward pass. After
    every epoch, the mean batch loss, the epoch's duration and the total
    elapsed training time are printed.

    Args:
        model: ``nn.Module`` to train.
        trainloader: sized iterable of ``(inputs, labels)`` batches.
        criterion: loss callable, ``criterion(outputs, labels)`` -> scalar tensor.
        optimizer: optimizer already constructed over the model's parameters.
        num_epochs: number of passes over ``trainloader``.

    Returns:
        The trained model (the same module object, now on the training device).

    Raises:
        ValueError: if ``trainloader`` has no batches (mean loss is undefined).
    """
    num_batches = len(trainloader)
    if num_batches == 0:
        raise ValueError("trainloader is empty; nothing to train on")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    run_start = time.perf_counter()
    for epoch in range(num_epochs):
        # NOTE(review): train() also switches BatchNorm layers to training mode,
        # so their running statistics keep updating even when their weights are
        # frozen. Confirm that this is intended when fine-tuning only the head.
        model.train()
        epoch_start = time.perf_counter()
        running_loss = 0.0

        pbar = tqdm(enumerate(trainloader), total=num_batches)
        for batch_idx, (inputs, labels) in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Running mean loss over the batches seen so far in this epoch.
            pbar.set_description(f'Epoch [{epoch + 1}/{num_epochs}] - Loss: {running_loss / (batch_idx + 1):.4f}')

        now = time.perf_counter()
        # Report the per-epoch duration and the cumulative time separately, so
        # the "time" figure is not mistaken for the duration of a single epoch.
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / num_batches:.4f}, "
              f"Epoch time: {now - epoch_start:.1f} s, Total: {now - run_start:.1f} s")

    return model



def evaluate_model(model, testloader):
    """Compute, print and return the top-1 classification accuracy of ``model``.

    The model is moved to ``cuda:0`` when CUDA is available (otherwise to the
    CPU) and put in eval mode. It is then run without gradient tracking over
    every ``(images, labels)`` batch. The predicted class is the argmax of the
    model's outputs along dimension 1.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to per-class scores.
        testloader: iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy as a fraction in ``[0, 1]``.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Move the model explicitly, so evaluation does not depend on the model
    # having already been placed on ``device`` by a previous training call.
    model = model.to(device)
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("testloader yielded no samples; accuracy is undefined")

    accuracy = correct / total
    # Use the real sample count and keep two decimals, rather than a
    # hard-coded count and an integer-floored percentage.
    print(f'Accuracy of the network on the {total} test images: {100 * accuracy:.2f} %')
    return accuracy

def main(argv=None):
    """Command-line entry point: fine-tune ResNet-50 on CIFAR-10, then evaluate it.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the usual
            command-line behaviour. Inside a Jupyter kernel, ``sys.argv`` holds
            the kernel's own arguments, so pass an explicit list instead, e.g.
            ``main([])`` or ``main(["--epochs", "10"])``.
    """
    parser = argparse.ArgumentParser(description='Fine-tune ResNet-50 on CIFAR-10')
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=5, help='number of epochs to train (default: 5)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    args = parser.parse_args(argv)

    trainloader, testloader = load_data(args.batch_size)
    resnet50 = modify_resnet50()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(resnet50.parameters(), lr=args.lr, momentum=0.9)

    train_model(resnet50, trainloader, criterion, optimizer, args.epochs)
    evaluate_model(resnet50, testloader)



In [None]:
# Notebook equivalent of main(): argparse would try to parse the Jupyter
# kernel's own sys.argv here, so the hyper-parameters are set directly.
batch_size = 32
epochs = 10
lr = 0.001
SEED = 0  # seeds the shuffle order and the new head's random initialisation

# Seeds the torch RNG on all devices. Bit-exact GPU reproducibility would also
# require deterministic cuDNN settings, which are not enabled here.
torch.manual_seed(SEED)

trainloader, testloader = load_data(batch_size)
resnet50 = modify_resnet50()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet50.parameters(), lr=lr, momentum=0.9)

train_model(resnet50, trainloader, criterion, optimizer, epochs)
evaluate_model(resnet50, testloader)

Files already downloaded and verified
Files already downloaded and verified


Epoch [1/10] - Loss: 0.8637: 100%|██████████| 1563/1563 [02:27<00:00, 10.60it/s]


Epoch 1/10, Loss: 0.8637331230512278, Time: 147.40082120895386 seconds


Epoch [2/10] - Loss: 0.6473: 100%|██████████| 1563/1563 [02:26<00:00, 10.68it/s]


Epoch 2/10, Loss: 0.6473155628010316, Time: 293.78012251853943 seconds


Epoch [3/10] - Loss: 0.6099: 100%|██████████| 1563/1563 [02:27<00:00, 10.57it/s]


Epoch 3/10, Loss: 0.6099471946843373, Time: 441.7081437110901 seconds


Epoch [4/10] - Loss: 0.5907: 100%|██████████| 1563/1563 [02:26<00:00, 10.67it/s]


Epoch 4/10, Loss: 0.5906538623537074, Time: 588.1396687030792 seconds


Epoch [5/10] - Loss: 0.5800: 100%|██████████| 1563/1563 [02:27<00:00, 10.63it/s]


Epoch 5/10, Loss: 0.5800197400290922, Time: 735.1776878833771 seconds


Epoch [6/10] - Loss: 0.5643: 100%|██████████| 1563/1563 [02:28<00:00, 10.55it/s]


Epoch 6/10, Loss: 0.5642732564295551, Time: 883.3346328735352 seconds


Epoch [7/10] - Loss: 0.5599: 100%|██████████| 1563/1563 [02:26<00:00, 10.68it/s]


Epoch 7/10, Loss: 0.5599104827478462, Time: 1029.668349981308 seconds


Epoch [8/10] - Loss: 0.5552: 100%|██████████| 1563/1563 [02:26<00:00, 10.66it/s]


Epoch 8/10, Loss: 0.5551910062230556, Time: 1176.2514276504517 seconds


Epoch [9/10] - Loss: 0.5500: 100%|██████████| 1563/1563 [02:26<00:00, 10.68it/s]


Epoch 9/10, Loss: 0.5500141528559586, Time: 1322.5582876205444 seconds


Epoch [10/10] - Loss: 0.5403: 100%|██████████| 1563/1563 [02:26<00:00, 10.67it/s]


Epoch 10/10, Loss: 0.5402739241538106, Time: 1468.9861254692078 seconds
0.8126
Accuracy of the network on the 10000 test images: 81 %
