In [1]:
import torch
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import lightning as L
import numpy as np
from sklearn.model_selection import train_test_split
from data import Cub2011
from model import CubClassifier

Seed set to 42


In [2]:
torch.manual_seed(42)
WORKPATH = '/home/longkailin/sms2/nndl/midterm'

# Channel-wise mean / std shared by the training and evaluation pipelines
# (the standard ImageNet statistics used with torchvision models).
normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Training pipeline: resize, then light augmentation (padded random crop and
# horizontal flip) before tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic, no augmentation.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

# Load the CUB-2011 dataset.
# `train_datasest` (sic) keeps its original spelling because other cells may
# reference it.
train_datasest = Cub2011(root=WORKPATH, download=False, transform=transform_train)
# Second view of the same training split that applies the evaluation
# transforms, so validation samples are not randomly cropped / flipped.
# NOTE(review): assumes Cub2011 orders its samples identically on every
# construction with the same arguments - confirm against data.py.
validation_source = Cub2011(root=WORKPATH, download=False, transform=transform_test)
test_dataset = Cub2011(root=WORKPATH, download=False, train=False, transform=transform_test)

# Stratified 90/10 split of the training indices so every class is
# represented proportionally in both the train and validation subsets.
train_idx, validation_idx = train_test_split(
    np.arange(len(train_datasest)),
    test_size=0.1,
    random_state=999,
    shuffle=True,
    stratify=train_datasest.data['target'].to_numpy(),
)

# Subset datasets for train and val: the indices come from one split, so the
# two subsets are disjoint even though they read from two dataset objects.
validation_dataset = torch.utils.data.Subset(validation_source, validation_idx)
train_dataset = torch.utils.data.Subset(train_datasest, train_idx)

# Create data loaders. Only the training loader shuffles; evaluation loaders
# keep a fixed order so results are reproducible batch-for-batch.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(validation_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [5]:
# Restore the trained classifier from the SGD / lr=0.1 run
# (checkpoint file saved at epoch 41, global step 1806).
model_path = (
    WORKPATH
    + '/model_optim_SGD_lr0.1/lightning_logs/version_0/checkpoints/epoch=41-step=1806.ckpt'
)
test_model = CubClassifier.load_from_checkpoint(checkpoint_path=model_path)

In [6]:
# Evaluate the restored model on the test set and report top-1 accuracy.
test_model.eval()  # put layers such as dropout / batch-norm into inference mode

# Feed batches on whatever device the model's weights live on; without this a
# GPU-resident model fails on the CPU tensors produced by the DataLoader.
device = next(test_model.parameters()).device

correct = 0
total = 0
with torch.no_grad():  # no gradients needed for evaluation
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = test_model(images)
        # Predicted class = index of the largest score along the class axis.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    # Avoid an opaque ZeroDivisionError when the loader is empty.
    print('Test loader yielded no samples; accuracy is undefined.')
else:
    # Two decimals instead of '%d', which silently truncated the percentage.
    print(f'Accuracy of the network on the test images: {100 * correct / total:.2f} %')
