In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import matplotlib.pyplot as plt
import numpy as np
import pickle
import random

from collections import defaultdict
from torchvision.models import resnet18
from torchvision.models.resnet import ResNet18_Weights
from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image
from sklearn.model_selection import train_test_split

from src.utils import load_active_learning_embeddings
from src.training.engine import create_resnet18_model, calculate_metrics, train_model, test_model

In [2]:
# Seed the three RNG sources used by this notebook (Python, PyTorch, NumPy)
# with the same value. The NumPy call is kept last because it returns None,
# so the cell produces no output.
random.seed(2004)
torch.manual_seed(2004)
np.random.seed(2004)

In [3]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyper-parameters passed to TrainingPipeline. They replace the former
# stand-alone BATCH_SIZE / LR / MOMENTUM / USE_NESTEROV / EPOCHS /
# PRINT_INTERVAL constants, which are no longer defined.
TRAINING_SETTINGS = dict(
    batch_size=4,
    lr=0.025,  # Same as paper
    momentum=0.9,  # Same as paper
    use_nesterov=True,  # Same as paper
    n_epochs=100,
    print_interval=1000,
)

# Directory that trained weights are saved into.
MODEL_SAVE_DIR = "models"

# Class names; a label's integer value is its position in this tuple.
CLASSES = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

In [4]:
from src.training.engine import TrainingPipeline
from src.training.transforms import get_transform
# Single pipeline instance configured from the notebook-level settings; the
# execution loop further down calls pipeline.execute() on it repeatedly.
pipeline = TrainingPipeline(
    training_settings=TRAINING_SETTINGS,
    classes=CLASSES,
    transform=get_transform(),
    device=DEVICE,
)

In [5]:
# Default (version, setting) pair, used only for the directory preview at the
# end of this cell; the execution loop reassigns both names per variation.
version = "typiclust"
setting = "top"

# Every (version, setting) combination the pipeline is executed for.
variations = [
    {"version": "typiclust", "setting": "top"},
    {"version": "typiclust", "setting": "bottom"},
    {"version": "random", "setting": "top"},
    {"version": "random", "setting": "bottom"}
    ]

# os.listdir() returns entries in arbitrary order, so sort them to get the
# same sequence on every machine. Hidden entries (e.g. .ipynb_checkpoints,
# .DS_Store) are not embeddings directories and are skipped.
directories = sorted(
    entry
    for entry in os.listdir(f"embeddings/{version}/{setting}")
    if not entry.startswith(".")
)

all_embeddings_dirs = [f"embeddings/{version}/{setting}/{directory}" for directory in directories]

# all_images, all_labels, _ = load_active_learning_embeddings(embeddings_dir)

In [6]:
# Execute the pipeline once per embeddings directory, for every variation.
for variation in variations:
    version = variation["version"]
    setting = variation["setting"]

    # os.listdir() order is arbitrary; sorting makes the run order (and thus
    # the RNG state each run starts from) reproducible. Hidden entries such as
    # .ipynb_checkpoints or .DS_Store are skipped.
    directories = sorted(
        entry
        for entry in os.listdir(f"embeddings/{version}/{setting}")
        if not entry.startswith(".")
    )
    all_embeddings_dirs = [f"embeddings/{version}/{setting}/{directory}" for directory in directories]

    for embeddings_dir in all_embeddings_dirs:
        print(f"Executing for version: {version}, setting: {setting} and embeddings_dir: {embeddings_dir}")
        pipeline.execute(version=version, setting=setting, embeddings_dir=embeddings_dir)

Executing for version: typiclust, setting: top and embeddings_dir: embeddings/typiclust/top/1_iterations_B10
Number of training images: 8
Number of validation images: 2
Epoch: 1 |  Batch: 2/2 | Loss: 1.1105365753173828 | LR: 0.025
Epoch: 1
Training metrics
Average train loss: 1.1105
Class: plane | Precision: 0.00 | Recall: 0.00 | F1 Score: 0.00 | Accuracy: 0.00
Class: car | Precision: 0.00 | Recall: 0.00 | F1 Score: 0.00 | Accuracy: 0.00
Class: bird | Precision: 0.00 | Recall: 0.00 | F1 Score: 0.00 | Accuracy: 0.00
Class: cat | Precision: 0.00 | Recall: 0.00 | F1 Score: 0.00 | Accuracy: 0.00
Class: deer | Precision: 0.00 | Recall: 0.00 | F1 Score: 0.00 | Accuracy: 0.00
Class: dog | Precision: 0.50 | Recall: 0.50 | F1 Score: 0.50 | Accuracy: 50.00
Class: frog | Precision: 0.00 | Recall: 0.00 | F1 Score: 0.00 | Accuracy: 0.00
Class: horse | Precision: 0.00 | Recall: 0.00 | F1 Score: 0.00 | Accuracy: 0.00
Class: ship | Precision: 0.00 | Recall: 0.00 | F1 Score: 0.00 | Accuracy: 0.00
Class

In [None]:
# Deliberate stop: halts "Run All" at this point. The cells below read
# all_images / all_labels, whose only assignment in the visible cells is
# commented out, and CustomDataset, which no visible cell imports or defines.
raise NotImplementedError("This is not finished yet")

In [None]:
# Hold out 20% of the samples for validation; the fixed random_state keeps the
# split identical between runs.
# NOTE(review): all_images / all_labels are not assigned in any visible cell
# above (their loader call is commented out) -- confirm where they come from.
train_images, val_images, train_labels, val_labels = train_test_split(
    all_images,
    all_labels,
    test_size=0.2,
    random_state=2004,
)

print(
    f"Number of training images: {len(train_images)}",
    f"Number of validation images: {len(val_images)}",
    sep="\n",
)

In [None]:
def count_instances_per_class(labels, classes=None):
    """Count how many labels fall into each class.

    Args:
        labels: Iterable of class indices. Elements may be plain integers or
            objects exposing ``.item()`` (e.g. single-element tensors or
            NumPy scalars).
        classes: Optional sequence of class names indexed by label value.
            Defaults to the notebook-level ``CLASSES``.

    Returns:
        dict mapping every class name to its count (0 for classes that do not
        occur), with keys in the same order as ``classes`` so that reports for
        different splits line up.
    """
    if classes is None:
        classes = CLASSES
    # Pre-fill in class order so absent classes appear and ordering is stable.
    class_counts = {c_class: 0 for c_class in classes}
    for label in labels:
        index = label.item() if hasattr(label, "item") else label
        class_counts[classes[index]] += 1
    return class_counts

def print_class_counts(class_counts, labels):
    """Print each class's count and its share of ``labels`` as a percentage.

    Args:
        class_counts: Mapping of class name to number of occurrences.
        labels: The sized collection the counts were taken from; only its
            length is used, as the denominator of the percentage.
    """
    total = len(labels)
    for c_class, count in class_counts.items():
        # An empty split would otherwise raise ZeroDivisionError; report 0%.
        percentage = count / total * 100 if total else 0.0
        print(f"Class: {c_class} | Count: {count} | Percentage: {percentage:.5f}%")

In [None]:
# Tally both splits first, then print the two reports separated by blank lines.
train_class_counts, val_class_counts = map(
    count_instances_per_class, (train_labels, val_labels)
)

print("Training class counts:")
print_class_counts(train_class_counts, train_labels)

print("\n", "Validation class counts:", sep="\n")
print_class_counts(val_class_counts, val_labels)

In [None]:
# Image preprocessing: random 32x32 crop, random horizontal flip, conversion
# to a tensor, then per-channel normalisation with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=0),  # Padding not mentioned in paper
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# Deterministic preprocessing for validation/test data. Evaluation must not
# use the random crop/flip augmentation, otherwise metrics vary between runs.
# CenterCrop(32) is the fixed counterpart of RandomCrop(32, padding=0).
eval_transform = transforms.Compose(
    [
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)

# Batch size comes from TRAINING_SETTINGS; the stand-alone BATCH_SIZE constant
# is no longer defined.
# NOTE(review): CustomDataset is not imported or defined in any visible cell.
# train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_set = CustomDataset(images=train_images, labels=train_labels, transform=transform)
train_dl = torch.utils.data.DataLoader(train_set, batch_size=TRAINING_SETTINGS["batch_size"], shuffle=True, num_workers=0)

val_set = CustomDataset(images=val_images, labels=val_labels, transform=eval_transform)
val_dl = torch.utils.data.DataLoader(val_set, batch_size=TRAINING_SETTINGS["batch_size"], shuffle=False, num_workers=0)

test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=eval_transform)
test_dl = torch.utils.data.DataLoader(test_set, batch_size=TRAINING_SETTINGS["batch_size"], shuffle=False, num_workers=2)

In [None]:
# Number of batches in the training and test loaders, one per line.
print(len(train_dl), len(test_dl), sep="\n")

In [None]:
# Pull one batch from the training loader to check the tensor shapes.
data = next(iter(train_dl))
images, labels = data
print(images.shape, labels.shape)

In [None]:
def imshow(img):
    """Display a normalised, channel-first image tensor with matplotlib.

    Args:
        img: Image tensor laid out as (channels, height, width). It is
            rescaled with ``img / 2 + 0.5``, which undoes a mean-0.5 /
            std-0.5 normalisation. May live on any device and may be
            attached to an autograd graph.
    """
    img = img / 2 + 0.5 # Unnormalize
    # .numpy() only works on CPU tensors that do not require grad.
    np_img = img.detach().cpu().numpy()
    # matplotlib expects (height, width, channels).
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    plt.show()

# Show one training batch as an image grid with its class names underneath.
data_iter = iter(train_dl)
images, labels = next(data_iter)

imshow(torchvision.utils.make_grid(images))
# Iterate over the labels actually present: the stand-alone BATCH_SIZE
# constant is no longer defined, and a final batch can be shorter than the
# configured batch size.
print(" ".join(f"{CLASSES[label]:5s}" for label in labels))

In [None]:
# Build the network and move it to the selected device in one step.
model = create_resnet18_model().to(DEVICE)

In [None]:
# Loss, optimiser and learning-rate schedule. Hyper-parameters are read from
# TRAINING_SETTINGS because the stand-alone LR / MOMENTUM / USE_NESTEROV /
# EPOCHS constants are no longer defined.
criterion = nn.CrossEntropyLoss()
optimiser = optim.SGD(
    model.parameters(),
    lr=TRAINING_SETTINGS["lr"],
    momentum=TRAINING_SETTINGS["momentum"],
    nesterov=TRAINING_SETTINGS["use_nesterov"],
)
scheduler = CosineAnnealingLR(optimiser, T_max=TRAINING_SETTINGS["n_epochs"]) # T_max is the number of epochs

In [None]:
# Batches per epoch for the training and test loaders, on one line.
print(f"{len(train_dl)} {len(test_dl)}")

In [None]:
# Number of batches in one pass over the training loader.
# NOTE(review): not referenced by any other visible cell -- confirm it is still needed.
num_train_batches = len(train_dl)

In [None]:
# Train the model. The epoch count and print interval are read from
# TRAINING_SETTINGS because the stand-alone EPOCHS / PRINT_INTERVAL constants
# are no longer defined.
model = train_model(
    model=model,
    criterion=criterion,
    optimiser=optimiser,
    scheduler=scheduler,
    train_dl=train_dl,
    val_dl=val_dl,
    num_epochs=TRAINING_SETTINGS["n_epochs"],
    classes=CLASSES,
    device=DEVICE,
    print_interval=TRAINING_SETTINGS["print_interval"],
)

In [None]:
# Save only the weights (state_dict); the target directory is created on
# first use.
model_path = f"{MODEL_SAVE_DIR}/fully_supervised_model.pth"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
torch.save(model.state_dict(), model_path)

In [None]:
# Show the first test batch with its ground-truth class names.
data_iter = iter(test_dl)
images, labels = next(data_iter)
imshow(torchvision.utils.make_grid(images))
# Iterate over the labels in the batch: the stand-alone BATCH_SIZE constant
# is no longer defined, and a batch can be shorter than the configured size.
print(f"Ground truth: {', '.join(f'{CLASSES[label]:5s}' for label in labels)}")

In [None]:
# Rebuild the architecture and restore the saved weights.
saved_model = create_resnet18_model()
# map_location lets a checkpoint saved from a CUDA device load on a CPU-only
# host (and vice versa).
saved_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
saved_model.to(DEVICE)
# Switch to inference mode so layers such as batch norm and dropout stop using
# per-batch statistics / random masking. As the last expression, this also
# displays the model summary, as the cell did before.
saved_model.eval()

In [None]:
# Inference only: put the model in eval mode and skip autograd bookkeeping so
# predictions do not depend on batch composition and no graph is retained.
saved_model.eval()
with torch.no_grad():
    output = saved_model(images.to(DEVICE))
# Index of the highest score along dim 1 is the predicted class.
_, predicted = torch.max(output, 1)

In [None]:
# Iterate over the predictions themselves: the stand-alone BATCH_SIZE constant
# is no longer defined, and a batch can be shorter than the configured size.
print(f"Predictions: {', '.join(f'{CLASSES[label]:5s}' for label in predicted)}")

In [None]:
# Evaluate the reloaded model on the test loader; test_model returns four
# values, unpacked here as true_positive, false_positive, false_negative and
# total (used below as per-class mappings).
true_positive, false_positive, false_negative, total = test_model(
    model=saved_model,
    test_dl=test_dl,
    classes=CLASSES,
    device=DEVICE,
)

In [None]:
# Per-class metrics derived from the counts collected on the test set.
precision, recall, f1_score, accuracy = calculate_metrics(
    true_positive=true_positive,
    false_positive=false_positive,
    false_negative=false_negative,
    total=total,
    classes=CLASSES,
)

# Overall accuracy: true positives summed over classes, divided by the summed
# per-class totals.
total_accuracy = sum(true_positive.values()) / sum(total.values())

In [None]:
# One report per class (each followed by two blank lines), then the overall
# accuracy.
for class_name in total.keys():
    print(
        f"Class name: {class_name}",
        f"Accuracy: {accuracy[class_name]:.4f}",
        f"Precision: {precision[class_name]:.4f}",
        f"Recall: {recall[class_name]:.4f}",
        f"F1 Score: {f1_score[class_name]:.4f}",
        "\n",
        sep="\n",
    )
print(f"Total accuracy: {total_accuracy:.4f}")