# Imports and Setup

In [1]:
import random

import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.utils.torch as fotorch
import fiftyone.utils.annotations as foua
import fiftyone.utils.patches as foup

import torch
import torch.optim as optim
import torch.nn as nn
import torchvision as tv
from torchvision import transforms as tf
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import functional as tfunc
from torch.utils.data import DataLoader
import torchvision.utils as vutils

import torchviz

from PIL import Image

import matplotlib.pyplot as plt

import numpy as np
from sklearn.cluster import KMeans, OPTICS

In [2]:
# Run the model on GPU if it is available
ngpu = 1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#  Get COCO 2017 Dataset

In [3]:
coco_17_train = foz.load_zoo_dataset(
    "coco-2017",
    split="train",
    dataset_name="detector-recipe",
    label_types=["detections", "segmentations"]
)

coco_17_test = foz.load_zoo_dataset(
    "coco-2017",
    split="test",
    dataset_name="detector-recipe",
    label_types=["detections", "segmentations"]
)

coco_17_validation = foz.load_zoo_dataset(
    "coco-2017",
    split="validation",
    dataset_name="detector-recipe",
    label_types=["segmentations"]
)

Downloading split 'train' to '/home/boggsj/fiftyone/coco-2017/train' if necessary
Found annotations at '/home/boggsj/fiftyone/coco-2017/raw/instances_train2017.json'
Images already downloaded
Existing download of split 'train' is sufficient
Loading existing dataset 'detector-recipe'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
Downloading split 'test' to '/home/boggsj/fiftyone/coco-2017/test' if necessary
Found test info at '/home/boggsj/fiftyone/coco-2017/raw/image_info_test2017.json'
Images already downloaded
Existing download of split 'test' is sufficient
Loading existing dataset 'detector-recipe'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
Downloading split 'validation' to '/home/boggsj/fiftyone/coco-2017/validation' if necessary
Found annotations at '/home/boggsj/fiftyone/coco-2017/raw/instances_val2017.json'
Images already downloaded
Existing download of split 'validation'

In [4]:
session = fo.launch_app(coco_17_validation)

In [5]:
test_input_raw = coco_17_validation["6350559f099f10e8691afb1f"]
test_image = np.asarray(Image.open(test_input_raw.filepath)).copy()
test_tens = tfunc.to_tensor(test_image).to(device)

def draw_image(image):
    plt.figure()
    plt.axis("off")
    if len(image.shape) == 2:
        plt.imshow(image, cmap="binary", interpolation="none")
    else:
        plt.imshow(image, interpolation="none")
    plt.show()
    plt.close()
    
print(f"Test image (shape={test_image.shape}):")
draw_image(test_image)

KeyError: "No sample found with ID '6350559f099f10e8691afb1f'"

In [None]:
def make_ground_truth_segmentations(coco_object, draw=True):
    image = np.asarray(Image.open(coco_object.filepath))
    image_shape = image.shape
    gt_segs = coco_object['segmentations']['detections']
    segs_image = np.zeros(image_shape[:2])

    for i, seg in enumerate(gt_segs, start=1):
        mask = seg['mask']
        label = seg['label']
        bbox_raw = seg['bounding_box']
        bbox_t = int(bbox_raw[1] * image_shape[0])
        bbox_l = int(bbox_raw[0] * image_shape[1])
        bbox_h, bbox_w = mask.shape
        if draw:
            print(label, bbox_h, bbox_w, mask.shape)
        
        segs_image[bbox_t:bbox_t+bbox_h, bbox_l:bbox_l+bbox_w] += mask*i
    
    if draw:
        print("Segmentation:")
        plt.figure()
        plt.axis("off")
        plt.imshow(segs_image)
        plt.show()
        plt.close()
    
    return segs_image
    
test_segs_image = make_ground_truth_segmentations(test_input_raw)

# Get Resnet 50 and test

In [None]:
resnet50_backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
resnet50_backbone.to(device)
resnet50_backbone

In [None]:
print(f"Input shape: {test_tens.shape}")
test_tens = test_tens.reshape((1,) + test_tens.shape)
test_output = resnet50_backbone(test_tens)
print(f"Output shape: {test_output.shape}")

# Make Odin networks

## Network class definition

In [None]:
class FPNStage(nn.Module):
    def __init__(self, fpn_dim, bbone_dim):
        super().__init__()
        self.lat = nn.Conv2d(bbone_dim, fpn_dim, kernel_size=1)
        self.top = nn.ConvTranspose2d(fpn_dim, fpn_dim, kernel_size=4, stride=2, padding=1)
        self.aa  = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, stride=1, padding=1)
        
    def forward(self, bbone_activation, prev_fpn_stage):
        lat_out = self.lat(bbone_activation)
        top_out = self.top(prev_fpn_stage)
        
        if not lat_out.shape == top_out.shape:
            top_out = nn.UpsamplingNearest2d(size=lat_out.shape[2:])(top_out)          
            
        final_out = self.aa(lat_out + top_out)
        return final_out
        

class FeatureExtractor(nn.Module):
    def __init__(self, ngpu, fpn_dim=256):
        super().__init__()
        self.ngpu = ngpu
        self.num_fpn_stages = 4
        self.fpn_dim = fpn_dim
        
        self.activation = {}
        self.resnet50_backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.resnet50_backbone.layer1.register_forward_hook(self.get_activation('conv2'))
        self.resnet50_backbone.layer2.register_forward_hook(self.get_activation('conv3'))
        self.resnet50_backbone.layer3.register_forward_hook(self.get_activation('conv4'))
        self.resnet50_backbone.layer4.register_forward_hook(self.get_activation('conv5'))
        
        self.fpn_stage_1 = nn.Conv2d(2048, self.fpn_dim, kernel_size=1)
        self.fpn_stage_2 = FPNStage(self.fpn_dim, 1024)
        self.fpn_stage_3 = FPNStage(self.fpn_dim, 512)
        self.fpn_stage_4 = FPNStage(self.fpn_dim, 256)
        
    def forward(self, input):
        backbone_output = self.resnet50_backbone(input)
        fpn_stage_1_output = self.fpn_stage_1(self.activation['conv5'])
        fpn_stage_2_output = self.fpn_stage_2(self.activation['conv4'], fpn_stage_1_output)
        fpn_stage_3_output = self.fpn_stage_3(self.activation['conv3'], fpn_stage_2_output)
        fpn_stage_4_output = self.fpn_stage_4(self.activation['conv2'], fpn_stage_3_output)

        
        return fpn_stage_4_output
        
    def get_activation(self, name):
        def hook(model, input, output):
            self.activation[name] = output.detach()
        return hook 

In [None]:
class Projector(nn.Module):
    def __init__(self, fpn_dim=256):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(fpn_dim, fpn_dim, 1),
            nn.ReLU(inplace=False),
            nn.Linear(fpn_dim, fpn_dim, 1),
            nn.ReLU(inplace=False)
        )
        
    def forward(self, input):
        return self.main(input)

In [None]:
class Predictor(nn.Module):
    def __init__(self, fpn_dim=256):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(fpn_dim, fpn_dim, 1),
            nn.ReLU(inplace=False),
            nn.Linear(fpn_dim, fpn_dim, 1),
            nn.ReLU(inplace=False)
        )
        
    def forward(self, input):
        return self.main(input)

In [None]:
f = FeatureExtractor(ngpu).to(device)
g = Projector().to(device)
q = Predictor().to(device)

if (device.type == 'cuda') and (ngpu > 1):
    f = nn.DataParallel(f, list(range(ngpu)))
    g = nn.DataParallel(g, list(range(ngpu)))
    q = nn.DataParallel(q, list(range(ngpu)))
    
    
print(f)
print(g)
print(q)

## Test Odin networks

In [None]:
h = f(test_tens).transpose(1,3)
print(h.shape)

z = g(h)
print(z.shape)

p = q(z)
print(p.shape)

# Initialize and save $\tau$, $\theta$, and $\xi$ parameters

In [None]:
torch.save(f.state_dict(), "f_tau.pth")
torch.save(f.state_dict(), "f_theta.pth")
torch.save(f.state_dict(), "f_xi.pth")

torch.save(g.state_dict(), "g_tau.pth")
torch.save(g.state_dict(), "g_theta.pth")
torch.save(g.state_dict(), "g_xi.pth")

torch.save(q.state_dict(), "q_tau.pth")
torch.save(q.state_dict(), "q_theta.pth")
torch.save(q.state_dict(), "q_xi.pth")

# View Generator

In [None]:
class ViewGenerator(nn.Module):
    """
    nn.Module sub-class which, when called on an image, generates three views of the image, v0, v1, and v2. v1
    and v2 are generated first, and then v0 is generated from their bounding box. The class also has a
    `reverse` method which, given an image the size of v0, produces the equivalent v1 and v2 crops from it.
    """
    SCALE_RANGE = (0.08, 1.0)
    RATIO_RANGE = (0.75, 1.33333333333)

    FLIP_PROB = 0.5

    COLOR_JITTER_PROB = 0.8
    COLOR_OPERATIONS = ["brightness", "contrast", "saturation", "hue"]
    BRIGHTNESS_MAX = 0.4
    CONTRAST_MAX = 0.4
    SATURATION_MAX = 0.2
    HUE_MAX = 0.1

    GRAY_PROB = 0.2

    v1_BLUR_PROB = 1.0
    v2_BLUR_PROB = 0.1

    v1_SOLAR_PROB = 0.0
    v2_SOLAR_PROB = 0.2

    v0_SHAPE = (448, 448)
    v1_SHAPE = v2_SHAPE = (224, 224)
    INTERPOLATION = tf.InterpolationMode.BILINEAR

    def __init__(self, image):
        super().__init__()
        self.crop_v1= tf.RandomResizedCrop.get_params(image, self.SCALE_RANGE, self.RATIO_RANGE)
        self.crop_v2 = tf.RandomResizedCrop.get_params(image, self.SCALE_RANGE, self.RATIO_RANGE)
        
        flip_v1 = random.random() < self.FLIP_PROB
        flip_v2 = random.random() < self.FLIP_PROB
        
        gray_v1 = random.random() < self.GRAY_PROB
        gray_v2 = random.random() < self.GRAY_PROB
        
        blur_v1 = random.random() < self.v1_BLUR_PROB
        blur_v2 = random.random() < self.v2_BLUR_PROB
        
        solar_v1 = random.random() < self.v1_SOLAR_PROB
        solar_v2 = random.random() < self.v2_SOLAR_PROB
        
        if random.random() < self.COLOR_JITTER_PROB: 
            color_params = tf.ColorJitter.get_params(
                [max(0, 1 - self.BRIGHTNESS_MAX), 1 + self.BRIGHTNESS_MAX],
                [max(0, 1 - self.CONTRAST_MAX), 1 + self.CONTRAST_MAX],
                [max(0, 1 - self.SATURATION_MAX), 1 + self.SATURATION_MAX],
                [-self.HUE_MAX, self.HUE_MAX]
            )
            order = color_params[0]
            color_params = color_params[1:]
            jitter_v1 = [(self.COLOR_OPERATIONS[i], color_params[i]) for i in order]
        else:
            jitter_v1 = None
        if random.random() < self.COLOR_JITTER_PROB: 
            color_params = tf.ColorJitter.get_params(
                [max(0, 1 - self.BRIGHTNESS_MAX), 1 + self.BRIGHTNESS_MAX],
                [max(0, 1 - self.CONTRAST_MAX), 1 + self.CONTRAST_MAX],
                [max(0, 1 - self.SATURATION_MAX), 1 + self.SATURATION_MAX],
                [-self.HUE_MAX, self.HUE_MAX]
            )
            order = color_params[0]
            color_params = color_params[1:]
            jitter_v2 = [(self.COLOR_OPERATIONS[i], color_params[i]) for i in order]
        else:
            jitter_v2 = None
        
        self.v1_params = (self.crop_v1, flip_v1, jitter_v1, gray_v1, blur_v1, solar_v1, self.v1_SHAPE)
        self.v2_params = (self.crop_v2, flip_v2, jitter_v2, gray_v2, blur_v2, solar_v2, self.v2_SHAPE)
        
        self.v1_proportional_crop = None
        self.v2_proportional_crop = None
        
    def __call__(self, img):
        v1 = self._generate_sub_view(img, self.v1_params)
        v2 = self._generate_sub_view(img, self.v2_params)
        v0 = self._generate_whole_view(img, self.crop_v1, self.crop_v2)
        
        return v0, v1, v2
    
    def reverse(self, image):
        image_height, image_width = image.shape[-2:]
        
        v1_top_scaled = round(self.v1_proportional_crop[0] * image_height)
        v1_left_scaled = round(self.v1_proportional_crop[1] * image_width)
        v1_height_scaled = round(self.v1_proportional_crop[2] * image_height)
        v1_width_scaled = round(self.v1_proportional_crop[3] * image_width)
        
        v2_top_scaled = round(self.v2_proportional_crop[0] * image_height)
        v2_left_scaled = round(self.v2_proportional_crop[1] * image_width)
        v2_height_scaled = round(self.v2_proportional_crop[2] * image_height)
        v2_width_scaled = round(self.v2_proportional_crop[3] * image_width)
        
        v1_scaled = tfunc.resized_crop(image, v1_top_scaled, v1_left_scaled, v1_height_scaled, v1_width_scaled, self.v1_SHAPE, tf.InterpolationMode.BILINEAR)
        v2_scaled = tfunc.resized_crop(image, v2_top_scaled, v2_left_scaled, v2_height_scaled, v2_width_scaled, self.v2_SHAPE, tf.InterpolationMode.BILINEAR)
        
        if self.v1_params[1]: v1_scaled = tfunc.hflip(v1_scaled)
        if self.v2_params[1]: v2_scaled = tfunc.hflip(v2_scaled)
        
        return v1_scaled, v2_scaled
        
    def _generate_whole_view(self, image, crop_v1, crop_v2):
        v1_top, v1_left, v1_height, v1_width = crop_v1
        v2_top, v2_left, v2_height, v2_width = crop_v2
        v1_bot = v1_top + v1_height
        v1_right = v1_left + v1_width
        v2_bot = v2_top + v2_height
        v2_right = v2_left + v2_width
        
        v0_top = min(v1_top, v2_top)
        v0_left = min(v1_left, v2_left)
        v0_bot = max(v1_bot, v2_bot)    
        v0_right = max(v1_right, v2_right)
        v0_height = v0_bot - v0_top
        v0_width = v0_right - v0_left
        
        self.v0_crop = (v0_top, v0_left, v0_height, v0_width)
        
        v1_proportional_top = (v1_top - v0_top)/v0_height
        v1_proportional_left = (v1_left - v0_left)/v0_width
        v1_proportional_height = v1_height/v0_height
        v1_proportional_width = v1_width/v0_width
        self.v1_proportional_crop = (v1_proportional_top, v1_proportional_left, v1_proportional_height, v1_proportional_width)
        
        v2_proportional_top = (v2_top - v0_top)/v0_height
        v2_proportional_left = (v2_left - v0_left)/v0_width
        v2_proportional_height = v2_height/v0_height
        v2_proportional_width = v2_width/v0_width
        self.v2_proportional_crop = (v2_proportional_top, v2_proportional_left, v2_proportional_height, v2_proportional_width)
        
        return tfunc.resized_crop(image, v0_top, v0_left, v0_height, v0_width, self.v0_SHAPE, tf.InterpolationMode.BILINEAR)
    
    def _generate_sub_view(self, image, params):
        crop, flip, jitter, gray, blur, solar, shape = params
        
        t, l, h, w = crop
        output = tfunc.resized_crop(image, t, l, h, w, shape, self.INTERPOLATION)
        if flip: output = tfunc.hflip(output)
        if jitter is not None:
            for param, value in jitter:
                if param == "brightness": output = tfunc.adjust_brightness(output, value)
                elif param == "contrast": output = tfunc.adjust_contrast(output, value)
                elif param == "hue": output = tfunc.adjust_hue(output, value)
                elif param == "saturation": output = tfunc.adjust_saturation(output, value)
        if gray: output = tfunc.rgb_to_grayscale(output, 3)
        if blur: output = tfunc.gaussian_blur(output, 23, (0.1, 2.0))
        if solar: output = tfunc.solarize(output, 0.5)

        return output

## Test View Generation

In [None]:
view_gen = ViewGenerator(test_tens)

v0, v1, v2 = view_gen(test_tens)

print(f"Inputted test image ({test_image.shape}):")
draw_image(test_image)

print(f"v0 ({v0.shape}):")
draw_image(v0.squeeze().detach().cpu().numpy().transpose(1,2,0))
print(f"v1 ({v1.shape}):")
draw_image(v1.squeeze().detach().cpu().numpy().transpose(1,2,0))
print(f"v2 ({v2.shape}):")
draw_image(v2.squeeze().detach().cpu().numpy().transpose(1,2,0))

## Test apply_transforms to recover alignment

In [None]:
test_image_gt_segs = make_ground_truth_segmentations(test_input_raw, draw=False)
v0_segs = tfunc.resized_crop(tfunc.to_tensor(test_image_gt_segs), *view_gen.v0_crop, (448, 448), tf.InterpolationMode.BILINEAR)
v1_segs, v2_segs = view_gen.reverse(v0_segs)

print(f"v0 segmentation shape: {v0_segs.shape}")
draw_image(v0_segs.squeeze())
print(f"v1 segmentation shape: {v1_segs.shape}")
draw_image(v1_segs.squeeze())
print(f"v2 segmentation shape: {v2_segs.shape}")
draw_image(v2_segs.squeeze())

# Apply $\tau$ Odin networks on generated v0

In [None]:
f.load_state_dict(torch.load("f_tau.pth"))
g.load_state_dict(torch.load("g_tau.pth"))
q.load_state_dict(torch.load("q_tau.pth"))

h0 = f(v0)
h0 = tfunc.resize(h0, (448, 448)).transpose(1,3)
print(h0.shape)

z0 = g(h0)
print(z0.shape)

p0 = q(z0)
print(p0.shape)

# Apply K-means to h0

In [None]:
h0_np = np.transpose(h0.detach().cpu().squeeze(),(1,2,0))

original_shape = h0_np.shape
flat_shape = (original_shape[0]*original_shape[1], original_shape[2])

h0_np_flat = h0_np.reshape((flat_shape))
print(f"Original shape: {h0_np.shape}")
print(f"Flattened shape: {h0_np_flat.shape}")

In [None]:
clusterer = KMeans(n_clusters=8, init='k-means++', n_init=10, max_iter=300, tol=0.0001, verbose=0, random_state=None, copy_x=True, algorithm='lloyd')
# clusterer = OPTICS(cluster_method="xi", xi=0.01, n_jobs=4)
mask_assignments_flat = clusterer.fit_predict(h0_np_flat)

In [None]:
cluster_ids = set(mask_assignments_flat)
num_masks = len(cluster_ids)
print(f"Num masks: {num_masks}")
print(cluster_ids)

In [None]:
mask_assignments = tfunc.resize(tfunc.to_tensor(mask_assignments_flat.reshape(original_shape[:2])), (448, 448), interpolation=tf.InterpolationMode.NEAREST).squeeze()
print("v0:")
draw_image(v0.detach().cpu().squeeze().numpy().transpose(1,2,0))
print("Mask assignments on v0")
draw_image(mask_assignments.squeeze())
print("v1:")
draw_image(v1.detach().cpu().squeeze().numpy().transpose(1,2,0))
print("v2:")
draw_image(v2.detach().cpu().squeeze().numpy().transpose(1,2,0))

In [None]:
m0_layers = [torch.where(mask_assignments==c_id, 1., 0.).numpy() for c_id in range(num_masks)]
m0_np = np.stack(m0_layers, 2)
m0 = tfunc.to_tensor(m0_np)
print(m0.shape)

In [None]:
m1_raw, m2_raw = view_gen.reverse(m0)
m1_raw = m1_raw > 0
m1_raw = m1_raw.to(torch.float32)
m1_nonempty_idxs = []
for idx, mask in enumerate(m1_raw):
    if mask.sum() != 0:
        m1_nonempty_idxs.append(idx)
        
if len(m1_nonempty_idxs) != 0:
    m1 = tfunc.to_tensor(np.stack([m1_raw[i] for i in m1_nonempty_idxs], 2)).to(device)
    m1 = m1.reshape((1,)+m1.shape)
    print(f"m1 shape: {m1.shape}")

    m1_np = m1.detach().cpu().numpy().squeeze()
    draw_image(np.transpose(vutils.make_grid([tfunc.to_tensor(m1_i_np) for m1_i_np in m1_np], padding=5, normalize=True, pad_value = 0.5), (1,2,0)))

In [None]:
m2_raw = m2_raw > 0
m2_raw = m2_raw.to(torch.float32)
m2_nonempty_idxs = []
for idx, mask in enumerate(m2_raw):
    if mask.sum() != 0:
        m2_nonempty_idxs.append(idx)
        
if len(m2_nonempty_idxs) != 0:
    m2 = tfunc.to_tensor(np.stack([m2_raw[i] for i in m2_nonempty_idxs], 2)).to(device)
    m2 = m2.reshape((1,)+m2.shape)
    print(f"m2 shape: {m2.shape}")

m2_np = m2.detach().cpu().numpy().squeeze()

In [None]:
draw_image(m0_np[:,:,5])
draw_image(m1_np[5])
draw_image(m2_np[5])

# Apply $\theta$ networks

In [None]:
f.load_state_dict(torch.load("f_theta.pth"))
g.load_state_dict(torch.load("g_theta.pth"))
q.load_state_dict(torch.load("q_theta.pth"))

h_1_theta = tfunc.resize(f(v1), (224, 224), tf.InterpolationMode.BILINEAR)
print(h_1_theta.shape)

h_2_theta = tfunc.resize(f(v2), (224, 224), tf.InterpolationMode.BILINEAR)
print(h_2_theta.shape)

In [None]:
h_k_1_thetas = torch.concat([(1/(m1_i.sum()))*(torch.where(m1_i==1, h_1_theta, 0).sum(dim=[2,3])) for m1_i in m1[0]])
h_k_2_thetas = torch.concat([(1/(m2_i.sum()))*(torch.where(m2_i==1, h_2_theta, 0).sum(dim=[2,3])) for m2_i in m2[0]])

In [None]:
print(f"h_k_1_thetas shape: {h_k_1_thetas.shape}, device: {h_k_1_thetas.device}")
print(f"h_k_2_thetas shape: {h_k_2_thetas.shape}, device: {h_k_2_thetas.device}")

In [None]:
z_k_1_theta = g(h_k_1_thetas)
print(f"z_k_1_theta shape: {z_k_1_theta.shape}, device: {z_k_1_theta.device}")
z_k_2_theta = g(h_k_2_thetas)
print(f"z_k_2_theta shape: {z_k_2_theta.shape}, device: {z_k_2_theta.device}")

In [None]:
p_k_1_theta = q(z_k_1_theta)
print(f"p_k_1_theta shape: {p_k_1_theta.shape}, device: {p_k_1_theta.device}")
p_k_2_theta = q(z_k_2_theta)
print(f"p_k_2_theta shape: {p_k_2_theta.shape}, device: {p_k_2_theta.device}")

# Apply $\xi$ networks

In [None]:
f.load_state_dict(torch.load("f_xi.pth"))
g.load_state_dict(torch.load("g_xi.pth"))
q.load_state_dict(torch.load("q_xi.pth"))

h_1_xi = tfunc.resize(f(v1), (224, 224), tf.InterpolationMode.BILINEAR)
h_2_xi = tfunc.resize(f(v2), (224, 224), tf.InterpolationMode.BILINEAR)
print(h_1_xi.shape)
print(h_2_xi.shape)

h_k_1_xi = torch.concat([(1/(m1_i.sum()))*(torch.where(m1_i==1, h_1_xi, 0).sum(dim=[2,3])) for m1_i in m1[0]])
h_k_2_xi = torch.concat([(1/(m2_i.sum()))*(torch.where(m2_i==1, h_2_xi, 0).sum(dim=[2,3])) for m2_i in m2[0]])
print(f"h_k_1_xi shape: {h_k_1_xi.shape}, device: {h_k_1_xi.device}")
print(f"h_k_2_xi shape: {h_k_2_xi.shape}, device: {h_k_2_xi.device}")

z_k_1_xi = g(h_k_1_xi)
z_k_2_xi = g(h_k_2_xi)
print(f"z_k_1_xi shape: {z_k_1_xi.shape}, device: {z_k_1_xi.device}")
print(f"z_k_2_xi shape: {z_k_2_xi.shape}, device: {z_k_2_xi.device}")

p_k_1_xi = q(z_k_1_xi)
p_k_2_xi = q(z_k_2_xi)
print(f"p_k_1_xi shape: {p_k_1_xi.shape}, device: {p_k_1_xi.device}")
print(f"z_k_2_xi shape: {p_k_2_xi.shape}, device: {p_k_2_xi.device}")

# Compute loss

In [None]:
def single_mask_similarity(pk1_theta, zk2_xi, alpha=0.1):
    top = torch.dot(pk1_theta, zk2_xi)
    bot = torch.norm(pk1_theta) * torch.norm(zk2_xi)
    return (1/alpha)*(top/bot)

def feature_contrastive_loss(p_k_1_theta, z_k_2_xi, index, alpha=0.1):
    positive_similarity = single_mask_similarity(p_k_1_theta[index], z_k_2_xi[index], alpha)
    negative_similarities_sum = sum([single_mask_similarity(p_k_1_theta[index], zk2_xi, alpha) for i, zk2_xi in enumerate(z_k_2_xi) if i != index])
    bot = positive_similarity + negative_similarities_sum
    nll = -1*torch.log(positive_similarity/bot)
    return nll

def total_contrastive_loss(p_k_1_theta, p_k_2_theta, z_k_1_xi, z_k_2_xi, alpha):
    cum_loss = torch.Tensor([0.]).to(device)
    for mask_idx in range(num_masks):
        l_12_k = feature_contrastive_loss(p_k_1_theta, z_k_2_xi, mask_idx, alpha)
        l_21_k = feature_contrastive_loss(p_k_2_theta, z_k_1_xi, mask_idx, alpha)
        cum_loss += l_12_k + l_21_k
    return cum_loss/num_masks
        

loss = total_contrastive_loss(p_k_1_theta, p_k_2_theta, z_k_1_xi, z_k_2_xi, 0.1)

In [None]:
print(loss)

# Optimize

In [None]:
f.load_state_dict(torch.load("f_theta.pth"))
g.load_state_dict(torch.load("g_theta.pth"))
q.load_state_dict(torch.load("q_theta.pth"))

In [None]:
torch.autograd.set_detect_anomaly(True)
optimizer = optim.SGD(q.parameters(), lr=0.05)
optimizer.zero_grad()

In [None]:
loss.backward()

In [None]:
torchviz.make_dot(loss, params=dict(q.named_parameters()))

In [None]:
loss_test = h0 - 0
loss_test = loss_test.sum()
print(loss_test)

In [None]:
torchviz.make_dot(loss_test, params=dict(q.named_parameters()))

In [None]:
torch.autograd.set_detect_anomaly(True)
optimizer = optim.SGD(q.parameters(), lr=0.05)
optimizer.zero_grad()
loss_test.backward()