# SimCLR Representation Generation

In [None]:
import torch
from torch import optim, nn
import torchvision
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader

from resnet_wider import resnet50x1
from oc_svm_k9de import K9_OCSVM

from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import gc
import numpy as np
import time as tm
import os
working_dir = os.getcwd()

## Load pre-trained net weights and data

In [None]:
# Build the preprocessing pipeline and the CIFAR-10 train/test datasets.
#
# `normalize` is deliberately NOT part of `transform`: it is applied to whole
# batches later on, inside extract_representations_cifar.
# NOTE(review): the constants look like the usual ImageNet channel mean/std - confirm.
normalize = torchvision.transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

# Upscale, centre-crop to 224x224 and convert to a tensor.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
    ]
)

# Train split first, then test split; both are downloaded into the current directory.
cifar10_train, cifar10_test = (
    torchvision.datasets.CIFAR10(".", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)


In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the pre-trained checkpoint. map_location="cpu" keeps the load independent
# of the device the checkpoint was saved from; the model is moved to `device` below.
weights_file = os.path.join(working_dir, 'resnet50-1x.pth')
weights = torch.load(weights_file, map_location="cpu")

model = resnet50x1()
model.load_state_dict(weights["state_dict"])

model.eval()
print(model)

# Initialise feature extractor model, set to eval mode and send to device.
# The last two child modules of the network are dropped and replaced by a
# global average pool, so the extractor returns (batch, channels, 1, 1) maps.
feature_extractor_model = nn.Sequential(*list(model.children())[:-2], nn.AdaptiveAvgPool2d((1, 1)))
feature_extractor_model.eval()
feature_extractor_model = feature_extractor_model.to(device)

# Create DataLoaders
# NOTE(review): the test loader shuffles, so the saved test representations come
# out in a different order on every run. Labels are collected in the same pass,
# so representation/label pairs stay aligned.
batch_sz = 50
train_loader_cifar = DataLoader(cifar10_train, batch_size=batch_sz)
test_loader_cifar = DataLoader(cifar10_test, batch_size=batch_sz, shuffle=True)

### Representation Extraction

In [None]:
def extract_representations_cifar(extractor_model, dataloader, num_samples, vec_size=2048):
  """Collect the extractor's output vector and the label for every sample.

  Each batch is normalised with the module-level ``normalize`` transform, moved
  to the device holding the model's parameters and passed through
  ``extractor_model`` with gradient tracking disabled.

  Args:
    extractor_model: module whose output is indexed as ``[:, :, 0, 0]``, i.e. a
      4-D tensor whose second axis has ``vec_size`` entries.
    dataloader: iterable yielding ``(images, labels)`` batches.
    num_samples: number of rows to allocate in the returned arrays.
    vec_size: length of each representation vector.

  Returns:
    ``(repr_vectors, repr_labels)``: NumPy arrays of shape
    ``(num_samples, vec_size)`` and ``(num_samples, 1)``. Rows that the
    dataloader never fills stay zero.
  """
  repr_vectors = np.zeros((num_samples, vec_size))
  repr_labels = np.zeros((num_samples, 1))

  # Follow the model's own device instead of assuming CUDA. A model without
  # parameters gives no hint, so the inputs are then left where they are.
  first_param = next(extractor_model.parameters(), None)
  target_device = None if first_param is None else first_param.device

  # Running write position. Using the real size of each batch (rather than a
  # global batch-size constant) keeps rows aligned for any dataloader.
  offset = 0

  # Inference only: skip building the autograd graph for every batch.
  with torch.no_grad():
    for images, labels in tqdm(dataloader):
      images = normalize(images)
      if target_device is not None:
        images = images.to(target_device)

      # Drop the two trailing singleton spatial axes -> (batch, vec_size).
      features = extractor_model(images)[:, :, 0, 0].detach().cpu().numpy()
      count = features.shape[0]

      repr_vectors[offset:offset + count, :] = features
      repr_labels[offset:offset + count, :] = np.expand_dims(labels.numpy(), 1)
      offset += count

  return repr_vectors, repr_labels

In [None]:
# Run the feature extractor over both CIFAR-10 splits.
cifar10_repr_train, cifar10_repr_train_lbl = extract_representations_cifar(
    feature_extractor_model,
    train_loader_cifar,
    len(cifar10_train),
)
cifar10_repr_test, cifar10_repr_test_labels = extract_representations_cifar(
    feature_extractor_model,
    test_loader_cifar,
    len(cifar10_test),
)

# Persist each array next to the notebook as a .npy file.
# File names are kept exactly as they were (including the "simcplr" spelling).
simclr_cifar_train_path = working_dir + "/cifar10_simcplr_repr_train.npy"
np.save(simclr_cifar_train_path, cifar10_repr_train)

simclr_cifar_train_labels_path = working_dir + "/cifar10_simcplr_repr_train_lbl.npy"
np.save(simclr_cifar_train_labels_path, cifar10_repr_train_lbl)

simclr_cifar_tst_path = working_dir + "/cifar10_simcplr_repr_test.npy"
np.save(simclr_cifar_tst_path, cifar10_repr_test)

simclr_cifar_tst_labels_path = working_dir + "/cifar10_simcplr_repr_test_lbl.npy"
np.save(simclr_cifar_tst_labels_path, cifar10_repr_test_labels)

In [None]:
# Run the one-class SVM on the extracted representations.
# The label arrays are squeezed to drop their singleton axis before being passed on.
simclr_aucs = K9_OCSVM(
    cifar10_repr_train,
    np.squeeze(cifar10_repr_train_lbl),
    cifar10_repr_test,
    np.squeeze(cifar10_repr_test_labels),
    kernel_type="rbf",
)