# Load modules

In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl

from lightly.loss import NegativeCosineSimilarity, NTXentLoss
from lightly.models.modules.heads import SimSiamPredictionHead, SimSiamProjectionHead

from sklearn import random_projection
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange
from PIL import Image
import matplotlib.offsetbox as osb
from matplotlib import rcParams as rcp
import matplotlib.gridspec as gridspec

from cuml import UMAP
from cuml.cluster import hdbscan

import torchvision.transforms.functional as functional

from data.dataset import SDOTilesDataset
from data.augmentation_list import AugmentationList
import matplotlib.image as mpimg
import seaborn as sns

from utils.image_utils import read_image

# Global seed so random draws in this notebook (tile index, shuffles, cluster
# picks) are reproducible; workers=True also asks Lightning to seed
# dataloader worker processes.
seed = 42
pl.seed_everything(seed=seed, workers=True)

# Data Setup

### Define augmentation

In [None]:
# Build the project's augmentation set for the 'euv' configuration
# (presumably EUV solar imagery — semantics live in data.augmentation_list).
augmentation_list = AugmentationList('euv')
# Echo the enabled augmentation names in the notebook.
augmentation_list.keys

In [None]:
# Restrict the augmentation set to a single entry, 'h_flip' (presumably a
# horizontal flip — confirm against AugmentationList).
augmentation_list.keys = ['h_flip']

In [None]:
# Echo the keys again to confirm the restriction above took effect.
augmentation_list.keys

In [None]:
# Re-draw the augmentation choice/parameters — presumably; the exact effect of
# randomize() is defined in AugmentationList (not visible here).
augmentation_list.randomize()

### Initialize dataset

In [None]:
# Alternative dataset locations, kept for reference:
# DATA_PATH = '/home/jovyan/scratch_space/andresmj/data/AIA_211_193_171_256x256_small'
# DATA_PATH = '/d0/euv/aia/preprocessed_ext/AIA_211_193_171/AIA_211_193_171_256x256'
DATA_PATH = '/d0/euv/aia/preprocessed_ext/AIA_211_193_171/AIA_211_193_171_256x256_small'

# Subsampling stride handed to the dataset (presumably keeps every
# DATA_STRIDE-th tile — see SDOTilesDataset).
DATA_STRIDE = 10

dataset = SDOTilesDataset(
    data_path=DATA_PATH,
    augmentation_list=augmentation_list,
    augmentation_strategy='single',
    data_stride=DATA_STRIDE,
)
# Echo the number of tiles.
len(dataset)

### Visualize Solar image

In [None]:
# Display one example solar image from local scratch storage for reference
# (the filename suggests an AIA 211/193/171 composite from 2014-11-09 08:00:02).
image_path = '/home/jovyan/scratch_space/andresmj/data/20141109_080002_aia_211_193_171.jpg'
img = mpimg.imread(image_path)
plt.imshow(img)
plt.axis('off');  # trailing ';' suppresses the notebook echo of the return value

### Visualize Augmentation

In [None]:
# Draw a random tile index in [0, len(dataset)); the bare name echoes it.
idx = np.random.randint(len(dataset))
idx

In [None]:
# Fetch both views of the chosen tile and show them side by side.
x0, x1, _ = dataset[idx]

fig = plt.figure(figsize=(4, 2), constrained_layout=True)
spec = fig.add_gridspec(ncols=2, nrows=1, wspace=0, hspace=0)

for col, (view, title) in enumerate([(x0, "Original"), (x1, "Augmented")]):
    ax = fig.add_subplot(spec[0, col])
    # Views are channel-first; imshow expects channels last.
    ax.imshow(rearrange(view, 'c h w -> h w c'))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)


In [None]:
# --- Runtime ------------------------------------------------------------------
DEVICE = 'cuda'
BATCH_SIZE = 256

# --- Training settings ----------------------------------------------------------
# NOTE(review): EPOCHS, AUGMENTATION and LEARNING_RATE are not read by the
# inference code in this notebook; SimSiam.configure_optimizers hard-codes
# lr=0.06 rather than using LEARNING_RATE.
EPOCHS = 2
AUGMENTATION = 'single'
LOSS = 'contrast'   # 'cos' -> negative cosine similarity; anything else -> NT-Xent ('contrast')
LEARNING_RATE = 0.1

# --- Head sizes (read by SimSiam.__init__; presumably must match the checkpoint) ---
PROJECTION_HEAD_SIZE = 128
PREDICTION_HEAD_SIZE = 128
EMBEDING_SIZE = 64  # (sic) passed as the prediction head's middle size argument

# --- Trained weights to load ------------------------------------------------------
CHECKPOINT_PATH = '/d0/amunozj/git_repos/hss-self-supervision/sim_siam_256_ext/epoch-epoch=15.ckpt'

# Build dataloader

In [None]:
# Ordered, non-dropping loader: shuffle=False keeps dataset order and
# drop_last=False keeps the final partial batch, so every tile is embedded once.
val_dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    shuffle=False,
    drop_last=False,
)

# Set up SimSiam model

In [None]:
class SimSiam(pl.LightningModule):
    """SimSiam-style self-supervised model: ResNet-18 backbone plus lightly heads.

    The head sizes (PROJECTION_HEAD_SIZE, EMBEDING_SIZE, PREDICTION_HEAD_SIZE)
    and the loss choice (LOSS) are read from notebook-level constants when the
    model is constructed. Because ``__init__`` takes no arguments,
    ``load_from_checkpoint`` rebuilds the architecture from whatever those
    constants are at load time.

    NOTE(review): those constants presumably must match the values used when the
    checkpoint was trained, otherwise weight loading fails on shape mismatches.
    Attribute names here are also state-dict keys, so do not rename them.
    """
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Drop the last child module (the classification layer of torchvision's
        # ResNet) so the backbone yields pooled 512-channel features.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SimSiamProjectionHead(512, 512, PROJECTION_HEAD_SIZE)
        self.prediction_head = SimSiamPredictionHead(PROJECTION_HEAD_SIZE, EMBEDING_SIZE, PREDICTION_HEAD_SIZE)
        # NOTE(review): `criterion` is never used in this notebook; training_step
        # uses loss_cos / loss_contrast below.
        self.criterion = NegativeCosineSimilarity()

        # Selects the optimized loss: 'cos' -> loss_cos, anything else -> loss_contrast.
        self.loss = LOSS
        self.loss_cos = NegativeCosineSimilarity()
        self.loss_contrast = NTXentLoss()

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction for ``x``.

        ``z`` is detached so that, when used as the other view's target, no
        gradient flows back through it (SimSiam's stop-gradient).
        """
        f = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(f)
        p = self.prediction_head(z)
        z = z.detach()
        return z, p

    def training_step(self, batch, batch_idx):
        """Compute the symmetric two-view loss for one batch.

        ``batch`` unpacks as ``(x0, x1, _)``. Each view's prediction is compared
        with the other view's detached projection. Both the cosine and the NT-Xent
        losses are computed and logged; the one selected by ``self.loss`` is
        returned for optimization.
        """
        (x0, x1, _) = batch
        z0, p0 = self.forward(x0)
        z1, p1 = self.forward(x1)

        loss_cos = 0.5 * (self.loss_cos(p0, z1) + self.loss_cos(p1, z0))
        loss_contrast = 0.5 * (self.loss_contrast(p0, z1) + self.loss_contrast(p1, z0))

        if self.loss == 'cos':
            loss = loss_cos
        else:
            loss = loss_contrast

        self.log('loss cos', loss_cos)
        self.log('loss contrast', loss_contrast)
        self.log('loss', loss)

        return loss

    def configure_optimizers(self):
        """Plain SGD over all parameters with a hard-coded learning rate of 0.06.

        NOTE(review): the notebook's LEARNING_RATE constant (0.1) is not used
        here. Confirm which value the checkpoint was actually trained with.
        """
        optim = torch.optim.SGD(self.parameters(), lr=0.06)
        return optim

# Rebuild the model from the trained checkpoint and move it to DEVICE; the
# bare `model` line echoes the module structure in the notebook.
model = SimSiam.load_from_checkpoint(CHECKPOINT_PATH).to(DEVICE)
model

# Visualize Output

In [None]:
# Embed every tile with the backbone of the loaded model.
# Result: `embeddings` is a NumPy array with one row per tile (2-D, rows in
# dataloader order), and `filenames[i]` is the third item the dataloader
# yielded for row i (presumably the tile's image path, as the thumbnail plot
# opens it with PIL).
embeddings = []
filenames = []

# eval() puts layers such as batch norm into inference mode; no_grad() skips
# autograd bookkeeping, saving memory and time.
model.eval()
with torch.no_grad():
    for x, _, fnames in tqdm(val_dataloader, total=len(val_dataloader)):
        # Backbone features only (the projection/prediction heads are skipped),
        # flattened to one vector per image.
        y = model.backbone(x.to(DEVICE)).flatten(start_dim=1)
        # Move each batch off the device right away so device memory does not
        # grow with the size of the dataset.
        embeddings.append(y.cpu())
        # extend() is O(batch); rebuilding the list with `+` every batch was
        # quadratic in the number of tiles.
        filenames.extend(fnames)

# Stack per-batch features into a single (num_tiles, feature_dim) NumPy array.
embeddings = torch.cat(embeddings, dim=0).numpy()

In [None]:
# UMAP settings. Only n_neighbors, n_components and metric are currently passed;
# the others are kept for experimentation (see the commented-out arguments).
n_neighbors = 5
min_dist = 0.0
n_components = 2  # also read by the cluster scatter plot below
metric = 'euclidean'
spread = 0.5
repulsion_strength = 2

fit = UMAP(
    n_neighbors=n_neighbors,
    # Pass n_components explicitly so the embedding's dimensionality always
    # matches the value later cells branch on. NOTE: the thumbnail scatter plot
    # assumes 2 components.
    n_components=n_components,
    metric=metric,
    # min_dist=min_dist,
    # spread=spread,
    # repulsion_strength=repulsion_strength,
    verbose=True,
)

embeddings_2d = fit.fit_transform(embeddings)
# Min-max normalize each axis into the [0, 1] square. An axis with zero range
# (all values equal) is divided by 1 instead of 0, mapping it to 0 rather than NaN.
M = np.max(embeddings_2d, axis=0)
m = np.min(embeddings_2d, axis=0)
span = np.where(M > m, M - m, 1.0)
embeddings_2d = (embeddings_2d - m) / span

In [None]:
# # for the scatter plot we want to transform the images to a two-dimensional
# # vector space using a random Gaussian projection
# projection = random_projection.GaussianRandomProjection(n_components=2)
# embeddings_2d = projection.fit_transform(embeddings)

# # normalize the embeddings to fit in the [0, 1] square
# M = np.max(embeddings_2d, axis=0)
# m = np.min(embeddings_2d, axis=0)
# embeddings_2d = (embeddings_2d - m) / (M - m)

In [None]:
# display a scatter plot of the dataset
# clustering similar images together

def get_scatter_plot_with_thumbnails():
    """Scatter the normalized 2-D embeddings with image thumbnails overlaid.

    Reads the module-level ``embeddings_2d`` (points normalized to [0, 1]) and
    ``filenames`` (one image path per point). Points are visited in random
    order, and a thumbnail is drawn only when its squared distance to every
    already-drawn thumbnail is at least 5e-4, which keeps the overlay readable.

    Returns:
        float: the aspect ratio ``1 / ax.get_data_ratio()`` applied to the axes,
        so later plots can reuse it.
    """
    fig = plt.figure(figsize=(9, 9), dpi=150)
    fig.suptitle("Scatter Plot of the SDO/AIA 171 Tiles")
    ax = fig.add_subplot(1, 1, 1)

    # Pick thumbnails greedily in random order. The array starts with a dummy
    # point at (1, 1) so np.min always has something to compare against; as a
    # side effect, points very close to that corner are never thumbnailed.
    shown_images_idx = []
    shown_images = np.array([[1.0, 1.0]])
    iterator = list(range(embeddings_2d.shape[0]))
    np.random.shuffle(iterator)
    for i in iterator:
        dist = np.sum((embeddings_2d[i] - shown_images) ** 2, 1)
        if np.min(dist) < 5e-4:
            continue
        shown_images = np.r_[shown_images, [embeddings_2d[i]]]
        shown_images_idx.append(i)

    # Loop-invariant: thumbnail size derived from the default figure width.
    thumbnail_size = int(rcp["figure.figsize"][0] * 2.0)
    for idx in shown_images_idx:
        # PIL opens files lazily and keeps the handle open; the context manager
        # closes it once the resized pixels have been copied out.
        with Image.open(filenames[idx]) as img:
            thumb = np.array(functional.resize(img, thumbnail_size))
        img_box = osb.AnnotationBbox(
            osb.OffsetImage(thumb, cmap=plt.cm.gray_r),
            embeddings_2d[idx],
            pad=0.2,
        )
        ax.add_artist(img_box)

    # set aspect ratio
    ratio = 1.0 / ax.get_data_ratio()
    ax.set_aspect(ratio, adjustable="box")
    return ratio


# get a scatter plot with thumbnail overlays
ratio = get_scatter_plot_with_thumbnails()

# Cluster

In [None]:
min_samples = 5
min_cluster_size = 5

# Density-based clustering of the normalized UMAP coordinates.
clusterer = hdbscan.HDBSCAN(min_samples=min_samples, min_cluster_size=min_cluster_size)
clusterer.fit(embeddings_2d)

# One palette colour per cluster id; label -1 (HDBSCAN noise) is drawn grey.
# (The former standalone `sns.color_palette(...)` line was a dead duplicate:
# it was not the cell's last expression, so it neither displayed nor stored anything.)
n_clusters = clusterer.labels_.max() + 1
color_palette = sns.color_palette('Paired', n_clusters)
cluster_colors = [color_palette[x] if x >= 0 else (0.5, 0.5, 0.5) for x in clusterer.labels_]
# Per-point colours with saturation scaled by cluster-membership probability.
# NOTE(review): not used by the plots in this notebook.
cluster_member_colors = [sns.desaturate(x, p) for x, p in zip(cluster_colors, clusterer.probabilities_)]
# Noise points are drawn much more transparent than clustered ones.
cluster_alphas = np.where(clusterer.labels_ == -1, 0.1, 0.5)

In [None]:
# Scatter the UMAP coordinates coloured by HDBSCAN cluster (noise faint grey).
MARKER_SIZE = 3
fig = plt.figure(figsize=(9, 9), dpi=150)
u = embeddings_2d
if n_components == 1:
    # 1-D embedding: use the point index as the y coordinate so points separate.
    ax = fig.add_subplot(111)
    ax.scatter(u[:, 0], range(len(u)), c=cluster_colors, s=MARKER_SIZE, alpha=cluster_alphas)
elif n_components == 2:
    ax = fig.add_subplot(111)
    ax.scatter(u[:, 0], u[:, 1], c=cluster_colors, s=MARKER_SIZE, alpha=cluster_alphas)
elif n_components == 3:
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(u[:, 0], u[:, 1], u[:, 2], c=cluster_colors, s=MARKER_SIZE, alpha=cluster_alphas)
else:
    # Without this, `ax` would silently refer to an axes from an earlier cell.
    raise ValueError(f"n_components must be 1, 2 or 3, got {n_components!r}")

# Reuse the thumbnail plot's aspect ratio. 3-D axes do not accept a numeric
# aspect ratio, so apply it only to 2-D axes.
if n_components != 3:
    ax.set_aspect(ratio, adjustable="box")
# plt.title(title, fontsize=18)
# plt.title(title, fontsize=18)

# Visualize clusters

In [None]:
def cluster_plot(cluster_rows: int, cluster_col: int, images_list: np.ndarray, dpi: int):
    """Show a random grid of sample images from one cluster.

    Args:
        cluster_rows: Number of grid rows.
        cluster_col: Number of grid columns.
        images_list: 1-D array of JPEG image paths belonging to the cluster.
            It is not modified; a shuffled copy is used.
        dpi: Figure resolution.

    Up to ``cluster_rows * cluster_col`` images are drawn, filled row by row.
    If the cluster has fewer images, the remaining cells stay empty.
    """
    fig = plt.figure(figsize=[cluster_col, cluster_rows], layout='constrained', dpi=dpi)
    spec = gridspec.GridSpec(ncols=cluster_col, nrows=cluster_rows, figure=fig, wspace=0, hspace=0)

    # Shuffle a copy (same RNG draw as an in-place shuffle) so the caller's
    # array is not mutated.
    shuffled = np.random.permutation(images_list)

    n_cells = cluster_rows * cluster_col
    for n, image_path in enumerate(shuffled[:n_cells]):
        row, col = divmod(n, cluster_col)
        # Load with the project's image reader.
        image = read_image(image_loc=image_path, image_format="jpg")
        ax = fig.add_subplot(spec[row, col])
        ax.imshow(image)
        ax.set_xticks([])
        ax.set_yticks([])

In [None]:
# Pick up to `n_clusters_2_plot` distinct clusters at random (HDBSCAN noise,
# label -1, excluded) and show a grid of sample tiles from each.
n_clusters_2_plot = 20

labels = clusterer.labels_
# Sample from the *distinct* cluster ids. Sampling from the per-point labels
# returned the same cluster many times (biased toward large clusters).
# Capping `size` avoids ValueError when fewer clusters exist.
cluster_ids = np.unique(labels[labels >= 0])
clusters_2_plot = np.random.choice(
    cluster_ids,
    size=min(n_clusters_2_plot, cluster_ids.size),
    replace=False,
)

# Convert once instead of on every iteration.
filenames_arr = np.array(filenames)
for cluster in clusters_2_plot:
    cluster_plot(cluster_rows=2,
                 cluster_col=10,
                 images_list=filenames_arr[labels == cluster],
                 dpi=300)