In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import Dataset, DataLoader, Subset

import numpy as np
import os
import gc
import copy
import time as tm

from tqdm.autonotebook import tqdm
from itertools import chain
import matplotlib.pyplot as plt

# Resolve the working directory once, then show its contents and its path.
working_dir = os.getcwd()
print(os.listdir(working_dir))
print(working_dir)

from rot_pred_helper import LRDOCCRotNet18, K9_OCSVM

In [None]:
# Create transformations and datasets.
# Images are resized to 256, centre-cropped to 224x224 and converted to tensors here.
# Mean/std normalisation (`normalize_v2`) is NOT part of `transform`: it is applied
# inside `extract_representations`, after the optional channel replication.
normalize_v2 = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Train and test splits, downloaded into the current directory if not already present.
f_mnist_data = torchvision.datasets.FashionMNIST(".", train=True, download=True, transform=transform)
f_mnist_data_test = torchvision.datasets.FashionMNIST(".", train=False, download=True, transform=transform)

In [None]:
# NOTE(review): DEVICE is not referenced in the visible cells; the model is moved
# using a separate `device` variable when the checkpoint is loaded.
DEVICE = "cuda"
HEAD_DIMS = 512

# Build a ResNet-18 (no pretrained weights requested) and replace its pooling and
# final linear layer so that the network ends in a HEAD_DIMS -> HEAD_DIMS projection.
resnet18_rot_pred_model = resnet18()
resnet18_rot_pred_model.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
resnet18_rot_pred_model.fc = nn.Linear(in_features=HEAD_DIMS, out_features=HEAD_DIMS)

In [None]:
# Load the rotation-prediction checkpoint. map_location="cpu" makes loading independent
# of the GPU the checkpoint was saved from; the model is moved to `device` below.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted
# source (recent PyTorch versions also accept weights_only=True to restrict this).
weights_file = os.path.join(working_dir, 'rot_net_200e.pth')
weights = torch.load(weights_file, map_location="cpu")

model = LRDOCCRotNet18(resnet18_rot_pred_model)
model.load_state_dict(weights)

# Initialise feature extractor model, set to eval mode and send to device.
device = "cuda:0"
# Keep only the first five modules of the projection head; the truncated model is the
# feature extractor. NOTE(review): assumes projection_head supports slicing
# (nn.Sequential) and that index 5 is the intended cut point — TODO confirm.
model.projection_head = model.projection_head[:5]
# eval() must come AFTER truncation: slicing an nn.Sequential creates a new container
# module (in training mode by default), so switching to eval last guarantees every
# module the extractor now contains is in eval mode.
feature_extractor_model = model.to(device).eval()

In [None]:
# Create DataLoaders.
# Neither loader is shuffled. In this notebook they are only used for feature extraction,
# where features and labels are written row-aligned, so shuffling adds nothing. A fixed
# order also makes the saved representation/label arrays reproducible between runs.
batch_sz = 128
train_loader = DataLoader(f_mnist_data, batch_size=batch_sz, shuffle=False)
test_loader = DataLoader(f_mnist_data_test, batch_size=batch_sz, shuffle=False)

## Generate Representations

In [None]:
def extract_representations(extractor_model, data_sz,  dataloader, batch_size, interleave3d=False, vec_size=2048, device="cuda"):
    """
    Extract representations based on the process described in the LRDOCC paper.

    Runs every (images, labels) batch from `dataloader` through `extractor_model`
    without gradient tracking. The outputs and labels are stacked row-aligned into
    preallocated NumPy arrays.

    Args:
        extractor_model: callable (e.g. an nn.Module) mapping a batch of images to a
            (batch, vec_size) tensor. Callers should put it on `device` and in eval
            mode beforehand; this function does not change its mode.
        data_sz: total number of samples the loader yields; sizes the output arrays.
        dataloader: iterable of (images, labels) tensor pairs.
        batch_size: kept for backward compatibility only. Rows are placed using the
            actual size of each batch, so this need not match the loader's batch size.
        interleave3d: if True, each image channel is repeated 3 times along dim 1
            before normalisation.
        vec_size: length of each representation vector.
        device: device image batches are moved to before the forward pass
            ("cuda" matches the previous hard-coded `.cuda()` call).

    Images are normalised with the module-level `normalize_v2` transform.

    Returns:
        (feature_tensor, label_tensor): float64 arrays of shape (data_sz, vec_size)
        and (data_sz, 1).

    Raises:
        ValueError: if the loader yields more or fewer than `data_sz` samples
            (a shortfall would otherwise leave silent all-zero rows).
    """
    since = tm.time()

    feature_tensor = np.zeros((data_sz, vec_size))
    label_tensor = np.zeros((data_sz, 1))

    offset = 0
    # Inference only: no_grad avoids building an autograd graph for every batch,
    # so no per-batch garbage collection is needed to keep memory bounded.
    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            if interleave3d:
                images = images.repeat_interleave(3, dim=1)
            images = normalize_v2(images).to(device)

            end = offset + images.shape[0]
            if end > data_sz:
                raise ValueError(
                    f"dataloader yielded more than data_sz={data_sz} samples"
                )

            feature_tensor[offset:end, :] = extractor_model(images).detach().cpu().numpy()
            label_tensor[offset:end, :] = labels.numpy().reshape(-1, 1)
            offset = end

    if offset != data_sz:
        raise ValueError(
            f"dataloader yielded {offset} samples but data_sz={data_sz}"
        )

    elapsed = tm.time() - since
    print('Feature extraction complete in {:.0f}m {:.0f}s'.format(elapsed // 60, elapsed % 60))

    return feature_tensor, label_tensor

In [None]:
# Extract 512-dimensional representations for both splits (train first, then test).
# interleave3d=True repeats each input channel three times before normalisation.
rot_pred_fmnist_train, rot_pred_fmnist_train_lbl = extract_representations(
    feature_extractor_model,
    len(f_mnist_data),
    train_loader,
    batch_sz,
    interleave3d=True,
    vec_size=512,
)

rot_pred_fmnist_test, rot_pred_fmnist_test_lbl = extract_representations(
    feature_extractor_model,
    len(f_mnist_data_test),
    test_loader,
    batch_sz,
    interleave3d=True,
    vec_size=512,
)

### Save representations

In [None]:
# Persist representations and labels as .npy files in working_dir.
# os.path.join keeps the paths portable and consistent with how the checkpoint path is built.
rotclr_dataset_path_train = os.path.join(working_dir, "rot_net_repr_fmnist_train.npy")
rotclr_labels_path_train_lbl = os.path.join(working_dir, "rot_net_labels_fmnist_train_lbl.npy")
rotclr_dataset_path_test = os.path.join(working_dir, "rot_net_repr_fmnist_test.npy")
rotclr_labels_path_test_lbl = os.path.join(working_dir, "rot_net_labels_fmnist_test_lbl.npy")

np.save(rotclr_dataset_path_train, rot_pred_fmnist_train)
np.save(rotclr_labels_path_train_lbl, rot_pred_fmnist_train_lbl)

np.save(rotclr_dataset_path_test, rot_pred_fmnist_test)
np.save(rotclr_labels_path_test_lbl, rot_pred_fmnist_test_lbl)

### Run OC-SVM on the generated representations

In [None]:
# Run a one-class SVM (RBF kernel) on the extracted train/test representations.
# Labels were stored as (N, 1) columns; squeeze them to 1-D before passing them on.
train_labels_1d = np.squeeze(rot_pred_fmnist_train_lbl)
test_labels_1d = np.squeeze(rot_pred_fmnist_test_lbl)

# NOTE(review): the result name says "simclr" although these are rotation-prediction
# features; kept unchanged in case later cells reference it.
simclr_aucs = K9_OCSVM(
    rot_pred_fmnist_train,
    train_labels_1d,
    rot_pred_fmnist_test,
    test_labels_1d,
    kernel_type="rbf",
)