In [None]:
from datetime import datetime
from pathlib import Path
import random

import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.utils.torch as fotorch
import fiftyone.utils.annotations as foua
import fiftyone.utils.patches as foup

import torch
import torchvision.datasets as dset
import torch.optim as optim
import torch.nn as nn
import torchvision as tv
from torchvision import transforms as tf
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import functional as tfunc
from torch.utils.data import DataLoader
import torchvision.utils as vutils

import torchviz

from PIL import Image

import matplotlib.pyplot as plt

import numpy as np
from sklearn.cluster import KMeans, OPTICS, DBSCAN, Birch

In [None]:
COCO_2017_DATASET_DIR = Path("./data/coco_2017/")


def _make_coco_split_loader(split):
    """Build a shuffled, batch-size-1 DataLoader over one split directory of the COCO-2017 export."""
    split_dir = COCO_2017_DATASET_DIR / split
    dataset = dset.CocoDetection(
        root=str(split_dir / "data"),
        # Each split must use its own annotation file (validation used to read train/labels.json).
        annFile=str(split_dir / "labels.json"),
        transform=tf.Compose([tf.ToTensor()]),
    )
    return DataLoader(dataset, batch_size=1, shuffle=True)


def make_dataloaders():
    """
    Download COCO-2017 through the FiftyOne zoo if it is missing and build one DataLoader per split.

    Images are read from ``<COCO_2017_DATASET_DIR>/<split>/data`` and annotations from
    ``<COCO_2017_DATASET_DIR>/<split>/labels.json``; images are converted with ``ToTensor``.
    Every loader uses ``batch_size=1`` and ``shuffle=True``.

    Returns
    -------
    tuple
        ``(training_dataloader, testing_dataloader, validation_dataloader)``.
    """
    # NOTE(review): only checks that the directory exists, not that a previous download completed.
    if not COCO_2017_DATASET_DIR.exists():
        foz.download_zoo_dataset("coco-2017", dataset_dir=str(COCO_2017_DATASET_DIR))

    return tuple(_make_coco_split_loader(split) for split in ("train", "test", "validation"))

In [None]:
def draw_image(image):
    """
    Display one image or 2-D map with matplotlib (8x8 inch figure, axes hidden, no interpolation).

    Parameters
    ----------
    image : torch.Tensor or numpy.ndarray
        Tensors may live on any device; they are detached, moved to CPU and squeezed first.
        A 2-D result is shown as-is; anything else is assumed channels-first (C, H, W) and
        transposed to (H, W, C) for ``imshow``.
    """
    if isinstance(image, torch.Tensor):
        # Detach/move unconditionally: the old `device == 'cuda:0'` test skipped other GPUs and
        # left CPU tensors that require grad attached, which makes `.numpy()` raise.
        image = image.detach().cpu().squeeze().numpy()

    if image.ndim != 2:
        image = np.transpose(image, (1, 2, 0))

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.axis("off")
    ax.imshow(image, interpolation="none")
    plt.show()
    plt.close(fig)
    
def draw_layers(data):
    """
    Display every channel of a feature map as one tile of a single grid image.

    Parameters
    ----------
    data : torch.Tensor or numpy.ndarray
        Tensors may live on any device; they are detached, moved to CPU and squeezed first.
        A 2-D result is forwarded to ``draw_image`` unchanged; otherwise the first axis is
        treated as the channel axis and each channel becomes one tile.

    The grid has ``int(sqrt(num_channels))`` tiles per row, no padding, and is min-max
    normalised by ``make_grid`` before display.
    """
    if isinstance(data, torch.Tensor):
        # Same fix as draw_image: the old 'cuda:0'-only check missed other devices and grad-tracking CPU tensors.
        data = data.detach().cpu().squeeze().numpy()

    if data.ndim == 2:
        return draw_image(data)

    data_layers = [tfunc.to_tensor(layer) for layer in data]
    grid_image = vutils.make_grid(
        data_layers,
        nrow=int(len(data_layers) ** 0.5),
        padding=0,
        pad_value=0.5,
        normalize=True,
    )
    draw_image(grid_image)

In [None]:
class FPNStage(nn.Module):
    """
    One top-down step of a feature pyramid network.

    The backbone activation goes through a 1x1 lateral conv (``bbone_dim`` -> ``fpn_dim``). The
    previous (coarser) FPN level is upsampled 2x with a transposed conv. The two are summed and
    smoothed by a 3x3 conv.

    Parameters
    ----------
    fpn_dim : int
        Channel count of every FPN level.
    bbone_dim : int
        Channel count of the backbone activation merged at this stage.
    """

    def __init__(self, fpn_dim, bbone_dim):
        super().__init__()
        self.lat = nn.Conv2d(bbone_dim, fpn_dim, kernel_size=1)
        # kernel 4 / stride 2 / padding 1 exactly doubles H and W.
        self.top = nn.ConvTranspose2d(fpn_dim, fpn_dim, kernel_size=4, stride=2, padding=1)
        self.aa = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, bbone_activation, prev_fpn_stage):
        """Merge ``bbone_activation`` (N, bbone_dim, H, W) with the coarser ``prev_fpn_stage`` (N, fpn_dim, h, w)."""
        lat_out = self.lat(bbone_activation)
        top_out = self.top(prev_fpn_stage)

        # Odd backbone sizes make 2x upsampling overshoot or undershoot by a pixel; snap to the lateral size.
        # Functional interpolate replaces constructing a (deprecated) UpsamplingNearest2d module per call.
        if top_out.shape[-2:] != lat_out.shape[-2:]:
            top_out = nn.functional.interpolate(top_out, size=lat_out.shape[-2:], mode="nearest")

        return self.aa(lat_out + top_out)
        

class FeatureExtractor(nn.Module):
    """
    ImageNet-pretrained torchvision ResNet-50 backbone followed by a top-down FPN.

    Forward hooks on ``layer1``..``layer4`` record their outputs in ``self.activation`` under
    ``'conv2'``..``'conv5'``. The FPN then merges them top-down, from 'conv5' (2048 ch) down to
    'conv2' (256 ch). ``forward`` returns the finest FPN level, which has ``fpn_dim`` channels.

    Parameters
    ----------
    fpn_dim : int
        Channel count of every FPN level.
    """

    def __init__(self, fpn_dim=256):
        super().__init__()
        self.num_fpn_stages = 4
        self.fpn_dim = fpn_dim

        self.activation = {}
        self.resnet50_backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.resnet50_backbone.layer1.register_forward_hook(self.get_activation('conv2'))
        self.resnet50_backbone.layer2.register_forward_hook(self.get_activation('conv3'))
        self.resnet50_backbone.layer3.register_forward_hook(self.get_activation('conv4'))
        self.resnet50_backbone.layer4.register_forward_hook(self.get_activation('conv5'))

        self.fpn_stage_1 = nn.Conv2d(2048, self.fpn_dim, kernel_size=1)
        self.fpn_stage_2 = FPNStage(self.fpn_dim, 1024)
        self.fpn_stage_3 = FPNStage(self.fpn_dim, 512)
        self.fpn_stage_4 = FPNStage(self.fpn_dim, 256)

    def forward(self, input):
        """Return the finest FPN feature map for ``input`` (presumably an (N, 3, H, W) image batch)."""
        # The full ResNet forward (including its unused avgpool/fc head) is run only to fire the hooks.
        self.resnet50_backbone(input)
        fpn_stage_1_output = self.fpn_stage_1(self.activation['conv5'])
        fpn_stage_2_output = self.fpn_stage_2(self.activation['conv4'], fpn_stage_1_output)
        fpn_stage_3_output = self.fpn_stage_3(self.activation['conv3'], fpn_stage_2_output)
        fpn_stage_4_output = self.fpn_stage_4(self.activation['conv2'], fpn_stage_3_output)
        return fpn_stage_4_output

    def get_activation(self, name):
        """Return a forward hook that stores the hooked module's output in ``self.activation[name]``."""
        def hook(model, input, output):
            # Keep the autograd graph. The previous `output.detach()` silently stopped gradients here,
            # so the backbone parameters handed to the optimizer could never be updated.
            self.activation[name] = output
        return hook


class Projector(nn.Module):
    """
    Two-layer MLP (Linear -> ReLU -> Linear -> ReLU) of constant width ``fpn_dim``.

    The Linear layers act on the last axis of the input, so any number of leading dims is accepted.

    Parameters
    ----------
    fpn_dim : int
        Input, hidden and output width.
    """

    def __init__(self, fpn_dim=256):
        super().__init__()
        # nn.Linear's third positional parameter is `bias`; the former literal `1` (presumably a
        # leftover Conv2d kernel size) only worked because it is truthy. Spelled out explicitly.
        # NOTE(review): BYOL-style projection heads normally end in a plain Linear; the final ReLU
        # forces non-negative outputs — confirm this is intended.
        self.main = nn.Sequential(
            nn.Linear(fpn_dim, fpn_dim, bias=True),
            nn.ReLU(inplace=False),
            nn.Linear(fpn_dim, fpn_dim, bias=True),
            nn.ReLU(inplace=False),
        )

    def forward(self, input):
        """Apply the MLP to the last axis of ``input``."""
        return self.main(input)


class Predictor(nn.Module):
    """
    Two-layer MLP (Linear -> ReLU -> Linear -> ReLU) of constant width ``fpn_dim``, used as the online predictor.

    The Linear layers act on the last axis of the input, so any number of leading dims is accepted.

    Parameters
    ----------
    fpn_dim : int
        Input, hidden and output width.
    """

    def __init__(self, fpn_dim=256):
        super().__init__()
        # nn.Linear's third positional parameter is `bias`; the former literal `1` only worked
        # because it is truthy. Spelled out explicitly (same parameters, same state-dict keys).
        # NOTE(review): BYOL-style predictors normally end in a plain Linear; confirm the final ReLU is intended.
        self.main = nn.Sequential(
            nn.Linear(fpn_dim, fpn_dim, bias=True),
            nn.ReLU(inplace=False),
            nn.Linear(fpn_dim, fpn_dim, bias=True),
            nn.ReLU(inplace=False),
        )

    def forward(self, input):
        """Apply the MLP to the last axis of ``input``."""
        return self.main(input)


class TauModel(nn.Module):
    """
    Tau branch bundled as one module: feature extractor ``f``, projector ``g`` and predictor ``q``.

    ``forward`` returns ``(h, z, p)``:
    - ``h`` is ``f(input)`` resized to 448x448 (bilinear), in (N, C, H, W) layout;
    - ``z`` and ``p`` are computed per spatial position after swapping axes 1 and 3, so they come
      back in (N, W, H, C) layout.
    """

    def __init__(self, fpn_dim=256):
        super().__init__()
        # Construction order is kept (f, g, q) so parameter initialisation consumes the RNG identically.
        self.f = FeatureExtractor(fpn_dim=fpn_dim)
        self.g = Projector(fpn_dim=fpn_dim)
        self.q = Predictor(fpn_dim=fpn_dim)

    def forward(self, input):
        features = self.f(input)
        h = tfunc.resize(features, (448, 448), tf.InterpolationMode.BILINEAR)
        # The MLP heads act on the last axis, so bring channels there: (N, C, H, W) -> (N, W, H, C).
        channels_last = h.transpose(1, 3)
        z = self.g(channels_last)
        return h, z, self.q(z)
        
class ThetaXiModel(nn.Module):
    """
    Theta/xi branch bundled as one module: feature extractor ``f``, projector ``g`` and predictor ``q``.

    ``forward`` returns ``(h, z, p)``:
    - ``h`` is ``f(input)`` resized to 224x224 (bilinear), in (N, C, H, W) layout;
    - ``z`` and ``p`` are computed per spatial position after swapping axes 1 and 3 (the same
      convention as TauModel), so they come back in (N, W, H, C) layout.
    """

    def __init__(self, fpn_dim=256):
        super().__init__()
        self.f = FeatureExtractor(fpn_dim=fpn_dim)
        self.g = Projector(fpn_dim=fpn_dim)
        self.q = Predictor(fpn_dim=fpn_dim)

    def forward(self, input):
        h = tfunc.resize(self.f(input), (224, 224), tf.InterpolationMode.BILINEAR)
        # g/q are Linear stacks over the last axis, so channels must be moved there first.
        # Feeding `h` directly applied them over the width axis, which only type-checks when fpn_dim == 224.
        z = self.g(h.transpose(1, 3))
        p = self.q(z)
        return h, z, p

In [None]:
class ViewGenerator(nn.Module):
    """
    Generate three views of an image (v0, v1, v2) from augmentation parameters sampled once.

    All random parameters are drawn in ``__init__`` from the image passed there, so each call replays
    the same augmentation. Calling the instance produces the views:
    - v1 and v2 are independent random-resized crops (``SCALE_RANGE`` / ``RATIO_RANGE``). Each is
      resized to 224x224 and then augmented with horizontal flip, colour jitter, grayscale, Gaussian
      blur and solarisation, each applied with its own probability.
    - v0 is the un-augmented crop of the bounding box enclosing both v1 and v2, resized to ``v0_SHAPE``.

    ``reverse`` takes a tensor laid out like v0 (e.g. masks computed on v0) and produces its v1 and v2
    equivalents by replaying only the geometric transforms. ``__call__`` must run before ``reverse``
    because it records the crop boxes relative to v0.
    """
    SCALE_RANGE = (0.08, 1.0)
    RATIO_RANGE = (0.75, 1.33333333333)

    FLIP_PROB = 0.5

    COLOR_JITTER_PROB = 0.8
    # Order must match the permutation indices returned by tf.ColorJitter.get_params
    # (0=brightness, 1=contrast, 2=saturation, 3=hue).
    COLOR_OPERATIONS = ["brightness", "contrast", "saturation", "hue"]
    BRIGHTNESS_MAX = 0.4
    CONTRAST_MAX = 0.4
    SATURATION_MAX = 0.2
    HUE_MAX = 0.1

    GRAY_PROB = 0.2

    v1_BLUR_PROB = 1.0
    v2_BLUR_PROB = 0.1
    BLUR_KERNEL_SIZE = 23
    # Range from which one isotropic blur sigma is sampled per view.
    BLUR_SIGMA_RANGE = (0.1, 2.0)

    v1_SOLAR_PROB = 0.0
    v2_SOLAR_PROB = 0.2
    # Assumes float images in [0, 1] (as produced by ToTensor).
    SOLARIZE_THRESHOLD = 0.5

    v0_SHAPE = (448, 448)
    v1_SHAPE = v2_SHAPE = (224, 224)
    INTERPOLATION = tf.InterpolationMode.BILINEAR

    def __init__(self, image):
        super().__init__()
        self.crop_v1 = tf.RandomResizedCrop.get_params(image, self.SCALE_RANGE, self.RATIO_RANGE)
        self.crop_v2 = tf.RandomResizedCrop.get_params(image, self.SCALE_RANGE, self.RATIO_RANGE)

        flip_v1 = random.random() < self.FLIP_PROB
        flip_v2 = random.random() < self.FLIP_PROB

        gray_v1 = random.random() < self.GRAY_PROB
        gray_v2 = random.random() < self.GRAY_PROB

        blur_v1 = self._sample_blur_sigma(self.v1_BLUR_PROB)
        blur_v2 = self._sample_blur_sigma(self.v2_BLUR_PROB)

        solar_v1 = random.random() < self.v1_SOLAR_PROB
        solar_v2 = random.random() < self.v2_SOLAR_PROB

        jitter_v1 = self._sample_color_jitter()
        jitter_v2 = self._sample_color_jitter()

        # Layout: (crop box, flip, jitter ops or None, grayscale, blur sigma or None, solarize, output shape)
        self.v1_params = (self.crop_v1, flip_v1, jitter_v1, gray_v1, blur_v1, solar_v1, self.v1_SHAPE)
        self.v2_params = (self.crop_v2, flip_v2, jitter_v2, gray_v2, blur_v2, solar_v2, self.v2_SHAPE)

        # Filled in by __call__ (through _generate_whole_view); reverse() needs them.
        self.v1_proportional_crop = None
        self.v2_proportional_crop = None

    def __call__(self, img):
        """Return ``(v0, v1, v2)`` for ``img`` using the parameters sampled in ``__init__``."""
        v1 = self._generate_sub_view(img, self.v1_params)
        v2 = self._generate_sub_view(img, self.v2_params)
        v0 = self._generate_whole_view(img, self.crop_v1, self.crop_v2)

        return v0, v1, v2

    def reverse(self, image):
        """
        Map a tensor in v0's frame onto the v1 and v2 crops (geometry only).

        The recorded proportional crop boxes are rescaled to ``image``'s own height/width. Each crop
        is resized to the view's shape with bilinear interpolation and flipped if that view was
        flipped. Photometric augmentations are not applied.

        Returns
        -------
        tuple
            ``(v1_equivalent, v2_equivalent)``.

        Raises
        ------
        RuntimeError
            If the generator has not been called on an image yet.
        """
        if self.v1_proportional_crop is None or self.v2_proportional_crop is None:
            raise RuntimeError("ViewGenerator must be called on an image before reverse()")

        v1_scaled = self._replay_geometry(image, self.v1_proportional_crop, self.v1_params[1], self.v1_SHAPE)
        v2_scaled = self._replay_geometry(image, self.v2_proportional_crop, self.v2_params[1], self.v2_SHAPE)
        return v1_scaled, v2_scaled

    def _sample_blur_sigma(self, probability):
        """With the given probability return a blur sigma drawn uniformly from BLUR_SIGMA_RANGE, else None."""
        if random.random() < probability:
            return random.uniform(*self.BLUR_SIGMA_RANGE)
        return None

    def _sample_color_jitter(self):
        """With probability COLOR_JITTER_PROB return an ordered list of (operation, factor) pairs, else None."""
        if random.random() >= self.COLOR_JITTER_PROB:
            return None
        order, *factors = tf.ColorJitter.get_params(
            [max(0, 1 - self.BRIGHTNESS_MAX), 1 + self.BRIGHTNESS_MAX],
            [max(0, 1 - self.CONTRAST_MAX), 1 + self.CONTRAST_MAX],
            [max(0, 1 - self.SATURATION_MAX), 1 + self.SATURATION_MAX],
            [-self.HUE_MAX, self.HUE_MAX],
        )
        return [(self.COLOR_OPERATIONS[int(i)], factors[int(i)]) for i in order]

    def _replay_geometry(self, image, proportional_crop, flip, shape):
        """Crop ``image`` with a (top, left, height, width) box given as fractions of its size, resize, optionally flip."""
        image_height, image_width = image.shape[-2:]
        top, left, height, width = proportional_crop
        output = tfunc.resized_crop(
            image,
            round(top * image_height),
            round(left * image_width),
            round(height * image_height),
            round(width * image_width),
            shape,
            self.INTERPOLATION,
        )
        return tfunc.hflip(output) if flip else output

    def _generate_whole_view(self, image, crop_v1, crop_v2):
        """Crop the union bounding box of both sub-view crops, record each sub-crop relative to it, resize to v0_SHAPE."""
        v1_top, v1_left, v1_height, v1_width = crop_v1
        v2_top, v2_left, v2_height, v2_width = crop_v2
        v1_bot = v1_top + v1_height
        v1_right = v1_left + v1_width
        v2_bot = v2_top + v2_height
        v2_right = v2_left + v2_width

        v0_top = min(v1_top, v2_top)
        v0_left = min(v1_left, v2_left)
        v0_bot = max(v1_bot, v2_bot)
        v0_right = max(v1_right, v2_right)
        v0_height = v0_bot - v0_top
        v0_width = v0_right - v0_left

        self.v0_crop = (v0_top, v0_left, v0_height, v0_width)

        # Sub-crops as fractions of v0, so reverse() can rescale them to any v0-shaped tensor.
        self.v1_proportional_crop = (
            (v1_top - v0_top) / v0_height,
            (v1_left - v0_left) / v0_width,
            v1_height / v0_height,
            v1_width / v0_width,
        )
        self.v2_proportional_crop = (
            (v2_top - v0_top) / v0_height,
            (v2_left - v0_left) / v0_width,
            v2_height / v0_height,
            v2_width / v0_width,
        )

        return tfunc.resized_crop(image, v0_top, v0_left, v0_height, v0_width, self.v0_SHAPE, self.INTERPOLATION)

    def _generate_sub_view(self, image, params):
        """Apply one sub-view's crop/resize, flip, colour jitter, grayscale, blur and solarisation."""
        crop, flip, jitter, gray, blur_sigma, solar, shape = params

        top, left, height, width = crop
        output = tfunc.resized_crop(image, top, left, height, width, shape, self.INTERPOLATION)
        if flip:
            output = tfunc.hflip(output)
        if jitter is not None:
            adjusters = {
                "brightness": tfunc.adjust_brightness,
                "contrast": tfunc.adjust_contrast,
                "saturation": tfunc.adjust_saturation,
                "hue": tfunc.adjust_hue,
            }
            for operation, factor in jitter:
                output = adjusters[operation](output, factor)
        if gray:
            output = tfunc.rgb_to_grayscale(output, 3)
        if blur_sigma is not None:
            # A scalar sigma gives an isotropic kernel. Passing the tuple (0.1, 2.0) here, as before,
            # means sigma_x=0.1 and sigma_y=2.0 — a fixed anisotropic blur, not a sampled range.
            output = tfunc.gaussian_blur(output, self.BLUR_KERNEL_SIZE, blur_sigma)
        if solar:
            output = tfunc.solarize(output, self.SOLARIZE_THRESHOLD)

        return output

In [None]:
# Seed every RNG used downstream so a fresh kernel reproduces the run:
# torch (model init, DataLoader shuffling, crop params), random (ViewGenerator choices),
# numpy (scikit-learn clusterers use the global numpy RNG when random_state is None).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_dataloader, test_dataloader, val_dataloader = make_dataloaders()

fpn_dim = 64
lr_tau = 1e-3    # EMA rate pulling the tau (mask) network toward theta
lr_theta = 1e-2  # SGD learning rate of the online theta networks
lr_xi = 1e-3     # EMA rate pulling the xi (target) network toward theta
# Fall back to CPU so the notebook still runs without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
num_clusters = 8

f_tau = FeatureExtractor(fpn_dim).to(device)
f_theta = FeatureExtractor(fpn_dim).to(device)
f_xi = FeatureExtractor(fpn_dim).to(device)

g_tau = Projector(fpn_dim).to(device)
g_theta = Projector(fpn_dim).to(device)
g_xi = Projector(fpn_dim).to(device)

q_theta = Predictor(fpn_dim).to(device)

# tau and xi are only ever moved by the EMA in byol_parameter_adjustment. They therefore start as
# exact copies of theta (instead of independent random FPN/MLP weights) and do not track gradients.
for _target_net, _online_net in ((f_tau, f_theta), (f_xi, f_theta), (g_tau, g_theta), (g_xi, g_theta)):
    _target_net.load_state_dict(_online_net.state_dict())
    _target_net.requires_grad_(False)

f_theta_optim = optim.SGD(f_theta.parameters(), lr=lr_theta)
g_theta_optim = optim.SGD(g_theta.parameters(), lr=lr_theta)
q_theta_optim = optim.SGD(q_theta.parameters(), lr=lr_theta)

In [None]:
def get_clusterer(h0, z0, eps_coeff=1.0, num_clusters=4, clusterer_type="kmeans"):
    """
    Build unfitted scikit-learn clusterer(s) for grouping per-position feature vectors into masks.

    Parameters
    ----------
    h0 : object
        Unused; kept so existing call sites keep working.
    z0 : torch.Tensor
        Projected feature map. Only needed for OPTICS, whose ``eps`` is ``eps_coeff`` times the median
        L2 norm taken over dim 1 (presumably the channel axis of an (N, C, H, W) map — TODO confirm).
    eps_coeff : float
        Multiplier on that median norm.
    num_clusters : int
        ``n_clusters`` for KMeans.
    clusterer_type : {"kmeans", "optics", "both"}
        Which clusterer(s) to return.

    Returns
    -------
    KMeans, OPTICS, or ``(KMeans, OPTICS)`` when ``clusterer_type == "both"``.

    Raises
    ------
    ValueError
        ``ValueError(clusterer_type)`` for any other type. Validation happens before any work is done.
    """
    if clusterer_type not in ("kmeans", "optics", "both"):
        raise ValueError(clusterer_type)

    def make_kmeans():
        return KMeans(
            n_clusters=num_clusters,
            n_init=10,
            max_iter=500,
            tol=0.0001,
            copy_x=False,
            algorithm='elkan',
        )

    def make_optics():
        # Only OPTICS needs statistics of z0; the former all-pairs cdist and the other unused
        # summary statistics were wasted work.
        z0_norms_median = torch.norm(z0, p=2, dim=1).median().item()
        return OPTICS(
            cluster_method="dbscan",
            min_samples=0.05,
            eps=z0_norms_median * eps_coeff,
            n_jobs=4,
        )

    if clusterer_type == "kmeans":
        return make_kmeans()
    if clusterer_type == "optics":
        return make_optics()
    return make_kmeans(), make_optics()

In [None]:
def generate_masks(feature_map, clusterer, view_gen=None):
    """
    Cluster the per-position feature vectors of one feature map into binary masks.

    Parameters
    ----------
    feature_map : torch.Tensor
        Map with a singleton batch dim, (1, C, H, W); squeezed to (C, H, W) before clustering.
    clusterer : object
        Anything with ``fit_predict(X) -> labels`` (scikit-learn style), called on an (H*W, C) array.
        Label ``-1`` (the noise label of OPTICS/DBSCAN) gets no mask.
    view_gen : object, optional
        Object with ``reverse(m0) -> (m1_raw, m2_raw)``, e.g. a ViewGenerator that has already been called.

    Returns
    -------
    cluster_ids : set
        Non-noise cluster labels.
    mask_assignments : torch.Tensor
        (H, W) label map.
    m0 : torch.Tensor
        (1, K, H, W) float one-hot masks, one per sorted cluster id, on ``feature_map.device``.
    m1, m2 : torch.Tensor or None
        (1, K', h, w) masks mapped into the two sub-views, or ``None`` when ``view_gen`` is None.
        Only masks that are non-empty in BOTH views are kept, so index k of m1 and index k of m2
        refer to the same cluster.
    """
    features = feature_map.detach().cpu().squeeze().numpy()
    channels, height, width = features.shape
    # (C, H, W) -> (H, W, C) -> (H*W, C): one sample per spatial position, row-major over (H, W).
    flat_features = np.transpose(features, (1, 2, 0)).reshape(height * width, channels)

    labels_flat = clusterer.fit_predict(flat_features)
    cluster_ids = set(labels_flat)
    cluster_ids.discard(-1)

    mask_assignments = torch.from_numpy(np.asarray(labels_flat).reshape(height, width))
    if cluster_ids:
        m0 = torch.stack([
            torch.where(mask_assignments == int(c_id), 1., 0.) for c_id in sorted(cluster_ids)
        ])
    else:
        # Everything was noise: return an empty mask stack instead of crashing on np.stack([]).
        m0 = torch.zeros((0, height, width))
    m0 = m0.to(feature_map.device)

    if view_gen is not None:
        m1_raw, m2_raw = view_gen.reverse(m0)
        # Boolean indexing keeps the mask axis even when a single mask survives (argwhere+squeeze
        # collapsed it). Filtering with a joint mask keeps m1/m2 index-aligned for the contrastive loss.
        keep = (m1_raw.sum(dim=(1, 2)) > 0) & (m2_raw.sum(dim=(1, 2)) > 0)
        m1 = m1_raw[keep].unsqueeze(0)
        m2 = m2_raw[keep].unsqueeze(0)
    else:
        m1 = m2 = None

    return cluster_ids, mask_assignments, m0.unsqueeze(0), m1, m2

In [None]:
def run_tau_network(v0):
    """
    Run the tau network (global ``f_tau`` then ``g_tau``) on the whole view ``v0``.

    Returns
    -------
    (h_0, z_0)
        ``h_0`` is the FPN feature map and ``z_0`` its per-position projection, both in the
        same (N, C, H, W) layout.

    Runs under ``torch.no_grad()``. tau is only moved by the EMA in byol_parameter_adjustment, and in
    this notebook its outputs feed clustering and mask generation, so no autograd graph is needed.
    """
    with torch.no_grad():
        h_0 = f_tau(v0)
        # g_tau is an MLP over the last axis: move channels last, project, then swap back.
        z_0 = g_tau(h_0.transpose(1, 3)).transpose(3, 1)
    return h_0, z_0

def run_theta_network(v1, v2, m1, m2):
    """
    Online (theta) branch: mask-pooled features, projections and predictions for both sub-views.

    Parameters
    ----------
    v1, v2 : torch.Tensor
        The two augmented sub-views fed to the global ``f_theta``.
    m1, m2 : torch.Tensor
        Per-view masks as produced by ``generate_masks`` — presumably (1, K, 224, 224). The
        pooling below assumes a leading batch dim of 1.

    Returns
    -------
    (h_1, h_2)               FPN features resized to 224x224.
    (masked_h_1, masked_h_2) One masked copy of the features per mask, (K, C, 224, 224).
    (h_k_1, h_k_2)           Mask-averaged feature vectors, (K, C).
    (z_k_1, z_k_2)           ``g_theta`` projections of the pooled vectors.
    (p_k_1, p_k_2)           ``q_theta`` predictions from the projections.
    """
    def mask_pool(h, masks):
        # Flatten the masks to (K, 1, H, W) so broadcasting against h (1, C, H, W) gives (K, C, H, W),
        # in the same mask order as the former nested per-mask concat.
        flat_masks = masks.reshape(-1, 1, *masks.shape[-2:])
        masked = flat_masks * h
        # Masked sum divided by mask area = mean feature over each mask, shape (K, C).
        pooled = masked.sum(dim=(2, 3)) / flat_masks.sum(dim=(2, 3))
        return masked, pooled

    h_1 = tfunc.resize(f_theta(v1), (224, 224), tf.InterpolationMode.BILINEAR)
    h_2 = tfunc.resize(f_theta(v2), (224, 224), tf.InterpolationMode.BILINEAR)

    masked_h_1, h_k_1 = mask_pool(h_1, m1)
    # View 2's masks must pool view 2's features (this previously multiplied m2 with h_1).
    masked_h_2, h_k_2 = mask_pool(h_2, m2)

    z_k_1 = g_theta(h_k_1)
    z_k_2 = g_theta(h_k_2)

    p_k_1 = q_theta(z_k_1)
    p_k_2 = q_theta(z_k_2)

    return (h_1, h_2), (masked_h_1, masked_h_2), (h_k_1, h_k_2), (z_k_1, z_k_2), (p_k_1, p_k_2)

def run_xi_network(v1, v2, m1, m2):
    """
    Target (xi) branch: mask-pooled features and projections for both sub-views.

    Mirrors run_theta_network without the predictor and uses the global ``f_xi`` / ``g_xi``. Runs
    under ``torch.no_grad()``: xi is only moved by the EMA in byol_parameter_adjustment, and its
    projections are used solely as targets in the contrastive loss.

    Parameters
    ----------
    v1, v2 : torch.Tensor
        The two augmented sub-views.
    m1, m2 : torch.Tensor
        Per-view masks, presumably (1, K, 224, 224) as from ``generate_masks`` (batch of one assumed).

    Returns
    -------
    (h_1, h_2), (h_k_1, h_k_2), (z_k_1, z_k_2)
        Resized features, mask-averaged vectors (K, C), and their ``g_xi`` projections.
    """
    def mask_pool(h, masks):
        # Mean feature over each mask: (K, 1, H, W) * (1, C, H, W) summed over H, W, divided by mask area.
        flat_masks = masks.reshape(-1, 1, *masks.shape[-2:])
        return (flat_masks * h).sum(dim=(2, 3)) / flat_masks.sum(dim=(2, 3))

    with torch.no_grad():
        # f_xi / g_xi are notebook globals; the former `self.` prefix raised NameError in this
        # module-level function. Pooling is done inline because the previously referenced
        # `compute_mask_pooled_vectors` is not defined in this notebook.
        h_1 = tfunc.resize(f_xi(v1), (224, 224), tf.InterpolationMode.BILINEAR)
        h_2 = tfunc.resize(f_xi(v2), (224, 224), tf.InterpolationMode.BILINEAR)

        h_k_1 = mask_pool(h_1, m1)
        h_k_2 = mask_pool(h_2, m2)

        z_k_1 = g_xi(h_k_1)
        z_k_2 = g_xi(h_k_2)

    return (h_1, h_2), (h_k_1, h_k_2), (z_k_1, z_k_2)

In [None]:
def single_mask_similarity(p_theta, z_xi, alpha, eps=1e-8):
    """
    Temperature-scaled cosine similarity between two 1-D vectors: cos(p_theta, z_xi) / alpha.

    Parameters
    ----------
    p_theta, z_xi : torch.Tensor
        1-D vectors of equal length (``torch.dot`` requires this).
    alpha : float
        Temperature; smaller values sharpen the similarity.
    eps : float
        Lower bound on the norm product. When either vector is all zeros (possible after the
        final ReLU of the projector/predictor) the result is 0 instead of NaN.
    """
    top = torch.dot(p_theta, z_xi)
    bot = torch.mul(torch.norm(p_theta), torch.norm(z_xi)).clamp_min(eps)
    return (1.0 / alpha) * (top / bot)

def feature_contrastive_loss(p_k_1_theta, z_k_2_xi, index, alpha=0.1):
    """
    InfoNCE loss for one mask.

    The prediction ``p_k_1_theta[index]`` should be more similar to its own target ``z_k_2_xi[index]``
    (the positive) than to every other ``z_k_2_xi[j]`` (negatives):

        s_j  = single_mask_similarity(p_k_1_theta[index], z_k_2_xi[j], alpha)
        loss = -log( exp(s_index) / sum_j exp(s_j) ) = logsumexp_j(s_j) - s_index

    Returns
    -------
    torch.Tensor or int
        Scalar loss tensor, or the int ``0`` when ``index`` is out of range for either input.
    """
    if index >= len(p_k_1_theta) or index >= len(z_k_2_xi):
        return 0
    similarities = torch.stack([
        single_mask_similarity(p_k_1_theta[index], zk2_xi, alpha) for zk2_xi in z_k_2_xi
    ])
    # Log-softmax via logsumexp. The previous ratio of raw similarities omitted exp: it cancelled the
    # temperature alpha entirely and could take log of a non-positive value (NaN).
    return torch.logsumexp(similarities, dim=0) - similarities[index]

def total_contrastive_loss(p_k_1_theta, p_k_2_theta, z_k_1_xi, z_k_2_xi, alpha):
    """
    Symmetric contrastive loss averaged over masks.

    For each mask index k below the smallest input length, adds
    ``feature_contrastive_loss(p_k_1_theta, z_k_2_xi, k)`` (view 1 predicts view 2) and
    ``feature_contrastive_loss(p_k_2_theta, z_k_1_xi, k)`` (view 2 predicts view 1), then divides by
    the number of masks.

    Returns
    -------
    torch.Tensor
        Shape (1,), on ``p_k_1_theta``'s device. Zero (without grad) when there are no masks.
    """
    num_masks_found = min(len(p_k_1_theta), len(p_k_2_theta), len(z_k_1_xi), len(z_k_2_xi))
    cum_loss = torch.zeros(1, device=p_k_1_theta.device)
    # Early exit: the old `num_masks_found = 1e-5` guard made `range(1e-5)` raise TypeError.
    if num_masks_found == 0:
        return cum_loss
    for mask_idx in range(num_masks_found):
        l_12_k = feature_contrastive_loss(p_k_1_theta, z_k_2_xi, mask_idx, alpha)
        l_21_k = feature_contrastive_loss(p_k_2_theta, z_k_1_xi, mask_idx, alpha)
        cum_loss = cum_loss + l_12_k + l_21_k
    return cum_loss / num_masks_found

In [None]:
def within_mask_closeness_loss(masked_h):
    """
    Spread of feature values inside each mask, summed over channels and averaged over masks.

    Parameters
    ----------
    masked_h : torch.Tensor
        (K, C, H, W) masked feature maps. Entries equal to 0 are treated as outside the mask.

    Returns
    -------
    torch.Tensor
        Scalar: sum over (k, c) of (max - min of the non-zero entries), divided by K.
        A (k, c) slice with no non-zero entries contributes 0.
    """
    inside = masked_h != 0
    within_h_max = torch.where(inside, masked_h, float("-Inf")).amax(dim=(2, 3))
    within_h_min = torch.where(inside, masked_h, float("Inf")).amin(dim=(2, 3))
    within_h_range = within_h_max - within_h_min
    # An empty slice gives -inf - inf = -inf, which previously turned the whole loss into -inf.
    within_h_range = torch.where(
        torch.isfinite(within_h_range), within_h_range, torch.zeros_like(within_h_range)
    )
    return within_h_range.sum() / within_h_range.shape[0]

In [None]:
def byol_parameter_adjustment(param_zip, tau_rate=None, xi_rate=None):
    """
    In-place exponential-moving-average update of the tau and xi parameters toward theta.

    For each ``(p_tau, p_theta, p_xi)`` triple yielded by ``param_zip``:

        p_xi  <- (1 - xi_rate)  * p_xi  + xi_rate  * p_theta
        p_tau <- (1 - tau_rate) * p_tau + tau_rate * p_theta

    Parameters
    ----------
    param_zip : iterable of (tensor, tensor, tensor)
        Typically ``zip(tau.parameters(), theta.parameters(), xi.parameters())``.
    tau_rate, xi_rate : float, optional
        EMA rates. Default to the notebook globals ``lr_tau`` / ``lr_xi`` when ``None``.

    NOTE(review): when ``param_zip`` is built from ``.parameters()``, module buffers (e.g. BatchNorm
    running statistics) are not averaged.
    """
    if tau_rate is None:
        tau_rate = lr_tau
    if xi_rate is None:
        xi_rate = lr_xi
    with torch.no_grad():
        for p_f_tau, p_f_theta, p_f_xi in param_zip:
            new_p_f_xi = (1 - xi_rate) * p_f_xi + xi_rate * p_f_theta
            new_p_f_tau = (1 - tau_rate) * p_f_tau + tau_rate * p_f_theta

            p_f_xi.copy_(new_p_f_xi)
            p_f_tau.copy_(new_p_f_tau)

In [None]:
def run_networks_training_mode(v1, v2, m1, m2):
    """
    One optimisation step of the online theta networks, followed by the EMA update of tau and xi.

    The loss is ``total_contrastive_loss`` between theta's predictions and xi's projections (alpha = 0.1).
    Gradients are cleared before ``backward``. The step is skipped when the loss carries no gradient
    (e.g. no masks were found).

    Returns
    -------
    torch.Tensor
        The detached loss, for logging/accumulation by the caller.
    """
    # run_theta_network returns five tuples and run_xi_network three; only the last of each is needed.
    *_, (pk1_theta, pk2_theta) = run_theta_network(v1, v2, m1, m2)
    *_, (zk1_xi, zk2_xi) = run_xi_network(v1, v2, m1, m2)
    loss = total_contrastive_loss(pk1_theta, pk2_theta, zk1_xi, zk2_xi, 0.1)

    theta_optimizers = (f_theta_optim, g_theta_optim, q_theta_optim)
    if loss.requires_grad:
        # Without zero_grad, gradients from every previous iteration accumulate into each step.
        for optimizer in theta_optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in theta_optimizers:
            optimizer.step()

    byol_parameter_adjustment(zip(f_tau.parameters(), f_theta.parameters(), f_xi.parameters()))
    byol_parameter_adjustment(zip(g_tau.parameters(), g_theta.parameters(), g_xi.parameters()))

    return loss.detach()

In [None]:
def run_training_iteration(input_tensor, eps_coeff, clusterer_type="kmeans"):
    """
    Full training iteration on one image: sample views, build masks from tau, train theta.

    Parameters
    ----------
    input_tensor : torch.Tensor
        Image batch passed to ``ViewGenerator``.
    eps_coeff : float
        Forwarded to ``get_clusterer`` (OPTICS eps scaling).
    clusterer_type : {"kmeans", "optics", "both"}
        Which clusterer(s) to use. With "both", kmeans runs first, then optics, and each
        performs its own training step.

    Returns
    -------
    (cluster_results, loss)
        ``cluster_results`` maps clusterer name to ``(cluster_ids, masks, m0, m1, m2)``. ``loss`` is
        the (1,)-shaped sum of the per-clusterer losses.

    Raises
    ------
    ValueError
        ``ValueError(clusterer_type)`` for an unknown type (previously it was silently ignored).
    """
    if clusterer_type == "both":
        clusterer_names = ["kmeans", "optics"]
    elif clusterer_type in ("kmeans", "optics"):
        clusterer_names = [clusterer_type]
    else:
        raise ValueError(clusterer_type)

    view_gen = ViewGenerator(input_tensor)
    v0, v1, v2 = view_gen(input_tensor)

    h0, z0 = run_tau_network(v0)
    cluster_results = dict()

    loss = torch.Tensor([0.]).to(device)
    for name in clusterer_names:
        clusterer = get_clusterer(h0, z0, eps_coeff, num_clusters, clusterer_type=name)
        # generate_masks takes (feature_map, clusterer, view_gen); the arguments were previously swapped.
        cluster_ids, masks, m0, m1, m2 = generate_masks(z0, clusterer, view_gen)
        loss += run_networks_training_mode(v1, v2, m1, m2)
        cluster_results[name] = (cluster_ids, masks, m0, m1, m2)

    return cluster_results, loss

In [None]:
input_tensor = next(iter(test_dataloader))[0]  # [0] is the image batch; [1] is the CocoDetection target
input_tensor = input_tensor.to(device)
draw_image(input_tensor)

In [None]:
# Sample augmentation parameters from this image, then build the whole view v0 and sub-views v1/v2.
view_gen = ViewGenerator(input_tensor)
v0, v1, v2 = view_gen(input_tensor)

# Cluster the tau features of v0 into masks and map them into v1/v2 via view_gen.reverse.
h0, z0 = run_tau_network(v0)
clusterer = get_clusterer(
    h0,
    z0,
    eps_coeff=1.0,
    num_clusters=num_clusters,
    clusterer_type="kmeans",
)
cluster_ids, mask_assignments, m0, m1, m2 = generate_masks(
    feature_map=h0,
    clusterer=clusterer,
    view_gen=view_gen,
)

In [None]:
# Online (theta) features for both sub-views, resized to the 224x224 mask resolution.
h_1_theta, h_2_theta = (
    tfunc.resize(f_theta(view), (224, 224), tf.InterpolationMode.BILINEAR) for view in (v1, v2)
)

# One masked copy of the feature map per mask, stacked along dim 0: (K, C, 224, 224).
masked_h_1_theta = torch.cat([mask * h_1_theta for per_batch in m1 for mask in per_batch])
masked_h_2_theta = torch.cat([mask * h_2_theta for per_batch in m2 for mask in per_batch])

sum_masked_h_1_theta = masked_h_1_theta.sum((2, 3))
sum_masked_h_2_theta = masked_h_2_theta.sum((2, 3))

# Mask areas as a (K, B) column (B is 1 here) so each pooled row is divided by its own mask's size.
m1_sums_theta = m1.sum((2, 3)).t()
m2_sums_theta = m2.sum((2, 3)).t()

h_k_1_theta = sum_masked_h_1_theta / m1_sums_theta
h_k_2_theta = sum_masked_h_2_theta / m2_sums_theta

z_k_1_theta, z_k_2_theta = g_theta(h_k_1_theta), g_theta(h_k_2_theta)
p_k_1_theta, p_k_2_theta = q_theta(z_k_1_theta), q_theta(z_k_2_theta)

In [None]:
# Target (xi) features for both sub-views, resized to the 224x224 mask resolution.
h_1_xi, h_2_xi = (
    tfunc.resize(f_xi(view), (224, 224), tf.InterpolationMode.BILINEAR) for view in (v1, v2)
)

# One masked copy of the feature map per mask, stacked along dim 0: (K, C, 224, 224).
masked_h_1_xi = torch.cat([mask * h_1_xi for per_batch in m1 for mask in per_batch])
masked_h_2_xi = torch.cat([mask * h_2_xi for per_batch in m2 for mask in per_batch])

sum_masked_h_1_xi = masked_h_1_xi.sum((2, 3))
sum_masked_h_2_xi = masked_h_2_xi.sum((2, 3))

# Mask areas as a (K, B) column (B is 1 here) so each pooled row is divided by its own mask's size.
m1_sums_xi = m1.sum((2, 3)).t()
m2_sums_xi = m2.sum((2, 3)).t()

h_k_1_xi = sum_masked_h_1_xi / m1_sums_xi
h_k_2_xi = sum_masked_h_2_xi / m2_sums_xi

z_k_1_xi, z_k_2_xi = g_xi(h_k_1_xi), g_xi(h_k_2_xi)