# Validation SSL model

In this notebook our goal is to test how good our SSL-pretrained weights are.
- We query images from different classes and compare their embeddings. This gives us insight into the intraclass and interclass variability:
    - Intraclass variance: the variance within one class, i.e. how much the individual embeddings of a class differ from each other.
    - Interclass variance: the variance between different classes, i.e. how much the class means differ from each other.
- Note: run this notebook with a kernel from your venv so that the vissl libraries are available: https://janakiev.com/blog/jupyter-virtual-envs/#add-virtual-environment-to-jupyter-notebook
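A minimal sketch of how the two quantities could be computed from a stack of embeddings. It assumes squared Euclidean distances: intraclass variance as the mean squared distance of each embedding to its class mean (averaged over classes), and interclass variance as the mean squared distance of each class mean to the mean of all class means. The helper name `class_variances` is hypothetical; it is not part of VISSL or torch.

```python
from typing import Sequence, Tuple

import torch


def class_variances(embeddings: torch.Tensor, labels: Sequence[str]) -> Tuple[float, float]:
    """Return (mean intraclass variance, interclass variance) of (N, d) embeddings.

    Hypothetical helper; one reasonable definition among several:
    - intraclass: mean squared distance to the class mean, averaged over classes
    - interclass: mean squared distance of the class means to their overall mean
    """
    class_means = []
    intra = []
    for c in sorted(set(labels)):
        idx = [i for i, lab in enumerate(labels) if lab == c]
        members = embeddings[idx]                  # (n_c, d) embeddings of class c
        mean_c = members.mean(dim=0)               # (d,) class mean
        class_means.append(mean_c)
        intra.append(((members - mean_c) ** 2).sum(dim=1).mean())
    means = torch.stack(class_means)               # (n_classes, d)
    inter = ((means - means.mean(dim=0)) ** 2).sum(dim=1).mean()
    return torch.stack(intra).mean().item(), inter.item()
```

Applied to the `fts_stack` and `labels` computed in the feature-extraction cells, `class_variances(fts_stack, labels)` would return both numbers; a useful embedding should show a small intraclass and a large interclass variance.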

## Imports
- torch and torchvision for the models
- pandas and IPython's `display` for tabular output
- matplotlib for visualisation

In [1]:
%matplotlib inline

In [2]:
import torch
import torchvision
import pandas
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import display

## Reading in pretrained weights

### Option 1: ImageNet pretrained
- Load the best ImageNet-pretrained weights, docs: https://pytorch.org/vision/stable/models.html
- This is currently `ResNet50_Weights.IMAGENET1K_V2`, with a top-1 accuracy of 80.858% on ImageNet
- The weights are cached in /home/olivier/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


In [3]:
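# ImageNet weights.
# The weights= enum API requires torchvision >= 0.13; older versions raise
# AttributeError: module 'torchvision.models' has no attribute 'ResNet50_Weights'.
if hasattr(torchvision.models, "ResNet50_Weights"):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
else:
    # Fallback for older torchvision: pretrained=True loads the IMAGENET1K_V1 weights (76.13% top-1).
    model = torchvision.models.resnet50(pretrained=True)
# Uncomment once to create resnet50_imgnet.pth before loading it below.
#torch.save(model.state_dict(), "resnet50_imgnet.pth")
weights = torch.load("resnet50_imgnet.pth")
#print(weights.keys())
#print(model)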

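AttributeError: module 'torchvision.models' has no attribute 'ResNet50_Weights'

(Output of the original run: the installed torchvision is older than 0.13, the version that introduced the `weights=` enum API.)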

### Option 2: SSL pretrained
Load the weights from a checkpoint, following the VISSL tutorial:
https://github.com/facebookresearch/vissl/blob/v0.1.6/tutorials/Using_a_pretrained_model_for_inference_V0_1_6.ipynb


In [3]:
from omegaconf import OmegaConf
from vissl.utils.hydra_config import AttrDict
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict

# The checkpoint config is located at vissl/configs/config/validation.
# The weights path is set with config.MODEL.WEIGHTS_INIT.PARAMS_FILE below.
# All other options override the train_config.yaml config.

cfg = [
  'config=validation/rotnet_full/train_config.yaml',
  'config.MODEL.WEIGHTS_INIT.PARAMS_FILE=/home/olivier/Documents/master/mp/checkpoints/sku110k/rotnet_full/model_final_checkpoint_phase104.torch', # Specify path for the model weights.
  'config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True', # Turn on model evaluation mode.
  'config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=True', # Freeze trunk. 
  'config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=True', # Extract the trunk features, as opposed to the HEAD.
  'config.MODEL.FEATURE_EVAL_SETTINGS.SHOULD_FLATTEN_FEATS=False', # Do not flatten features.
  'config.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP=[["res5avg", ["Identity", []]]]' # Extract only the res5avg features.
]

# Compose the hydra configuration.
cfg = compose_hydra_configuration(cfg)
# Convert to AttrDict. This method will also infer certain config options
# and validate the config is valid.
_, cfg = convert_to_attrdict(cfg)

** Please migrate to the version in iopath repo. **
https://github.com/facebookresearch/iopath 



Now let's build the model with the exact training configs:

In [4]:
from vissl.models import build_model

model = build_model(cfg.MODEL, cfg.OPTIMIZER)

#### Loading the pretrained weights

In [5]:
from classy_vision.generic.util import load_checkpoint
from vissl.utils.checkpoint import init_model_from_consolidated_weights

# Load the checkpoint weights.
weights = load_checkpoint(checkpoint_path=cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE)


# Initialize the model with the RotNet model weights.
init_model_from_consolidated_weights(
    config=cfg,
    model=model,
    state_dict=weights,
    state_dict_key_name="classy_state_dict",
    skip_layers=[],  # Use this if you do not want to load all layers
)

print("Weights have loaded")

Weights have loaded


#### Extra info
- VISSL uses its own `ResNeXt` trunk class, a custom wrapper around torchvision's ResNet
    - The wrapper class is defined at https://github.com/facebookresearch/vissl/blob/04788de934b39278326331f7a4396e03e85f6e55/vissl/models/trunks/resnext.py
    - See the ResNet base class at https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py for the interface of its `__init__` method.
    - The model inside this wrapper is a `torchvision.models.ResNet()`, which we will reconstruct here from the YAML config parameters.
- Checkpoints from pretraining are stored in /home/olivier/Documents/master/mp/checkpoints/sku110k/
    - Checkpoints carry phase numbers. In VISSL, if the workflow involves both training and testing, the number of phases = train phases + test phases. So if train and test alternate, the phases are numbered 0 (train), 1 (test), 2 (train), 3 (test), ..., while `train_phase_idx` only counts the train phases: 0 (corresponds to phase 0), 1 (corresponds to phase 2), ...
    - The trunk weights are stored in the nested dict `["classy_state_dict"]["base_model"]["model"]["trunk"]` of the checkpoint, as printed in the next cell.
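The phase bookkeeping above can be sketched as a small helper. It generalises the alternating case by assuming a test phase is inserted after every `test_every`-th train phase (`test_every=1` is strict alternation). Note that the checkpoint inspected below reports `phase_idx` 125 and `train_phase_idx` 104, which does not fit strict alternation (that would give 208). One schedule consistent with those numbers is a test phase after every 5th train phase (105 train + 21 test phases, indices 0 to 125), but that is an inference from the numbers, not something read from the config. The helper names are hypothetical.

```python
from typing import Optional


def global_phase_idx(train_phase_idx: int, test_every: Optional[int] = 1) -> int:
    """Global phase index of a given train phase (hypothetical helper).

    Assumes a test phase follows every `test_every`-th train phase;
    test_every=1 gives 0 (train), 1 (test), 2 (train), 3 (test), ...
    With test_every=None (no test phases) both indices coincide.
    """
    if test_every is None:
        return train_phase_idx
    # One test phase precedes this train phase per completed block of `test_every` train phases.
    return train_phase_idx + train_phase_idx // test_every


def total_phases(num_train_phases: int, test_every: Optional[int] = 1) -> int:
    """Number of phases = train phases + test phases, under the same assumption."""
    if test_every is None:
        return num_train_phases
    return num_train_phases + num_train_phases // test_every
```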

In [38]:
print("Loading vissl checkpoint")
path_checkpoint = Path("/home/olivier/Documents/master/mp/checkpoints/sku110k/rotnet_full/model_final_checkpoint_phase104.torch")
# map_location="cpu" keeps all tensors on the CPU; the installed PyTorch build does not support the local GPU.
ssl_checkpoint = torch.load(path_checkpoint, map_location="cpu")
print("Checkpoint contains:")
# Collect the checkpoint metadata in a one-column table.
metadata_keys = ["phase_idx", "iteration_num", "train_phase_idx", "iteration", "type"]
df = pandas.DataFrame(
    data=[ssl_checkpoint[k] for k in metadata_keys],
    index=metadata_keys,
    columns=["Value"],
)
display(df)
if "loss" in ssl_checkpoint and "classy_state_dict" in ssl_checkpoint:
    print("Checkpoint also contains elements loss and classy_state_dict")

# The weights of the trunk ResNet are stored in a nested dict:
print(ssl_checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"].keys())

Loading vissl checkpoint
Checkpoint contains:


NVIDIA GeForce RTX 4070 Laptop GPU with CUDA capability sm_89 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70.
If you want to use the NVIDIA GeForce RTX 4070 Laptop GPU GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



|  | Value |
|---|---|
| phase_idx | 125 |
| iteration_num | 4007807 |
| train_phase_idx | 104 |
| iteration | 3948315 |
| type | consolidated |


Checkpoint also contains elements loss and classy_state_dict
odict_keys(['_feature_blocks.conv1.weight', '_feature_blocks.bn1.weight', '_feature_blocks.bn1.bias', '_feature_blocks.bn1.running_mean', '_feature_blocks.bn1.running_var', '_feature_blocks.bn1.num_batches_tracked', '_feature_blocks.layer1.0.conv1.weight', '_feature_blocks.layer1.0.bn1.weight', '_feature_blocks.layer1.0.bn1.bias', '_feature_blocks.layer1.0.bn1.running_mean', '_feature_blocks.layer1.0.bn1.running_var', '_feature_blocks.layer1.0.bn1.num_batches_tracked', '_feature_blocks.layer1.0.conv2.weight', '_feature_blocks.layer1.0.bn2.weight', '_feature_blocks.layer1.0.bn2.bias', '_feature_blocks.layer1.0.bn2.running_mean', '_feature_blocks.layer1.0.bn2.running_var', '_feature_blocks.layer1.0.bn2.num_batches_tracked', '_feature_blocks.layer1.0.conv3.weight', '_feature_blocks.layer1.0.bn3.weight', '_feature_blocks.layer1.0.bn3.bias', '_feature_blocks.layer1.0.bn3.running_mean', '_feature_blocks.layer1.0.bn3.running_var', '
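The trunk keys printed above start with `_feature_blocks.`, while torchvision's ResNet uses bare names such as `conv1.weight` and `layer1.0.conv1.weight`. A minimal sketch of reconstructing a plain `torchvision.models.ResNet` from the trunk is to strip that prefix and load with `strict=False`. This assumes the trunk was configured as a standard ResNet-50 (depth 50, no grouped convolutions), which the printed key names suggest but which should be checked against the YAML config. Both helper names are hypothetical.

```python
from collections import OrderedDict
from typing import Mapping


def vissl_trunk_to_torchvision(trunk_state: Mapping, prefix: str = "_feature_blocks.") -> OrderedDict:
    """Rename VISSL trunk keys to torchvision ResNet keys by stripping `prefix` (hypothetical helper)."""
    converted = OrderedDict()
    for key, value in trunk_state.items():
        # Keys without the prefix are passed through unchanged.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        converted[new_key] = value
    return converted


def load_into_torchvision_resnet50(trunk_state: Mapping):
    """Load converted trunk weights into a randomly initialised torchvision ResNet-50.

    Returns the model plus the missing/unexpected keys reported by load_state_dict.
    The SSL trunk has no classification head, so fc.weight and fc.bias are expected to be missing.
    """
    import torchvision  # imported here so the key conversion itself does not need torchvision

    resnet = torchvision.models.resnet50()  # architecture only, no pretrained weights
    result = resnet.load_state_dict(vissl_trunk_to_torchvision(trunk_state), strict=False)
    return resnet, result.missing_keys, result.unexpected_keys
```

With the checkpoint loaded above, `load_into_torchvision_resnet50(ssl_checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"])` should report only `fc.weight` and `fc.bias` as missing and nothing as unexpected; anything else suggests the two architectures do not line up.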

## Extracting features

In [26]:
from PIL import Image
import torchvision.transforms as transforms
import os

# Image transformation pipeline (as in the VISSL inference tutorial).
pipeline = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Inference mode: BatchNorm uses its running statistics instead of per-batch statistics.
model.eval()

def extract_features(path):
    image = Image.open(path)
    # Convert images to RGB. This is important
    # as the model was trained on RGB images.
    image = image.convert("RGB")
    x = pipeline(image)

    # unsqueeze adds a batch dimension of size 1 for this single image.
    # no_grad: only the forward pass is needed, so no autograd graph is built.
    with torch.no_grad():
        features = model(x.unsqueeze(0))
    # The model returns a list of features, one per requested pooling op (here only res5avg).
    return features[0]

path_to_CornerShop_crops = Path("/home/olivier/Documents/master/mp/CornerShop/CornerShop/crops")

# Optional: write the features of all crops to a text file.
# with open("fts.txt", "w") as savefile:
#     for crop_dir in os.listdir(path_to_CornerShop_crops):
#         # every folder is a class of this dataset
#         cdir = path_to_CornerShop_crops / crop_dir
#         if not cdir.is_dir():
#             continue  # skip non-dirs, these are not classes
#         for img in os.listdir(cdir):
#             # look in all the crop folders for images to get feature vectors from
#             if not img.endswith(".jpg"):
#                 continue  # skip non-image files
#             fts = extract_features(cdir / img)  # get feature vector
#             print(fts, file=savefile)

In [27]:
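# "*/*.jpg" only looks one folder deep; use "**/*.jpg" to search subdirectories at every depth.
img_paths = list(path_to_CornerShop_crops.glob("*/*.jpg"))
n_samples = 20  # only use the first 20 images for now

# The class label is the name of the parent folder (.name keeps the full folder name,
# whereas .stem would drop everything after the last dot, as it does for a file extension).
labels = [p.parent.name for p in img_paths[:n_samples]]
fts_stack = torch.stack([extract_features(p).squeeze() for p in img_paths[:n_samples]])
print(fts_stack.shape)
print(labels)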

torch.Size([20, 2048])
['CawstonDry', 'CawstonDry', 'CawstonDry', 'MinuteMaidAppelPerzik', 'CarrefourSmoothieAardbeiBlauweBessen', 'CarrefourSmoothieAardbeiBlauweBessen', 'CarrefourSmoothieAardbeiBlauweBessen', 'CarrefourSmoothieAardbeiBlauweBessen', 'GiniZeroFles1,5L', 'GiniZeroFles1,5L', 'GiniZeroFles1,5L', 'TropicanaSanguinello', 'TropicanaSanguinello', 'TropicanaSanguinello', 'TropicanaSanguinello', 'TropicanaSanguinello', '7upLemon', '7upLemon', '7upLemon', '7upLemon']


In [29]:
# Without normalisation this is the matrix of raw dot products, not cosine similarities.
dot_products = fts_stack.matmul(fts_stack.T)
print(dot_products.shape)
print(dot_products)

torch.Size([20, 20])
tensor([[ 87.3110,  83.9168,  87.3473, 109.3133, 135.7349,  77.9418,  95.8140,
         120.6644,  75.7141, 139.5700,  75.3643, 109.6743, 179.2607,  81.3136,
         126.0925,  98.3647, 115.1595,  96.6702, 104.6836,  94.0268],
        [ 83.9168,  88.7915,  87.2473, 110.2363, 137.6209,  79.6288,  97.1384,
         122.2782,  77.9874, 142.5118,  77.8924, 110.9518, 181.8811,  82.6621,
         127.6149, 100.1078, 115.6713,  97.8970, 105.6707,  95.2109],
        [ 87.3473,  87.2473,  93.0548, 112.7572, 141.2955,  81.3216,  99.6091,
         125.5396,  79.4377, 145.5551,  79.1773, 114.1908, 187.3758,  84.7680,
         131.0218, 102.5855, 119.3287, 100.4800, 108.6117,  97.5758],
        [109.3133, 110.2363, 112.7572, 148.6957, 179.3886, 102.5920, 125.8214,
         159.0141,  99.9860, 185.4647,  99.1786, 144.2331, 237.8656, 106.6190,
         166.2170, 129.2002, 150.8825, 126.7883, 137.6863, 123.7811],
        [135.7349, 137.6209, 141.2955, 179.3886, 232.1633, 128.4045
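The matrix printed above holds raw dot products, which grow with the feature norms (the diagonal entries differ, and off-diagonal entries exceed 1), so they are not cosine similarities. Writing $F \in \mathbb{R}^{N \times d}$ for `fts_stack` with rows $f_i$, the two matrices are related as follows:

```latex
S_{ij} = (F F^\top)_{ij} = f_i^\top f_j, \qquad
\cos(f_i, f_j) = \frac{f_i^\top f_j}{\lVert f_i \rVert_2 \, \lVert f_j \rVert_2}, \qquad
C = \hat{F} \hat{F}^\top, \quad \hat{F} = D^{-1} F, \quad
D = \operatorname{diag}\left(\lVert f_1 \rVert_2, \dots, \lVert f_N \rVert_2\right)
```

So $C_{ii} = 1$ and in general $C_{ij} \in [-1, 1]$. Because the res5avg features are average-pooled outputs of a ReLU, every entry of $F$ is non-negative, so here $C_{ij} \in [0, 1]$.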

In [36]:
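# Normalise every feature vector to unit L2 norm; the matrix product of the
# normalised features then gives the pairwise cosine similarities.
fts_stack_norm = fts_stack / fts_stack.norm(dim=1).unsqueeze(1)
cosim = fts_stack_norm.matmul(fts_stack_norm.T)
cosim.min()  # smallest pairwise cosine similarity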

tensor(0.9315)
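To turn the cosine-similarity matrix `cosim` into the intraclass/interclass comparison this notebook is after, one option is to average the similarities over pairs with the same label and over pairs with different labels. A minimal sketch with a hypothetical helper name: the diagonal (each image compared with itself) is excluded from the intraclass average, and a class with a single image contributes no intraclass pairs.

```python
from typing import Sequence, Tuple

import torch


def mean_intra_inter_similarity(sim: torch.Tensor, labels: Sequence[str]) -> Tuple[float, float]:
    """Average similarity over same-label pairs and over different-label pairs (hypothetical helper).

    `sim` is an (N, N) similarity matrix, `labels` holds one class name per row.
    Returns nan for the intraclass mean if no class has two or more samples.
    """
    n = sim.shape[0]
    # same[i, j] is True when samples i and j share a label.
    same = torch.tensor([[a == b for b in labels] for a in labels])
    off_diag = torch.eye(n) == 0  # True everywhere except on the diagonal
    intra = sim[same & off_diag].mean()
    inter = sim[~same].mean()
    return intra.item(), inter.item()
```

For the 20 crops above this would be `mean_intra_inter_similarity(cosim, labels)`. Since every pairwise similarity lies between 0.9315 and 1, both averages do too, so the gap between them, not their absolute values, indicates whether the embedding separates the classes.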