# SimCLR Representation Generation

In [None]:
import torch
from torch import optim, nn
import torchvision
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader

from resnet_wider import resnet50x1
from oc_svm_k9de import K9_OCSVM

from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import gc
import numpy as np
import time as tm
import os
# Directory the notebook is run from; the weights file is read from here and
# the extracted representation/label arrays are written here.
working_dir = os.getcwd()

### Load weights and create DataLoaders

In [None]:
_tv_transforms = torchvision.transforms

# Three-channel mean/std normalisation. It is deliberately NOT part of
# `transform` below: it is kept as a separate callable so it can be applied
# to image batches later on.
normalize = _tv_transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

# Per-image preprocessing: resize to 256, centre-crop to 224x224, convert to
# a tensor.
transform = _tv_transforms.Compose([
    _tv_transforms.Resize(256),
    _tv_transforms.CenterCrop(224),
    _tv_transforms.ToTensor(),
])

# Train split first, then test split; both are downloaded into the current
# directory if they are not already present.
f_mnist_data, f_mnist_data_test = (
    FashionMNIST(".", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [None]:
weights_file = os.path.join(working_dir,"resnet50-1x.pth")
# Load the checkpoint onto the CPU first so it opens regardless of the device
# it was saved from; the weights are copied into the (CPU) model right below
# and the extractor is moved to the target device afterwards.
weights = torch.load(weights_file, map_location="cpu")

model = resnet50x1()
model.load_state_dict(weights["state_dict"])

model.eval()
print(model)

# Initialise feature extractor model, set to eval mode and send to device.
# Use the first GPU when one is available and fall back to the CPU otherwise,
# instead of failing outright on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Keep every child module except the last two and append a global average
# pool that reduces the spatial dimensions to 1x1.
# NOTE(review): the two dropped children are presumably the original pooling
# and classification layers -- confirm against resnet_wider.
feature_extractor_model = nn.Sequential(*list(model.children())[:-2], nn.AdaptiveAvgPool2d((1,1)))
feature_extractor_model.eval()
feature_extractor_model = feature_extractor_model.to(device)

# Create DataLoaders
batch_sz = 128
train_loader = DataLoader(f_mnist_data, batch_size=batch_sz)
# NOTE(review): the test loader shuffles, so the order of the extracted test
# samples differs between runs (labels are collected in the same order).
test_loader = DataLoader(f_mnist_data_test, batch_size=batch_sz, shuffle=True)

### SimCLR representation extractor function

In [None]:
def extract_representations_simclr(extractor_model, data_sz,  dataloader, batch_size, interleave3d=False, vec_size=2048):
    """
    Extract representations based on the process described in the LRDOCC paper.

    Every batch yielded by ``dataloader`` is normalised with the module-level
    ``normalize`` transform, passed through ``extractor_model`` and the
    resulting vectors and labels are stored in pre-allocated NumPy arrays.

    Args:
        extractor_model: Module applied to each image batch. Its output is
            indexed as ``[:, :, 0, 0]``, so it must be 4-D with ``vec_size``
            entries along dim 1.
        data_sz: Total number of samples the dataloader yields; sizes the
            output arrays.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        batch_size: Nominal batch size. Kept for interface compatibility;
            rows are placed using the actual size of each batch.
        interleave3d: If true, every channel is repeated three times along
            dim 1 before normalisation.
        vec_size: Length of each extracted feature vector.

    Returns:
        Tuple ``(feature_tensor, label_tensor)`` of NumPy arrays with shapes
        ``(data_sz, vec_size)`` and ``(data_sz, 1)``.
    """
    start_time = tm.time()

    feature_tensor = np.zeros((data_sz, vec_size))
    label_tensor = np.zeros((data_sz, 1))

    # Send inputs to whichever device the model lives on rather than assuming
    # CUDA. A model without parameters keeps the previous CUDA default.
    first_param = next(extractor_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # Running write position; robust to a short final batch and to a
    # `batch_size` argument that differs from the loader's real batch size.
    offset = 0

    # Inference only: disable autograd so no graph is built per batch.
    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            if interleave3d:
                images = images.repeat_interleave(3, dim=1)
            images = normalize(images).to(device)

            # Drop the two trailing singleton spatial dims -> (batch, vec_size).
            batch_features = extractor_model(images)[:, :, 0, 0].cpu().numpy()
            n_rows = batch_features.shape[0]

            feature_tensor[offset:offset + n_rows, :] = batch_features
            label_tensor[offset:offset + n_rows, :] = np.expand_dims(labels.numpy(), 1)
            offset += n_rows

            gc.collect()

    elapsed = tm.time() - start_time
    print('Feature extraction complete in {:.0f}m {:.0f}s'.format(elapsed // 60, elapsed % 60))

    return feature_tensor, label_tensor

In [None]:
# Extract representations and labels for the training split. The images are
# expanded to three channels (interleave3d=True) before normalisation.
simclr_fmnist_train, simclr_fmnist_train_lbl = extract_representations_simclr(
    feature_extractor_model,
    len(f_mnist_data),
    train_loader,
    batch_sz,
    interleave3d=True,
)

# Same extraction for the test split.
simclr_fmnist_test, simclr_fmnist_test_lbl = extract_representations_simclr(
    feature_extractor_model,
    len(f_mnist_data_test),
    test_loader,
    batch_sz,
    interleave3d=True,
)

In [None]:
# Output locations (inside the working directory) for the extracted
# representations and their labels, train and test splits.
simclr_dataset_path_train = f"{working_dir}/simcplr_repr_fmnist.npy"
simclr_labels_path_train_lbl = f"{working_dir}/simclr_labels_fmnist.npy"
simclr_dataset_path_test = f"{working_dir}/simcplr_repr_fmnist_test.npy"
simclr_labels_path_test_labels = f"{working_dir}/simclr_labels_fmnist_test_lbl.npy"

# Persist each array to its matching path, in the order train features,
# train labels, test features, test labels.
for _npy_path, _npy_array in (
    (simclr_dataset_path_train, simclr_fmnist_train),
    (simclr_labels_path_train_lbl, simclr_fmnist_train_lbl),
    (simclr_dataset_path_test, simclr_fmnist_test),
    (simclr_labels_path_test_labels, simclr_fmnist_test_lbl),
):
    np.save(_npy_path, _npy_array)

### Run OC-SVM

In [None]:
# Run OC-SVM on the extracted representations with an RBF kernel.
# The label arrays are squeezed to drop their size-1 axes before being
# handed over.
_train_labels_flat = np.squeeze(simclr_fmnist_train_lbl)
_test_labels_flat = np.squeeze(simclr_fmnist_test_lbl)

simclr_aucs = K9_OCSVM(
    simclr_fmnist_train,
    _train_labels_flat,
    simclr_fmnist_test,
    _test_labels_flat,
    kernel_type="rbf",
)