# Load modules

In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl

from lightly.loss import NegativeCosineSimilarity, NTXentLoss
from lightly.models.modules.heads import SimSiamPredictionHead, SimSiamProjectionHead

from sklearn import random_projection

import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange
from PIL import Image
import matplotlib.offsetbox as osb
from matplotlib import rcParams as rcp

from cuml import UMAP

import torchvision.transforms.functional as functional

from data.dataset import SDOTilesDataset
from data.augmentation_list import AugmentationList

# Fixed seed so that repeated runs of the notebook draw the same random numbers.
# workers=True asks Lightning to seed DataLoader worker processes as well.
seed = 42
pl.seed_everything(seed, workers=True)

# Data Setup

### Define augmentation

In [None]:
# Build the augmentation list with the 'euv' argument and display the keys it
# starts out with.
augmentation_list = AugmentationList('euv')
augmentation_list.keys

In [None]:
# Replace the keys so that 'h_flip' is the only entry
# (presumably a horizontal flip; confirm against AugmentationList).
augmentation_list.keys = ['h_flip']

In [None]:
# Display the keys again to check the override from the previous cell.
augmentation_list.keys

In [None]:
# Call randomize() on the augmentation list (presumably draws the augmentation
# parameters; see AugmentationList for the exact behaviour).
augmentation_list.randomize()

### Initialize dataset

In [None]:
# Active data configuration. Each setting is assigned exactly once so the value
# in effect is unambiguous.
#
# Alternative configuration that was previously assigned first and then
# immediately overwritten (kept here for reference):
#   DATA_PATH = '/home/jovyan/scratch_space/andresmj/hss-self-supervision/AIA_211_193_171_128x128_small'
#   DATA_STRIDE = 1
DATA_PATH = '/d0/euv/aia/preprocessed_ext/AIA_211_193_171/AIA_211_193_171_256x256'
# Forwarded as `data_stride`; presumably keeps every DATA_STRIDE-th tile
# (confirm against SDOTilesDataset).
DATA_STRIDE = 10000

dataset = SDOTilesDataset(
    data_path=DATA_PATH,
    augmentation_list=augmentation_list,
    augmentation_strategy='single',
    data_stride=DATA_STRIDE,
)
# Show how many samples the dataset exposes.
len(dataset)

### Visualize Augmentation

In [None]:
# Draw one sample position uniformly from [0, len(dataset)) and display it.
idx = np.random.randint(len(dataset))
idx

In [None]:
# Fetch the pair of images stored at the chosen index; the third item returned
# by the dataset is not needed here.
x0, x1, _ = dataset[idx]

fig = plt.figure(figsize=(4, 2), constrained_layout=True)
spec = fig.add_gridspec(ncols=2, nrows=1, wspace=0, hspace=0)

# Both panels share the same settings, so draw them in one loop. The images are
# channel-first, hence the rearrange to channel-last before imshow.
for column, (image, title) in enumerate(((x0, "Original"), (x1, "Augmented"))):
    ax = fig.add_subplot(spec[0, column])
    ax.imshow(rearrange(image, 'c h w -> h w c'))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)


# Set up training parameters

In [None]:
# Passed to pl.Trainer as `accelerator`.
DEVICE = 'cuda'
# Passed to pl.Trainer as `max_epochs`.
EPOCHS = 2
# Batch size used by both DataLoaders.
BATCH_SIZE = 64
# NOTE(review): not referenced in the cells shown here; the dataset is built
# with a literal augmentation_strategy='single' - confirm whether it should
# read this setting instead.
AUGMENTATION = 'single'
# Selects the training loss in SimSiam.training_step: 'cos' picks the negative
# cosine similarity, any other value picks the NT-Xent loss.
LOSS = 'contrast'   # 'contrast' or 'cos'
# Learning-rate setting (see SimSiam.configure_optimizers for the value the
# optimizer actually receives).
LEARNING_RATE = 0.1
# Third argument of SimSiamProjectionHead and first argument of
# SimSiamPredictionHead.
PROJECTION_HEAD_SIZE = 128
# Third argument of SimSiamPredictionHead.
PREDICTION_HEAD_SIZE = 128
# Second argument of SimSiamPredictionHead (spelling of the name kept as is,
# since other cells refer to it).
EMBEDING_SIZE = 64

# Build dataloader

In [None]:
# Both loaders read the same dataset with identical settings; only the order
# differs: shuffled for training, fixed for the embedding pass later on.
_shared_loader_args = dict(batch_size=BATCH_SIZE, drop_last=False, num_workers=4)

dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, **_shared_loader_args)
val_dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, **_shared_loader_args)

# Set up SimSiam model

In [None]:
class SimSiam(pl.LightningModule):
    """SimSiam-style siamese network on a ResNet-18 backbone.

    Two views of every sample pass through the same backbone, projection head
    and prediction head. The training loss compares the prediction of one view
    with the detached projection of the other view, symmetrically.

    Args:
        loss: ``'cos'`` trains on the negative cosine similarity; any other
            value trains on the NT-Xent loss. ``None`` (default) uses the
            notebook-level ``LOSS`` setting.
        learning_rate: Learning rate given to the SGD optimizer. ``None``
            (default) uses the notebook-level ``LEARNING_RATE`` setting.
    """

    def __init__(self, loss=None, learning_rate=None):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Keep every ResNet child except the last one (the classifier layer),
        # so the backbone outputs pooled features.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SimSiamProjectionHead(512, 512, PROJECTION_HEAD_SIZE)
        self.prediction_head = SimSiamPredictionHead(PROJECTION_HEAD_SIZE, EMBEDING_SIZE, PREDICTION_HEAD_SIZE)
        # Not used by training_step; retained so existing references keep working.
        self.criterion = NegativeCosineSimilarity()

        # Settings are resolved at construction time so that changing the
        # notebook constants and re-instantiating takes effect.
        self.loss = LOSS if loss is None else loss
        self.learning_rate = LEARNING_RATE if learning_rate is None else learning_rate
        self.loss_cos = NegativeCosineSimilarity()
        self.loss_contrast = NTXentLoss()

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction for ``x``."""
        f = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(f)
        p = self.prediction_head(z)
        # Stop-gradient: the projection is only ever used as a target.
        z = z.detach()
        return z, p

    def training_step(self, batch, batch_idx):
        """Compute and log both losses; return the one selected by ``self.loss``."""
        x0, x1, _ = batch
        z0, p0 = self(x0)
        z1, p1 = self(x1)

        # Both objectives are evaluated on every step so they can be monitored
        # side by side; only the selected one is returned for optimisation.
        loss_cos = 0.5 * (self.loss_cos(p0, z1) + self.loss_cos(p1, z0))
        loss_contrast = 0.5 * (self.loss_contrast(p0, z1) + self.loss_contrast(p1, z0))
        loss = loss_cos if self.loss == 'cos' else loss_contrast

        self.log('loss cos', loss_cos)
        self.log('loss contrast', loss_contrast)
        self.log('loss', loss)

        return loss

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters."""
        # Use the configured learning rate instead of a literal, so the
        # notebook-level setting is actually honoured.
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        
# Instantiate the network; the bare expression makes the notebook print the
# module structure.
model = SimSiam()
model

# Train model

In [None]:
# Train for EPOCHS epochs on one device of type DEVICE with deterministic
# algorithms requested; only the training loader is given to fit().
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator=DEVICE,
    devices=1,
    strategy="auto",
    deterministic=True,
)
trainer.fit(model=model, train_dataloaders=dataloader)

# Visualize Output

In [None]:
# With training finished, embed every image using the backbone only and record
# which file each embedding belongs to.
embeddings = []
filenames = []

# Run on whatever device the model currently lives on, so inputs and weights
# always match regardless of where the trainer left the model.
embedding_device = next(model.parameters()).device

# eval() switches layers such as batch norm to inference behaviour; no_grad()
# skips gradient bookkeeping, which saves time and memory.
model.eval()
with torch.no_grad():
    for x, _, fnames in val_dataloader:
        x = x.to(embedding_device)
        # backbone output flattened to one feature vector per image
        y = model.backbone(x).flatten(start_dim=1)
        # move results to the CPU right away so the list does not pin
        # accelerator memory
        embeddings.append(y.cpu())
        filenames.extend(fnames)

# stack the per-batch results into one (num_images, num_features) numpy array
embeddings = torch.cat(embeddings, dim=0).numpy()

In [None]:
# UMAP settings. Only n_neighbors and metric are passed to the estimator at the
# moment; the remaining values belong to the disabled arguments below.
n_neighbors = 5
min_dist = 0.0
n_components = 2
metric = 'euclidean'
spread = 0.5
repulsion_strength = 2

fit = UMAP(
    n_neighbors=n_neighbors,
    # min_dist=min_dist,
    # n_components=n_components,
    metric=metric,
    # spread=spread,
    # repulsion_strength=repulsion_strength,
    verbose=True,
)

embeddings_2d = fit.fit_transform(embeddings)

# Normalize the embeddings to fit in the [0, 1] square. An axis on which all
# points coincide has zero extent; dividing by 1 there maps it to 0 instead of
# filling the column with NaN.
M = np.max(embeddings_2d, axis=0)
m = np.min(embeddings_2d, axis=0)
extent = M - m
embeddings_2d = (embeddings_2d - m) / np.where(extent > 0, extent, 1.0)

In [None]:
# # for the scatter plot we want to transform the images to a two-dimensional
# # vector space using a random Gaussian projection
# projection = random_projection.GaussianRandomProjection(n_components=2)
# embeddings_2d = projection.fit_transform(embeddings)

# # normalize the embeddings to fit in the [0, 1] square
# M = np.max(embeddings_2d, axis=0)
# m = np.min(embeddings_2d, axis=0)
# embeddings_2d = (embeddings_2d - m) / (M - m)

In [None]:
# display a scatter plot of the dataset
# clustering similar images together

def get_scatter_plot_with_thumbnails():
    """Draw the 2-D embedding as a scatter of image thumbnails.

    Reads the notebook-level ``embeddings_2d`` (one 2-D point per sample) and
    ``filenames`` (one image path per sample, in the same order). Samples are
    visited in random order and a thumbnail is placed only when its point is
    far enough from every thumbnail placed so far, which keeps the plot
    readable. The figure is created as a side effect; nothing is returned.
    """
    fig = plt.figure()
    fig.suptitle("Scatter Plot of the SDO/AIA 171 Tiles")
    ax = fig.add_subplot(1, 1, 1)

    # Greedy selection. The list of occupied positions starts with the point
    # (1, 1), so no thumbnail is placed right next to that corner.
    shown_images_idx = []
    shown_images = np.array([[1.0, 1.0]])
    visit_order = list(range(embeddings_2d.shape[0]))
    np.random.shuffle(visit_order)
    for i in visit_order:
        # squared Euclidean distance to every occupied position
        sq_dist = np.sum((embeddings_2d[i] - shown_images) ** 2, 1)
        if np.min(sq_dist) < 2e-3:
            continue
        shown_images = np.r_[shown_images, [embeddings_2d[i]]]
        shown_images_idx.append(i)

    # The thumbnail size depends only on the default figure width, so it is
    # computed once rather than per image.
    thumbnail_size = int(rcp["figure.figsize"][0] * 2.0)
    for idx in shown_images_idx:
        # The context manager closes the file as soon as the resized pixels
        # have been copied into an array, instead of leaving one open handle
        # per thumbnail.
        with Image.open(filenames[idx]) as img:
            thumbnail = np.array(functional.resize(img, thumbnail_size))
        img_box = osb.AnnotationBbox(
            osb.OffsetImage(thumbnail, cmap=plt.cm.gray_r),
            embeddings_2d[idx],
            pad=0.2,
        )
        ax.add_artist(img_box)

    # Choose the aspect so that the axes box is drawn square.
    ratio = 1.0 / ax.get_data_ratio()
    ax.set_aspect(ratio, adjustable="box")


# Render the thumbnail scatter plot from the current embeddings_2d and
# filenames.
get_scatter_plot_with_thumbnails()