In [None]:
# Re-import edited local modules (dataset, augmentation, model, ...) automatically
# before executing code, so changes to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import motti
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

In [None]:
# NOTE(review): the second line re-binds `PathMNIST`, so `dataset.PathMNIST` is
# never used -- the medmnist class is the one instantiated below. Keep only the
# import for the intended class once that is confirmed.
from dataset import PathMNIST
from medmnist import PathMNIST
from augmentation import BarlowTwinsTransform, pathmnist_normalization
# NOTE(review): these three names are re-assigned with literal values in a later
# cell, which overrides whatever `constant` provides here.
from constant import PathMNIST_HIST, PathMNIST_MEAN, PathMNIST_STD

In [None]:
# Notebook stand-in for parsed command-line options; only the input image size
# (passed to the transform as `input_height`) is needed here.
opts = argparse.Namespace(img_size=28)

In [None]:
# Barlow Twins training-time augmentation at the notebook's image size, with
# Gaussian blur disabled, jitter strength 0.5 and PathMNIST normalization.
train_transform = BarlowTwinsTransform(
    normalize=pathmnist_normalization(),
    train=True,
    input_height=opts.img_size,
    jitter_strength=0.5,
    gaussian_blur=False,
)

# PathMNIST training split read from the local medmnist2d directory (no download).
train_dataset = PathMNIST(
    root="../data/medmnist2d/",
    split="train",
    transform=train_transform,
    download=False,
)

# Sequential (unshuffled) loader in batches of 4.
# NOTE(review): no visible cell iterates this loader; the cells below read
# `train_dataset` directly.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=False, num_workers=4)

In [None]:
# Per-channel statistics used by `get_corr` to undo input normalization, and a
# 9-entry histogram (presumably per-class sample counts -- confirm).
# NOTE(review): these override the same names imported from `constant`;
# confirm both sources agree, or keep a single source of truth.
PathMNIST_MEAN = [
    0.73765225,
    0.53090023,
    0.70307171,
]
PathMNIST_STD = [
    0.12319908,
    0.17607205,
    0.12394462,
]
PathMNIST_HIST = [
    9366, 9510, 10362,
    10404, 8010, 12187,
    7892, 9408, 12893,
]

In [None]:
# Display the metadata object the dataset exposes as `.info`
# (presumably task description, label names and split sizes -- check the output).
train_dataset.info

In [None]:
from model.barlow_twins import (
    BarlowTwins,
    get_modified_resnet18,
)

In [None]:
# Restore the trained Barlow Twins model from the epoch-99 checkpoint.
# NOTE(review): assumes BarlowTwins is a LightningModule, whose
# load_from_checkpoint forwards these keyword arguments to __init__;
# num_training_samples / batch_size look like training-only settings -- confirm.
model = BarlowTwins.load_from_checkpoint(
    "../ckpt/epoch=99-step=8700.ckpt",
    z_dim=128,
    encoder_out_dim=512,
    encoder=get_modified_resnet18(),
    batch_size=1,
    num_training_samples=1,
)

In [None]:
# Take the first training sample and add a leading batch dimension to each of its
# first three augmented views (assumes the transform returns at least three).
# Using next(...) with an explicit error means an empty dataset fails loudly
# instead of silently leaving x0/x1/x2 undefined or stale from an earlier run.
first_sample = next(iter(train_dataset), None)
if first_sample is None:
    raise RuntimeError("train_dataset is empty; cannot draw a sample")
x, y = first_sample
x0, x1, x2 = (view.unsqueeze(dim=0) for view in x[:3])

In [None]:
# `model` is already restored from "../ckpt/epoch=99-step=8700.ckpt" with identical
# arguments in the earlier loading cell, and nothing in between modifies it.
# Reloading here only repeated the expensive restore, so it has been removed.

In [None]:
def _to_hwc_image(x, mean, std):
    """Convert a normalized image tensor into a de-normalized (H, W, C) numpy array.

    Args:
        x: image tensor of shape (C, H, W) or (B, C, H, W); with a batch
            dimension only the first sample is converted.
        mean: per-channel means used for the forward normalization.
        std: per-channel standard deviations used for the forward normalization.

    Returns:
        numpy array of shape (H, W, C) equal to ``x * std + mean`` per channel.
    """
    img = x.detach().cpu()
    if img.dim() == 4:
        img = img[0]
    # (C, H, W) -> (H, W, C): the channel-last layout plt.imshow expects.
    img = img.permute(1, 2, 0).numpy()
    return img * np.asarray(std) + np.asarray(mean)


@torch.no_grad()
def get_corr(model, x1, x2, mean=None, std=None):
    """Cross-correlate the projections of two inputs and return both as images.

    Args:
        model: Barlow Twins model. ``model(x)`` is assumed to return encoder
            features that ``model.projection_head`` maps to embeddings;
            ``model.device`` is the device inputs are moved to.
        x1: first normalized image tensor, presumably (1, C, H, W).
        x2: second normalized image tensor, same shape as ``x1``.
        mean: per-channel means used to undo normalization for display;
            defaults to ``PathMNIST_MEAN``.
        std: per-channel standard deviations; defaults to ``PathMNIST_STD``.

    Returns:
        ``(cross_corr, X1, X2)``: ``cross_corr`` is ``z1.T @ z2`` as a numpy
        array; ``X1`` / ``X2`` are the de-normalized (H, W, C) images of
        ``x1`` / ``x2`` (previously the first image was read from a notebook
        global and ``x2`` was never shown).

    Notes:
        The product is not standardized per embedding dimension: with a batch
        of one, ``torch.std`` over dim 0 is NaN, so standardization would
        produce NaNs.
        Side effect: switches ``model`` to eval mode.
    """
    if mean is None:
        mean = PathMNIST_MEAN
    if std is None:
        std = PathMNIST_STD

    model.eval()
    z1 = model.projection_head(model(x1.to(model.device)))
    z2 = model.projection_head(model(x2.to(model.device)))
    cross_corr = z1.T @ z2

    return (
        cross_corr.cpu().numpy(),
        _to_hwc_image(x1, mean, std),
        _to_hwc_image(x2, mean, std),
    )

In [None]:
# C: cross-correlation of the projected embeddings of x0 and x1.
# X0 / X1: de-normalized images returned by get_corr for display.
C, X0, X1 = get_corr(model, x0, x1)

In [None]:
# Show both images side by side. Drawing them on the same axes (as before) left
# only the second one visible.
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
for ax, img, title in zip(axes, (X0, X1), ("X0", "X1")):
    # De-normalized values can fall slightly outside [0, 1]; clip so imshow
    # doesn't warn and silently clip for us.
    ax.imshow(np.clip(img, 0.0, 1.0))
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
# Raw (unstandardized) z1.T @ z2 matrix: its scale is not fixed, so a colorbar
# is needed to read the values. Rows index z1 dimensions, columns z2 dimensions.
fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(C, cmap="gray")
fig.colorbar(im, ax=ax)
ax.set(title="Cross-correlation z1.T @ z2", xlabel="z2 dimension", ylabel="z1 dimension")
plt.show()