In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from iris.data import LandMarkDataset
from iris.models.baseline import BaseLine

import pandas as pd


# Location of the image files and the CSV table describing them.
img_dir = "../dataset/train/"
img_metadata = pd.read_csv("../img_metadata_train_dev.csv")

# Select the rows whose second column equals 0 and keep only the first 100.
second_column_is_zero = img_metadata.iloc[:, 1] == 0
train_img_metadata = img_metadata[second_column_is_zero].iloc[:100]
# NOTE(review): this is exactly the same selection as the training rows above,
# so the "test" set is the training set and the reported accuracy is not a
# held-out score — confirm which value of the second column marks the
# evaluation split.
test_img_metadata = img_metadata[second_column_is_zero].iloc[:100]

# From here on `img_metadata` is the (train, test) pair, not the full table.
img_metadata = (train_img_metadata, test_img_metadata)

# Shared preprocessing settings: target image size and the per-channel
# mean / std handed to Normalize.
# NOTE(review): presumably the statistics expected by the pretrained ResNet
# used later in this file — confirm against the torchvision model docs.
INPUT_SIZE = (224, 224)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize -> tensor -> per-channel normalisation.
trans = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: currently the same steps as `trans`, kept as a
# separate object so the two can diverge independently.
test_trans = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

batch_size = 12

# Build one dataset per split from the (train, test) metadata pair, each
# with its own transform pipeline.
train_metadata, test_metadata = img_metadata
train_ds = LandMarkDataset(img_dir, train_metadata, trans)
test_ds = LandMarkDataset(img_dir, test_metadata, test_trans)

# Training batches are reshuffled on every pass; evaluation keeps dataset order.
train_dl = DataLoader(train_ds, shuffle=True, batch_size=batch_size)
test_dl = DataLoader(test_ds, shuffle=False, batch_size=batch_size)

# Number of logits produced by the replacement classification head.
NUM_CLASSES = 21

# Load a pretrained ResNet-18 and switch off gradients for every parameter
# it currently has.
model = models.resnet18(pretrained=True)
model.requires_grad_(False)

# Replace the final fully connected layer. The new layer is created after the
# freeze, so its parameters are the only ones that still require gradients.
head_in_features = model.fc.in_features
model.fc = torch.nn.Linear(head_in_features, NUM_CLASSES)


# Training configuration.
NUM_EPOCHS = 20
BEST_MODEL_PATH = 'best_model.pth'  # file the best-scoring weights are saved to
best_accuracy = 0.0  # highest test accuracy seen so far
device = "cpu"

# Every batch is moved to `device` in the training loop, so the model has to
# live there too; otherwise changing `device` away from "cpu" fails on the
# first forward pass.
model = model.to(device)

# Give the optimizer only the parameters that still require gradients, so it
# does not track the frozen ones.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)

# Train for NUM_EPOCHS epochs. After each epoch the model is evaluated on
# `test_dl`, and its weights are written to BEST_MODEL_PATH whenever the test
# accuracy beats the best value seen so far.
train_size = len(train_dl.dataset)
test_size = len(test_dl.dataset)

for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()
    seen = 0
    for images, labels in train_dl:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        # Progress is reported in samples on both sides of the slash. A
        # running total stays correct when the final batch is smaller than
        # batch_size.
        seen += len(images)
        print(f"loss: {loss.item():>7f}  [{seen:>5d}/{train_size:>5d}]")

    # ---- evaluation pass ----
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in test_dl:
            images = images.to(device)
            labels = labels.to(device)
            predictions = model(images).argmax(1)
            # Count exact matches. Summing |label - prediction| only equals
            # the number of errors when there are two classes; with more
            # classes it can exceed the dataset size and push the derived
            # "accuracy" below zero.
            correct += (predictions == labels).sum().item()

    # Guard against an empty evaluation set instead of dividing by zero.
    test_accuracy = correct / test_size if test_size else 0.0
    print(
        f"Test Error: \n Accuracy: {(100*test_accuracy):>0.1f}%"
    )
    print('%d: %f' % (epoch, test_accuracy))

    # Checkpoint only on strict improvement of the test accuracy.
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy