# Validating the SSL model

In this notebook, our goal is to test how good our SSL-pretrained weights are.
- We will query images from different classes and compare their embeddings. This gives better insight into the intraclass and interclass variability.
    - Intraclass variance: the variance within one class. It measures the differences between the individual embeddings within each class.
    - Interclass variance: the variance between different classes. It measures the differences between the class means.
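
The two quantities above can be sketched in a few lines of torch. This is a minimal illustration, not the evaluation used later in the notebook: the helper name `class_variances`, the use of squared Euclidean distance, and the toy data are all assumptions made for the example.

```python
import torch


def class_variances(embeddings: torch.Tensor, labels: torch.Tensor):
    """Return (intraclass, interclass) variance for a batch of embeddings.

    embeddings: (N, D) float tensor, labels: (N,) integer tensor.
    """
    classes = labels.unique()
    # Mean embedding (centroid) of every class, shape (C, D).
    centroids = torch.stack([embeddings[labels == c].mean(dim=0) for c in classes])

    # Intraclass: mean squared distance of each embedding to its own class
    # centroid, averaged over the classes.
    intra = torch.stack([
        (embeddings[labels == c] - centroids[i]).pow(2).sum(dim=1).mean()
        for i, c in enumerate(classes)
    ]).mean()

    # Interclass: mean squared distance of each class centroid to the mean
    # of all class centroids.
    inter = (centroids - centroids.mean(dim=0)).pow(2).sum(dim=1).mean()
    return intra.item(), inter.item()


# Toy data (hypothetical): two tight clusters that lie far apart.
toy_embeddings = torch.tensor([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
toy_labels = torch.tensor([0, 0, 1, 1])
intra, inter = class_variances(toy_embeddings, toy_labels)
print(f"intraclass: {intra:.2f}, interclass: {inter:.2f}")
```

Good pretrained weights should give a low intraclass value relative to the interclass value, as in the toy data, where each cluster is tight and the two centroids are far apart.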

## Imports
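- matplotlib for visualisation
- torch and torchvision for the models and pretrained weights
- pandas and IPython.display for showing the checkpoint metadata
- pathlib for file paths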

In [1]:
%matplotlib inline

In [2]:
import torch
import torchvision
import pandas
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import display

## Reading in pretrained weights

### Option 1: ImageNet pretrained
- Load the best ImageNet-pretrained weights; docs: https://pytorch.org/vision/stable/models.html
- This is currently ResNet50_Weights.IMAGENET1K_V2, with a top-1 accuracy of 80.858%
- The weights are cached in /home/olivier/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
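
To check which weight files have already been downloaded, the cache folder can be listed directly. This is a small sketch that assumes the default torch hub layout, where torchvision places downloads in a "checkpoints" subfolder of the directory returned by `torch.hub.get_dir()`; the variable names are illustrative.

```python
from pathlib import Path

import torch

# torch.hub.get_dir() returns the hub cache root (by default ~/.cache/torch/hub);
# downloaded weight files are assumed to live in its "checkpoints" subfolder.
checkpoint_dir = Path(torch.hub.get_dir()) / "checkpoints"

if checkpoint_dir.is_dir():
    # Print every cached .pth file together with its size.
    for weight_file in sorted(checkpoint_dir.glob("*.pth")):
        size_mb = weight_file.stat().st_size / 1e6
        print(f"{weight_file.name}: {size_mb:.1f} MB")
else:
    print(f"No cached checkpoints found in {checkpoint_dir}")
```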


In [3]:
#ImageNet weights
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
#run the next line once to write the file that torch.load reads back
#torch.save(model.state_dict(),"resnet50_imgnet.pth")
weights = torch.load("resnet50_imgnet.pth", map_location="cpu")
#print(weights.keys())
#print(model)

### Option 2: SSL pretrained
- VISSL wraps its ResNet trunk in a custom class, ResNeXt
    - The ResNeXt wrapper class is defined at https://github.com/facebookresearch/vissl/blob/04788de934b39278326331f7a4396e03e85f6e55/vissl/models/trunks/resnext.py
    - The torchvision ResNet base class, https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, documents the interface of the __init__ method.
    - The model inside this wrapper class is a torchvision.models.ResNet(), which we reconstruct here from the YAML config parameters.


In [4]:
#vissl uses a wrapper class ResNeXt for more flexibility;
#we rebuild the torch model here to be able to load the weights from SSL pretraining
resnet_depth = 50 #in YAML config file MODEL.TRUNK.RESNETS.DEPTH: 50

#vissl uses these block configs:
#based on the depth in the YAML the right config is chosen
BLOCK_CONFIG = {
    18: {"layers": (2, 2, 2, 2), "block": torchvision.models.resnet.BasicBlock},
    34: {"layers": (3, 4, 6, 3), "block": torchvision.models.resnet.BasicBlock},
    50: {"layers": (3, 4, 6, 3), "block": torchvision.models.resnet.Bottleneck},
    101: {"layers": (3, 4, 23, 3), "block": torchvision.models.resnet.Bottleneck},
    152: {"layers": (3, 8, 36, 3), "block": torchvision.models.resnet.Bottleneck},
    200: {"layers": (3, 24, 36, 3), "block": torchvision.models.resnet.Bottleneck}
}


#gathering the correct parameters for resnet
(n1, n2, n3, n4) = BLOCK_CONFIG[resnet_depth]["layers"]
block_constructor = BLOCK_CONFIG[resnet_depth]["block"]

#vissl builds their torchvision resnet like this:
pretrained_ssl_model = torchvision.models.ResNet(
    block = block_constructor,
    layers = [n1, n2, n3, n4],
    zero_init_residual = False, #in YAML config file MODEL.TRUNK.RESNETS.ZERO_INIT_RESIDUAL: false
    groups = 1, #in YAML config file MODEL.TRUNK.RESNETS.GROUPS: 1
    width_per_group = 64, #in YAML config file MODEL.TRUNK.RESNETS.WIDTH_PER_GROUP: 64
    norm_layer = torch.nn.BatchNorm2d #in YAML config file MODEL.TRUNK.RESNETS.NORM: BatchNorm
)
#interface of the __init__ method:
#block: Type[Union[BasicBlock, Bottleneck]]
#layers: List[int]
#zero_init_residual: bool = False
#groups: int = 1
#width_per_group: int = 64

#check the model
#print(pretrained_ssl_model)

- Checkpoints from pretraining are stored in /home/olivier/Documents/master/mp/checkpoints/sku110k/
    - Checkpoints have phase numbers: in VISSL, if the workflow involves both training and testing, the number of phases = train phases + test epochs. So if we alternate train and test, the phase numbers are 0 (train), 1 (test), 2 (train), 3 (test), ... and train_phase_idx counts only the train phases: 0 (corresponds to phase 0), 1 (corresponds to phase 2), ...
    - The weights are stored in the checkpoint under the "classy_state_dict" key.
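
The phase numbering described above can be written down as a small mapping. This is a sketch under the assumption that train and test phases alternate strictly, starting with a train phase; the helper name `phase_to_train_phase` is hypothetical and not part of VISSL.

```python
def phase_to_train_phase(phase_idx: int):
    """Map a phase number to its train_phase_idx, or None for a test phase.

    Assumes strict alternation: even phases train, odd phases test.
    """
    if phase_idx % 2 == 1:
        return None  # odd phases are test phases and have no train index
    return phase_idx // 2


# Phases 0..3 alternate train, test, train, test.
for phase in range(4):
    print(phase, phase_to_train_phase(phase))
```

If the workflow contains only train phases, the two numbers coincide and no mapping is needed.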

In [5]:
print("Loading vissl checkpoint")
path_checkpoint = Path("/home/olivier/Documents/master/mp/checkpoints/sku110k/rotnet_full/model_final_checkpoint_phase104.torch")
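#the checkpoint holds CUDA tensors; map_location lets it load on a CPU-only machine
#without it, torch.load raises the RuntimeError recorded in this cell's output
ssl_checkpoint = torch.load(path_checkpoint, map_location=torch.device("cpu"))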
print("Checkpoint contains:")
dataframe_dict = dict()
dataframe_dict["phase_idx"] = ssl_checkpoint["phase_idx"]
dataframe_dict["iteration_num"] = ssl_checkpoint["iteration_num"]
dataframe_dict["train_phase_idx"] = ssl_checkpoint["train_phase_idx"]
dataframe_dict["iteration"] = ssl_checkpoint["iteration"]
dataframe_dict["type"] = ssl_checkpoint["type"]
df = pandas.DataFrame(data=dataframe_dict.values(), index=dataframe_dict.keys(),columns=["Value"])
display(df)
#check both keys explicitly
if "loss" in ssl_checkpoint and "classy_state_dict" in ssl_checkpoint:
    print("Checkpoint also contains elements loss and classy_state_dict")
    
print(ssl_checkpoint["classy_state_dict"].keys())
    
#pretrained_ssl_model.load_state_dict(ssl_checkpoint["classy_state_dict"])

Loading vissl checkpoint

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
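
Once the checkpoint loads, the commented-out load_state_dict call will only succeed if the parameter names in the checkpoint match those of the torchvision model. One way to reconcile them is sketched below, assuming the stored weights carry an extra name prefix. Both the helper `strip_prefix` and the prefix "_feature_blocks." are illustrative assumptions, so compare against the keys printed from the actual checkpoint first.

```python
def strip_prefix(state_dict, prefix):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Toy stand-in for checkpoint weights (hypothetical names); the strings
# take the place of real tensors.
toy_weights = {
    "_feature_blocks.conv1.weight": "tensor A",
    "_feature_blocks.bn1.weight": "tensor B",
}
renamed = strip_prefix(toy_weights, "_feature_blocks.")
print(list(renamed))
```

With real weights, passing the renamed dict to pretrained_ssl_model.load_state_dict(..., strict=False) returns the missing and unexpected key names, which shows whether the renaming matched the torchvision layout.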