# Initializing the Google Colab environment

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; later cells read the dataset archive
# and the pretrained checkpoints from paths below that mount point.
drive.mount('/content/drive')

# Clone the project repository, then install the mmcv and ViTPose copies it contains.
! git clone https://github.com/JozifekSVK/dp_vitpose_vicreg.git
! cd dp_vitpose_vicreg/mmcv && python setup.py install
! pip install -v -e dp_vitpose_vicreg/ViTPose/.
# Additional packages installed for this notebook (only yapf is version-pinned).
! pip install timm einops
! pip install yapf==0.40.1
! pip install umap-learn
! pip install pycocotools


# Importing dependencies

In [None]:
import cv2
import requests
import umap
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from PIL import Image
from pycocotools.coco import COCO
from matplotlib.colors import ListedColormap

import collections
from dp_vitpose_vicreg.ijepa.src.models.vision_transformer import vit_small ### Defining model

# Number of images downloaded and embedded for each COCO category.
IMAGES_IN_CLASS = 50

## Extracting the dataset from Google Drive

In [None]:
! unzip "/content/drive/MyDrive/DP_pose_estimation/Dataset/COCO_dataset/Dataset_dp.zip" -d "/content"

In [None]:
def make_transforms(
    crop_size=224,
    crop_scale=(0.3, 1.0),
    normalization=((0.485, 0.456, 0.406),
                   (0.229, 0.224, 0.225))
):
    """Build the preprocessing pipeline applied to every PIL image.

    The pipeline takes a random resized crop, converts the crop to a tensor
    and normalises it channel-wise.

    Args:
        crop_size: Output size passed to ``RandomResizedCrop``.
        crop_scale: ``scale`` range passed to ``RandomResizedCrop``.
        normalization: Pair ``(mean, std)`` of per-channel statistics handed
            to ``Normalize``.

    Returns:
        A ``transforms.Compose`` object chaining the three steps.
    """
    # The crop is random, so two calls on the same image generally differ.
    return transforms.Compose([
        transforms.RandomResizedCrop(crop_size, scale=crop_scale),
        transforms.ToTensor(),
        transforms.Normalize(normalization[0], normalization[1]),
    ])

def get_links_from_category(category_name, coco):
  """Return the COCO image records that belong to one category.

  Args:
    category_name: Category name to look up (passed as ``catNms``).
    coco: Object exposing ``getCatIds``, ``getImgIds`` and ``loadImgs``
      (a ``pycocotools.coco.COCO`` instance in this notebook).

  Returns:
    Whatever ``coco.loadImgs`` returns for the image ids of that category.
  """
  category_ids = coco.getCatIds(catNms=[category_name])
  return coco.loadImgs(coco.getImgIds(catIds=category_ids))

def download_five_images(links):
  """Download the raw bytes of the first ``IMAGES_IN_CLASS`` images in ``links``.

  Despite the name, the number of images is set by the module-level
  ``IMAGES_IN_CLASS`` constant. If ``links`` holds fewer entries, all of them
  are downloaded.

  Args:
    links: Sequence of dicts that each provide a ``'coco_url'`` entry.

  Returns:
    List of ``bytes`` objects, one per downloaded image, in input order.

  Raises:
    requests.HTTPError: If the server answers with an error status.
    requests.RequestException: On connection failures or timeouts.
  """
  retry = Retry(connect=3, backoff_factor=0.5)
  adapter = HTTPAdapter(max_retries=retry)

  result = []
  # One session serves every download (connection reuse) and is closed on exit,
  # instead of creating a new, never-closed session per image.
  with requests.Session() as session:
    session.mount('http://', adapter)
    session.mount('https://', adapter)

    for link in links[:IMAGES_IN_CLASS]:
      # The timeout keeps a stalled connection from hanging the notebook.
      response = session.get(link['coco_url'], timeout=30)
      # Fail here rather than handing an HTML error page to the image decoder.
      response.raise_for_status()
      result.append(response.content)

  return result

def print_images(images):
  """Decode encoded image buffers into NumPy arrays (nothing is printed).

  Args:
    images: Iterable of bytes-like objects holding encoded image files.

  Returns:
    List with one entry per input, in order. Three-channel images are
    returned in RGB channel order; every other decode result (single-channel,
    four-channel, or ``None`` when decoding fails) is passed through as
    ``cv2.imdecode`` produced it.
  """
  converted_images = []
  for image in images:
    nparr = np.frombuffer(image, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)

    # OpenCV decodes colour images as BGR, while the arrays are later handed to
    # PIL and ImageNet-style normalisation, which expect RGB. Swap here so the
    # red and blue channels are not exchanged.
    if img is not None and img.ndim == 3 and img.shape[2] == 3:
      img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    converted_images.append(img)

  return converted_images

#### Initializing MS COCO dataset

In [None]:
# COCO annotation index (train2017 instances) used to look up the image records of each category.
coco = COCO('/content/Dataset_dp/annotations/instances_train2017.json')

### Downloading images for every defined class

In [None]:
# COCO category names to sample; make_predictions_and_umap iterates over this list.
categories = [
  "person", "bus", "boat", "dog", "tennis racket", "banana", "pizza", "cow", "stop sign", "snowboard"

]

# Maps each category name to the list of decoded images returned by print_images.
images = {}
for categ in categories:
  # Look up the category's image records, download the image bytes, then decode them.
  links = get_links_from_category(categ, coco)
  downloaded_images = download_five_images(links)
  images[categ] = print_images(downloaded_images)


# I-JEPA

In [None]:
# Checkpoint file for the I-JEPA encoder, read from the mounted Drive.
ijepa_model = "/content/drive/MyDrive/DP_pose_estimation/pretrained_encoders/vit_small_ijepa.pth"
ijepa = vit_small()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
loaded_model = torch.load(ijepa_model,map_location=torch.device('cpu'))

# Rebuild the checkpoint mapping under the key names that are loaded into the model.
new_state_dict = collections.OrderedDict()
for k, v in loaded_model.items():
    # Drop every "module." substring from the key (presumably a DataParallel/DDP prefix; TODO confirm).
    name = k.replace("module.", '')

    # The final norm parameters are stored under "fc_norm.*".
    # NOTE(review): the comparison uses the raw key k, so a key saved as
    # "module.norm.*" would NOT be renamed here, whereas the VICReg and MAE
    # cells compare against "module.norm.*". Confirm the key format of this checkpoint.
    if k == 'norm.weight':
      name = 'fc_norm.weight'
    elif k == 'norm.bias':
      name = 'fc_norm.bias'

    new_state_dict[name] = v

# Load the remapped weights into the freshly constructed model.
ijepa.load_state_dict(new_state_dict)

# VICReg

In [None]:
# Checkpoint file for the VICReg encoder, read from the mounted Drive.
vicreg_model = "/content/drive/MyDrive/DP_pose_estimation/pretrained_encoders/vit_small_vicreg.pth"
vicreg = vit_small()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
loaded_model = torch.load(vicreg_model,map_location=torch.device('cpu'))

# Rebuild the checkpoint mapping under the key names that are loaded into the model.
new_state_dict = collections.OrderedDict()
for k, v in loaded_model.items():
    # Drop every "module." substring from the key (presumably a DataParallel/DDP prefix; TODO confirm).
    name = k.replace("module.", '')

    # The final norm parameters, saved here as "module.norm.*", are stored under "fc_norm.*".
    # NOTE(review): the comparison uses the raw key k; an un-prefixed "norm.*"
    # key would keep its name. The I-JEPA cell compares against "norm.*" instead.
    if k == 'module.norm.weight':
      name = 'fc_norm.weight'
    elif k == 'module.norm.bias':
      name = 'fc_norm.bias'

    new_state_dict[name] = v

# Load the remapped weights into the freshly constructed model.
vicreg.load_state_dict(new_state_dict)

# MAE

In [None]:
# Checkpoint file for the MAE backbone, read from the mounted Drive.
mae_model = "/content/drive/MyDrive/DP_pose_estimation/pretrained_encoders/mae_backbone_trained.pth"
mae = vit_small()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
loaded_mae_model = torch.load(mae_model,map_location=torch.device('cpu'))

# Rebuild the checkpoint mapping, keeping only the entries that are loaded into the model.
new_state_dict = collections.OrderedDict()
for k, v in loaded_mae_model.items():

  # Entries whose key mentions "decoder" are not copied.
  if 'decoder' in k:
    continue

  # Entries whose key mentions "mask_token" are not copied either.
  if 'mask_token' in k:
    continue

  # Drop every "module." substring from the key (presumably a DataParallel/DDP prefix; TODO confirm).
  name = k.replace("module.", '')

  # The final norm parameters, saved here as "module.norm.*", are stored under "fc_norm.*".
  # NOTE(review): the comparison uses the raw key k; an un-prefixed "norm.*" key would keep its name.
  if k == 'module.norm.weight':
    name = 'fc_norm.weight'
  elif k == 'module.norm.bias':
    name = 'fc_norm.bias'

  new_state_dict[name] = v


# Load the remapped weights into the freshly constructed model.
mae.load_state_dict(new_state_dict)

# I-JEPA-VICREG

In [None]:
# Checkpoint file for the combined I-JEPA + VICReg encoder, read from the mounted Drive.
vicreg_ijepa_model = "/content/drive/MyDrive/DP_pose_estimation/pretrained_encoders/vit_small_ijepa_vicreg.pth"
vicreg_ijepa = vit_small()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
loaded_vicreg_ijepa_model = torch.load(vicreg_ijepa_model,map_location=torch.device('cpu'))

# Rebuild the checkpoint mapping, keeping only the entries that are loaded into the model.
new_state_dict = collections.OrderedDict()
for k, v in loaded_vicreg_ijepa_model.items():

  # Entries whose key mentions "decoder" are not copied.
  if 'decoder' in k:
    continue

  # Entries whose key mentions "mask_token" are not copied either.
  if 'mask_token' in k:
    continue

  name = k.replace("module.", '') # drop every "module." substring from the key

  # The final norm parameters are stored under "fc_norm.*".
  # NOTE(review): the comparison uses the raw key k, so a key saved as
  # "module.norm.*" would NOT be renamed here, whereas the VICReg and MAE
  # cells compare against "module.norm.*". Confirm the key format of this checkpoint.
  if k == 'norm.weight':
    name = 'fc_norm.weight'
  elif k == 'norm.bias':
    name = 'fc_norm.bias'

  new_state_dict[name] = v


# Load the remapped weights into the freshly constructed model.
vicreg_ijepa.load_state_dict(new_state_dict)

# UMAP

Defining functions for the UMAP projection. Creating a UMAP visualization of the representations.

In [None]:
def make_predictions_and_umap(selected_model, random_state=None):
    """Embed the downloaded images with ``selected_model`` and project them with UMAP.

    For every name in the module-level ``categories`` list, up to
    ``IMAGES_IN_CLASS`` arrays from the module-level ``images`` dict are
    preprocessed with ``make_transforms()`` and passed through the model as
    one batch. Images that are not ``H x W x 3`` arrays (greyscale,
    four-channel, or failed decodes stored as ``None``) are skipped, and a
    category left without usable images is skipped entirely.

    Args:
        selected_model: Callable mapping a stacked image batch to a tensor
            whose second dimension is averaged over in the 'person'
            diagnostic (the notebook's models return ``(N, 196, 384)``).
            The output shape must be the same for every category.
        random_state: Optional seed forwarded to ``umap.UMAP``; the default
            ``None`` keeps UMAP's unseeded behaviour.

    Returns:
        Tuple ``(result_umap, calculated_embeddings)``: a DataFrame with the
        UMAP coordinates in columns ``0`` and ``1`` plus a ``'category'``
        column, and the model outputs of all processed images concatenated
        along dimension 0 (same row order as the DataFrame).

    Raises:
        ValueError: If no category contains a usable image.
    """
    trans = make_transforms()
    labels = []
    category_outputs = []

    for categ in categories:
        images_in_category = []
        # Slicing tolerates categories that hold fewer than IMAGES_IN_CLASS images.
        for raw_image in images[categ][:IMAGES_IN_CLASS]:
            # Test the channel axis explicitly; "3 in shape" would also match
            # an image whose height or width happens to be 3.
            if raw_image is None or raw_image.ndim != 3 or raw_image.shape[2] != 3:
                continue
            images_in_category.append(trans(Image.fromarray(raw_image)))

        if not images_in_category:
            # torch.stack cannot handle an empty list.
            continue

        labels.extend([categ] * len(images_in_category))

        batch = torch.stack(images_in_category)
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            output = selected_model(batch).detach().cpu()
        category_outputs.append(output)

        if categ == 'person':
            # Diagnostic: std/|mean| ratio across dimension 1 for the first 'person' sample.
            output_patches_mean = output.mean(axis=1)
            output_patches_std = output.std(axis=1)

            print('Percentage of varience per sample')
            print( (output_patches_std[0,:] / output_patches_mean[0,:].abs()).mean() )

    if not category_outputs:
        raise ValueError("no usable images found in any category")

    # Concatenate once at the end; this avoids hard-coding the output shape.
    calculated_embeddings = torch.cat(category_outputs, dim=0)
    flatten_data = torch.flatten(calculated_embeddings, start_dim=1)

    reducer = umap.UMAP(random_state=random_state)
    result_umap = pd.DataFrame(reducer.fit_transform(flatten_data.numpy()))
    result_umap['category'] = labels

    return result_umap, calculated_embeddings

def create_map_plot(result_umap, model_name):
  """Scatter-plot a 2-D UMAP projection with one fixed colour per category.

  Each category is drawn as its own labelled scatter series, so a category
  keeps its colour and its legend label even when other categories have no
  points. Rows whose category is not listed below are not drawn.

  Args:
    result_umap: DataFrame with the x coordinates in column ``0``, the y
      coordinates in column ``1`` and the category name in ``'category'``.
    model_name: Text appended to the plot title.

  Returns:
    None. The figure is displayed with ``plt.show()``.
  """
  # Fixed category -> colour assignment (dict order defines the legend order).
  category_colors = {
      "person": '#1f77b4',
      "bus": '#ff7f0e',
      "boat": '#2ca02c',
      "dog": '#d62728',
      "tennis racket": '#9467bd',
      "banana": '#8c564b',
      "pizza": '#e377c2',
      "cow": '#7f7f7f',
      "stop sign": '#bcbd22',
      "snowboard": '#17becf',
  }

  for category_name, color in category_colors.items():
    points = result_umap[result_umap.category == category_name]
    if points.empty:
      # Nothing to draw, and no legend entry, for a category without samples.
      continue
    plt.scatter(points[0], points[1], color=color, label=category_name)

  plt.gca().set_aspect('equal', 'datalim')
  # Build the legend from the series labels, so handles and names cannot drift apart.
  plt.legend()
  plt.title(f'UMAP projection of the embedded COCO dataset {model_name}', fontsize=24)
  plt.show()

# Embed the images with each pretrained encoder; the embeds_* tensors are reused
# by the distance-matrix and PCA cells further down.
result_umap_vicreg, embeds_vicreg = make_predictions_and_umap(vicreg)
result_umap_ijepa, embeds_ijepa = make_predictions_and_umap(ijepa)
result_umap_vicreg_ijepa, embeds_vicreg_ijepa = make_predictions_and_umap(vicreg_ijepa)
result_umap_mae, embeds_mae = make_predictions_and_umap(mae)

# One UMAP scatter plot per encoder.
create_map_plot(result_umap_vicreg, "VICReg")
create_map_plot(result_umap_ijepa, "IJEPA")
create_map_plot(result_umap_vicreg_ijepa, "IJEPA VICREG")
create_map_plot(result_umap_mae, "MAE")

# Distance matrix

Defining functions for calculating distance matrices and visualizing them.

In [None]:
def print_dist_matrix(dist_matrix, model_name, min_value=None, max_value=None):
  """Print the median of a log-scaled distance matrix and draw it as a heatmap.

  The matrix is transformed with ``log(1 + d)`` before the median is taken
  and before plotting.

  Args:
    dist_matrix: Distance matrix as a tensor (or anything accepted by
      ``torch.as_tensor``).
    model_name: Text used in the heatmap title.
    min_value: Optional lower colour limit, on the same log scale. ``None``
      lets matplotlib choose it.
    max_value: Optional upper colour limit, on the same log scale. ``None``
      lets matplotlib choose it. It is honoured independently of
      ``min_value``.

  Returns:
    None. The median is printed and the figure is displayed with ``plt.show()``.
  """
  # torch.log1p keeps the computation in torch rather than applying a NumPy
  # ufunc to a tensor.
  log_dist = torch.log1p(torch.as_tensor(dist_matrix).detach())
  arr_median = log_dist.median().item()

  print(f"Median of distance matrix - {arr_median}")

  # Plain floats are accepted by matplotlib whether the caller passed Python
  # numbers, NumPy scalars or 0-d tensors.
  vmin = None if min_value is None else float(min_value)
  vmax = None if max_value is None else float(max_value)

  heatmap = plt.pcolor(log_dist.cpu().numpy(), cmap='autumn', vmin=vmin, vmax=vmax)

  plt.colorbar(heatmap, orientation='vertical')
  plt.title(f"Heatmap for embeddings from {model_name}")
  plt.show()


def calc_dist_matrix(embeddings):
  """Return the pairwise Euclidean distance matrix between samples.

  Every sample (index along dimension 0) is flattened into one vector before
  the distances are computed. The input tensor is not modified.

  Args:
    embeddings: Tensor with at least two dimensions; dimension 0 indexes samples.

  Returns:
    Tensor of shape ``(N, N)`` holding the L2 distance between each pair of
    flattened samples.
  """
  # Neither flatten nor cdist mutates its input, so no defensive copy is needed.
  flat = torch.flatten(embeddings, start_dim=1)
  return torch.cdist(flat, flat, p=2)

### Calculating distance matrices (one per encoder, between all embedded images)
dist_matrix_vicreg = calc_dist_matrix( embeds_vicreg)
dist_matrix_ijepa = calc_dist_matrix( embeds_ijepa)
dist_matrix_vicreg_ijepa = calc_dist_matrix( embeds_vicreg_ijepa)
dist_matrix_mae = calc_dist_matrix(embeds_mae)

### Calculating min-max values for scaling
# A shared colour range over all four matrices keeps the heatmaps comparable.
# The limits are log(1 + d), matching the transform inside print_dist_matrix.
# NOTE(review): np.log is applied to torch values here; torch.log1p would avoid
# mixing the two libraries. Confirm before changing.
concat_distance_matrices = torch.cat((dist_matrix_vicreg,dist_matrix_ijepa, dist_matrix_vicreg_ijepa, dist_matrix_mae),axis=0)
min_value = np.log(concat_distance_matrices.min() + 1)
max_value = np.log(concat_distance_matrices.max() + 1)

# One heatmap per encoder, all drawn with the shared colour limits.
print_dist_matrix(dist_matrix_vicreg, "vicreg", min_value, max_value)
print_dist_matrix(dist_matrix_ijepa, "ijepa", min_value, max_value)
print_dist_matrix(dist_matrix_vicreg_ijepa, "vicreg_ijepa", min_value, max_value)
print_dist_matrix(dist_matrix_mae, "mae", min_value, max_value)

# PCA projection

Calculating singular values from the PCA projection and visualizing them in a line plot.

In [None]:
# For each encoder, take the embedding at index 91 along dimension 1 for every
# image and run a low-rank PCA on it; element [1] of the torch.pca_lowrank
# result holds the singular values.
# NOTE(review): index 91 and q=384 are hard-coded; presumably one chosen patch
# and the full embedding width. Confirm against the model configuration.

### VICReg
data_to_pca = embeds_vicreg[:,91,:]
res = torch.pca_lowrank(data_to_pca, q=384)
vicreg_pca = res[1]

### I-JEPA
data_to_pca = embeds_ijepa[:,91,:]
res = torch.pca_lowrank(data_to_pca, q=384)
ijepa_pca = res[1]

### I-JEPA-VICReg
data_to_pca = embeds_vicreg_ijepa[:,91,:]
res = torch.pca_lowrank(data_to_pca, q=384)
vicreg_ijepa_pca = res[1]

### MAE
data_to_pca = embeds_mae[:,91,:]
res = torch.pca_lowrank(data_to_pca, q=384)
mae_pca = res[1]

### plot lines
# One curve of singular values per encoder, drawn on a shared axis.
plt.plot(vicreg_pca,label='vicreg')
plt.plot(ijepa_pca,label='ijepa')
plt.plot(vicreg_ijepa_pca,label='vicreg_ijepa')
plt.plot(mae_pca,label='mae')
plt.legend()
plt.show()