In [None]:
!pip install torch
!pip install torchvision
!pip install sklearn
!pip install tqdm

In [None]:
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

import torchvision
import torchvision.transforms as transforms

from sklearn.metrics import confusion_matrix

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
root_path = "/content/drive/MyDrive/NUS/CS4243/CS4243_mini_project"
data_path = os.path.join(root_path, "image_data_cleaned_split")
model_root_path = os.path.join(root_path, "models")

In [None]:
# helper function to display images
def imshow(img):
    npimg = img.cpu().numpy()
    plt.figure(figsize=(20, 20))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

In [None]:
input_size = (299, 299)
batch_size = 32
is_split = True
# train, validation, test
data_split = [0.8, 0.1, 0.1]

# Image transformations
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

if is_split:
    # Load data that has already been split
    print("Reading from split data...")
    train_path = os.path.join(data_path, "train")
    validation_path = os.path.join(data_path, "validation")
    test_path = os.path.join(data_path, "test")
    datasets = (torchvision.datasets.ImageFolder(x, transform) for x in [train_path, validation_path, test_path])
else:
    # Load and split data
    print("Reading and splitting data...")
    dataset = torchvision.datasets.ImageFolder(data_path, transform)
    n_data = len(dataset)
    n_train = int(n_data * data_split[0])
    n_validation = int(n_data * data_split[1])
    n_test = n_data - n_train - n_validation
    datasets = random_split(dataset, (n_train, n_validation, n_test))
    train_dataloader, validation_dataloader, test_dataloader = (DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=2) for ds in datasets)

train_dataloader, validation_dataloader, test_dataloader = (DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=2) for ds in datasets)
num_train = len(train_dataloader) * batch_size
num_validation = len(validation_dataloader) * batch_size
num_test = len(test_dataloader) * batch_size
print(num_train, "training")
print(num_validation, "validation")
print(num_test, "testing")
print("Total:", num_train + num_validation + num_test)

Reading from split data...
4512 training
416 validation
608 testing
Total: 5536


In [None]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
num_classes = len(next(os.walk(data_path))[1])
print(f"Found {num_classes} classes")
model.AuxLogits.fc = nn.Linear(768, num_classes)
model.fc = nn.Linear(2048, num_classes)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

Found 3 classes


In [None]:
model_load_path = os.path.join(model_root_path, "inception_ensemble_image_classifier_lr3_e20_elr7")
model.load_state_dict(torch.load(model_load_path))

<All keys matched successfully>

In [None]:
learning_rate = 1e-3 * 0.7 ** 10
lr_decay = 0.7
num_epochs = 10

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay)

In [None]:
def get_labels(logit, size):
    return torch.max(logit, dim=1)[1].view(size)

def get_accuracy(logit, target, batch_size):
    corrects = (get_labels(logit, target.size()).data == target.data).sum()
    accuracy = 100.0 * corrects/batch_size
    return accuracy.item()

In [None]:
model_save_path = os.path.join(model_root_path, "inception_ensemble_image_classifier_lr3_e20_elr7")

for epoch in range(num_epochs):

    train_running_loss = 0.0
    train_acc = 0.0

    model = model.train()

    # training steps
    pbar = tqdm(total=len(train_dataloader))
    for i, (images, labels) in enumerate(train_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        # forward + backprop + loss
        logit, aux_logit = model(images)
        loss = criterion(logit, labels)
        aux_loss = criterion(aux_logit, labels)
        loss = loss + 0.4 * aux_loss
        optimizer.zero_grad()
        loss.backward()

        # update model params
        optimizer.step()

        # calc training metrics
        train_running_loss += loss.detach().item()
        train_acc += get_accuracy(logit, labels, batch_size)

        # increment progress bar
        pbar.update(1)

    pbar.close()
    scheduler.step()

    val_running_loss = 0.0
    val_acc = 0.0
    model.eval()

    # validation step
    for j, (images, labels) in enumerate(validation_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        # forward step and loss, no bacckprop
        logit = model(images)
        loss = criterion(logit, labels)

        # calc validation metrics
        val_running_loss += loss.detach().item()
        val_acc += get_accuracy(logit, labels, batch_size)


    print('Epoch: %d | Train Loss: %.4f | Train Accuracy: %.2f | Validation Loss: %.4f | Validation Accuracy: %.2f' \
          %(epoch, train_running_loss/i, train_acc/i, val_running_loss/j, val_acc/j))        

torch.save(model.state_dict(), model_save_path)

100%|██████████| 141/141 [02:25<00:00,  1.03s/it]


Epoch: 0 | Train Loss: 0.1906 | Train Accuracy: 95.74 | Validation Loss: 0.3701 | Validation Accuracy: 90.36


100%|██████████| 141/141 [02:25<00:00,  1.03s/it]


Epoch: 1 | Train Loss: 0.1482 | Train Accuracy: 97.08 | Validation Loss: 0.3433 | Validation Accuracy: 91.15


100%|██████████| 141/141 [02:25<00:00,  1.03s/it]


Epoch: 2 | Train Loss: 0.1435 | Train Accuracy: 97.23 | Validation Loss: 0.2890 | Validation Accuracy: 92.71


100%|██████████| 141/141 [02:24<00:00,  1.02s/it]


Epoch: 3 | Train Loss: 0.1361 | Train Accuracy: 97.50 | Validation Loss: 0.3124 | Validation Accuracy: 92.19


100%|██████████| 141/141 [02:25<00:00,  1.03s/it]


Epoch: 4 | Train Loss: 0.1169 | Train Accuracy: 97.95 | Validation Loss: 0.3573 | Validation Accuracy: 89.84


100%|██████████| 141/141 [02:25<00:00,  1.03s/it]


Epoch: 5 | Train Loss: 0.1111 | Train Accuracy: 98.08 | Validation Loss: 0.5342 | Validation Accuracy: 91.93


100%|██████████| 141/141 [02:24<00:00,  1.02s/it]


Epoch: 6 | Train Loss: 0.1218 | Train Accuracy: 97.97 | Validation Loss: 0.3436 | Validation Accuracy: 90.89


100%|██████████| 141/141 [02:21<00:00,  1.00s/it]


Epoch: 7 | Train Loss: 0.1095 | Train Accuracy: 98.21 | Validation Loss: 0.3629 | Validation Accuracy: 90.10


100%|██████████| 141/141 [02:22<00:00,  1.01s/it]


Epoch: 8 | Train Loss: 0.1126 | Train Accuracy: 98.06 | Validation Loss: 0.4091 | Validation Accuracy: 90.89


100%|██████████| 141/141 [02:24<00:00,  1.03s/it]


Epoch: 9 | Train Loss: 0.1035 | Train Accuracy: 98.42 | Validation Loss: 0.4063 | Validation Accuracy: 92.19


In [None]:
model.eval()

test_acc = 0
total_conf_table = np.zeros((3, 3))

for i, (images, labels) in enumerate(test_dataloader):
    images = images.to(device)
    labels = labels.to(device)

    # forward step
    logit = model(images)
    pred = get_labels(logit, labels.size())
    # calc validation metrics
    test_acc += get_accuracy(logit, labels, batch_size)
    pred_np = pred.cpu().detach().numpy()
    label_np = labels.cpu().detach().numpy()
    conf_table = confusion_matrix(label_np, pred_np, labels=[0, 1, 2])
    total_conf_table += conf_table

print("Test Accuracy: %.2f" %(test_acc/i)) 
print("Confusion Table:")
print(total_conf_table)

Test Accuracy: 93.75
Confusion Table:
[[170.   5.   4.]
 [ 19. 228.  14.]
 [  0.   0. 142.]]
