<a href="https://colab.research.google.com/github/alexgabriel28/thesis/blob/main/VICReg_CNN2_All.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Preliminary installs



In [1]:
from google.colab import drive
drive.mount("/content/drive/")

Mounted at /content/drive/


In [None]:
cd /content/drive/MyDrive/Master_Thesis/package/

In [None]:
!pip install wandb
!pip install lightning-flash

In [None]:
!apt-get install libmagic-dev
!pip install -e .
!pip install -r requirements.txt

In [None]:
cd /content/drive/MyDrive/Master_Thesis/package/

# Imports

In [1]:
import wandb

In [2]:
import torch
import numpy as np
from thesis.helper.dataset import collate_CNN_2

# Dataset and Dataloader

## Dataloader for continual learning

In [3]:
import tensorflow
import torch
from torch.utils.data import DataLoader, Dataset

from thesis.helper import dataset, utils, tensor_img_transforms
from thesis.models.VICRegModel import VICRegCNN_2
from thesis.models import VICRegModel
#from thesis.models import training
from thesis.models import eval
from thesis.helper.utils import save_train_specs
from sklearn.model_selection import train_test_split
#from thesis.helper.dataset import train_test_split
#from thesis.helper.utils import cal_linclf_acc

import os
import numpy as np
import glob

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

from PIL import Image

from sklearn.model_selection import train_test_split

import augly.image as imaugs

from tqdm.auto import tqdm

from thesis.helper import tensor_img_transforms

class ContImageDataset(Dataset):
  def __init__(self, root_dir: str, classes: list):
    self.root_dir = root_dir
    self.classes = classes
    self.label_list = []
    self.image_list = []
    self.class_map = {}
    cls_list = []
    count = 0
    for class_path in glob.glob(root_dir + "*"):
        cls = class_path.split("/")[-1]
        cls_list.append(cls)
        if count in self.classes:
            for img_path in glob.glob(class_path + "/*.png"):
              img = Image.open(str(img_path)).convert("RGB")
              tensor_image = TF.pil_to_tensor(img)
              self.image_list.append(tensor_image)
              self.label_list.append(cls)
        count += 1
    
    for i, cls in enumerate(cls_list):
        self.class_map[cls] = i

  def __len__(self):
    return len(self.image_list)

  def __getitem__(self, idx):
    class_name = self.label_list[idx]
    if torch.is_tensor(idx):
          idx = idx.tolist()

    self.class_id = self.class_map[class_name]
    self.class_id = torch.tensor([self.class_id])
    return self.image_list[idx], self.class_id

def create_dataloader_cont(
    num_classes: str = 5,
    cls_per_run: int = 2,
    batch_size: int = 128, 
    root_dir: str = "/content/drive/MyDrive/MT Gabriel/data_ext/"
    ) -> torch.utils.data.DataLoader:

    dataloader_list = []
    dataloader_test_list = []
    cls_list = []
    cls_list_per_run = []

    for i in range(num_classes - 1):
        cls_list_per_run = []
        for j in range(cls_per_run):
            ####################################################################
            ### Change 0 to i to receive sequential cls_list, i.e. [0,1,2,3,...]
            cls_list_per_run.extend([i + j])
        cls_list.append(cls_list_per_run)

    for i, classes in enumerate(cls_list):
        image_dir = root_dir
        image_dataset = ContImageDataset(image_dir, classes = classes)

        labels = [label.numpy() for tensor, label in iter(image_dataset)]
        train_indices, test_indices = train_test_split(list(range(len(labels))), test_size=0.2, stratify=labels)
        train_dataset = torch.utils.data.Subset(image_dataset, train_indices)
        test_dataset = torch.utils.data.Subset(image_dataset, test_indices)

        dataloader = DataLoader(
            train_dataset, 
            batch_size = batch_size, 
            shuffle = True, 
            pin_memory = False,
            collate_fn = collate_CNN_2
            )
        dataloader_list.append(dataloader)

        dataloader_test = DataLoader(
            test_dataset, 
            batch_size = batch_size,
            shuffle = True,
            pin_memory = False,
            collate_fn = collate_CNN_2
            )
        dataloader_test_list.append(dataloader_test)
    return dataloader_list, dataloader_test_list

Using device: cuda


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [4]:
dataloader_list, dataloader_test_list = create_dataloader_cont()

## Regular DataLoader

In [None]:
import tensorflow
import torch
from torch.utils.data import DataLoader, Dataset

from thesis.helper import dataset, utils, tensor_img_transforms
from thesis.models.VICRegModel import VICRegCNN_2
from thesis.models import VICRegModel
#from thesis.models import training
from thesis.models import eval
from thesis.helper.utils import save_train_specs
from sklearn.model_selection import train_test_split
#from thesis.helper.dataset import train_test_split
#from thesis.helper.utils import cal_linclf_acc

import os
import numpy as np
import glob

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import networkx as nx

from skimage.segmentation import slic, mark_boundaries
from skimage.util import img_as_float, img_as_int
from skimage.future import graph
from skimage.color import gray2rgb
from skimage import measure
from PIL import Image

from sklearn.model_selection import train_test_split

import augly.image as imaugs

from tqdm.auto import tqdm

from thesis.helper import tensor_img_transforms

class ImageDataset(Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.label_list = []
    self.image_list = []
    self.class_map = {}
    cls_list = []

    for class_path in glob.glob(root_dir + "*"):
      cls = class_path.split("/")[-1]
      cls_list.append(cls)
      for img_path in glob.glob(class_path + "/*.png"):
        img = Image.open(str(img_path)).convert("RGB")
        tensor_image = TF.pil_to_tensor(img)
        self.image_list.append(tensor_image)
        self.label_list.append(cls)
    
    for i, cls in enumerate(cls_list):
        self.class_map[cls] = i

  def __len__(self):
    return len(self.image_list)

  def __getitem__(self, idx):
    class_name = self.label_list[idx]
    if torch.is_tensor(idx):
          idx = idx.tolist()

    self.class_id = self.class_map[class_name]
    self.class_id = torch.tensor([self.class_id])
    return self.image_list[idx], self.class_id

image_dir = "/content/drive/MyDrive/MT Gabriel/data_ext/"
image_dataset = ImageDataset(image_dir)

labels = [label.numpy() for tensor, label in iter(image_dataset)]
train_indices, test_indices = train_test_split(list(range(len(labels))), test_size=0.2, stratify=labels)
train_dataset = torch.utils.data.Subset(image_dataset, train_indices)
test_dataset = torch.utils.data.Subset(image_dataset, test_indices)

dataloader = DataLoader(
    train_dataset, 
    batch_size = 128, 
    shuffle = True, 
    pin_memory = False,
    collate_fn = collate_CNN_2
    )

dataloader_test = DataLoader(
    test_dataset, 
    batch_size = 128, 
    shuffle = True, 
    pin_memory = False,
    collate_fn = collate_CNN_2
    )

# Cosine sim with protos

In [None]:
def cosine_sim(
                reps: torch.Tensor,
                labels: torch.Tensor,
                t: float,
                alpha: float,
                protos: torch.Tensor,
                epoch: int,
                num_classes: int,
                ) -> torch.Tensor:
                
    device = "cuda" if torch.cuda.is_available() else "cpu"
    assert (alpha >= 0) & (alpha <=1), "Alpha must be in [0,1]"
    
    # For numerical stability of exp function
    eps = torch.Tensor([1e-08]).to(device)
    Eps = torch.Tensor([1.797693134862315e+308]).to(device)
    
    means = torch.Tensor([]).to(device)
    labels_np = labels.detach().cpu().numpy()

    if epoch < 10:
        for label in range(num_classes): 
            mean = torch.mean((reps[label == labels_np.squeeze()]), 0, keepdim = True)
            means = torch.cat((means, mean), 0)
            labels = torch.cat((labels, torch.Tensor([[label]]).to(device)), 0)
        reps = torch.cat((reps, means), 0).to(device)
    else:
        for label in range(num_classes): 
            mean = torch.mean((reps[label == labels_np.squeeze()]), 0, keepdim = True)
            means = torch.cat((means, mean), 0)
            labels = torch.cat((labels, torch.Tensor([[label]]).to(device)), 0)
        protos += 0.01*means
        reps = torch.cat((reps, protos), 0).to(device)

    # Calculate l2-norms of all vector combinations 
    v_1, v_2 = torch.sum(torch.square(reps), dim = 1).view(reps.size(0), 1), \
        torch.sum(torch.square(reps), dim = 1).view(1, reps.size(0))
    norm_matrix = torch.matmul(torch.sqrt(v_1), torch.sqrt(v_2))

    labels_np = labels.detach().cpu().numpy()

    diff_p = torch.Tensor([0]).to(device)
    diff_n = torch.Tensor([0]).to(device)

    for label in range(num_classes):
        diff_pos = torch.sqrt((reps[label == labels_np.squeeze()] - means[int(label), :]).pow(2).sum(1)).sum()
        diff_neg = torch.sqrt((reps[label != labels_np.squeeze()] -  means[int(label), :]).pow(2).sum(1)).sum()
        diff_p += diff_pos
        diff_n += diff_neg

    # Calculate vector (cosine) similarities and normalize by l2 norms of vectors
    sim = torch.matmul(reps, reps.T)/(torch.max(eps, norm_matrix)*t)
    # Delete "self-loops" from similarity matrix by subtracting diagonal values
    sim = sim - torch.diag(torch.diagonal(sim))
    # Add zero for stability and clamp to float32 values
    sim = torch.clamp(torch.exp(sim + 0), min = eps, max = Eps)


    # Finds which instances are of the same class
    # If cls1 == cls2 -> label_1 - label_2 == 0
    # If cls 1 != cls2 -> abs(label_1 - label_2) >= 0
    pos_mask = (~torch.abs(labels.T - labels).bool()).float()
    neg_mask = (torch.abs(labels.T - labels).bool()).float()

    # Average positive and negative similarities for a batch and weight by alpha
    pos_loss = torch.mean(alpha*(pos_mask*sim))
    neg_loss = torch.mean((1 - alpha)*(neg_mask*sim))
    sim_loss =  (neg_loss - pos_loss).to(reps.device)
    proto_loss = (diff_pos - diff_neg).to(reps.device)/reps.size()[1]

    # Return overall loss
    return sim_loss, proto_loss, means

# Cosine sim loss with fraction and negative selection

In [None]:
def cosine_sim(
                reps: torch.Tensor = None,
                labels: torch.Tensor = None,
                t: float = 0.3,
                alpha: float = 0.3,
                fraction: float = 1.,
                num_classes: int = 5,
                neg_agg_choice: str = "proto",
                neg_selection: bool = True,
                ) -> torch.Tensor:
                
    device = "cuda" if torch.cuda.is_available() else "cpu"
    assert (alpha >= 0) & (alpha <=1), "Alpha must be in [0,1]"
    
    # For numerical stability of exp function
    eps = torch.Tensor([1e-08]).to(device)
    Eps = torch.Tensor([1.797693134862315e+308]).to(device)

    # Calculate l2-norms of all vector combinations 
    v_1, v_2 = torch.sum(torch.square(reps), dim = 1).view(reps.size(0), 1), \
        torch.sum(torch.square(reps), dim = 1).view(1, reps.size(0))

    norm_matrix = norm_matrix_pn = torch.matmul(
                    reps.norm(dim = -1, keepdim = True).view(reps.size(0), 1), 
                    reps.norm(dim = -1, keepdim = True).view(1, reps.size(0))
                    )
    # Use only the fraction of data, specified in fraction -> semi-supervised
    if fraction < 1.:
        erase_v = (torch.rand(size=(reps.size()[0], 1)) < fraction).to(device).float()
        reps = erase_v * reps

    # Calculate vector (cosine) similarities and normalize by l2 norms of vectors
    sim = torch.matmul(reps, reps.T)/(torch.max(eps, norm_matrix)*t)
    # Delete "self-loops" from similarity matrix by subtracting diagonal values
    sim = sim - torch.diag(torch.diagonal(sim))

    # Add zero for stability and clamp to float32 values
    sim = torch.clamp(torch.exp(sim + 0), min = eps, max = Eps)

    # Finds which instances are of the same class
    # If cls1 == cls2 -> label_1 - label_2 == 0
    # If cls 1 != cls2 -> abs(label_1 - label_2) >= 0
    pos_mask = (~torch.abs(labels.T - labels).bool()).float()

    # Average positive and negative similarities for a batch and weight by alpha
    pos_loss = torch.mean(alpha*(pos_mask*sim))
    
    # Use this to calculate the loss, in case only one randomly sampled class
    # acts as neg representations
    if neg_selection == True:
        classes = np.unique(labels.detach().cpu().numpy())
        neg_loss = torch.Tensor([0]).to(device)

        # Iterate over all classes in classes and specify as pos_class
        for i, pos_cls in enumerate(classes):
            neg_class = np.random.choice(classes[classes != pos_cls])
            labels_neg = labels[labels == neg_class]
            neg_reps = reps[labels.squeeze() == neg_class]
            #print(f"neg_reps: {neg_reps}, nr.size: {neg_reps.size()}")
            pos_reps = reps[labels.squeeze() == pos_cls]

            # Calculate neg part of loss from proto from the one sampled class
            if neg_agg_choice == "proto":
                neg_proto = neg_reps.mean(0, keepdim = True)

                proto_norm = torch.sqrt(torch.sum(torch.square(neg_proto), dim = 1).view(neg_proto.size(0), 1))
                reps_norm = torch.sqrt(torch.sum(torch.square(pos_reps), dim = 1).view(pos_reps.size(0), 1))
                sim_neg = pos_reps @ neg_proto.T/(torch.max(eps, proto_norm*reps_norm))
                sim_neg = torch.clamp(torch.exp(sim_neg + 0), min = eps, max = Eps)
                neg_loss += torch.mean((1-alpha)*sim_neg)

            # Calculate neg part of loss with all instances of the randomly
            # selected negative class for each class in classes
            elif neg_agg_choice == "single":
                # Calc norms of vectors
                norm_matrix_pn = torch.matmul(
                    pos_reps.norm(dim = -1, keepdim = True).view(pos_reps.size(0), 1), 
                    neg_reps.norm(dim = -1, keepdim = True).view(1, neg_reps.size(0))
                    )
                # Calc similarities 
                sim_pn = torch.matmul(pos_reps, neg_reps.T)/(torch.max(eps, norm_matrix_pn)*t)
                #print(f"Cosine sim: {sim_pn}")
                #print(sim_pn.size())
                sim_pn = torch.clamp(torch.exp(sim_pn + 0), min = eps, max = Eps)
                #print(f"After exp: {sim_pn}")
                neg_loss += (1-alpha)*torch.mean(sim_pn)
        
        # Use to return 0-dim tensor, fast work-around
        neg_loss = torch.sum(neg_loss)

    else:
        neg_mask = (torch.abs(labels.T - labels).bool()).float()
        neg_loss = torch.mean((1 - alpha)*(neg_mask*sim))

    # Calc final sim loss
    loss =  (neg_loss - pos_loss)
    
    # Return overall loss
    return loss.to(reps.device)

# Cosine loss neg selection, continual and prototypes

In [None]:
def cosine_sim(
                reps: torch.Tensor = None,
                labels: torch.Tensor = None,
                t: float = 0.3,
                alpha: float = 0.3,
                fraction: float = 1.,
                num_classes: int = 5,
                neg_agg_choice: str = "proto",
                neg_selection: bool = True,
                return_proto: bool = True,
                protos: torch.Tensor = None,
                cls_counter: int = 0,
                ) -> torch.Tensor:
                
    device = "cuda" if torch.cuda.is_available() else "cpu"
    assert (alpha >= 0) & (alpha <=1), "Alpha must be in [0,1]"
    
    # For numerical stability of exp function
    eps = torch.Tensor([1e-08]).to(device)
    Eps = torch.Tensor([1.797693134862315e+308]).to(device)

    # Calculate l2-norms of all vector combinations 
    v_1, v_2 = torch.sum(torch.square(reps), dim = 1).view(reps.size(0), 1), \
        torch.sum(torch.square(reps), dim = 1).view(1, reps.size(0))

    norm_matrix = norm_matrix_pn = torch.matmul(
                    reps.norm(dim = -1, keepdim = True).view(reps.size(0), 1), 
                    reps.norm(dim = -1, keepdim = True).view(1, reps.size(0))
                    )
    # Use only the fraction of data, specified in fraction -> semi-supervised
    if fraction < 1.:
        erase_v = (torch.rand(size=(reps.size()[0], 1)) < fraction).to(device).float()
        reps = erase_v * reps

    # Calculate vector (cosine) similarities and normalize by l2 norms of vectors
    sim = torch.matmul(reps, reps.T)/(torch.max(eps, norm_matrix)*t)
    # Delete "self-loops" from similarity matrix by subtracting diagonal values
    sim = sim - torch.diag(torch.diagonal(sim))

    # Add zero for stability and clamp to float32 values
    sim = torch.clamp(torch.exp(sim + 0), min = eps, max = Eps)

    # Finds which instances are of the same class
    # If cls1 == cls2 -> label_1 - label_2 == 0
    # If cls 1 != cls2 -> abs(label_1 - label_2) >= 0
    pos_mask = (~torch.abs(labels.T - labels).bool()).float()

    # Average positive and negative similarities for a batch and weight by alpha
    pos_loss = torch.mean(alpha*(pos_mask*sim))

    ############################################################################
    #pos_loss = torch.mean(alpha*(pos_mask*2*(1-sim)))
    ############################################################################

    # Use this to calculate the loss, in case only one randomly sampled class
    # acts as neg representations
    if neg_selection == True:
        classes = np.unique(labels.detach().cpu().numpy())
        neg_loss = torch.Tensor([0]).to(device)
        neg_proto_loss = torch.Tensor([0]).to(device)

        # Iterate over all classes in classes and specify as pos_class
        for i, pos_cls in enumerate(classes):
            neg_class = np.random.choice(classes[classes != pos_cls])
            labels_neg = labels[labels == neg_class]
            neg_reps = reps[labels.squeeze() == neg_class]
            #print(f"neg_reps: {neg_reps}, nr.size: {neg_reps.size()}")
            pos_reps = reps[labels.squeeze() == pos_cls]

            # Calculate neg part of loss from proto from the one sampled class
            if neg_agg_choice == "proto":
                neg_proto = neg_reps.mean(0, keepdim = True)

                reps.norm(dim = -1, keepdim = True).view(reps.size(0), 1),

                proto_norm = neg_proto.norm(dim=-1, keepdim = True).view(1, neg_proto.size(0))
                reps_norm = pos_reps.norm(dim=-1, keepdim = True).view(pos_reps.size(0), 1)
                norm_matrix = torch.matmul(reps_norm, proto_norm)
                sim_neg = (pos_reps @ neg_proto.T)/(torch.max(eps, norm_matrix*t))
                sim_neg = torch.clamp(torch.exp(sim_neg + 0), min = eps, max = Eps)
                neg_loss += torch.mean((1-alpha)*sim_neg)

            # Calculate neg part of loss with all instances of the randomly
            # selected negative class for each class in classes
            elif neg_agg_choice == "single":
                # Calc norms of vectors
                norm_matrix_pn = torch.matmul(
                    pos_reps.norm(dim = -1, keepdim = True).view(pos_reps.size(0), 1), 
                    neg_reps.norm(dim = -1, keepdim = True).view(1, neg_reps.size(0))
                    )
                # Calc similarities 
                sim_pn = torch.maximum(torch.Tensor([0]).to(device), torch.matmul(pos_reps, neg_reps.T)/(torch.max(eps, norm_matrix_pn)*t) - 0.5)
                #print(f"Cosine sim: {sim_pn}")
                #print(sim_pn.size())
                sim_pn = torch.clamp(torch.exp(sim_pn + 0), min = eps, max = Eps)
                #print(f"After exp: {sim_pn}")

                ################################################################
                #sim_pn = -2*(1-sim_pn)
                ################################################################

                neg_loss += (1-alpha)*torch.mean(sim_pn)

                if cls_counter > 0:
                    neg_proto = protos
                    proto_norm = neg_proto.norm(dim=-1, keepdim = True).view(1, neg_proto.size(0))
                    reps_norm = pos_reps.norm(dim=-1, keepdim = True).view(pos_reps.size(0), 1)
                    norm_matrix = torch.matmul(reps_norm, proto_norm)
                    sim_neg = torch.maximum(torch.Tensor([0]).to(device), (pos_reps @ neg_proto.T)/(torch.max(eps, norm_matrix*t)) - 0.5)
                    sim_pn = torch.clamp(torch.exp(sim_pn + 0), min = eps, max = Eps)
                    ############################################################
                    #sim_neg = -2*(1-sim_neg)
                    ############################################################

                    neg_proto_loss += torch.mean((1-alpha)*sim_neg)
                else:
                    neg_proto_loss = torch.Tensor([0])
        
        # Use to return 0-dim tensor, fast work-around
        neg_loss = torch.sum(neg_loss)
        proto_loss = torch.sum(neg_proto_loss)

    else:
        neg_mask = (torch.abs(labels.T - labels).bool()).float()
        neg_loss = torch.mean((1 - alpha)*(neg_mask*sim))

    # Running batch mean
    batch_mean_cls = reps[labels.squeeze() == cls_counter].mean(0, keepdim = True).to(device)

    # Calc final sim loss
    loss =  (neg_loss - pos_loss)
    ############################################################################
    #loss = (pos_loss + neg_loss)
    ############################################################################
    if return_proto == True:
        protos = torch.cat((protos, batch_mean_cls)).to(device)
        return loss.to(reps.device), proto_loss, protos
    else:
        # Return overall loss
        return loss.to(reps.device), proto_loss

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
reps = torch.randn(size = (20, 15)).to(device)
labels = torch.randint(low = 0, high = 5, size = (20, 1)).to(device)
protos = torch.randn(size = (1, 15)).to(device)
loss, proto_loss = cosine_sim(reps, labels, t = 1, neg_agg_choice = "single", neg_selection = True, protos = protos, return_proto = False)

# Training function for regular cosine_sim function



In [None]:
import os
import torch
import numpy as np
from torch import nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from IPython.display import clear_output

from thesis.loss import vicreg_loss_fn as vlf
from thesis.loss.similarity_loss import cosine_sim
from thesis.loss.similarity_loss import NCELoss

def train_vicreg_cnn_2(
  model: torch.nn.Module, 
  dataloader: torch.utils.data.DataLoader, 
  epochs: int,
  weight_vicreg: float = 1,
  sim_vicreg: float = 25,
  var_vicreg: float = 25,
  cov_vicreg: float = 1,
  decay_rate_vicreg: float = 0.01,
  decay_steps_vicreg: float = 100,
  weight_sim: float = 0.01,
  decay_rate_sim: float = 0.01,
  decay_steps_sim: float = -100,
  lr: float = 0.001,
  t: float = 1.,
  alpha: float = 0.5,
  alpha_prot: float = 0.3,
  epsilon: float = 0.05,
  instance_weight: float = 1,
  proto_weight: float = 5,
  cel_weight: float = 1,
  dist_weight: float = 500,
  num_classes: float = 3,
  sim_loss_fn: str = "cosine",
  lr_scheduler: str = "exp",
  gamma: float = 0.9,
  ssv_prob: float = 1,
  root_dir = None, **kwargs) -> torch.Tensor:

    """Training step for Self-Supervised Training model with VICReg and Supervised
    Training with Sim loss.

    Args:
      batch(Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
    [X] is a list of size num_crops containing batches of images.
 
      model: torch.nn.Module: 
      dataloader: torch.utils.data.DataLoader: 
      epochs: int: 
      weight_vicreg: float:  (Default value = 1)
      sim_vicreg: float:  (Default value = 25)
      var_vicreg: float:  (Default value = 25)
      cov_vicreg: float:  (Default value = 1)
      decay_rate_vicreg: float:  (Default value = 0.01)
      decay_steps_vicreg: float:  (Default value = 100)
      weight_sim: float:  (Default value = 0.01)
      decay_rate_sim: float:  (Default value = 0.01)
      decay_steps_sim: float:  (Default value = -100)
      lr: float:  (Default value = 0.001)
      t: float:  (Default value = 1.)
      alpha: float:  (Default value = 0.5)
      alpha_prot: float:  (Default value = 0.3)
      epsilon: float:  (Default value = 0.05)
      instance_weight: float:  (Default value = 1)
      proto_weight: float:  (Default value = 5)
      cel_weight: float:  (Default value = 1)
      dist_weight: float:  (Default value = 500)
      num_classes: float:  (Default value = 3)
      sim_loss_fn: str:  (Default value = "cosine")
      lr_scheduler: str:  (Default value = "exp")
      gamma: float:  (Default value = 0.9)
      ssv_prob: float:  (Default value = 1)

    Returns:
      torch.Tensor: total loss composed of VICReg loss and classification loss.
      Gratefully adapted from: https://github.com/vturrisi/solo-learn

    """

    try:
      device = "cuda" if torch.cuda.is_available() else "cpu"
      model.train()
      model.to(device)

      # Initiate return variables
      loss_list, vicreg_loss_list, sim_loss_list, prototypes_list = [], [], [], []
      prototypes = None
      
      # Define optimizer and scheduler
      optimizer = torch.optim.Adam(model.parameters(), lr = lr)
      if lr_scheduler == "exp":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)

      weight_vicreg_init = weight_vicreg
      weight_sim_init = weight_sim

      # Training loop
      for epoch in tqdm(range(epochs)):
          batch_count = batch_loss = vicreg_batch_loss = sim_batch_loss = epoch_loss = 0
          
          # VICReg Loss Decay
          if (bool(decay_rate_vicreg)) & (epoch <= abs(decay_steps_vicreg)):
              weight_vicreg = weight_vicreg_init*decay_rate_vicreg**(epoch/decay_steps_vicreg)
          
          # Sim Loss Decay
          if (bool(decay_rate_sim)) & (epoch <= abs(decay_steps_sim)):
              weight_sim = weight_sim_init*decay_rate_sim**(epoch/decay_steps_sim)

          # Batch loop
          for image_1, image_2, labels in tqdm(dataloader, leave = False):
              
              # Zero grads -> forward pass -> compute loss -> backprop
              optimizer.zero_grad()
              out = model(image_1.float(), image_2.float()).float().squeeze()
              feature_size = out.size()[1]

              labels = labels.view(labels.size(dim = 0), 1).repeat(2, 1)
              
              # Calculate VICReg Loss function
              vicreg_loss = weight_vicreg*vlf.vicreg_loss_func(
                  out[:,:int(feature_size*0.5)],
                  out[:, int(feature_size*0.5):], sim_loss_weight = sim_vicreg,
                  var_loss_weight = var_vicreg, cov_loss_weight = cov_vicreg,
                  ).float()
              
              # Assign features for handling in following loss functions from model
              # outputs
              sim_features = torch.cat(
                              (
                                  out[:,:int(feature_size*0.5)],
                                  out[:, int(feature_size*0.5):]
                                ), dim = 0)
              
              # Implementation of cosine similarity loss
              if sim_loss_fn == "cosine":
                  sim_loss = weight_sim*cosine_sim(
                      sim_features, labels, t, alpha
                      ).float()
              
              #Implementation of NCELoss
              elif sim_loss_fn == "NCELoss":
                  sim_loss = weight_sim*NCELoss(sim_features, labels, t).float()

              #Implementation of prototypical loss
              elif sim_loss_fn == "proto":        
                  sim_loss, instance_loss, proto_loss, \
                  ce_loss, dist_loss, prototypes_updated = proto_sim(
                                reps = sim_features, labels = labels, 
                                prototypes = prototypes, 
                                t = t, alpha = alpha, alpha_prot = alpha_prot, 
                                instance_weight = instance_weight, 
                                proto_weight = proto_weight, dist_weight = dist_weight,
                                cel_weight = cel_weight, num_classes = num_classes,
                                epsilon = epsilon, epoch = epoch
                                )
                  
                  # Reassign prototypes
                  prototypes = prototypes_updated.detach()
                  sim_loss = weight_sim*sim_loss.float().detach()
              
              # Determine the probability with which supervised labels will be used
              semi_sup_ = int(np.random.choice(2, 1, p = [1- ssv_prob, ssv_prob]))
              loss = vicreg_loss + semi_sup_*sim_loss

              loss.backward()
              optimizer.step()

              # Output batch losses
              batch_count += 1
              batch_loss += loss.detach().cpu().numpy()
              vicreg_batch_loss += vicreg_loss.detach().cpu().numpy()
              sim_batch_loss += sim_loss.detach().cpu().numpy()
              print(f"Epoch: {epoch} | Batch_Loss: {loss.detach().cpu().numpy()}")
              
          clear_output()

          # Calculate and log epoch losses
          epoch_loss = batch_loss/batch_count
          vicreg_loss = vicreg_batch_loss/batch_count
          sim_loss = sim_batch_loss/batch_count

          if sim_loss_fn == "proto":
              wandb.log({
                          "loss":epoch_loss, 
                          "vicreg_loss": vicreg_loss, 
                          "sim_loss": sim_loss,
                          "sim_loss_norm": sim_loss/weight_sim,
                          "vicreg_loss_norm": vicreg_loss/weight_vicreg,
                          "weight_vicreg":weight_vicreg, 
                          "weight_sim":weight_sim,
                          "inst_loss":instance_loss,
                          "proto_loss":proto_loss,
                          "ce_loss":ce_loss,
                          "dist_loss":dist_loss,
                          "prototypes":prototypes_updated,
                        })
              prototypes_list.append(prototypes_updated.detach().cpu().numpy())

              
          elif sim_loss_fn == "cosine":
              wandb.log({
                          "loss":epoch_loss, 
                          "vicreg_loss": vicreg_loss, 
                          "sim_loss": sim_loss,
                          "sim_loss_norm": sim_loss/weight_sim,
                          "vicreg_loss_norm": vicreg_loss/weight_vicreg,
                          "weight_vicreg":weight_vicreg, 
                          "weight_sim":weight_sim,
                        })              
          loss_list.append(epoch_loss)
          vicreg_loss_list.append(vicreg_loss)
          sim_loss_list.append(sim_loss)

          print(f"Epoch: {epoch} | Epoch loss: {epoch_loss:.2f}")
            
          # Save model, in case a root_dir is given
          if (epoch > 5) & (root_dir is not None) & (~np.isnan(loss.detach().cpu().numpy())):
              if (loss < loss_list[-2]):
                  PATH = os.path.join(root_dir, f"{run_name}.pt")
                  torch.save({
                      'epoch': epoch,
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'loss': loss,
                    }, PATH) 
        
            
      # Return loss logs and prototypes, in case it's given
      if sim_loss_fn == "proto":
          return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
      else:
          return loss_list, vicreg_loss_list, sim_loss_list

    except KeyboardInterrupt:
        print("Execution interrupted by user")
        if sim_loss_fn == "proto":
            return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
        else:
            return loss_list, vicreg_loss_list, sim_loss_list

# Training function for cosine sim with prototypes

In [None]:
import os
import torch
import numpy as np
from torch import nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from IPython.display import clear_output

from thesis.loss import vicreg_loss_fn as vlf
#from thesis.loss.similarity_loss import cosine_sim
from thesis.loss.similarity_loss import NCELoss

def train_vicreg_cnn_2(
  model: torch.nn.Module, 
  dataloader: torch.utils.data.DataLoader, 
  epochs: int,
  weight_vicreg: float = 1,
  sim_vicreg: float = 25,
  var_vicreg: float = 25,
  cov_vicreg: float = 1,
  decay_rate_vicreg: float = 0.01,
  decay_steps_vicreg: float = 100,
  weight_sim: float = 0.01,
  decay_rate_sim: float = 0.01,
  decay_steps_sim: float = -100,
  lr: float = 0.001,
  t: float = 1.,
  alpha: float = 0.5,
  alpha_prot: float = 0.3,
  epsilon: float = 0.05,
  instance_weight: float = 1,
  proto_weight: float = 5,
  cel_weight: float = 1,
  dist_weight: float = 500,
  num_classes: float = 3,
  sim_loss_fn: str = "cosine",
  lr_scheduler: str = "exp",
  gamma: float = 0.9,
  ssv_prob: float = 1,
  root_dir = None, **kwargs) -> torch.Tensor:

  """Training step for Self-Supervised Training model with VICReg and Sim loss
  Args:
      batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
          [X] is a list of size num_crops containing batches of images.
      batch_idx (int): index of the batch.
  Returns:
      torch.Tensor: total loss composed of VICReg loss and classification loss.
  Gratefully adapted from: https://github.com/vturrisi/solo-learn
  """

  try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.train()
    model.to(device)

    # Initiate return variables
    loss_list, vicreg_loss_list, sim_loss_list, proto_loss_list, prototypes_list = [], [], [], [], []
    prototypes = None

    # Define function to access intermediate layer outputs
    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook
    
    # Define optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    if lr_scheduler == "exp":
      scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)

    weight_vicreg_init = weight_vicreg
    weight_sim_init = weight_sim

    # Training loop
    for epoch in tqdm(range(epochs)):
        batch_count = batch_loss = vicreg_batch_loss = 0
        sim_batch_loss = proto_batch_loss = epoch_loss = 0
        
        # VICReg Loss Decay
        if (bool(decay_rate_vicreg)) & (epoch <= abs(decay_steps_vicreg)):
            weight_vicreg = weight_vicreg_init*decay_rate_vicreg**(epoch/decay_steps_vicreg)
        
        # Sim Loss Decay
        if (bool(decay_rate_sim)) & (epoch <= abs(decay_steps_sim)):
            weight_sim = weight_sim_init*decay_rate_sim**(epoch/decay_steps_sim)

        # Batch loop
        for image_1, image_2, labels in tqdm(dataloader, leave = False):
            model.backbone_1.fc.register_forward_hook(get_activation("fc_1"))
            model.backbone_2.fc.register_forward_hook(get_activation("fc_2"))

            # Zero grads -> forward pass -> compute loss -> backprop
            optimizer.zero_grad()
            out = model(image_1.float(), image_2.float()).float().squeeze()
            feature_size = out.size()[1]

            labels = labels.view(labels.size(dim = 0), 1).repeat(2, 1)
            
            # Calculate VICReg Loss function
            vicreg_loss = weight_vicreg*vlf.vicreg_loss_func(
                out[:,:int(feature_size*0.5)],
                out[:, int(feature_size*0.5):], sim_loss_weight = sim_vicreg,
                var_loss_weight = var_vicreg, cov_loss_weight = cov_vicreg,
                ).float()

            # Assign features for handling in following loss functions from model
            # outputs
            a = activation["fc_1"]
            b = activation["fc_2"]

            sim_features = torch.cat(
                            (
                                a,
                                b
                              ), dim = 0)
            if epoch == 0:
                protos = torch.randn(size = (num_classes, sim_features.size()[1])).to(device)

            # Implementation of cosine similarity loss
            if sim_loss_fn == "cosine":

                sim_loss, proto_loss, protos = cosine_sim(
                    sim_features, labels, t, alpha, protos, 
                    epoch = epoch, num_classes = num_classes
                    )

            #Implementation of NCELoss
            elif sim_loss_fn == "NCELoss":
                sim_loss = weight_sim*NCELoss(sim_features, labels, t).float()

            #Implementation of prototypical loss
            elif sim_loss_fn == "proto":        
                sim_loss, instance_loss, proto_loss, \
                ce_loss, dist_loss, prototypes_updated = proto_sim(
                              reps = sim_features, labels = labels, 
                              prototypes = prototypes, 
                              t = t, alpha = alpha, alpha_prot = alpha_prot, 
                              instance_weight = instance_weight, 
                              proto_weight = proto_weight, dist_weight = dist_weight,
                              cel_weight = cel_weight, num_classes = num_classes,
                              epsilon = epsilon, epoch = epoch
                              )
                # Reassign prototypes
                prototypes = prototypes_updated.detach()
                sim_loss = weight_sim*sim_loss.float().detach()
            
            # Determine the probability with which supervised labels will be used
            #semi_sup_ = int(np.random.choice(2, 1, p = [1- ssv_prob, ssv_prob]))
            sim_loss = weight_sim*sim_loss.float().detach()
            proto_loss = proto_weight*proto_loss.float().detach()
            loss = vicreg_loss + sim_loss + proto_loss

            loss.backward()
            optimizer.step()

            # Output batch losses
            if ~np.isnan(loss.detach().cpu().numpy()):
                batch_count += 1
                batch_loss += loss.detach().cpu().numpy()
                sim_batch_loss += sim_loss.detach().cpu().numpy()
                proto_batch_loss += proto_loss.detach().cpu().numpy()

            vicreg_batch_loss += vicreg_loss.detach().cpu().numpy()

            print(f"Epoch: {epoch} | Batch_Loss: {loss.detach().cpu().numpy()}")
            
        clear_output()

        # Calculate and log epoch losses
        epoch_loss = batch_loss/batch_count
        vicreg_loss = vicreg_batch_loss/batch_count
        sim_loss = sim_batch_loss/batch_count
        proto_loss = proto_batch_loss/batch_count

        if sim_loss_fn == "proto":
            wandb.log({
                        "loss":epoch_loss, 
                        "vicreg_loss": vicreg_loss,
                        "sim_loss": sim_loss,
                        "sim_loss_norm": sim_loss/weight_sim,
                        "vicreg_loss_norm": vicreg_loss/weight_vicreg,
                        "weight_vicreg":weight_vicreg, 
                        "weight_sim":weight_sim,
                        "inst_loss":instance_loss,
                        "proto_loss":proto_loss,
                        "ce_loss":ce_loss,
                        "dist_loss":dist_loss,
                        "prototypes":prototypes_updated,
                      })
            prototypes_list.append(prototypes_updated.detach().cpu().numpy())

            
        elif sim_loss_fn == "cosine":
            wandb.log({
                        "loss":epoch_loss, 
                        "vicreg_loss":vicreg_loss, 
                        "sim_loss":sim_loss,
                        "sim_loss_norm":sim_loss/weight_sim,
                        "vicreg_loss_norm":vicreg_loss/weight_vicreg,
                        "weight_vicreg":weight_vicreg, 
                        "weight_sim":weight_sim,
                        "proto_loss":proto_loss/proto_weight,
                      })              
        loss_list.append(epoch_loss)
        vicreg_loss_list.append(vicreg_loss)
        sim_loss_list.append(sim_loss)
        proto_loss_list.append(proto_loss)

        print(f"Epoch: {epoch} | Epoch loss: {epoch_loss}")

        # Save model, in case a root_dir is given
        if (epoch > 5) & (root_dir is not None) & (~np.isnan(loss.detach().cpu().numpy())):
            if (loss < loss_list[-2]):
                PATH = os.path.join(root_dir, f"{run_name}.pt")
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                  }, PATH) 
          
    # Return loss logs and prototypes, in case it's given
    if sim_loss_fn == "proto":
        return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
    else:
        return loss_list, vicreg_loss_list, sim_loss_list, proto_loss_list, protos

  except KeyboardInterrupt:
      print("Execution interrupted by user")
      if sim_loss_fn == "proto":
          return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
      else:
          return loss_list, vicreg_loss_list, sim_loss_list, proto_loss_list, protos

# Training function for cosine sim loss with negative selection and fraction

In [5]:
import os
import torch
import numpy as np
from torch import nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from IPython.display import clear_output

from thesis.loss import vicreg_loss_fn as vlf
#from thesis.loss.similarity_loss import cosine_sim
#from thesis.loss.similarity_loss import NCELoss

def train_vicreg_cnn_2(
  model: torch.nn.Module, 
  dataloader: torch.utils.data.DataLoader, 
  epochs: int,
  weight_vicreg: float = 1,
  sim_vicreg: float = 25,
  var_vicreg: float = 25,
  cov_vicreg: float = 1,
  decay_rate_vicreg: float = 0.01,
  decay_steps_vicreg: float = 100,
  weight_sim: float = 0.01,
  decay_rate_sim: float = 0.01,
  decay_steps_sim: float = -100,
  lr: float = 0.001,
  t: float = 1.,
  alpha: float = 0.5,
  alpha_prot: float = 0.3,
  epsilon: float = 0.05,
  instance_weight: float = 1,
  proto_weight: float = 5,
  cel_weight: float = 1,
  dist_weight: float = 500,
  num_classes: float = 3,
  sim_loss_fn: str = "cosine",
  lr_scheduler: str = "exp",
  gamma: float = 0.9,
  neg_agg_choice: str = "proto",
  neg_selection: bool = True,
  projected: bool = False,
  fraction: float = 1.,
  root_dir = None, **kwargs) -> torch.Tensor:

    """Training step for Self-Supervised Training model with VICReg and Supervised
    Training with Sim loss.

    Args:
      batch(Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
    [X] is a list of size num_crops containing batches of images.
 
      model: torch.nn.Module: 
      dataloader: torch.utils.data.DataLoader: 
      epochs: int: 
      weight_vicreg: float:  (Default value = 1)
      sim_vicreg: float:  (Default value = 25)
      var_vicreg: float:  (Default value = 25)
      cov_vicreg: float:  (Default value = 1)
      decay_rate_vicreg: float:  (Default value = 0.01)
      decay_steps_vicreg: float:  (Default value = 100)
      weight_sim: float:  (Default value = 0.01)
      decay_rate_sim: float:  (Default value = 0.01)
      decay_steps_sim: float:  (Default value = -100)
      lr: float:  (Default value = 0.001)
      t: float:  (Default value = 1.)
      alpha: float:  (Default value = 0.5)
      alpha_prot: float:  (Default value = 0.3)
      epsilon: float:  (Default value = 0.05)
      instance_weight: float:  (Default value = 1)
      proto_weight: float:  (Default value = 5)
      cel_weight: float:  (Default value = 1)
      dist_weight: float:  (Default value = 500)
      num_classes: float:  (Default value = 3)
      sim_loss_fn: str:  (Default value = "cosine")
      lr_scheduler: str:  (Default value = "exp")
      gamma: float:  (Default value = 0.9)

    Returns:
      torch.Tensor: total loss composed of VICReg loss and classification loss.
      Gratefully adapted from: https://github.com/vturrisi/solo-learn

    """

    try:
      device = "cuda" if torch.cuda.is_available() else "cpu"
      model.train()
      model.to(device)

      # Initiate return variables
      loss_list, vicreg_loss_list, sim_loss_list, prototypes_list = [], [], [], []
      prototypes = None
      
      # Define optimizer and scheduler
      optimizer = torch.optim.Adam(model.parameters(), lr = lr)
      if lr_scheduler == "exp":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)

      weight_vicreg_init = weight_vicreg
      weight_sim_init = weight_sim

      # Define function to access intermediate layer outputs
      activation = {}
      def get_activation(name):
          def hook(model, input, output):
              activation[name] = output.detach()
          return hook

      # Training loop
      for epoch in tqdm(range(epochs)):
          batch_count = batch_loss = vicreg_batch_loss = sim_batch_loss = epoch_loss = 0

          if projected == False:
              # Register forward hook, in case projected == False
              model.backbone_1.fc.register_forward_hook(get_activation("fc_1"))
              model.backbone_2.fc.register_forward_hook(get_activation("fc_2"))

          # VICReg Loss Decay
          if (bool(decay_rate_vicreg)) & (epoch <= abs(decay_steps_vicreg)):
              weight_vicreg = weight_vicreg_init*decay_rate_vicreg**(epoch/decay_steps_vicreg)
          
          # Sim Loss Decay
          if (bool(decay_rate_sim)) & (epoch <= abs(decay_steps_sim)):
              weight_sim = weight_sim_init*decay_rate_sim**(epoch/decay_steps_sim)

          # Batch loop
          for image_1, image_2, labels in tqdm(dataloader, leave = False):
              
              # Zero grads -> forward pass -> compute loss -> backprop
              optimizer.zero_grad()
              out = model(image_1.float(), image_2.float()).float().squeeze()
              feature_size = out.size()[1]

              labels = labels.view(labels.size(dim = 0), 1).repeat(2, 1)
              
              # Calculate VICReg Loss function
              vicreg_loss = weight_vicreg*vlf.vicreg_loss_func(
                  out[:,:int(feature_size*0.5)],
                  out[:, int(feature_size*0.5):], sim_loss_weight = sim_vicreg,
                  var_loss_weight = var_vicreg, cov_loss_weight = cov_vicreg,
                  ).float()
              
              # Assign features for handling in following loss functions from model
              # outputs
              if projected == False:
                  a = activation["fc_1"]
                  b = activation["fc_2"]

                  sim_features = torch.cat(
                                  (
                                      a,
                                      b
                                    ), dim = 0)
              else:
                  sim_features = torch.cat(
                                  (
                                      out[:,:int(feature_size*0.5)],
                                      out[:, int(feature_size*0.5):]
                                    ), dim = 0)
              
              # Implementation of cosine similarity loss
              if sim_loss_fn == "cosine":
                  sim_loss = weight_sim*cosine_sim(
                      sim_features, labels, t, alpha, neg_agg_choice=neg_agg_choice, 
                      neg_selection = neg_selection, fraction = fraction,
                      ).float()
              
              #Implementation of NCELoss
              elif sim_loss_fn == "NCELoss":
                  sim_loss = weight_sim*NCELoss(sim_features, labels, t).float()

              #Implementation of prototypical loss
              elif sim_loss_fn == "proto":        
                  sim_loss, instance_loss, proto_loss, \
                  ce_loss, dist_loss, prototypes_updated = proto_sim(
                                reps = sim_features, labels = labels, 
                                prototypes = prototypes, 
                                t = t, alpha = alpha, alpha_prot = alpha_prot, 
                                instance_weight = instance_weight, 
                                proto_weight = proto_weight, dist_weight = dist_weight,
                                cel_weight = cel_weight, num_classes = num_classes,
                                epsilon = epsilon, epoch = epoch
                                )
                  
                  # Reassign prototypes
                  prototypes = prototypes_updated.detach()
                  sim_loss = weight_sim*sim_loss.float().detach()
              
              # Determine the probability with which supervised labels will be used
              loss = vicreg_loss + sim_loss

              loss.backward()
              optimizer.step()

              # Output batch losses
              batch_count += 1
              batch_loss += loss.detach().cpu().numpy()
              vicreg_batch_loss += vicreg_loss.detach().cpu().numpy()
              sim_batch_loss += sim_loss.detach().cpu().numpy()
              print(f"Epoch: {epoch} | Batch_Loss: {loss.detach().cpu().numpy()}")
              
          clear_output()

          # Calculate and log epoch losses
          epoch_loss = batch_loss/batch_count
          vicreg_loss = vicreg_batch_loss/batch_count
          sim_loss = sim_batch_loss/batch_count

          if sim_loss_fn == "proto":
              wandb.log({
                          "loss":epoch_loss, 
                          "vicreg_loss": vicreg_loss, 
                          "sim_loss": sim_loss,
                          "sim_loss_norm": sim_loss/weight_sim,
                          "vicreg_loss_norm": vicreg_loss/weight_vicreg,
                          "weight_vicreg":weight_vicreg, 
                          "weight_sim":weight_sim,
                          "inst_loss":instance_loss,
                          "proto_loss":proto_loss,
                          "ce_loss":ce_loss,
                          "dist_loss":dist_loss,
                          "prototypes":prototypes_updated,
                        })
              prototypes_list.append(prototypes_updated.detach().cpu().numpy())

              
          elif sim_loss_fn == "cosine":
              wandb.log({
                          "loss":epoch_loss, 
                          "vicreg_loss": vicreg_loss, 
                          "sim_loss": sim_loss,
                          "sim_loss_norm": sim_loss/weight_sim,
                          "vicreg_loss_norm": vicreg_loss/weight_vicreg,
                          "weight_vicreg":weight_vicreg, 
                          "weight_sim":weight_sim,
                        })              
          loss_list.append(epoch_loss)
          vicreg_loss_list.append(vicreg_loss)
          sim_loss_list.append(sim_loss)

          print(f"Epoch: {epoch} | Epoch loss: {epoch_loss}")
            
          # Save model, in case a root_dir is given
          if (epoch > 5) & (root_dir is not None) & (~np.isnan(loss.detach().cpu().numpy())):
              if (loss < loss_list[-2]):
                  PATH = os.path.join(root_dir, f"{run_name}.pt")
                  torch.save({
                      'epoch': epoch,
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'loss': loss,
                    }, PATH) 
        
            
      # Return loss logs and prototypes, in case it's given
      if sim_loss_fn == "proto":
          return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
      else:
          return loss_list, vicreg_loss_list, sim_loss_list

    except KeyboardInterrupt:
        print("Execution interrupted by user")
        if sim_loss_fn == "proto":
            return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
        else:
            return loss_list, vicreg_loss_list, sim_loss_list

#Train routine for continual learning setting

In [6]:
import os
import torch
import copy
import numpy as np
from torch import nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from IPython.display import clear_output

from thesis.loss import vicreg_loss_fn as vlf
#from thesis.loss.similarity_loss import cosine_sim
#from thesis.loss.similarity_loss import NCELoss

def train_vicreg_cnn_2(
  model: torch.nn.Module, 
  dataloader_list: list, 
  epoch_list: int,
  weight_vicreg: float = 1,
  sim_vicreg: float = 25,
  var_vicreg: float = 25,
  cov_vicreg: float = 1,
  decay_rate_vicreg: float = 0.01,
  decay_steps_vicreg: float = 100,
  weight_sim: float = 0.01,
  decay_rate_sim: float = 0.01,
  decay_steps_sim: float = -100,
  lr: float = 0.001,
  t: float = 1.,
  alpha: float = 0.5,
  alpha_prot: float = 0.3,
  epsilon: float = 0.05,
  instance_weight: float = 1,
  proto_weight: float = 5,
  cel_weight: float = 1,
  dist_weight: float = 500,
  num_classes: float = 3,
  sim_loss_fn: str = "cosine",
  lr_scheduler: str = "exp",
  gamma: float = 0.9,
  neg_agg_choice: str = "proto",
  neg_selection: bool = True,
  projected: bool = False,
  fraction: float = 1.,
  root_dir = None, **kwargs) -> torch.Tensor:

    """Training step for Self-Supervised Training model with VICReg and Supervised
    Training with Sim loss.

    Args:
      batch(Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
    [X] is a list of size num_crops containing batches of images.
 
      model: torch.nn.Module: 
      dataloader: torch.utils.data.DataLoader: 
      epochs: int: 
      weight_vicreg: float:  (Default value = 1)
      sim_vicreg: float:  (Default value = 25)
      var_vicreg: float:  (Default value = 25)
      cov_vicreg: float:  (Default value = 1)
      decay_rate_vicreg: float:  (Default value = 0.01)
      decay_steps_vicreg: float:  (Default value = 100)
      weight_sim: float:  (Default value = 0.01)
      decay_rate_sim: float:  (Default value = 0.01)
      decay_steps_sim: float:  (Default value = -100)
      lr: float:  (Default value = 0.001)
      t: float:  (Default value = 1.)
      alpha: float:  (Default value = 0.5)
      alpha_prot: float:  (Default value = 0.3)
      epsilon: float:  (Default value = 0.05)
      instance_weight: float:  (Default value = 1)
      proto_weight: float:  (Default value = 5)
      cel_weight: float:  (Default value = 1)
      dist_weight: float:  (Default value = 500)
      num_classes: float:  (Default value = 3)
      sim_loss_fn: str:  (Default value = "cosine")
      lr_scheduler: str:  (Default value = "exp")
      gamma: float:  (Default value = 0.9)

    Returns:
      torch.Tensor: total loss composed of VICReg loss and classification loss.
      Gratefully adapted VICReg loss from: https://github.com/vturrisi/solo-learn
    """

    try:
      assert len(epoch_list) == len(dataloader_list), \
      "epoch_list must be of same length as dataloader_list"

      device = "cuda" if torch.cuda.is_available() else "cpu"
      model.train()
      model.to(device)

      # Initiate return variables
      loss_list, vicreg_loss_list, sim_loss_list, prototypes_list = [], [], [], []
      prototypes = None
      
      # Define optimizer and scheduler
      optimizer = torch.optim.Adam(model.parameters(), amsgrad = True, lr = lr)
      if lr_scheduler == "exp":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)

      weight_vicreg_init = weight_vicreg
      weight_sim_init = weight_sim

      # Define function to access intermediate layer outputs
      activation = {}
      def get_activation(name):
          def hook(model, input, output):
              activation[name] = output.detach()
          return hook

      # Initialize dataloader and allocate next epoch_limit
      epochs_limit = 0
      epochs_limit_counter = 0
      epochs_limit = epoch_list[epochs_limit_counter]
      dataloader = dataloader_list[epochs_limit_counter]
      dataloader_test = dataloader_list[epochs_limit_counter]
      epochs = epoch_list[-1]

      #Training loop
      for epoch in tqdm(range(epochs)):
          if epoch == epochs_limit:
              model_cop = copy.deepcopy(model)

              cal_linclf_acc(
                  model_cop, dataloader, dataloader_test, protos = None, projected = False,
                  path = "/content/drive/MyDrive/MT Gabriel/model_runs/" + run_name + "_lin_" + str(epochs_limit_counter) + ".txt",
                  wandb_run = None,
                  )
              
              epochs_limit_counter += 1
              num_classes += 1
              epochs_limit = epoch_list[epochs_limit_counter]
              dataloader = dataloader_list[epochs_limit_counter]
              dataloader_test = dataloader_test_list[epochs_limit_counter]
              print(f"Using dataloader no. {epochs_limit_counter}")

          batch_count = batch_loss = vicreg_batch_loss = sim_batch_loss = epoch_loss = 0
          
          # VICReg Loss Decay
          if (bool(decay_rate_vicreg)) & (epoch <= abs(decay_steps_vicreg)):
              weight_vicreg = weight_vicreg_init*decay_rate_vicreg**(epoch/decay_steps_vicreg)
          
          # Sim Loss Decay
          if (bool(decay_rate_sim)) & (epoch <= abs(decay_steps_sim)):
              weight_sim = weight_sim_init*decay_rate_sim**(epoch/decay_steps_sim)

          # Batch loop
          for image_1, image_2, labels in tqdm(dataloader, leave = False):
              # Zero grads -> forward pass -> compute loss -> backprop
              optimizer.zero_grad()
              model.backbone_1.fc.register_forward_hook(get_activation("fc_1"))
              model.backbone_2.fc.register_forward_hook(get_activation("fc_2"))
              
              out = model(image_1.float(), image_2.float()).float().squeeze()
              feature_size = out.size()[1]

              labels = labels.view(labels.size(dim = 0), 1).repeat(2, 1)
              
              # Calculate VICReg Loss function
              vicreg_loss = weight_vicreg*vlf.vicreg_loss_func(
                  out[:,:int(feature_size*0.5)],
                  out[:, int(feature_size*0.5):], sim_loss_weight = sim_vicreg,
                  var_loss_weight = var_vicreg, cov_loss_weight = cov_vicreg,
                  ).float()
              
              # Assign features for handling in following loss functions from model
              # outputs
              if projected == False:
                  a = activation["fc_1"]
                  b = activation["fc_2"]

                  sim_features = torch.cat(
                                  (
                                      a,
                                      b
                                    ), dim = 0)
              else:
                  sim_features = torch.cat(
                                  (
                                      out[:,:int(feature_size*0.5)],
                                      out[:, int(feature_size*0.5):]
                                    ), dim = 0)
              
              # Implementation of cosine similarity loss
              if sim_loss_fn == "cosine":
                  sim_loss = weight_sim*cosine_sim(
                      sim_features, labels, t, alpha, neg_agg_choice=neg_agg_choice, 
                      neg_selection = neg_selection, fraction = fraction,
                      ).float()
              
              #Implementation of NCELoss
              elif sim_loss_fn == "NCELoss":
                  sim_loss = weight_sim*NCELoss(sim_features, labels, t).float()

              #Implementation of prototypical loss
              elif sim_loss_fn == "proto":        
                  sim_loss, instance_loss, proto_loss, \
                  ce_loss, dist_loss, prototypes_updated = proto_sim(
                                reps = sim_features, labels = labels, 
                                prototypes = prototypes, 
                                t = t, alpha = alpha, alpha_prot = alpha_prot, 
                                instance_weight = instance_weight, 
                                proto_weight = proto_weight, dist_weight = dist_weight,
                                cel_weight = cel_weight, num_classes = num_classes,
                                epsilon = epsilon, epoch = epoch
                                )
                  
                  # Reassign prototypes
                  prototypes = prototypes_updated.detach()
                  sim_loss = weight_sim*sim_loss.float().detach()
              
              # Determine the probability with which supervised labels will be used
              loss = vicreg_loss + sim_loss

              loss.backward()
              optimizer.step()

              # Output batch losses
              batch_count += 1
              batch_loss += loss.detach().cpu().numpy()
              vicreg_batch_loss += vicreg_loss.detach().cpu().numpy()
              sim_batch_loss += sim_loss.detach().cpu().numpy()
              print(f"Epoch: {epoch} | Batch_Loss: {loss.detach().cpu().numpy()}")
              
          clear_output()

          # Calculate and log epoch losses
          epoch_loss = batch_loss/batch_count
          vicreg_loss = vicreg_batch_loss/batch_count
          sim_loss = sim_batch_loss/batch_count

          if sim_loss_fn == "proto":
              wandb.log({
                          "loss":epoch_loss, 
                          "vicreg_loss": vicreg_loss, 
                          "sim_loss": sim_loss,
                          "sim_loss_norm": sim_loss/weight_sim,
                          "vicreg_loss_norm": vicreg_loss/weight_vicreg,
                          "weight_vicreg":weight_vicreg, 
                          "weight_sim":weight_sim,
                          "inst_loss":instance_loss,
                          "proto_loss":proto_loss,
                          "ce_loss":ce_loss,
                          "dist_loss":dist_loss,
                          "prototypes":prototypes_updated,
                        })
              prototypes_list.append(prototypes_updated.detach().cpu().numpy())

              
          elif sim_loss_fn == "cosine":
              wandb.log({
                          "loss":epoch_loss, 
                          "vicreg_loss": vicreg_loss, 
                          "sim_loss": sim_loss,
                          "sim_loss_norm": sim_loss/weight_sim,
                          "vicreg_loss_norm": vicreg_loss/weight_vicreg,
                          "weight_vicreg":weight_vicreg, 
                          "weight_sim":weight_sim,
                        })              
          loss_list.append(epoch_loss)
          vicreg_loss_list.append(vicreg_loss)
          sim_loss_list.append(sim_loss)

          print(f"Epoch: {epoch} | Epoch loss: {epoch_loss}")
            
          # Save model, in case a root_dir is given
          if (epoch > 5) & (root_dir is not None) & (~np.isnan(loss.detach().cpu().numpy())):
              if (loss < loss_list[-2]):
                  PATH = os.path.join(root_dir, f"{run_name}.pt")
                  torch.save({
                      'epoch': epoch,
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'loss': loss,
                    }, PATH) 
        
            
      # Return loss logs and prototypes, in case it's given
      if sim_loss_fn == "proto":
          return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
      else:
          return loss_list, vicreg_loss_list, sim_loss_list

    except KeyboardInterrupt:
        print("Execution interrupted by user")
        if sim_loss_fn == "proto":
            return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
        else:
            return loss_list, vicreg_loss_list, sim_loss_list

# Training function for continual w/ prototypes

In [None]:
import os
import torch
import copy
import numpy as np
from torch import nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from IPython.display import clear_output

from thesis.loss import vicreg_loss_fn as vlf
#from thesis.loss.similarity_loss import cosine_sim
#from thesis.loss.similarity_loss import NCELoss

def train_vicreg_cnn_2(
  model: torch.nn.Module, 
  dataloader_list: list,
  dataloader_test_list: list,
  epochs: int,
  epoch_list: int,
  weight_vicreg: float = 1,
  weight_sim: float = 0.01,
  weight_proto: float = 0.1,
  sim_vicreg: float = 25,
  var_vicreg: float = 25,
  cov_vicreg: float = 1,
  decay_rate_vicreg: float = 0.01,
  decay_steps_vicreg: float = 100,
  decay_rate_sim: float = 0.01,
  decay_steps_sim: float = -100,
  lr: float = 0.001,
  t: float = 1.,
  alpha: float = 0.5,
  num_classes: float = 3,
  sim_loss_fn: str = "cosine",
  lr_scheduler: str = "exp",
  gamma: float = 0.9,
  neg_agg_choice: str = "proto",
  neg_selection: bool = True,
  projected: bool = False,
  fraction: float = 1.,
  root_dir = None, **kwargs) -> torch.Tensor:

    """Training step for Self-Supervised Training model with VICReg and Supervised
    Training with Sim loss.

    Args:
      model: torch.nn.Module: A torch model
      dataloader_list: list(torch.utils.data.DataLoader)
          List of dataloaders, giving classes sequentially
          The dataloader has to return a batch of (img_1, img_2, labels)
      dataloader_test_list: list(torch.utils.data.DataLoader)
          List of test dataloaders, giving classes sequentially
          The dataloader has to return a batch of (img_1, img_2, labels)
      epochs: int: Number of epochs to train
      epoch_list: list(int):
          Epoch limits to train on each given dataloader
      weight_vicreg: float:  (Default value = 1)
          The initial weight for the vicreg loss function -> implicitly importance
      sim_vicreg: float:  (Default value = 25)
          The lambda parameter from the VICReg paper -> importance of invar term
      var_vicreg: float:  (Default value = 25)
          The mu parameter from the VICReg paper -> importance of var term
      cov_vicreg: float:  (Default value = 1)
          The eta parameter from the VICReg paper -> importance of cov term
      decay_rate_vicreg: float:  (Default value = 0.01) 
          If decay_steps_sim is specified, the weight_vicreg is updated by
          weight_vicreg = weight_vicreg_init*decay_rate_vicreg**(epoch/decay_steps_vicreg)
      decay_steps_vicreg: float:  (Default value = 100)
          Specifies the steps to update weight_vicreg, if decay_steps_vicreg is 
          specified -> see (decay_rate_vicreg)
      weight_sim: float:  (Default value = 0.01) analogous to weight_vicreg
      decay_rate_sim: float:  (Default value = 0.01) analogous to decay_rate_sim
      decay_steps_sim: float:  (Default value = -100) 
          analogous to decay_steps_sim; default is negative to increase importance
          during training
      lr: float:  (Default value = 0.001) Learning Rate
      t: float:  (Default value = 1.) 
          Temperature parameter for the cosine sim loss -> exp(sim/t)
      alpha: float:  (Default value = 0.5)
          Specifies the importance of negative and positive samples
      num_classes: float:  (Default value = 3)
          Specifies the number of classes
      sim_loss_fn: str:  (Default value = "cosine")
          Specifies the type of loss function
      lr_scheduler: str:  (Default value = "exp")
          Specifies the learning rate scheduler
      gamma: float:  (Default value = 0.9)
          Specifies the gamma value of the learning rate scheduler
      neg_agg_choice: str: (Default value = "proto")
          Specifies how the negatives should be aggregated: ["proto", "single"]
      neg_selection: bool: (Default value = True)
          Specifies whether a negative class is selected, or training is executed
          on the entire dataset
      projected: bool: (Default value = False)
          Specifies whether the similarity loss should be calculated on the 
          embeddings of the Backbone models, or the projected features
      fraction: float: (Default value = 1.) 
          Ratio of data per batch to be used for the similarity calculation
      root_dir: str: (Default value = None) Directory to write data to

    Returns:
      torch.Tensor: total loss composed of VICReg loss and classification loss.
      Gratefully adapted VICReg loss from: https://github.com/vturrisi/solo-learn
    """
    try:
      assert len(epoch_list) == len(dataloader_list), \
      "epoch_list must be of same length as dataloader_list"

      # Move model to gpu and set to train mode
      device = "cuda" if torch.cuda.is_available() else "cpu"
      model.train()
      model.to(device)

      # Initiate return variables
      loss_list, vicreg_loss_list, sim_loss_list, prototypes_list = [], [], [], []
      prototypes = None
      
      # Define optimizer and scheduler
      optimizer = torch.optim.Adam(model.parameters(), amsgrad = True, lr = lr)
      if lr_scheduler == "exp":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)

      # Initialize weights for vicreg and sim_loss of weight decay
      weight_vicreg_init = weight_vicreg
      weight_sim_init = weight_sim

      # Define function to access intermediate layer outputs
      activation = {}
      def get_activation(name):
          def hook(model, input, output):
              activation[name] = output.detach()
          return hook

      # Initialize dataloader and allocate next epoch_limit
      epochs_limit = 0
      epochs_limit_counter = 0

      # Initialize first epoch_limit
      epochs_limit = epoch_list[epochs_limit_counter]
      
      #Initialize first dataloader
      dataloader = dataloader_list[epochs_limit_counter]
      dataloader_test = dataloader_test_list[epochs_limit_counter]
      #epochs = epoch_list[-1]

      # Set return_proto to False -> First proto is returned after reaching first 
      # epoch_limit -> initialize protos to empty tensor
      return_proto = False
      protos = torch.Tensor().to(device)

      #Training loop
      for epoch in tqdm(range(epochs)):
          torch.cuda.empty_cache()

          # Test model with linear classifier after training on each dataloader
          if (epoch == epochs_limit) & (epoch != epoch_list[-1]):
              #model_cop = copy.deepcopy(model)
              # If first dataloader, no protos are yet available
              # if epochs_limit_counter == 0:
              #     protos_clf = None
              # else:
              #protos_clf = copy.deepcopy(protos)
              # cal_linclf_acc(
              #     model_cop, dataloader, dataloader_test, 
              #     protos = protos_clf, projected = False, num_classes = num_classes,
              #     path = "/content/drive/MyDrive/MT Gabriel/model_runs/" + run_name + "_lin_" + str(epochs_limit_counter) + ".txt",
              #     wandb_run = None,
              #     )
              
              # Update the counters, re-allocate dataloader and dataloader_test
              # Specify new epoch_limit
              epochs_limit_counter += 1
              num_classes += 1
              epochs_limit = epoch_list[epochs_limit_counter]
              dataloader = dataloader_list[epochs_limit_counter]
              dataloader_test = dataloader_test_list[epochs_limit_counter]
              print(f"Using dataloader no. {epochs_limit_counter}")

          # Set return_proto to True, to get the prototype of the next class
          if epoch == epochs_limit - 1:
              return_proto = True

          # Return last prototype
          elif epoch == epoch_list[-1]:
              return_proto = True

          # Initialize losses and counts
          batch_count = batch_loss = vicreg_batch_loss = sim_batch_loss = proto_batch_loss = epoch_loss = 0
          
          # VICReg Loss Decay Update
          if (bool(decay_rate_vicreg)) & (epoch <= abs(decay_steps_vicreg)):
              weight_vicreg = weight_vicreg_init*decay_rate_vicreg**(epoch/decay_steps_vicreg)
          
          # Sim Loss Decay Update
          if (bool(decay_rate_sim)) & (epoch <= abs(decay_steps_sim)):
              weight_sim = weight_sim_init*decay_rate_sim**(epoch/decay_steps_sim)

          # Batch loop
          for image_1, image_2, labels in tqdm(dataloader, leave = False):
              # Zero grads -> forward pass -> compute loss -> backprop
              optimizer.zero_grad()

              # Register forward hook, in case projected == False
              model.backbone_1.fc.register_forward_hook(get_activation("fc_1"))
              model.backbone_2.fc.register_forward_hook(get_activation("fc_2"))
              
              # Calculate output features; note that output is of the form
              # B x 2D, i.e. the embedding vector of the two views of the same
              # image are concatenated
              out = model(image_1.float(), image_2.float()).float().squeeze()

              # Retrieve output features in order to concatenate features later on
              feature_size = out.size()[1]
              labels = labels.view(labels.size(dim = 0), 1).repeat(2, 1)
              
              # Calculate VICReg Loss function
              vicreg_loss = weight_vicreg*vlf.vicreg_loss_func(
                  out[:,:int(feature_size*0.5)],
                  out[:, int(feature_size*0.5):], sim_loss_weight = sim_vicreg,
                  var_loss_weight = var_vicreg, cov_loss_weight = cov_vicreg,
                  ).float()
              
              # Assign features for handling in following loss functions from model
              # outputs, depending on projected == {True; False}
              if projected == False:
                  a = activation["fc_1"]
                  b = activation["fc_2"]

                  sim_features = torch.cat(
                                  (
                                      a,
                                      b
                                    ), dim = 0)
              else:
                  sim_features = torch.cat(
                                  (
                                      out[:,:int(feature_size*0.5)],
                                      out[:, int(feature_size*0.5):]
                                    ), dim = 0)

              # Implementation of cosine similarity loss
              if sim_loss_fn == "cosine":
                  # Return sim_loss, proto_loss and the protos, in case 
                  # epoch == epochs_limit - 1
                  if return_proto == True:
                      sim_loss, proto_loss, protos = cosine_sim(
                          sim_features, labels, t, alpha, neg_agg_choice=neg_agg_choice,
                          neg_selection = neg_selection, fraction = fraction, return_proto = return_proto, cls_counter = epochs_limit_counter,
                          protos = protos,
                          )
                      # Set return_proto to False
                      return_proto = False
                      sim_loss, proto_loss = weight_sim*sim_loss.float(), weight_proto*proto_loss.float()

                  else:
                      # Return sim_loss and proto_loss in case return_proto = False
                      sim_loss, proto_loss = cosine_sim(
                          sim_features, labels, t, alpha, neg_agg_choice=neg_agg_choice, return_proto = return_proto,
                          protos = protos, neg_selection = neg_selection, fraction = fraction, cls_counter = epochs_limit_counter,
                          )
                      sim_loss, proto_loss = weight_sim*sim_loss.float(), weight_proto*proto_loss.float()
              
              # Determine the probability with which supervised labels will be used
              loss = vicreg_loss + sim_loss + proto_loss

              loss.backward()
              optimizer.step()

              # Output batch losses
              batch_count += 1
              batch_loss += loss.detach().cpu().numpy()
              vicreg_batch_loss += vicreg_loss.detach().cpu().numpy()
              sim_batch_loss += sim_loss.detach().cpu().numpy()
              proto_batch_loss += proto_loss.detach().cpu().numpy()
              print(f"Epoch: {epoch} | Batch_Loss: {loss.detach().cpu().numpy()}")
              
          clear_output()

          # Calculate and log epoch losses
          epoch_loss = batch_loss/batch_count
          vicreg_loss = vicreg_batch_loss/batch_count
          sim_loss = sim_batch_loss/batch_count
          proto_loss = proto_batch_loss/batch_count
              
          if sim_loss_fn == "cosine":
              wandb.log({
                          "loss":epoch_loss, 
                          "vicreg_loss": vicreg_loss, 
                          "sim_loss": sim_loss,
                          "proto_loss":proto_loss,
                          "sim_loss_norm": sim_loss/weight_sim,
                          "vicreg_loss_norm": vicreg_loss/weight_vicreg,
                          "weight_vicreg":weight_vicreg, 
                          "weight_sim":weight_sim,
                          "weight_proto":weight_proto,
                        })              
          loss_list.append(epoch_loss)
          vicreg_loss_list.append(vicreg_loss)
          sim_loss_list.append(sim_loss)

          print(f"Epoch: {epoch} | Epoch loss: {epoch_loss}")
            
          # Save model, in case a root_dir is given
          if (epoch > 5) & (root_dir is not None) & (~np.isnan(loss.detach().cpu().numpy())):
              if (loss < loss_list[-2]):
                  PATH = os.path.join(root_dir, f"{run_name}.pt")
                  torch.save({
                      'epoch': epoch,
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'loss': loss,
                    }, PATH) 
        
            
      # Return loss logs and prototypes, in case it's given
      if sim_loss_fn == "proto":
          return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
      else:
          return loss_list, vicreg_loss_list, sim_loss_list, protos

    except KeyboardInterrupt:
        print("Execution interrupted by user")
        if sim_loss_fn == "proto":
            return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list
        else:
            return loss_list, vicreg_loss_list, sim_loss_list

# Linear Classifier Function for Evaluation

In [7]:
from matplotlib.rcsetup import validate_backend
import torch
from torch.utils.data import DataLoader, Dataset
from thesis.helper import utils

from typing import Any
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier, NearestCentroid
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
from sklearn.preprocessing import OneHotEncoder

def cal_linclf_acc(model: torch.nn.Module = None, 
                  train_dataloader: torch.utils.data.DataLoader = None,
                  test_dataloader: torch.utils.data.DataLoader = None,
                  num_heads: int = 2,
                  head_no: int = 1,
                  clfs: dict = {
                       "KNeighbors": KNeighborsClassifier(),
                       "NearestCentroid": NearestCentroid(),
                       "SVC": SVC(gamma = "auto")
                       }, 
                  train_sz: float = 0.8,
                  num_classes: int = 5,
                  projected: bool = True,
                  protos: torch.Tensor = None,
                  wandb_run: Any = wandb.run,
                  path: str = None,
                  ) -> [Any]:

    utils.set_parameter_requires_grad(model, False)
    if projected == True:
        assert num_heads == 2, \
        "When projection == True, both heads must be used (num_heads == 2)"

    #Calculate embeddings:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    outs = torch.Tensor().to(device)
    #outs_train = torch.Tensor().to(device)
    outs_test = torch.Tensor().to(device)

    labels = torch.Tensor().to(device)
    #labels_train = torch.Tensor().to(device)
    labels_test = torch.Tensor().to(device)

    if num_heads ==2:
        if projected == True:
            for data, data_2, label in train_dataloader:
                torch.no_grad()
                outs = torch.cat((outs, model(data.float(),data_2.float()).squeeze()), 0)
                labels = torch.cat((labels, label), 0)

            for data, data_2, label in test_dataloader:
                torch.no_grad()
                outs_test = torch.cat((outs_test, model(data.float(),data_2.float()).squeeze()), 0)
                labels_test = torch.cat((labels_test, label), 0)
            
        else:
            activation = {}
            def get_activation(name):
                def hook(model, input, output):
                    activation[name] = output.detach()
                return hook

            for data, data_2, label in train_dataloader:
                torch.no_grad()
                model.backbone_1.fc.register_forward_hook(get_activation("fc_1"))
                model.backbone_2.fc.register_forward_hook(get_activation("fc_2"))
                outs_proj = model(data.float(), data_2.float()).squeeze()

                a = activation["fc_1"]
                b = activation["fc_2"]
                features = torch.cat((a, b), 1)
                outs = torch.cat((outs, features), 0)
                labels = torch.cat((labels, label), 0)
            
            for data, data_2, label in test_dataloader:
                torch.no_grad()
                model.backbone_1.fc.register_forward_hook(get_activation("fc_1"))
                model.backbone_2.fc.register_forward_hook(get_activation("fc_2"))
                outs_proj = model(data.float(),data_2.float()).squeeze()

                a = activation["fc_1"]
                b = activation["fc_2"]
                features = torch.cat((a, b), 1)
                outs_test = torch.cat((outs_test, features), 0)
                labels_test = torch.cat((labels_test, label), 0)

        feature_size = 0.5*outs.size()[1]

        outs_cpu = outs.detach().cpu()
        outs_cpu = torch.cat((outs_cpu[:, :int(feature_size)], outs_cpu[:, int(feature_size):]), 0).detach().cpu().numpy()

        outs_test_cpu = outs_test.detach().cpu()
        outs_test_cpu = torch.cat((outs_test_cpu[:, :int(feature_size)], outs_test_cpu[:, int(feature_size):]), 0).detach().cpu().numpy() 

        labels_cpu = labels.detach().cpu().repeat(2).numpy()
        labels_test_cpu = labels_test.detach().cpu().repeat(2).numpy()


    else:
        if head_no == 1:
            model_1 = model.backbone_1
        else:
            model_1 = model.backbone_2

        for data, data_2, label in train_dataloader:
            torch.no_grad()

            outs = torch.cat((outs, model_1(torch.cat((data.float(), data_2.float()))).squeeze()), 0)
            labels = torch.cat((labels, label.repeat(2)), 0)

        for data, data_2, label in test_dataloader:
            torch.no_grad()

            outs_test = torch.cat((outs_test, model_1(torch.cat((data.float(),data_2.float()))).squeeze()), 0)
            labels_test = torch.cat((labels_test, label.repeat(2)), 0)

        outs_cpu = outs.detach().cpu().numpy()
        outs_test_cpu = outs_test.detach().cpu().numpy()

        labels_cpu = labels.detach().cpu().numpy()
        labels_test_cpu = labels_test.detach().cpu().numpy()

    # One Hot Encoder for ROC-AUC measure
    ohe = OneHotEncoder()
    ohe.fit(labels_cpu.reshape(-1, 1))

    # Define Prototypes, if not available (mean)
    proto = []

    if protos is None:
        for label in np.unique(labels_cpu):
              proto.append(np.mean(outs_cpu[labels_cpu == label], 0))
        proto = np.array(proto)
    else:
        proto = protos

    # Prototypical Proximity Evaluation
    ## Euclid Distance Loss
    count, correct = 0, 0
    min_dists = []

        
    correct_dist_list = [0. for i in range(num_classes)]
    count_dist_list = [0. for i in range(num_classes)]

    for i, instance in enumerate(outs_test_cpu):
        if protos is not None:
            distances = np.sum((proto.detach().cpu().numpy() - instance)**2, 1)
        else:
            distances = np.sum((proto - instance)**2, 1)

        min_dist = np.argmin(distances)
        min_dists.append(min_dist)

        if min_dist == labels_test_cpu[i]:
            correct += 1
            correct_dist_list[int(labels_test_cpu[i])] += 1
            count_dist_list[int(labels_test_cpu[i])] += 1
        else:
            count_dist_list[int(labels_test_cpu[i])] += 1
        count += 1
    
    dist_acc_list = []
    for correct, count in zip(correct_dist_list, count_dist_list):
        if count != 0:
            dist_acc_list.append(correct/count)

    ## Cosine Sim Loss
    count_sim, correct_sim = 0, 0
    max_sims = []
    correct_sim_list = [0. for i in range(num_classes)]
    count_sim_list = [0. for i in range(num_classes)]

    for i, instance in enumerate(outs_test_cpu):
        cos = torch.nn.CosineSimilarity(dim = 1)

        if protos is not None:
            sim = cos(proto.detach().cpu(), torch.Tensor(instance))
        else:
            sim = cos(torch.Tensor(proto), torch.Tensor(instance))

        max_sim = torch.argmax(sim)
        max_sims.append(max_sim.detach().cpu().numpy())

        if max_sim == labels_test_cpu[i]:
            correct_sim += 1
            correct_sim_list[int(labels_test_cpu[i])] += 1
            count_sim_list[int(labels_test_cpu[i])] += 1
        else:
            count_sim_list[int(labels_test_cpu[i])] += 1
        count_sim += 1
    
    sim_acc_list = []
    for correct, count in zip(correct_sim_list, count_sim_list):
        if count != 0:
            sim_acc_list.append(correct/count)

    acc_euclid = np.sum(dist_acc_list) / len(dist_acc_list)
    acc_sim = np.sum(sim_acc_list) / len(sim_acc_list)
    
    print(sim_acc_list)
    print(dist_acc_list)

    prototypical_loss = {"Acc_euclid": acc_euclid, "Acc_sim": acc_sim}
    
    with open(path, "a") as f:
        for k, v in clfs.items():
          pipeline = make_pipeline(StandardScaler(), v)
          pipeline.fit(outs_cpu, labels_cpu)

          predictions = pipeline.predict(outs_test_cpu)
          print(
              "------------------------------------------------------------------------------------", 
              file = f
              )
          print(f"Clf: {k}", file = f)
          print(classification_report(labels_test_cpu, predictions), file = f)
          lb= ohe.fit_transform(labels_test_cpu.reshape(-1, 1))
          pred = ohe.fit_transform(predictions.reshape(-1, 1))
          roc = roc_auc_score(lb.toarray(), pred.toarray())
          acc = accuracy_score(labels_test_cpu, predictions)
          # print(f"Clf: {k} | Acc: {acc} | AUC-ROC: {roc}", file = f)
          print(
              "------------------------------------------------------------------------------------", 
              file = f
              )
          name_acc = "Lin_clf_acc | " + k
          name_roc = "Lin_clf_ROC | " + k
          if wandb_run is not None:
              wandb_run.summary[name_acc] = acc
              wandb_run.summary[name_roc] = roc

        print("Prototypical Losses")
        for k, v in prototypical_loss.items():
            print(f"Metric: {k} | Value: {v}", file = f)
            print(
              "------------------------------------------------------------------------------------", 
              file = f
              )
            if wandb_run is not None:
                name_acc = "Proto_acc | " + k
                wandb_run.summary[name_acc] = v
                print("Prototypical Losses")
        acc_dict = {"Sim_loss": sim_acc_list, "Dist_loss": dist_acc_list}
        for k, v in {"Sim_loss": sim_acc_list, "Dist_loss": dist_acc_list}.items():
            print(f" {k} | ", file = f)
            for i, acc in enumerate(v):
                print(f"Class: {i} | Value: {acc}", file = f)
                if wandb_run is not None:
                    name_acc = "Proto_acc | " + k + " | Class " + str(i)
                    wandb_run.summary[name_acc] = v
            print(
                "------------------------------------------------------------------------------------", 
                file = f
                )

In [None]:
import warnings
warnings.filterwarnings("ignore")
run_name = "test_"
model = VICRegCNN_2()
cal_linclf_acc(model, dataloader_test, dataloader_test, protos = None, projected = True, path = "/content/drive/MyDrive/MT Gabriel/model_runs/" + run_name + "_lin.txt", wandb_run = None)

# Execute Training

In [1]:
# from thesis.models.training import train_vicreg_cnn_2
# from thesis.loss.similarity_loss import cosine_sim
# from thesis.models.eval import cal_linclf_acc

from thesis.models.VICRegModel import VICRegCNN_2
import warnings
warnings.filterwarnings("ignore")


run_name = "20220506_04"

model = VICRegCNN_2()

train_dict = {"epochs": 100,
              "dataloader_list":dataloader_list,
              "epoch_list":[15, 30, 45, 60],
              "weight_vicreg": 0.1,
              "weight_sim": 10,
              "weight_proto":100,
              "decay_rate_vicreg": None,
              "decay_steps": 100,
              "decay_rate_sim": 0.4,
              "decay_steps_sim": -100,
              "lr": 0.001,
              "alpha": 0.3,
              "num_classes": 5,
              "t":.3,
              "sim_loss_fn":"cosine",
              "gamma": 0.9,
              "neg_selection":True,
              "neg_agg_choice":"single",
              "fraction":1.,
              "projected":False,
              }

run = wandb.init(project = "thesis", entity = "agabriel", config = train_dict)
wandb.run.name = run_name

loss_list, vicreg_loss_list, sim_loss_list = train_vicreg_cnn_2(model, root_dir = "/content/drive/MyDrive/MT Gabriel/models/", **train_dict)
callbacks = [loss_list, vicreg_loss_list, sim_loss_list]
save_train_specs(model, train_dict, callbacks, "/content/drive/MyDrive/MT Gabriel/model_runs/", run_name + ".txt")

ModuleNotFoundError: ignored

In [12]:
torch.cuda.empty_cache()

## Check linear classifier

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

from PIL import Image
from sklearn.model_selection import train_test_split
import augly.image as imaugs
from tqdm.auto import tqdm
from thesis.helper import tensor_img_transforms

class ImageDataset(Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.label_list = []
    self.image_list = []
    self.class_map = {}
    cls_list = []

    for class_path in glob.glob(root_dir + "*"):
      cls = class_path.split("/")[-1]
      cls_list.append(cls)
      for img_path in glob.glob(class_path + "/*.png"):
        img = Image.open(str(img_path)).convert("RGB")
        tensor_image = TF.pil_to_tensor(img)
        self.image_list.append(tensor_image)
        self.label_list.append(cls)
    
    for i, cls in enumerate(cls_list):
        self.class_map[cls] = i

  def __len__(self):
    return len(self.image_list)

  def __getitem__(self, idx):
    class_name = self.label_list[idx]
    if torch.is_tensor(idx):
          idx = idx.tolist()

    self.class_id = self.class_map[class_name]
    self.class_id = torch.tensor([self.class_id])
    return self.image_list[idx], self.class_id

image_dir = "/content/drive/MyDrive/MT Gabriel/data_ext/"
image_dataset = ImageDataset(image_dir)

labels = [label.numpy() for tensor, label in iter(image_dataset)]
train_indices, test_indices = train_test_split(list(range(len(labels))), test_size=0.2, stratify=labels)
train_dataset = torch.utils.data.Subset(image_dataset, train_indices)
test_dataset = torch.utils.data.Subset(image_dataset, test_indices)

dataloader = DataLoader(
    train_dataset, 
    batch_size = 128, 
    shuffle = True, 
    pin_memory = False,
    collate_fn = collate_CNN_2
    )

dataloader_test = DataLoader(
    test_dataset, 
    batch_size = 128, 
    shuffle = True, 
    pin_memory = False,
    collate_fn = collate_CNN_2
    )

In [9]:
cal_linclf_acc(model, dataloader_test, dataloader_test, protos = None, num_classes = 5, projected = False, path = "/content/drive/MyDrive/MT Gabriel/model_runs/" + run_name + "_lin.txt", wandb_run = wandb.run)

[0.7901234567901234, 0.6170212765957447, 0.24166666666666667, 0.5472222222222223, 0.7958333333333333]
[0.7808641975308642, 0.6103723404255319, 0.24444444444444444, 0.5611111111111111, 0.8055555555555556]
Prototypical Losses
Prototypical Losses
Prototypical Losses


In [10]:
torch.cuda.empty_cache()

# Sweep configuration

In [None]:
#from thesis.models.training import train_vicreg_cnn_2
#from thesis.loss.similarity_loss import cosine_sim
#from thesis.models.eval import cal_linclf_acc
import math
from thesis.models.VICRegModel import VICRegCNN_2
import warnings
warnings.filterwarnings("ignore")

sweep_config = {
    "method":"grid",
    "entity":"agabriel",
    "metric":{
        "name":"loss",
        "goal":"minimize"
    },
    "parameters":{
        "epochs":{
            "value":61
        },
        "epoch_list":{
            "value":[15, 30, 45, 60]
        },
        "lr":{
            "value":0.001
        },
        "t":{
            "values":[0.25, 1.5]
        },
        "neg_selection":{
            "value":True
        },
        "neg_agg_choice":{
            "value":"single"
        },
        "fraction":{
            "value": 1.
        },
        "sim_loss_fn":{
            "value":"cosine"
        },
        "num_classes":{
            "value":2
        },
        "alpha":{
            "values":[0.25, 0.75]
        },
        "weight_vicreg":{
            "values":[0., 10.]
        },
        "weight_sim":{
            "values":[1., 100.]
        },
        "weight_proto":{
            "values":[0., 100.]
        },
        "decay_rate_vicreg":{
            "value":None
        },
        "decay_rate_sim":{
            "value":None
        },
        "gamma":{
            "values":[.9]
        },
    }
}

sweep_id = wandb.sweep(sweep_config, project = "thesis")

#run = wandb.init(project = "thesis", entity = "agabriel", config = train_dict)

def train():
    with wandb.init() as run:
        print(type(wandb.run))
        config = wandb.config
        run_name = "20220502_" + wandb.run.name
        model = VICRegCNN_2()

        train_vicreg_cnn_2(model, dataloader_list, dataloader_test_list, root_dir = "/content/drive/MyDrive/MT Gabriel/models/", **config)
        cal_linclf_acc(
          model, dataloader_test, dataloader_test, protos = None, projected = False, num_classes = 5,
          path = "/content/drive/MyDrive/MT Gabriel/model_runs/" + run_name + "_lin.txt", wandb_run = wandb.run
          )

count = 10 # number of runs to execute
wandb.agent(sweep_id, function=train, count=count)

#loss_list, vicreg_loss_list, sim_loss_list = train_vicreg_cnn_2(model, dataloader, root_dir = "/content/drive/MyDrive/MT Gabriel/models/", **config)
#callbacks = [loss_list, vicreg_loss_list, sim_loss_list]
#save_train_specs(model, train_dict, callbacks, "/content/drive/MyDrive/MT Gabriel/model_runs/", run_name + ".txt")

Epoch: 6 | Epoch loss: 436.2771689675071


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▂▁▁▁▁▁
proto_loss,▁▁▁▁▁▁▁
sim_loss,▁▆███▇▇
sim_loss_norm,▁▆███▇▇
vicreg_loss,█▂▁▁▁▁▁
vicreg_loss_norm,█▂▁▁▁▁▁
weight_proto,▁▁▁▁▁▁▁
weight_sim,▁▁▁▁▁▁▁
weight_vicreg,▁▁▁▁▁▁▁

0,1
loss,436.27717
proto_loss,0.0
sim_loss,6.44972
sim_loss_norm,6.44972
vicreg_loss,429.82744
vicreg_loss_norm,42.98274
weight_proto,0.0
weight_sim,1.0
weight_vicreg,10.0


Run 6m62h729 errored: NameError("name 'run_name' is not defined")
[34m[1mwandb[0m: [32m[41mERROR[0m Run 6m62h729 errored: NameError("name 'run_name' is not defined")
[34m[1mwandb[0m: Agent Starting Run: 4vwrhgbg with config:
[34m[1mwandb[0m: 	alpha: 0.25
[34m[1mwandb[0m: 	decay_rate_sim: None
[34m[1mwandb[0m: 	decay_rate_vicreg: None
[34m[1mwandb[0m: 	epoch_list: [15, 30, 45, 60]
[34m[1mwandb[0m: 	epochs: 61
[34m[1mwandb[0m: 	fraction: 1
[34m[1mwandb[0m: 	gamma: 0.9
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	neg_agg_choice: single
[34m[1mwandb[0m: 	neg_selection: True
[34m[1mwandb[0m: 	num_classes: 2
[34m[1mwandb[0m: 	sim_loss_fn: cosine
[34m[1mwandb[0m: 	t: 0.25
[34m[1mwandb[0m: 	weight_proto: 0
[34m[1mwandb[0m: 	weight_sim: 100
[34m[1mwandb[0m: 	weight_vicreg: 0


<class 'wandb.sdk.wandb_run.Run'>


  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

Run 4vwrhgbg errored: RuntimeError('CUDA out of memory. Tried to allocate 262.00 MiB (GPU 0; 15.90 GiB total capacity; 14.05 GiB already allocated; 273.75 MiB free; 14.55 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF')
[34m[1mwandb[0m: [32m[41mERROR[0m Run 4vwrhgbg errored: RuntimeError('CUDA out of memory. Tried to allocate 262.00 MiB (GPU 0; 15.90 GiB total capacity; 14.05 GiB already allocated; 273.75 MiB free; 14.55 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF')
[34m[1mwandb[0m: Agent Starting Run: mbkn66ho with config:
[34m[1mwandb[0m: 	alpha: 0.25
[34m[1mwandb[0m: 	decay_rate_sim: None
[34m[1mwandb[0m: 	decay_rate_vicreg: None
[34m[1mwandb[0m: 	epoch_list: [

<class 'wandb.sdk.wandb_run.Run'>


  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

Run mbkn66ho errored: RuntimeError('CUDA out of memory. Tried to allocate 784.00 MiB (GPU 0; 15.90 GiB total capacity; 13.85 GiB already allocated; 399.75 MiB free; 14.43 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF')
[34m[1mwandb[0m: [32m[41mERROR[0m Run mbkn66ho errored: RuntimeError('CUDA out of memory. Tried to allocate 784.00 MiB (GPU 0; 15.90 GiB total capacity; 13.85 GiB already allocated; 399.75 MiB free; 14.43 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF')
[34m[1mwandb[0m: Agent Starting Run: in8u8inh with config:
[34m[1mwandb[0m: 	alpha: 0.25
[34m[1mwandb[0m: 	decay_rate_sim: None
[34m[1mwandb[0m: 	decay_rate_vicreg: None
[34m[1mwandb[0m: 	epoch_list: [

<class 'wandb.sdk.wandb_run.Run'>


  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

Run in8u8inh errored: RuntimeError('CUDA out of memory. Tried to allocate 784.00 MiB (GPU 0; 15.90 GiB total capacity; 14.27 GiB already allocated; 101.75 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF')
[34m[1mwandb[0m: [32m[41mERROR[0m Run in8u8inh errored: RuntimeError('CUDA out of memory. Tried to allocate 784.00 MiB (GPU 0; 15.90 GiB total capacity; 14.27 GiB already allocated; 101.75 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF')
[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 6ri4zsmf with config:
[34m[1mwandb[0m: 	alpha: 0.25
[34m[1mwandb[0m: 	decay_rate_sim: 

<class 'wandb.sdk.wandb_run.Run'>


  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

Run 6ri4zsmf errored: RuntimeError('CUDA out of memory. Tried to allocate 148.00 MiB (GPU 0; 15.90 GiB total capacity; 14.41 GiB already allocated; 99.75 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF')
[34m[1mwandb[0m: [32m[41mERROR[0m Run 6ri4zsmf errored: RuntimeError('CUDA out of memory. Tried to allocate 148.00 MiB (GPU 0; 15.90 GiB total capacity; 14.41 GiB already allocated; 99.75 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF')
Detected 5 failed runs in a row at start, killing sweep.
[34m[1mwandb[0m: [32m[41mERROR[0m Detected 5 failed runs in a row at start, killing sweep.
[34m[1mwandb[0m: To change this value set WANDB_AGENT_MAX_INI

In [None]:
class ImageDataset(Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.label_list = []
    self.image_list = []
    self.class_map = {}
    cls_list = []

    for class_path in glob.glob(root_dir + "*"):
      cls = class_path.split("/")[-1]
      cls_list.append(cls)
      for img_path in glob.glob(class_path + "/*.png"):
        img = Image.open(str(img_path)).convert("RGB")
        tensor_image = TF.pil_to_tensor(img)
        self.image_list.append(tensor_image)
        self.label_list.append(cls)
    
    for i, cls in enumerate(cls_list):
        self.class_map[cls] = i

  def __len__(self):
    return len(self.image_list)

  def __getitem__(self, idx):
    class_name = self.label_list[idx]
    if torch.is_tensor(idx):
          idx = idx.tolist()

    self.class_id = self.class_map[class_name]
    self.class_id = torch.tensor([self.class_id])
    return self.image_list[idx], self.class_id

image_dir = "/content/drive/MyDrive/MT Gabriel/data_ext/"
image_dataset = ImageDataset(image_dir)

labels = [label.numpy() for tensor, label in iter(image_dataset)]
train_indices, test_indices = train_test_split(list(range(len(labels))), test_size=0.2, stratify=labels)
train_dataset = torch.utils.data.Subset(image_dataset, train_indices)
test_dataset = torch.utils.data.Subset(image_dataset, test_indices)

dataloader = DataLoader(
    train_dataset, 
    batch_size = 128, 
    shuffle = True, 
    pin_memory = False,
    collate_fn = collate_CNN_2
    )

dataloader_test = DataLoader(
    test_dataset, 
    batch_size = 128, 
    shuffle = True, 
    pin_memory = False,
    collate_fn = collate_CNN_2
    )

cal_linclf_acc(
    model, dataloader, dataloader_test, protos = None, projected = False,
    path = "/content/drive/MyDrive/MT Gabriel/model_runs/" + run_name + "_lin.txt",
    wandb_run = wandb.run,
    )

run.finish()

In [None]:
model_opt = VICRegCNN_2()
checkpoint = torch.load("/content/drive/MyDrive/MT Gabriel/models/" + run_name + ".pt")
model_opt.load_state_dict(checkpoint["model_state_dict"])

cal_linclf_acc(
    model_opt, dataloader, dataloader_test, protos = None, projected = False,
    path = "/content/drive/MyDrive/MT Gabriel/model_runs/" + run_name + "opt_lin.txt",
    wandb_run = wandb.run,
    )

artifact = wandb.Artifact('Best_Model', type='Model')
artifact.add_file('/content/drive/MyDrive/MT Gabriel/models/' + run_name + ".pt")
wandb.log_artifact(artifact)

run.finish()

In [None]:
cal_linclf_acc(
    model, dataloader, dataloader_test, protos = None, projected = False,
    path = "/content/drive/MyDrive/MT Gabriel/model_runs/" + run_name + "_lin.txt",
    wandb_run = None,
    )

# Train Regular ResNET18

In [None]:
def resnet_eval(resnet18_model, dataloader_test, num_classes, path):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    corrects = 0
    total = 0
    corrects_list = [0. for i in range(num_classes)]
    counts_list = [0. for i in range(num_classes)]
    test_acc_list = []

    resnet18_model.eval()
    utils.set_parameter_requires_grad(resnet18_model, False)
    resnet18_model.to(device)

    for data_1, data_2, label in dataloader_test:
        torch.no_grad()
        data_1 = data_1.to(device)
        data_2 = data_2.to(device)
        labels = label.to(device)
        outs_1 = resnet18_model(data_1)
        outs_2 = resnet18_model(data_2)
        outs = torch.cat((outs_1, outs_2), 0)

        labels = labels.view(labels.size(dim = 0), 1).repeat(2, 1).squeeze()
        _, preds = torch.max(outs, 1)
        
        for i, label in enumerate(range(num_classes)):
            corrects = 0
            total = 0

            mask = label == labels.detach().cpu().numpy()
            preds_masked = preds[mask].detach().cpu().numpy()
            labels_masked = labels[mask].detach().cpu().numpy()

            corrects_list[i] += (preds_masked == labels_masked).sum()
            counts_list[i] += labels_masked.size

    for i, j in zip(corrects_list, counts_list):
        test_acc_list.append(i/j)

    with open(path, "a") as f:
        for i, entry in enumerate(test_acc_list):
            print(f"Class: {i} | Test Acc: {entry}", file = f)
        print(f"Overall Test Acc: {np.sum(test_acc_list)/num_classes:.2}", file = f)
    return test_acc_list

In [None]:
import torchvision
import torch.nn as nn
import copy

def resnet18_training(
    model: torch.nn.Module = None, 
    dataloader_list: list = dataloader_list,
    dataloader_test_list: list = dataloader_test_list,
    num_classes: int = 2,
    epoch_list: list = [25, 50, 75, 100],
    lr_scheduler: str = "exp",
    lr: float = 0.001,
    gamma: float = 0.9,
    path: str = "/content/drive/MyDrive/MT Gabriel/model_runs/",
    run_name: str = None,
    **kwargs,
    ):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.train()

    num_features = resnet18_model.fc.in_features
    resnet18_model.fc = nn.Linear(num_features, num_classes)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    if lr_scheduler == "exp":
      scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)

    train_acc_list = []
    loss_list = []

    epochs_limit = 0
    epochs_limit_counter = 0
    epochs_limit = epoch_list[epochs_limit_counter]
    dataloader = dataloader_list[epochs_limit_counter]
    dataloader_test = dataloader_test_list[epochs_limit_counter]
    epochs = epoch_list[-1]

    for epoch in tqdm(range(epochs)):
        corrects = torch.Tensor([0]).to(device)
        total = torch.Tensor([0]).to(device)

        for data_1, data_2, label in tqdm(dataloader, leave = False):
            optimizer.zero_grad()
            if epoch == epochs_limit:
                model_cop = copy.deepcopy(model)

                resnet_eval(model_cop, dataloader_test, path + run_name + "_lin_" + str(epochs_limit_counter) + ".txt")
                epochs_limit_counter += 1
                num_classes += 1
                epochs_limit = epoch_list[epochs_limit_counter]
                dataloader = dataloader_list[epochs_limit_counter]
                dataloader_test = dataloader_test_list[epochs_limit_counter]
                print(f"Using dataloader no. {epochs_limit_counter}")
                
                #num_features = resnet18_model.fc.in_features
                #model.fc = torch.nn.Linear(num_features, num_classes)
                #model.to(device)

            data_1 = data_1.to(device)
            data_2 = data_2.to(device)
            labels = label.to(device)
            outs_1 = model(data_1)
            outs_2 = model(data_2)
            outs = torch.cat((outs_1, outs_2), 0)

            labels = labels.view(labels.size(dim = 0), 1).repeat(2, 1).squeeze()

            criterion = torch.nn.CrossEntropyLoss()
            loss = criterion(outs, labels)
            _, preds = torch.max(outs, 1)

            loss.backward()
            optimizer.step()

            corrects += torch.sum(preds == labels)
            total += torch.numel(labels)

        epoch_acc = (corrects/total).detach().cpu().numpy()
        wandb.log({"train_acc":epoch_acc, "loss":loss})

        train_acc_list.append(epoch_acc)
        loss_list.append(loss)
        print(f"Epoch: {epoch} | Epoch_acc: {epoch_acc} | Loss: {loss}")
    return loss_list, train_acc_list

In [None]:
run_name = "20220427_03_resnet18"

resnet18_model = torchvision.models.resnet18(pretrained = True, progress = True, zero_init_residual=True)

train_dict = {"model":resnet18_model,
              "epoch_list": [15, 30, 45, 60],
              "dataloader_list":dataloader_list,
              "dataloader_test_list":dataloader_test_list,
              "lr":0.001,
              "num_classes":5,
              "gamma":0.9,
              "lr":0.001,
              "lr_scheduler":"exp",
              }

run = wandb.init(project = "thesis", entity = "agabriel", config = train_dict)
wandb.run.name = run_name

loss_list, acc_list = resnet18_training(path = "/content/drive/MyDrive/MT Gabriel/model_runs/", run_name = run_name, **train_dict)

run.finish()

In [None]:
import tensorflow
import torch
from torch.utils.data import DataLoader, Dataset

from thesis.helper import dataset, utils, tensor_img_transforms
from thesis.models.VICRegModel import VICRegCNN_2
from thesis.models import VICRegModel
#from thesis.models import training
from thesis.models import eval
from thesis.helper.utils import save_train_specs
from sklearn.model_selection import train_test_split
#from thesis.helper.dataset import train_test_split
#from thesis.helper.utils import cal_linclf_acc

import os
import numpy as np
import glob

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

from PIL import Image

from sklearn.model_selection import train_test_split

import augly.image as imaugs

from tqdm.auto import tqdm

from thesis.helper import tensor_img_transforms

class ImageDataset(Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.label_list = []
    self.image_list = []
    self.class_map = {}
    cls_list = []

    for class_path in glob.glob(root_dir + "*"):
      cls = class_path.split("/")[-1]
      cls_list.append(cls)
      for img_path in glob.glob(class_path + "/*.png"):
        img = Image.open(str(img_path)).convert("RGB")
        tensor_image = TF.pil_to_tensor(img)
        self.image_list.append(tensor_image)
        self.label_list.append(cls)
    
    for i, cls in enumerate(cls_list):
        self.class_map[cls] = i

  def __len__(self):
    return len(self.image_list)

  def __getitem__(self, idx):
    class_name = self.label_list[idx]
    if torch.is_tensor(idx):
          idx = idx.tolist()

    self.class_id = self.class_map[class_name]
    self.class_id = torch.tensor([self.class_id])
    return self.image_list[idx], self.class_id

image_dir = "/content/drive/MyDrive/MT Gabriel/data_ext/"
image_dataset = ImageDataset(image_dir)

labels = [label.numpy() for tensor, label in iter(image_dataset)]
train_indices, test_indices = train_test_split(list(range(len(labels))), test_size=0.2, stratify=labels)
train_dataset = torch.utils.data.Subset(image_dataset, train_indices)
test_dataset = torch.utils.data.Subset(image_dataset, test_indices)

dataloader = DataLoader(
    train_dataset, 
    batch_size = 128, 
    shuffle = True, 
    pin_memory = False,
    collate_fn = collate_CNN_2
    )

dataloader_test = DataLoader(
    test_dataset, 
    batch_size = 128, 
    shuffle = True,
    pin_memory = False,
    collate_fn = collate_CNN_2
    )

In [None]:
import torchvision
resnet18_model = torch.load("/content/drive/MyDrive/MT Gabriel/model_runs/resnet18_temp.pt")

In [None]:
test_acc = resnet_eval(resnet18_model, dataloader_test, num_classes = 5, path = "/content/drive/MyDrive/MT Gabriel/model_runs/" + run_name + "_final.txt")

  "Argument interpolation should be of type InterpolationMode instead of int. "
  return torch.stack(timage_list_1, 0, out=out_1).squeeze().to(
  ), torch.stack(timage_list_2, 0, out=out_2).squeeze().to(
  ), torch.stack(label_list, 0, out=out_3).squeeze().to(
  return torch.stack(timage_list_1, 0, out=out_1).squeeze().to(
  ), torch.stack(timage_list_2, 0, out=out_2).squeeze().to(
  ), torch.stack(label_list, 0, out=out_3).squeeze().to(


# Embedding visualization

In [None]:
%matplotlib inline
from typing import Any
import torch
import plotly.express as px
import pandas as pd

import umap

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

from skimage.future import graph
from torch.utils.data import DataLoader, Dataset
from thesis.helper import utils
from sklearn.preprocessing import StandardScaler

def visualize_embeddings(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    n_neighbors: int = 30, 
    min_dist: int = 0,
    n_components: int = 3,
    prototypes: np.array = None,
    labels_dict: dict = {0.0:"fold", 1.0: "gap", 2.0: "hole", 3.0: "rabbet", 4.0: "regular"},
    projected: bool = True,
    ) -> [list, list]:

    """
    Args: pretrained model; n_neighbors, min_dist, n_components for UMAP algo

    Returns: plot of embeddings in 3d-space
    """

    utils.set_parameter_requires_grad(model, False)

    #Calculate embeddings:
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    model.to(device)
    outs = torch.Tensor().to(device)
    labels = torch.Tensor().to(device)

    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook

    for data, data_2, label in dataloader:
        torch.no_grad()
        if projected == False:
            model.backbone_1.fc.register_forward_hook(get_activation("fc_1"))
            model.backbone_2.fc.register_forward_hook(get_activation("fc_2"))
            outs_proj = model(data.float(),data_2.float()).squeeze()

            a = activation["fc_1"]
            b = activation["fc_2"]
            features = torch.cat((a, b), 1)
            outs = torch.cat((outs, features), 0)
            labels = torch.cat((labels, label), 0)

        else:
            outs = torch.cat((outs, model(data.float(), data_2.float()).squeeze().to(device)), 0)
            outs = outs.detach().cpu()
            labels = torch.cat((labels, label), 0)

    feature_size = 0.5*outs.size()[1]
    outs = torch.cat((outs[:, :int(feature_size)], outs[:, int(feature_size):]), 0)
    labels = labels.detach().cpu().repeat(2).numpy()

    if prototypes is not None:
        labels_prototypes = [(np.max(labels) + 1 + i).astype(float) for i, proto in enumerate(prototypes)]
        labels = np.concatenate((labels, labels_prototypes))
        outs = torch.cat((outs, protos), 0)

    #Scale inputs to ease UMAP operation (m = 0, std = 1)
    scaler = StandardScaler()
    out_np = outs.detach().cpu().numpy()
    outs_scaled = scaler.fit_transform(out_np)

    #Generate px.scatter_3d plot
    sns.set(style='white', context='poster', rc={'figure.figsize':(14,10)})

    clusterable_embedding = umap.UMAP(
        n_neighbors = n_neighbors,
        min_dist = min_dist,
        n_components = n_components,
        random_state = 47,
    ).fit_transform(out_np)

    df = pd.DataFrame()
    df["x"] = clusterable_embedding[:,0]
    df["y"] = clusterable_embedding[:,1]
    df["z"] = clusterable_embedding[:,2]
    df["labels"] = labels
    df["labels"] = df["labels"].replace(labels_dict)

    if prototypes is None:
        means = df.groupby("labels").mean()
        means["labels"] = ["p_fold", "p_gap", "p_hole", "p_rabbet", "p_regular"]
        df = df.append(means)

    fig = px.scatter_3d(df, x = "x", y = "y", z = "z", color = "labels")
    fig.show()
    if prototypes is None:
        return clusterable_embedding, labels
    else:
        return clusterable_embedding, labels

device = "cuda" if torch.cuda.is_available() else "cpu"
embed, labels = visualize_embeddings(model, dataloader, prototypes = None, projected = False)

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
visualize_embeddings(model, dataloader, prototypes = None, projected = False)

#Separator

In [None]:
import os
import torch
import numpy as np
from torch import nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from IPython.display import clear_output

from thesis.loss import vicreg_loss_fn as vlf
from thesis.loss.similarity_loss import cosine_sim
from thesis.loss.similarity_loss import NCELoss

def train_vicreg_cnn_2(
  model: torch.nn.Module, 
  dataloader: torch.utils.data.DataLoader, 
  epochs: int,
  weight_vicreg: float = 1,
  sim_vicreg: float = 25,
  var_vicreg: float = 25,
  cov_vicreg: float = 1,
  decay_rate_vicreg: float = 0.01,
  decay_steps_vicreg: float = 100,
  weight_sim: float = 0.01,
  decay_rate_sim: float = 0.01,
  decay_steps_sim: float = -100,
  lr: float = 0.001,
  t: float = 1.,
  m: float = 0.5,
  alpha: float = 0.5,
  lr_scheduler: str = "exp",
  metric: str = "euclid",
  warm_up: int = 20,
  num_classes: int = 3,
  gamma: float = 0.9,
  root_dir: str = None,
  run_name: str = None, 
  **kwargs) -> torch.Tensor:

  """Training step for Self-Supervised Training model with VICReg and Sim loss
  Args:
      batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
          [X] is a list of size num_crops containing batches of images.
      batch_idx (int): index of the batch.
  Returns:
      torch.Tensor: total loss composed of VICReg loss and classification loss.
  Gratefully adapted from: https://github.com/vturrisi/solo-learn
  """

  try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.train()
    model.to(device)
    prototypes = torch.Tensor().to(device)
    counts = torch.Tensor().to(device)

    # Initiate return variables
    loss_list, vicreg_loss_list, sim_loss_list, prototypes_list = [], [], [], []
    
    # Define optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    if lr_scheduler == "exp":
      scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma)

    weight_vicreg_init = weight_vicreg
    weight_sim_init = weight_sim

    # Training loop
    for epoch in tqdm(range(epochs)):
        batch_count = batch_loss = vicreg_batch_loss = sim_batch_loss = epoch_loss = 0
        
        # VICReg Loss Decay
        if (bool(decay_rate_vicreg)) & (epoch <= abs(decay_steps_vicreg)):
            weight_vicreg = weight_vicreg_init*decay_rate_vicreg**(epoch/decay_steps_vicreg)
        
        # Sim Loss Decay
        if (bool(decay_rate_sim)) & (epoch <= abs(decay_steps_sim)):
            weight_sim = weight_sim_init*decay_rate_sim**(epoch/decay_steps_sim)

        # Batch loop
        for image_1, image_2, labels in tqdm(dataloader, leave = False):
            
            # Zero grads -> forward pass -> compute loss -> backprop
            optimizer.zero_grad()
            out = model(image_1.float(), image_2.float()).float().squeeze()
            feature_size = out.size()[1]
            labels = labels.view(labels.size(dim = 0), 1).repeat(2, 1)
            
            # Calculate VICReg Loss function
            vicreg_loss = weight_vicreg*vlf.vicreg_loss_func(
                out[:,:int(feature_size*0.5)],
                out[:, int(feature_size*0.5):], sim_loss_weight = sim_vicreg,
                var_loss_weight = var_vicreg, cov_loss_weight = cov_vicreg,
                ).float()
            
            # Assign features for handling in following loss functions from model
            # outputs
            sim_features = torch.cat(
                            (
                                out[:,:int(feature_size*0.5)],
                                out[:, int(feature_size*0.5):]
                              ), dim = 0)
            
            #Implementation of prototypical loss
            sim_loss, prototypes, counts = energy_loss(
                reps = sim_features, 
                labels = labels,
                prototypes = prototypes,
                alpha = alpha,
                metric = metric,
                warm_up = warm_up,
                epoch = epoch,
                counts = counts,
                t = t,
                m = m,
                )
            
            # Reassign prototypes
            #prototypes = prototypes.detach()
            sim_loss = weight_sim*sim_loss.float().detach()
            
            # Determine the probability with which supervised labels will be used
            loss = vicreg_loss + sim_loss
            loss.backward()
            optimizer.step()

            # Output batch losses
            batch_count += 1
            batch_loss += loss.detach().cpu().numpy()
            vicreg_batch_loss += vicreg_loss.detach().cpu().numpy()
            sim_batch_loss += sim_loss.detach().cpu().numpy()
            print(f"Epoch: {epoch} | Batch_Loss: {loss.detach().cpu().numpy()}")
            
        clear_output()

        # Calculate and log epoch losses
        epoch_loss = batch_loss/batch_count
        vicreg_loss = vicreg_batch_loss/batch_count
        sim_loss = sim_batch_loss/batch_count

        wandb.log({
                    "loss":epoch_loss, 
                    "vicreg_loss": vicreg_loss, 
                    "sim_loss": sim_loss,
                    "sim_loss_norm": sim_loss/weight_sim,
                    "vicreg_loss_norm": vicreg_loss/weight_vicreg,
                    "weight_vicreg":weight_vicreg, 
                    "weight_sim":weight_sim,
                    "prototypes":prototypes.detach(),
                  })
        
        prototypes_list.append(prototypes.detach().cpu().numpy())
            
        loss_list.append(epoch_loss)
        vicreg_loss_list.append(vicreg_loss)
        sim_loss_list.append(sim_loss)

        print(f"Epoch: {epoch} | Epoch loss: {epoch_loss:.2f}")

        # Save model, in case a root_dir is given
        if (root_dir != None) & (loss < loss_list[-1]):
          PATH = os.path.join(root_dir, f"{run_name}.pt")
          torch.save({
              'epoch': epoch,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': loss,
              }, PATH)          
          
    # Return loss logs and prototypes, in case it's given
    return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list

  except KeyboardInterrupt:
      print("Execution interrupted by user")
      return loss_list, vicreg_loss_list, sim_loss_list, prototypes_list

In [None]:
def proto_sim(reps: torch.Tensor,
                labels: torch.Tensor,
                prototypes: torch.Tensor,
                t: float,
                alpha: float,
                alpha_prot: float,
                instance_weight: float,
                proto_weight: float,
                cel_weight: float,
                dist_weight: float,
                num_classes: int,
                epsilon: float,
                epoch: int,
                ) -> torch.Tensor:
        
    device = "cuda" if torch.cuda.is_available() else "cpu"
    assert (alpha >= 0) & (alpha <= 1), "alpha must be in [0,1]"
    assert (alpha_prot >= 0) & (alpha_prot <= 1), "alpha_prot must be in [0,1]"

    if epoch == 0:
        prototypes = torch.zeros(
            len(torch.unique(labels)), reps.size()[1]
            ).to(device)
        for i in torch.unique(labels):
            trues = labels == i
            trues = trues.view(-1, 1)
            prototypes[i] = (trues.T.float() @ reps)/torch.sum(trues)

    # For numerical stability of exp function
    eps = torch.Tensor([1e-08]).to(device)
    Eps = torch.Tensor([1.797693134862315e+308]).to(device)

    reps = reps/torch.sqrt(torch.sum(reps.pow(2), dim = 1, keepdim = True))
    prototypes = prototypes/torch.sqrt(torch.sum(prototypes.pow(2), dim = 1, keepdim = True))

    # Calculate l2-norms of all vector combinations 
    v_1, v_2 = torch.sum(torch.square(reps), dim = 1).view(reps.size(0), 1), \
        torch.sum(torch.square(reps), dim = 1).view(1, reps.size(0))

    v_p1, v_p2 = torch.sum(torch.square(prototypes), dim = 1).view(prototypes.size(0), 1), \
        torch.sum(torch.square(prototypes), dim = 1).view(1, prototypes.size(0))

    # Note: matmul of v_1 with v_p2 to get 100 x 3 norm_matrix_proto
    norm_matrix = torch.matmul(torch.sqrt(v_1), torch.sqrt(v_2))
    norm_matrix_proto = torch.matmul(torch.sqrt(v_1), torch.sqrt(v_p2))

    # Calculate vector (cosine) similarities and normalize by l2 norms of vectors
    sim = torch.matmul(reps, reps.T)/(torch.max(eps, norm_matrix)*t)
    #sim = sim/torch.sum(sim, dim = 1, keepdim = True)

    sim_p = torch.matmul(reps, prototypes.T)/(torch.max(eps, norm_matrix_proto)*t)
    sim_p = torch.clamp(torch.exp(sim_p + 0), min = eps, max = Eps)

    #sim_p = sim_p/torch.sum(sim_p, dim = 1, keepdim = True)

    # Concat output from head_1 and head_2 and choose best prototype by "ensemble" voting
    ensemble = torch.cat((
        sim_p[:int(0.5*sim_p.size()[0]), :], 
        sim_p[int(0.5*sim_p.size()[0]):, :]), 
        dim = 1
        )

    # Find the prototype p_c with the shortest distance to two corresponding image patches
    proto_max = torch.argmax(ensemble, dim = 1)
    proto_max_tot = torch.LongTensor([entry if entry < sim_p.size()[1] else entry - sim_p.size()[1] for entry in proto_max]).to(device)
    proto_c = torch.cat((proto_max_tot, proto_max_tot))

    # Calculate mean of the sims of instances of the same (positive)
    # prototypes of one class: p_c_+

    # If calculated from the predictions
    #proto_labels = torch.nn.functional.one_hot(proto_c, num_classes = num_classes)

    # If calculated from the labels
    proto_labels = torch.nn.functional.one_hot(labels.squeeze(), num_classes = num_classes)

    proto_pos_loss = alpha_prot*(torch.mean(proto_labels.float() @ sim_p.T))
    proto_neg_loss = (1 - alpha_prot)*torch.mean((abs(proto_labels.float() - 1) @ sim_p.T))

    # Calculate the CrossEntropyLoss -> correct classification of instance i_c to p_c
    cel = torch.nn.CrossEntropyLoss()
    ce_loss = cel(sim_p, labels.squeeze())

    # Delete "self-loops" from similarity matrix by subtracting diagonal values
    sim = sim - torch.diag(torch.diagonal(sim))
    # Add zero for stability and clamp to float32 values
    sim = torch.clamp(torch.exp(sim + 0), min = eps, max = Eps)

    # Finds which instances are of the same class
    # If cls1 == cls2 -> label_1 - label_2 == 0
    # If cls 1 != cls2 -> abs(label_1 - label_2) >= 0
    proto_class = proto_c.view(-1, 1)
    pos_mask = (~torch.abs(labels.T - labels).bool()).float()
    neg_mask = (torch.abs(labels.T - labels).bool()).float()

    # Average positive and negative similarities for a batch and weight by alpha
    pos_loss = torch.mean(alpha*(pos_mask*sim))
    neg_loss = torch.mean((1 - alpha)*(neg_mask*sim))

    # Update Prototypes
    prototypes_updated = prototypes.clone().detach()

    for i in range(prototypes.size()[0]):
        label_i = labels == i
        prototype_new = (epsilon*prototypes_updated[i] + (1-epsilon)*(torch.mean(label_i.float().T @ reps)/torch.sum(label_i)))
        prototypes_updated[i] = prototype_new/torch.linalg.vector_norm(prototype_new)

    # Calculate the distances of the prototypes -> L2-Norm normalization
    proto_dist = torch.Tensor([0]).to(device)
    for i, prototype in enumerate(prototypes_updated):
        proto_dist += torch.sqrt(torch.sum(torch.square(prototypes_updated - prototype)))
    proto_dist = proto_dist/prototypes_updated.size()[0]
    
    # Sum up and weigh the different losses
    instance_loss = instance_weight*(neg_loss - pos_loss)
    proto_loss = proto_weight*(proto_neg_loss - proto_pos_loss)
    ce_l = cel_weight*ce_loss
    dist_l = dist_weight*proto_dist
    loss =  instance_loss + proto_loss + ce_l - dist_l.squeeze()

    # Return overall loss
    return loss.to(reps.device), instance_loss, proto_loss, ce_l, dist_l, prototypes_updated
    #return loss.to(reps.device), instance_loss, proto_loss, ce_l, prototypes_updated

In [None]:
def energy_loss(
    reps: torch.Tensor = None, 
    labels: torch.Tensor = None,
    prototypes: torch.Tensor = None,
    alpha: float = 0.3,
    metric: str = "euclid",
    warm_up: int = 20,
    epoch: int = None,
    counts: int = None,
    num_classes: int = 3,
    t: float = 30,
    ):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    feature_size = reps.size()[1]    
    eps = torch.Tensor([1e-08]).to(device)
    Eps = torch.Tensor([1.797693134862315e+308]).to(device)

    def euclid(reps, labels, prototypes):
        # Calculate the euclidean distance by binomial expansion:
        # (a - b)**2 = a**2 + 2ab + b**2
        # vector --> a.T*a + 2*a.T*b + b.T * b = (a - b)^2
        loss = torch.Tensor([0]).to(device)
        if prototypes is None:
            cosine_matrix = reps @ reps.T
            reps_squared = reps.pow(2).sum(1)
            diff = reps_squared - reps_squared.T
            dist_matrix = (cosine_matrix + diff)/feature_size

            pos_mask = (~torch.abs(labels.T - labels).bool()).float()
            neg_mask = (torch.abs(labels.T - labels).bool()).float()

            loss = alpha*((pos_mask*dist_matrix).sum() + (1-alpha)*(torch.log(torch.exp(neg_mask*-dist_matrix).sum(1, keepdim = True))).sum())/reps.size()[0]
            return loss

        else:
            cosine_matrix = reps @ prototypes.T
            reps_squared = reps.pow(2).sum(1, keepdim = True)
            prototypes_squared = prototypes.pow(2).sum(1, keepdim = True)
            prototypes_diff = prototypes_squared - prototypes_squared.T

            cosine_matrix_neg = reps @ reps.T
            diff_neg = reps_squared - reps_squared.T
            dist_matrix_neg = (cosine_matrix_neg + diff_neg)/feature_size

            prototypes_squared = prototypes_squared.repeat(1, reps.size()[0])
            reps_squared = reps_squared.repeat(1, prototypes.size()[0])

            diff = prototypes_squared.T - reps_squared
            dist_matrix = (cosine_matrix + diff)/feature_size

            pos_mask_ohe = torch.nn.functional.one_hot(labels).to(device)
            #neg_mask_ohe = (~pos_mask_ohe.bool()).float().squeeze()
            neg_mask = (torch.abs(labels.T - labels).bool()).float()

            t = reps.size()[0]
            pos_loss = alpha*((pos_mask_ohe*dist_matrix)/t).sum()
            neg_loss = (1-alpha)*torch.log(torch.exp((neg_mask*-dist_matrix_neg)/t).sum(1, keepdim = True)).sum()
            loss = (pos_loss + neg_loss)/t

            prototypes_cosine = prototypes @ prototypes.T
            prototypes_dist = ((prototypes_cosine + prototypes_diff)/feature_size).sum()
            return loss - 0.01*prototypes_dist

    def cosine(reps, labels, prototypes, t):
        # Calculate the cosine similarity of reps / prototypes
        loss = torch.Tensor([0]).to(device)

        if prototypes is None:
            cosine_matrix = reps @ reps.T
            reps_squared = reps.pow(2).sum(1)
            norm_matrix = (reps_squared @ reps_squared.T).to(device)
            sim_matrix = cosine_matrix/norm_matrix
            sim_matrix = sim_matrix - torch.diag(torch.diagonal(sim_matrix))

            pos_mask = (~torch.abs(labels.T - labels).bool()).float()
            neg_mask = (torch.abs(labels.T - labels).bool()).float()
            pos_loss = alpha*(pos_mask*-sim_matrix).sum()
            print(pos_loss)
            neg_loss = (1-alpha)*(neg_mask*torch.log(torch.exp(-sim_matrix/t).sum(1, keepdim = True))).sum()
            print(neg_loss)
            loss = (pos_loss + neg_loss)/reps.size()[0]
            return loss

        else:
            cosine_matrix = reps @ prototypes.T
            reps_squared = reps.pow(2).sum(1)

            prototypes_squared = prototypes.pow(2).sum(1)
            norm_matrix = (reps_squared @ prototypes_squared.T).to(device)
            sim_matrix = cosine_matrix/norm_matrix
            sim_matrix = sim_matrix - torch.diag(torch.diagonal(sim_matrix))

            pos_mask = alpha*(~torch.abs(labels.T - labels).bool()).float()
            neg_mask = (1-alpha)*(torch.abs(labels.T - labels).bool()).float()
            loss = ((pos_mask*-sim_matrix).sum() + (neg_mask*torch.log(torch.exp(-sim_matrix/t)).sum(1, keepdim = True)).sum())/reps.size()[0]
            return loss

    def update_proto(reps, labels, prototypes, counts):
            label_copy = labels.detach().cpu().numpy()
            prototypes_updated = torch.Tensor().to(device)

            for i, label in enumerate(np.unique(label_copy)):
                mask = torch.Tensor(np.array(label == label_copy).astype(float)).to(device)
                label = int(label)
                prototype = mask.T @ reps
                counts[label] += mask.sum()
                prototypes_updated = torch.cat((prototypes_updated, (1-mask.sum()/counts[label])*prototypes[label]+(mask.sum()/counts[label])*prototype))
            return prototypes_updated, counts

    if epoch == 0:
        prototypes = torch.zeros(size = (num_classes, int(feature_size))).to(device)
        counts = torch.zeros(prototypes.size()[0]).to(device)

    if epoch < warm_up:
        if metric == "euclid":
            loss = euclid(reps, labels, None)
        elif metric == "cosine":
            loss = cosine(reps, labels, None, t)

    else:
        if epoch == warm_up:
            prototypes, counts = update_proto(reps, labels, prototypes, counts)
        if metric == "euclid":
            loss = euclid(reps, labels, prototypes)
        elif metric == "cosine":
            loss = cosine(reps, labels, prototypes, t)
        prototypes, counts = update_proto(reps, labels, prototypes, counts)

    return loss, prototypes, counts

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
class ImageDataset(Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.label_list = []
    self.image_list = []
    self.class_map = {}
    cls_list = []

    for class_path in glob.glob(root_dir + "*"):
      cls = class_path.split("/")[-1]
      cls_list.append(cls)
      for img_path in glob.glob(class_path + "/*.png"):
        img = Image.open(str(img_path)).convert("RGB")
        tensor_image = TF.pil_to_tensor(img)
        self.image_list.append(tensor_image)
        self.label_list.append(cls)
    
    for i, cls in enumerate(cls_list):
        self.class_map[cls] = i

  def __len__(self):
    return len(self.image_list)

  def __getitem__(self, idx):
    class_name = self.label_list[idx]
    if torch.is_tensor(idx):
          idx = idx.tolist()

    self.class_id = self.class_map[class_name]
    self.class_id = torch.tensor([self.class_id])
    return self.image_list[idx], self.class_id

#Clustering stuff

In [None]:
from sklearn.mixture import BayesianGaussianMixture
import itertools

import numpy as np
from scipy import linalg
import matplotlib.pyplot as plt
import matplotlib as mpl

def plot_results(X, Y_, means, covariances, index, title):
    splot = plt.subplot(2, 1, 1 + index)
    for i, (mean, covar, color) in enumerate(zip(means, covariances, color_iter)):
        v, w = linalg.eigh(covar)
        v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
        u = w[0] / linalg.norm(w[0])
        # as the DP will not use every component it has access to
        # unless it needs it, we shouldn't plot the redundant
        # components.
        if not np.any(Y_ == i):
            continue
        plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], 0.8, color=color)

        # Plot an ellipse to show the Gaussian component
        angle = np.arctan(u[1] / u[0])
        angle = 180.0 * angle / np.pi  # convert to degrees
        ell = mpl.patches.Ellipse(mean, v[0], v[1], 180.0 + angle, color=color)
        ell.set_clip_box(splot.bbox)
        ell.set_alpha(0.5)
        splot.add_artist(ell)

    plt.xlim(-9.0, 5.0)
    plt.ylim(-3.0, 6.0)
    plt.xticks(())
    plt.yticks(())
    plt.title(title)

# Fit a Dirichlet process Gaussian mixture using five components
dpgmm = BayesianGaussianMixture(n_components=5, covariance_type="full").fit(data_train)
plot_results(
    data_train,
    dpgmm.predict(data_test),
    dpgmm.means_,
    dpgmm.covariances_,
    1,
    "Bayesian Gaussian Mixture with a Dirichlet process prior",
)

plt.show()

In [None]:
model.eval()
device = "cuda"
model.to(device)
utils.set_parameter_requires_grad(model, False)
outs = torch.Tensor().to(device)
labels = torch.Tensor().to(device)

for data, data_2, label in dataloader:
    torch.no_grad()
    outs = torch.cat((outs, model(data.float(),data_2.float()).squeeze()), 0)
    labels = torch.cat((labels, label), 0)


Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.



In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, normalized_mutual_info_score

km = KMeans(n_clusters = 3)
labels_predicted = km.fit_predict(embeddings)
samples_transformed = km.transform(embeddings)
print(f"Label allocation: {np.unique(labels_predicted, return_counts = True)}")
rand_ind_score = adjusted_rand_score(labels.detach().cpu().numpy(), labels_predicted)
adj_info = adjusted_mutual_info_score(labels.detach().cpu().numpy(), labels_predicted)
print(f"Random index score: {rand_ind_score} | Mutual info score: {adj_info})")

Label allocation: (array([0, 1, 2], dtype=int32), array([1094,  858,  695]))
Random index score: -0.00014908501444914538 | Mutual info score: -0.00010214337805545946)


In [None]:
from sklearn.mixture import BayesianGaussianMixture
bgm = BayesianGaussianMixture(
            weight_concentration_prior_type="dirichlet_process",
            n_components=2 * 3,
            init_params="random",
            max_iter=1500,
            mean_precision_prior=0.8,
            random_state=3407,
        )

#Backup

In [None]:
class VICRegCNN_2(nn.Module):
    def __init__(
        self,
        features_dim: int = 512,
        proj_output_dim: int = 4096,
        proj_hidden_dim: int = 4096,
        sim_loss_weight: float = 25,
        var_loss_weight: float = 25,
        cov_loss_weight: float = 1,
        backbone_1 = torchvision.models.resnet18(pretrained = True, progress = True, zero_init_residual=True),
        backbone_2 = torchvision.models.resnet18(pretrained = True, progress = True, zero_init_residual=True),
        **kwargs
    ):

        """Implements VICReg with two CNN branches (https://arxiv.org/abs/2105.04906)
        Args:
            proj_output_dim (int): number of dimensions of the projected features.
            proj_hidden_dim (int): number of neurons in the hidden layers of the projector.
            sim_loss_weight (float): weight of the invariance term.
            var_loss_weight (float): weight of the variance term.
            cov_loss_weight (float): weight of the covariance term.
            backbone_1, backbone_2 (torch.nn.Module): models of the respective branches
        """

        super().__init__()

        self.sim_loss_weight = sim_loss_weight
        self.var_loss_weight = var_loss_weight
        self.cov_loss_weight = cov_loss_weight
        self.backbone_1 = backbone_1
        self.backbone_2 = backbone_2

        self.backbone_1.fc = nn.Linear(in_features=features_dim, out_features = features_dim)
        self.backbone_2.fc = nn.Linear(in_features = features_dim, out_features = features_dim)

        # projector
        self.projector = nn.Sequential(
            nn.Linear(features_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, proj_output_dim),
        )

    def forward(
        self, 
        timage_1: torch.Tensor, 
        timage_2: torch.Tensor, 
        *args, 
        **kwargs):
      
        """Performs the forward pass of the backbones and the projectors.
        Args:
            X (torch.Tensor): a batch of images in the tensor format.
        """
        z1 = self.projector(self.backbone_1(timage_1))
        z2 = self.projector(self.backbone_2(timage_1))

        out = torch.cat((z1, z2), 1)
        return out

In [None]:
def collate_CNN_2(batch):
  device = "cuda" if torch.cuda.is_available() else "cpu"

  timage_list_1, timage_list_2, label_list = [], [], []
  for timage, labels in batch:
    t = tensor_img_transforms.Transform()
    y1, y2 = t(timage)
    timage_list_1.append(y1)
    timage_list_2.append(y2)
    label_list.append(labels)

  elem_1 = timage_list_1[0]
  numel_1 = sum(x.numel() for x in timage_list_1)
  storage_1 = elem_1.storage()._new_shared(numel_1)
  out_1 = elem_1.new(storage_1).resize_(len(batch), *list(elem_1.size()))

  elem_2 = timage_list_2[0]
  numel_2 = sum(x.numel() for x in timage_list_2)
  storage_2 = elem_2.storage()._new_shared(numel_2)
  out_2 = elem_2.new(storage_2).resize_(len(batch), *list(elem_2.size()))

  elem_3 = label_list[0]
  numel_3 = sum(x.numel() for x in label_list)
  storage_3 = elem_3.storage()._new_shared(numel_3)
  out_3 = elem_3.new(storage_3).resize_(len(batch), *list(elem_3.size()))

  return torch.stack(timage_list_1, 0, out=out_1).squeeze().to(
      device, non_blocking = True
      ), torch.stack(timage_list_2, 0, out=out_2).squeeze().to(
      device, non_blocking = True
      ), torch.stack(label_list, 0, out=out_3).squeeze().to(
      device, non_blocking = True
      )

In [None]:
def energy_loss(
    reps: torch.Tensor = None, 
    labels: torch.Tensor = None,
    prototypes: torch.Tensor = None,
    metric: str = "euclid",
    warm_up: int = 20,
    epoch: int = None,
    counts: int = None,
    num_classes: int = 3
    ):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    feature_size = reps.size()[1]    
    eps = torch.Tensor([1e-08]).to(device)
    Eps = torch.Tensor([1.797693134862315e+308]).to(device)

    def euclid(reps, labels, prototypes):
        # Calculate the euclidean distance by binomial expansion:
        # (a - b)**2 = a**2 + 2ab + b**2
        # vector --> a.T*a + 2*a.T*b + b.T * b = (a - b)^2
        loss = torch.Tensor([0]).to(device)
        if prototypes is None:
            cosine_matrix = reps @ reps.T
            reps_squared = reps.pow(2).sum(1)
            diff = reps_squared - reps_squared.T
            dist_matrix = torch.sqrt((cosine_matrix + diff + eps)/feature_size)

            pos_mask = (~torch.abs(labels.T - labels).bool()).float()
            neg_mask = (torch.abs(labels.T - labels).bool()).float()

            loss = ((pos_mask*dist_matrix).sum() + (torch.log(torch.exp(neg_mask*(-1)*dist_matrix).sum(1, keepdim = True))).sum())/reps.size()[0]
            return loss

        else:
            cosine_matrix = reps @ prototypes.T
            reps_squared = reps.pow(2).sum(1, keepdim = True)

            prototypes_squared = prototypes.pow(2).sum(1, keepdim = True)
            prototypes_squared = prototypes_squared.repeat(1, reps.size()[0])
            print(f"Proto_squared {prototypes_squared}")
            reps_squared = reps_squared.repeat(1, prototypes.size()[0])
            print(f"Rep_squared {reps_squared}")
            diff = reps_squared.T - prototypes_squared
            dist_matrix = torch.sqrt((cosine_matrix + diff.T + eps)/feature_size)

            pos_mask = (~torch.abs(labels.T - labels).bool()).float()
            neg_mask = (torch.abs(labels.T - labels).bool()).float()

            pos_loss = (pos_mask@dist_matrix).sum()
            print(f"pos_loss: {pos_loss}")
            neg_loss = torch.log(torch.exp(neg_mask@(dist_matrix*(-1))).sum(1, keepdim = True)).sum()
            print(f"neg_loss: {neg_loss}")
            loss = (pos_loss + neg_loss)/reps.size()[0]
            print(f"Loss: {loss}")
            return loss

    def cosine(reps, labels, prototypes):
        # Calculate the cosine similarity of reps / prototypes
        loss = torch.Tensor([0]).to(device)
        if prototypes is None:
            cosine_matrix = reps @ reps.T

            reps_squared = reps.pow(2).sum(1)
            norm_matrix = reps_squared @ reps_squared.T
            sim_matrix = cosine_matrix/norm_matrix

            pos_mask = (~torch.abs(labels.T - labels).bool()).float()
            neg_mask = (torch.abs(labels.T - labels).bool()).float()
            loss = ((pos_mask*-sim_matrix).sum() + (neg_mask*torch.log(-sim_matrix)).sum())/reps.size()[0]
            return loss

        else:
            cosine_matrix = reps @ prototypes.T
            reps_squared = reps.pow(2).sum(1)

            prototypes_squared = prototypes.pow(2).sum(1)
            norm_matrix = reps_squared @ prototypes_squared.T
            sim_matrix = cosine_matrix/norm_matrix

            pos_mask = (~torch.abs(labels.T - labels).bool()).float()
            neg_mask = (torch.abs(labels.T - labels).bool()).float()
            loss = ((pos_mask*-sim_matrix).sum() + (neg_mask*torch.log(-sim_matrix)).sum())/reps.size()[0]
            return loss

    def update_proto(reps, labels, prototypes, counts):
            label_copy = labels.detach().cpu().numpy()
            prototypes_updated = torch.Tensor().to(device)

            for i, label in enumerate(np.unique(label_copy)):
                mask = torch.Tensor(np.array(label == label_copy).astype(float)).to(device)
                label = int(label)
                print(f"Proto_size: {prototype.size()}")
                prototype = mask.T @ reps
                counts[label] += mask.sum()
                prototypes_updated = torch.cat((prototypes_updated, (1-mask.sum()/counts[label])*prototypes[label]+(mask.sum()/counts[label])*prototype))
            return prototypes_updated, counts

    if epoch == 0:
        prototypes = torch.zeros(size = (num_classes, int(0.5*feature_size))).to(device)
        counts = torch.zeros(prototypes.size()[0]).to(device)

    if epoch < warm_up:
        if metric == "euclid":
            loss = euclid(reps, labels, None)
        elif metric == "cosine":
            loss = cosine(reps, labels, None)

    else:
        if epoch == warm_up:
            prototypes, counts = update_proto(reps, labels, prototypes, counts)
        if metric == "euclid":
            loss = euclid(reps, labels, prototypes)
        elif metric == "cosine":
            loss = cosine(reps, labels, prototypes)
        prototypes, counts = update_proto(reps, labels, prototypes, counts)
    print(f"Counts: {counts}")
    return loss, prototypes, counts