# Load modules

In [None]:
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl

from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules.heads import (
    SimSiamPredictionHead,
    SimSiamProjectionHead
)

import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange

from data.dataset import SDOTilesDataset

# Set up training parameters

In [None]:
import os

# Global seed for the Python, NumPy and torch RNGs; workers=True also asks
# Lightning to derive seeds for DataLoader worker processes.
seed = 42
pl.seed_everything(seed, workers=True)

# Dataset location. Set the SDO_DATA_PATH environment variable to point at a
# different copy of the tiles; otherwise the original hard-coded path is used.
data_path = os.environ.get(
    'SDO_DATA_PATH',
    '/home/jovyan/scratch_space/andresmj/hss-self-supervision/AIA_211_193_171_128x128_small',
)

# Data / training-loop settings.
epochs = 16
data_stride = 1
batch_size = 64
augmentation = 'single'

# Optimisation and architecture settings.
# NOTE(review): `loss`, `cosine_scheduler_start`, `cosine_scheduler_end` and
# `prediction_size` are not referenced by any other cell in this notebook as
# shown; confirm whether they should drive the model/optimizer or be removed.
loss = 'contrast'   # 'contrast' or 'cos'
learning_rate = 0.1
cosine_scheduler_start = .1
cosine_scheduler_end = 1.0
projection_size = 128
prediction_size = 128

# Initialize dataset

In [None]:
# Build the tile dataset from the options declared in the training-parameters cell.
dataset = SDOTilesDataset(
    data_path=data_path,
    augmentation=augmentation,
    data_stride=data_stride,
)

In [None]:
# Get random index
idx = np.random.randint(0, high=dataset.__len__())
idx

In [None]:
# Show the two views the dataset returns for sample `idx`, side by side.
x0, x1, _ = dataset[idx]

fig = plt.figure(figsize=(4, 2), constrained_layout=True)
spec = fig.add_gridspec(ncols=2, nrows=1, wspace=0, hspace=0)

for col, (image, title) in enumerate(((x0, "Original"), (x1, "Augmented"))):
    ax = fig.add_subplot(spec[0, col])
    # Samples are treated as channel-first; imshow expects channels last.
    ax.imshow(rearrange(image, 'c h w -> h w c'))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)

plt.show()


# Build dataloader

In [None]:
# Shuffled training loader. drop_last=True discards the final incomplete batch.
# lightly's SimSiam heads use BatchNorm1d, which raises in training mode on a
# batch of a single sample, and small trailing batches give noisy BN statistics.
# NOTE: if len(dataset) < batch_size this yields no batches at all.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

# Set up SimSiam model

In [None]:
class SimSiam(pl.LightningModule):
    """SimSiam self-supervised model: ResNet-18 backbone plus projection and prediction heads.

    Args:
        projection_dim: Output size of the projection head. The prediction head
            maps back to this size, so its output can be compared with ``z``.
        prediction_hidden_dim: Hidden size of the prediction head.
        learning_rate: Learning rate of the SGD optimizer.
    """

    def __init__(self, projection_dim=128, prediction_hidden_dim=64, learning_rate=0.06):
        super().__init__()
        # Store the constructor arguments in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        resnet = torchvision.models.resnet18()
        # Width of the pooled features = input size of the classifier being dropped.
        num_features = resnet.fc.in_features
        # Keep every child module except the final fully connected classifier.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SimSiamProjectionHead(num_features, num_features, projection_dim)
        self.prediction_head = SimSiamPredictionHead(projection_dim, prediction_hidden_dim, projection_dim)
        self.criterion = NegativeCosineSimilarity()
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction for batch ``x``."""
        f = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(f)
        p = self.prediction_head(z)
        # Stop-gradient on the projection branch, as required by SimSiam.
        z = z.detach()
        return z, p

    def training_step(self, batch, batch_idx):
        """Compute the symmetric negative-cosine loss between the two views of a batch."""
        x0, x1, _ = batch
        z0, p0 = self.forward(x0)
        z1, p1 = self.forward(x1)
        loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Plain SGD over all parameters at the configured learning rate."""
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)


# Build the model from the training-parameters cell so edits there take effect.
# NOTE(review): `prediction_size` is not passed; its intended role (predictor
# hidden vs. output size) is unclear, and the predictor output must equal
# projection_dim for the cosine loss.
model = SimSiam(projection_dim=projection_size, learning_rate=learning_rate)
model

# Train model

In [None]:
# Single-device CPU run; deterministic=True asks Lightning/PyTorch to use
# deterministic algorithms so seeded runs can be repeated.
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="cpu",
    devices=1,
    strategy="auto",
    deterministic=True,
)

trainer.fit(model=model, train_dataloaders=dataloader)