In [1]:
import torch
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import models
import time

In [2]:
# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalize each of the three channels with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(preprocessing_steps)

# Load the train and test sets
# Training split: shuffled batches of 64, loaded by 2 worker processes.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

# Test split: same preprocessing and batch size, but kept in dataset order.
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Load the checkpoint.
# - map_location=device remaps the stored tensors onto the device selected
#   earlier in the notebook, so a checkpoint saved on a GPU can still be
#   restored on a CPU-only machine.
# - weights_only=False is stated explicitly because the checkpoint stores a
#   full pickled model object under "model", not just tensors; relying on the
#   implicit default emits a warning and breaks once the default is True.
# SECURITY: unpickling can execute arbitrary code. Only load checkpoint files
# that come from a trusted source.
checkpoint = torch.load('vgg16_ht2.pth', map_location=device, weights_only=False)

# Extract the model structure and state dict
model = checkpoint["model"]  # This contains the model architecture
model.load_state_dict(checkpoint["state_dict"])  # Load the saved weights

# Move the model to the appropriate device
model = model.to(device)

# Set model to training mode. train() returns the module, so leaving it as the
# cell's last expression also displays the architecture.
model.train()


  checkpoint = torch.load('vgg16_ht2.pth')


VGG(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 27, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): Conv2d(27, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): ReLU(inplace=True)
    (2): Sequential(
      (0): Conv2d(64, 13, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): Conv2d(13, 19, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (2): Conv2d(19, 19, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (3): Conv2d(19, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Sequential(
      (0): Conv2d(64, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): Conv2d(18, 31, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (2): Conv2d(31, 34, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (3): Conv2d(34, 128, kernel_size=(1,

In [5]:
import torch.optim as optim
import torch.nn as nn

# Cross-entropy loss over the model's output scores.
criterion = nn.CrossEntropyLoss()

# SGD with momentum, restricted to the parameters of `model.classifier`.
# NOTE(review): only the classifier parameters are handed to the optimizer, and
# nothing visible here freezes the remaining parameters, so backward() may still
# compute gradients for them -- confirm whether they should be frozen or trained.
optimizer = optim.SGD(
    params=model.classifier.parameters(),
    lr=1e-3,
    momentum=0.9,
)

In [6]:
def validate(model, test_dataloader, criterion):
    """Evaluate ``model`` on every batch of ``test_dataloader`` without gradients.

    Args:
        model: Module to evaluate. It is switched to eval mode and left in it.
        test_dataloader: Iterable yielding ``(data, target)`` batches.
        criterion: Loss callable invoked as ``criterion(output, target)``. The
            running loss is weighted by batch size, which gives the per-sample
            average provided the criterion returns a batch mean.

    Returns:
        Tuple ``(val_loss, top1_accuracy, top5_accuracy)``; accuracies are in
        percent. All three values are NaN when the loader yields no samples.
    """
    model.eval()

    # Follow the model's own device instead of depending on a notebook-level
    # global; a parameterless model falls back to the CPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    top1_correct = 0
    top5_correct = 0
    total_samples = 0
    val_running_loss = 0.0

    with torch.no_grad():
        for data, target in test_dataloader:
            data, target = data.to(model_device), target.to(model_device)
            output = model(data)
            loss = criterion(output, target)
            batch_size = target.size(0)
            val_running_loss += loss.item() * batch_size  # accumulate the total loss

            # Clamp k so that models with fewer than 5 classes do not make
            # topk() raise; with >= 5 classes this is the usual top-5.
            k = min(5, output.size(1))
            _, topk_preds = output.topk(k, dim=1, largest=True, sorted=True)

            # hits[i, j] is True when the j-th ranked prediction of sample i
            # equals its target; column 0 holds the top-1 prediction.
            hits = topk_preds.eq(target.view(-1, 1))
            top1_correct += hits[:, :1].sum().item()
            top5_correct += hits.sum().item()
            total_samples += batch_size

    if total_samples == 0:
        # No samples seen: the averages are undefined, so report NaN instead of
        # failing with ZeroDivisionError.
        val_loss = top1_accuracy = top5_accuracy = float("nan")
    else:
        top1_accuracy = 100. * top1_correct / total_samples
        top5_accuracy = 100. * top5_correct / total_samples
        val_loss = val_running_loss / total_samples  # calculate average validation loss

    print(f'Validation Loss: {val_loss:.4f}, Top-1 Acc: {top1_accuracy:.2f}, Top-5 Acc: {top5_accuracy:.2f}')
    return val_loss, top1_accuracy, top5_accuracy

In [7]:
def fit(model, train_dataloader, criterion, optimizer):
    """Train ``model`` for one pass over ``train_dataloader``.

    Args:
        model: Module to train. It is switched to train mode.
        train_dataloader: Iterable yielding ``(data, target)`` batches.
        criterion: Loss callable invoked as ``criterion(output, target)``. The
            running loss is weighted by batch size, which gives the per-sample
            average provided the criterion returns a batch mean.
        optimizer: Optimizer whose gradients are zeroed before each batch and
            which is stepped after each backward pass.

    Returns:
        Tuple ``(train_loss, top1_accuracy, top5_accuracy)``; accuracies are in
        percent and are measured on the outputs produced before each update.
        All three values are NaN when the loader yields no samples.
    """
    model.train()

    # Follow the model's own device instead of depending on a notebook-level
    # global; a parameterless model falls back to the CPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    train_running_loss = 0.0
    top1_correct = 0
    top5_correct = 0
    total_samples = 0

    for data, target in train_dataloader:
        data, target = data.to(model_device), target.to(model_device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        batch_size = target.size(0)

        # Accumulate the total loss
        train_running_loss += loss.item() * batch_size

        # Metrics only need the ranking of the scores, so work on detached
        # logits. Clamp k so models with fewer than 5 classes do not make
        # topk() raise; with >= 5 classes this is the usual top-5.
        k = min(5, output.size(1))
        _, topk_preds = output.detach().topk(k, dim=1, largest=True, sorted=True)

        # hits[i, j] is True when the j-th ranked prediction of sample i
        # equals its target; column 0 holds the top-1 prediction.
        hits = topk_preds.eq(target.view(-1, 1))
        top1_correct += hits[:, :1].sum().item()
        top5_correct += hits.sum().item()
        total_samples += batch_size

        loss.backward()
        optimizer.step()

    if total_samples == 0:
        # No samples seen: the averages are undefined, so report NaN instead of
        # failing with ZeroDivisionError.
        train_loss = top1_accuracy = top5_accuracy = float("nan")
    else:
        train_loss = train_running_loss / total_samples  # average train loss
        top1_accuracy = 100. * top1_correct / total_samples
        top5_accuracy = 100. * top5_correct / total_samples

    print(f'Train Loss: {train_loss:.4f}, Top-1 Acc: {top1_accuracy:.2f}, Top-5 Acc: {top5_accuracy:.2f}')
    return train_loss, top1_accuracy, top5_accuracy

In [8]:
# Per-epoch metric histories: one entry is appended to each list every epoch.
train_loss, train_top1_accuracy, train_top5_accuracy = [], [], []
val_loss, val_top1_accuracy, val_top5_accuracy = [], [], []

# Histories in the same order as the per-epoch values collected below.
metric_histories = (
    train_loss, train_top1_accuracy, train_top5_accuracy,
    val_loss, val_top1_accuracy, val_top5_accuracy,
)

start = time.time()
for epoch in range(10):
    # One training pass followed by one evaluation pass on the test loader.
    train_epoch_loss, train_top1_acc, train_top5_acc = fit(model, train_loader, criterion, optimizer)
    val_epoch_loss, val_top1_acc, val_top5_acc = validate(model, test_loader, criterion)

    epoch_values = (
        train_epoch_loss, train_top1_acc, train_top5_acc,
        val_epoch_loss, val_top1_acc, val_top5_acc,
    )
    for history, value in zip(metric_histories, epoch_values):
        history.append(value)
end = time.time()

# Persist the weights only; the elapsed time covers both training and validation.
torch.save(model.state_dict(), 'vgg16_finetuned.pth')
print(f"Training took {(end - start) / 60:.2f} minutes")

# Print final statistics
print(f"Final Train Loss: {train_loss[-1]:.4f}, Top-1 Acc: {train_top1_accuracy[-1]:.2f}, Top-5 Acc: {train_top5_accuracy[-1]:.2f}")
print(f"Final Val Loss: {val_loss[-1]:.4f}, Top-1 Acc: {val_top1_accuracy[-1]:.2f}, Top-5 Acc: {val_top5_accuracy[-1]:.2f}")

Train Loss: 0.9314, Top-1 Acc: 67.75, Top-5 Acc: 97.15
Validation Loss: 0.7332, Top-1 Acc: 74.13, Top-5 Acc: 98.51
Train Loss: 0.7013, Top-1 Acc: 75.45, Top-5 Acc: 98.60
Validation Loss: 0.7034, Top-1 Acc: 75.57, Top-5 Acc: 98.73
Train Loss: 0.5994, Top-1 Acc: 79.00, Top-5 Acc: 98.99
Validation Loss: 0.6192, Top-1 Acc: 78.47, Top-5 Acc: 98.83
Train Loss: 0.5262, Top-1 Acc: 81.71, Top-5 Acc: 99.25
Validation Loss: 0.6033, Top-1 Acc: 78.90, Top-5 Acc: 98.95
Train Loss: 0.4584, Top-1 Acc: 84.00, Top-5 Acc: 99.48
Validation Loss: 0.5938, Top-1 Acc: 79.79, Top-5 Acc: 99.03
Train Loss: 0.3980, Top-1 Acc: 86.11, Top-5 Acc: 99.64
Validation Loss: 0.6254, Top-1 Acc: 78.79, Top-5 Acc: 98.96
Train Loss: 0.3475, Top-1 Acc: 87.79, Top-5 Acc: 99.74
Validation Loss: 0.5907, Top-1 Acc: 80.66, Top-5 Acc: 99.07
Train Loss: 0.2930, Top-1 Acc: 89.79, Top-5 Acc: 99.82
Validation Loss: 0.6329, Top-1 Acc: 79.93, Top-5 Acc: 98.96
Train Loss: 0.2478, Top-1 Acc: 91.27, Top-5 Acc: 99.87
Validation Loss: 0.6286, 

In [None]:
# TODO: FLOPS -- placeholder cell; the FLOPs measurement is not implemented yet.