In [None]:
import torchvision
from torch import nn
from task2 import Trainer, create_plots, compute_loss_and_accuracy
from dataloaders import load_cifar10


In [None]:

class ResNet18TransferModel(nn.Module):
    """ImageNet-pretrained ResNet-18 adapted for transfer learning.

    The pretrained backbone is frozen except for its final residual stage
    (``layer4``) and a freshly initialized classification head, so only those
    parameters receive gradients during fine-tuning.

    Args:
        num_classes: Number of logits produced by the new head (default 10,
            matching the original hard-coded CIFAR-10 setup).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        self.model = self._load_pretrained_resnet18()
        # Replace the ImageNet head; derive the input width from the existing layer
        # instead of hard-coding 512. No softmax: nn.CrossEntropyLoss expects raw logits.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

        # Freeze every parameter, then re-enable gradients only for the parts we fine-tune:
        # the new fully-connected head and the last residual stage (layer4).
        for param in self.model.parameters():
            param.requires_grad = False
        for trainable in (self.model.fc, self.model.layer4):
            for param in trainable.parameters():
                param.requires_grad = True

    @staticmethod
    def _load_pretrained_resnet18() -> nn.Module:
        """Build torchvision's ResNet-18 with ImageNet weights on old and new torchvision."""
        weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is the
            # checkpoint that flag used to load.
            return torchvision.models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        # Older torchvision without the weights-enum API.
        return torchvision.models.resnet18(pretrained=True)

    def forward(self, x):
        """Return unnormalized class logits for a batch of images.

        NOTE(review): assumes x is an ImageNet-normalized image batch shaped
        (N, 3, H, W) — confirm against the dataloader transforms.
        """
        return self.model(x)


In [None]:
# Hyperparameters for fine-tuning the pretrained ResNet-18.
epochs = 10
batch_size = 32
learning_rate = 5e-4
early_stop_count = 4

# ImageNet channel statistics: the pretrained backbone expects inputs normalized this way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

dataloaders = load_cifar10(
    batch_size,
    mean=IMAGENET_MEAN, std=IMAGENET_STD,
    # Upsample the small CIFAR-10 images to 224 px, the resolution the ImageNet weights were trained at.
    additional_transforms=[torchvision.transforms.Resize(224)],
)
model = ResNet18TransferModel()
trainer = Trainer(
    batch_size,
    learning_rate,
    early_stop_count,
    epochs,
    model,
    dataloaders,
    use_adam_optimizer=True
)
trainer.train()
create_plots(trainer, "task4a")

dataloader_train, dataloader_val, dataloader_test = dataloaders

# Switch to inference mode before the final evaluation so ResNet-18's BatchNorm layers
# use (and do not update) their running statistics. This is harmless if
# compute_loss_and_accuracy already toggles the mode itself.
# NOTE(review): these metrics use the weights at the end of training; if Trainer keeps a
# best (early-stopping) checkpoint, load it first to report that model instead.
model.eval()
loss_criterion = nn.CrossEntropyLoss()
for split_name, split_loader in (
    ("Train", dataloader_train),
    ("Validation", dataloader_val),
    ("Test", dataloader_test),
):
    split_accuracy = compute_loss_and_accuracy(split_loader, model, loss_criterion)[1]
    print(f"{split_name} Accuracy:", split_accuracy)
