In [1]:
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms

# --- Configuration ---------------------------------------------------------
# Everything a reader might want to tune lives here instead of being inlined.
SEED = 42                  # seeds torch's global RNG (used for shuffling and layer init)
BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)    # CIFAR-10 images are upscaled to this size before being fed to the model
NORM_MEAN = (0.485, 0.456, 0.406)   # per-channel mean passed to Normalize (ImageNet statistics)
NORM_STD = (0.229, 0.224, 0.225)    # per-channel std passed to Normalize (ImageNet statistics)
DATA_ROOT = './data'

# Seed once, up front, so a Restart & Run All yields the same shuffle order and
# the same initial weights for any newly created layers.
# NOTE: some GPU kernels can still be nondeterministic; this does not guarantee
# bit-identical results across runs or machines.
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing shared by the train and test splits:
# resize -> convert to tensor -> normalise each channel.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# Training split: downloaded on first use, reshuffled every epoch.
train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Test split: kept in a fixed order.
test_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [2]:
# Start from torchvision's ResNet-18 with its default pretrained checkpoint.
model = models.resnet18(weights='DEFAULT')

# Width of the features feeding the original classifier; read before the head is swapped.
num_ftrs = model.fc.in_features

# Freeze every pretrained parameter. This must happen BEFORE the new head is
# attached, otherwise the new head would be frozen as well.
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier with a fresh 10-way linear layer. Newly constructed
# layers have trainable parameters, so this is the only part left to train.
model.fc = nn.Linear(in_features=num_ftrs, out_features=10)

# Move the whole network to the selected device.
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\Rohan/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth


100.0%


In [3]:
import torch.optim as optim

# Phase 1: train only the new classification head.
criterion = nn.CrossEntropyLoss()
# Only model.fc's parameters are handed to the optimizer; everything else was
# frozen (requires_grad = False) when the model was built.
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
epochs = 3

for epoch in range(epochs):
    # NOTE(review): train() also switches any BatchNorm layers to training mode,
    # so their running statistics may keep updating even though the backbone
    # weights are frozen - confirm this is intended.
    model.train()
    running_loss = 0.0       # sum of per-sample losses seen this epoch
    correct, total = 0, 0    # correctly classified / total samples seen this epoch

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the MEAN loss over the batch. Weight it by the batch
        # size so the short final batch does not count as much as a full one;
        # dividing by `total` below then gives the true per-sample average.
        n_samples = labels.size(0)
        running_loss += loss.item() * n_samples
        _, predicted = outputs.max(1)    # index of the largest logit = predicted class
        total += n_samples
        correct += predicted.eq(labels).sum().item()

    print(f'Epoch {epoch+1}, Loss: {running_loss/total:.4f}, Accuracy: {100.*correct/total:.2f}%')

print("Transfer learning phase 1 finished!")

Epoch 1, Loss: 0.8009, Accuracy: 73.49%
Epoch 2, Loss: 0.6390, Accuracy: 77.90%
Epoch 3, Loss: 0.6189, Accuracy: 78.64%
Transfer learning phase 1 finished!


In [4]:
# Phase 2: fine-tune the whole network.
# Unfreeze everything
for param in model.parameters():
    param.requires_grad = True

# Use a MUCH smaller learning rate so we don't 'break' the pre-trained weights.
# A new optimizer is required because the previous one only tracked model.fc.
optimizer = optim.Adam(model.parameters(), lr=0.00001)

# One or two more 'polishing' epochs
for epoch in range(2):
    # Same procedure as the head-only phase, but now every parameter receives
    # gradient updates. `criterion` is reused from the earlier training cell.
    model.train()
    running_loss = 0.0       # sum of per-sample losses seen this epoch
    correct, total = 0, 0    # correctly classified / total samples seen this epoch

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the MEAN loss over the batch. Weight it by the batch
        # size so the short final batch does not count as much as a full one;
        # dividing by `total` below then gives the true per-sample average.
        n_samples = labels.size(0)
        running_loss += loss.item() * n_samples
        _, predicted = outputs.max(1)    # index of the largest logit = predicted class
        total += n_samples
        correct += predicted.eq(labels).sum().item()

    print(f'Epoch {epoch+1}, Loss: {running_loss/total:.4f}, Accuracy: {100.*correct/total:.2f}%')
print("Polished the weights...")

Epoch 1, Loss: 0.3689, Accuracy: 87.31%
Epoch 2, Loss: 0.1884, Accuracy: 93.51%
Polished the weights...


In [5]:
# Short display names, indexed by the integer class label (0..9).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Per-class tallies: number of correct predictions and number of samples seen.
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

model.eval()              # inference mode for layers such as BatchNorm / Dropout
with torch.no_grad():     # no gradients are needed for evaluation
    for data in test_loader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)    # index of the largest logit per sample
        # Convert each batch to plain Python ints ONCE. Iterating the tensors
        # directly yields 0-dim tensors, and comparing / indexing with those
        # forces a separate tensor-to-Python conversion for every sample.
        for label, prediction in zip(labels.tolist(), predictions.tolist()):
            classname = classes[label]
            if label == prediction:
                correct_pred[classname] += 1
            total_pred[classname] += 1

# Print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

Accuracy for class: plane is 95.2 %
Accuracy for class: car   is 96.1 %
Accuracy for class: bird  is 90.5 %
Accuracy for class: cat   is 83.5 %
Accuracy for class: deer  is 93.4 %
Accuracy for class: dog   is 87.5 %
Accuracy for class: frog  is 96.1 %
Accuracy for class: horse is 93.6 %
Accuracy for class: ship  is 94.4 %
Accuracy for class: truck is 95.2 %
