In [None]:
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
from datetime import datetime

from typing import Any, Mapping
from torch.nn.functional import normalize
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
import numpy as np

from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

from utils import generate_embeddings


In [None]:
# Pick the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Test data preparation

In [None]:
# Evaluation preprocessing: tensor conversion followed by normalisation with
# the IMAGENET_NORMALIZE mean/std. No augmentation is applied here.
to_tensor_step = torchvision.transforms.ToTensor()
normalize_step = torchvision.transforms.Normalize(
    mean=IMAGENET_NORMALIZE["mean"], std=IMAGENET_NORMALIZE["std"]
)
test_transform = torchvision.transforms.Compose([to_tensor_step, normalize_step])

In [None]:
# NOTE: "test" here does not mean the held-out half of a supervised train/test
# split. SimCLR is a self-supervised (SSL) method, so the labels are not used
# for training.
test_dataset = torchvision.datasets.CIFAR10(
    root="datasets/test-cifar10",
    train=False,
    transform=test_transform,
    download=True,
)

In [None]:
# Loader used for validation and for generating embeddings.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=2,
    # Keep the dataset order: the embeddings are later coloured with
    # `test_dataset.targets`, so row i must correspond to sample i.
    shuffle=False,
    # Never drop samples at evaluation time. With drop_last=True a batch size
    # that does not divide the dataset length would silently discard the tail
    # and leave fewer embeddings than targets.
    drop_last=False,
    num_workers=8,
)

### Model preparation

In [None]:
class SimCLR(pl.LightningModule):
    """SimCLR model: ResNet-18 backbone plus projection head, trained with NT-Xent loss.

    Validation is used as a visual diagnostic rather than a metric. The
    backbone embeddings of every validation batch are collected and, once per
    validation epoch, L2-normalised, reduced with PCA, clustered with k-means
    and saved as a 3D scatter plot named ``embedding-<timestamp>.png`` in the
    current working directory.
    """

    # Number of PCA components kept; the 3D scatter plot needs exactly three.
    NUM_PRINCIPAL_COMPONENTS = 3
    # Number of k-means clusters drawn in the diagnostic plot.
    N_CLUSTERS = 10

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Use every child module of the ResNet except the last one as the
        # feature extractor.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SimCLRProjectionHead(512, 2048, 2048)
        self.criterion = NTXentLoss()
        # Per-epoch buffer of CPU embeddings filled by `validation_step`.
        # NOTE: the data loader is deliberately NOT stored on the module as
        # `self.test_dataloader`; that name is a LightningModule hook and
        # assigning a DataLoader to it shadows the hook.
        self._val_embeddings = []

    def forward(self, x):
        """Return the projection-head output for a batch of images."""
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

    def training_step(self, batch, batch_index):
        """Compute the contrastive loss between the two tensors in ``batch[0]``.

        ``batch[0]`` is unpacked into two image batches, presumably the two
        augmented views produced by the SimCLR transform (verify against the
        training dataset).
        """
        (x0, x1) = batch[0]
        z0 = self(x0)
        z1 = self(x1)
        loss = self.criterion(z0, z1)
        return loss

    def validation_step(self, *args: Any, **kwargs: Any) -> torch.Tensor | Mapping[str, Any] | None:
        """Embed one validation batch with the backbone and buffer the result.

        Only the cheap per-batch work happens here. Clustering and plotting
        run once per epoch in ``on_validation_epoch_end``; doing them in this
        hook would repeat a full pass over the data for every batch.

        The batch is taken from the first positional argument (or the
        ``batch`` keyword) and is expected to be an ``(images, labels)`` pair.
        """
        batch = args[0] if args else kwargs["batch"]
        images = batch[0]
        with torch.no_grad():
            emb = self.backbone(images.to(self.device)).flatten(start_dim=1)
        # Move to the CPU right away so the buffer does not hold GPU memory.
        self._val_embeddings.append(emb.detach().cpu())
        return None

    def on_validation_epoch_end(self) -> None:
        """Cluster the embeddings buffered during this validation epoch and plot them."""
        # Take ownership of the buffer and reset it for the next epoch.
        collected, self._val_embeddings = self._val_embeddings, []
        # The sanity check only sees a couple of batches; skip it.
        if self.trainer.sanity_checking or not collected:
            return

        embeddings = torch.cat(collected, 0)
        print("Embeddings shape: ", embeddings.shape)
        embeddings = normalize(embeddings)

        # PCA and k-means both need at least as many samples as
        # components/clusters; skip the plot instead of raising.
        if embeddings.shape[0] < max(self.N_CLUSTERS, self.NUM_PRINCIPAL_COMPONENTS):
            return

        # Fixed seeds keep cluster ids (and therefore colours) comparable
        # between epochs and between runs.
        pca = PCA(n_components=self.NUM_PRINCIPAL_COMPONENTS, random_state=0)
        embeddings_reduced = pca.fit_transform(embeddings.to("cpu").numpy())

        kmeans = KMeans(n_clusters=self.N_CLUSTERS, random_state=0)
        labels = kmeans.fit_predict(embeddings_reduced)

        self._save_cluster_plot(embeddings_reduced, labels, kmeans.cluster_centers_)

    @staticmethod
    def _save_cluster_plot(points, labels, centroids) -> str:
        """Save a 3D scatter of ``points`` coloured by ``labels`` with red centroids.

        Returns the name of the PNG file that was written.
        """
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        filename = f'embedding-{timestamp}.png'

        fig = plt.figure()
        try:
            ax = fig.add_subplot(projection="3d")
            ax.scatter(
                points[:, 0],
                points[:, 1],
                points[:, 2],
                c=labels,
                alpha=0.05,
            )
            ax.scatter(
                centroids[:, 0],
                centroids[:, 1],
                centroids[:, 2],
                c="red",
                s=100,
                alpha=1.0,
            )
            fig.savefig(filename)
        finally:
            # Close the figure so one figure per epoch does not pile up in memory.
            plt.close(fig)
        return filename

    def configure_optimizers(self):
        """Plain SGD over all parameters with a fixed learning rate of 0.06."""
        optim = torch.optim.SGD(self.parameters(), lr=0.06)
        return optim

In [None]:
# Instantiate the SimCLR module defined above with its default configuration.
model = SimCLR()

# Transform from lightly's `simclr_transform` module, configured for 32-pixel
# inputs; it is applied to the training dataset created below.
transform = SimCLRTransform(input_size=32)

### Train data preparation

In [None]:
# Training split of CIFAR-10; every image is passed through `transform`.
dataset = torchvision.datasets.CIFAR10(
    root="datasets/cifar10",
    train=True,
    transform=transform,
    download=True,
)

# ... or MNIST
# dataset = torchvision.datasets.MNIST(
#     "datasets/mnist-clean-torchvision", download=True, transform=transform
# )

# ... or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)

# Shuffled training loader; incomplete final batches are dropped.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=256,
    num_workers=8,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Single-device trainer running on whichever accelerator `device` selected.
trainer = pl.Trainer(
    accelerator=device,
    devices=1,
    max_epochs=10,
)

In [None]:
# Train the model. `test_dataloader` is passed as the validation loader, so the
# model's `validation_step` is run on it during fitting.
# Variant without validation:
# trainer.fit(model=model, train_dataloaders=dataloader)
trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=test_dataloader)

# Alternatively, restore a previously trained model from a checkpoint
# (adjust the path to a checkpoint that exists on your machine):
# model = SimCLR.load_from_checkpoint(
#     r"lightning_logs\version_15\checkpoints\epoch=9-step=1950.ckpt"
# )

## Generate embeddings

In [None]:
# `generate_embeddings` comes from the local `utils` module. Presumably it runs
# `model` over every batch of `test_dataloader` and returns one embedding per
# image -- verify against utils.py.
embeddings = generate_embeddings(model, test_dataloader)

In [None]:
# Quick sanity check of the generated embeddings: container type, number of
# rows and full shape, printed one per line.
print(type(embeddings), len(embeddings), embeddings.shape, sep="\n")

## Validate embeddings

In [None]:
# Number of PCA components to keep; the 3D scatter plots below index
# components 0, 1 and 2, so this must be at least 3.
num_principal_components = 3
# Number of k-means clusters fitted on the reduced embeddings.
n_clusters = 10

In [None]:
# Reduce the embeddings to a few principal components, then cluster them.
# Both estimators are seeded so that cluster ids (and therefore the colours in
# the plots below) are reproducible across kernel restarts.
pca = PCA(n_components=num_principal_components, random_state=0)
embeddings_reduced = pca.fit_transform(embeddings.to("cpu").numpy())

kmeans = KMeans(n_clusters=n_clusters, random_state=0)
labels = kmeans.fit_predict(embeddings_reduced)
centroids = kmeans.cluster_centers_

In [None]:
# Display the shape of the fitted k-means cluster centres.
centroids.shape

In [None]:
# Switch matplotlib to the interactive widget backend (provided by ipympl) so
# the 3D plots below can be rotated.
%matplotlib widget

### New labels

In [None]:
# 3D scatter of the PCA-reduced embeddings, coloured by k-means cluster id,
# with the cluster centres overlaid in red.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")

pc_x, pc_y, pc_z = (embeddings_reduced[:, axis] for axis in range(3))
ax.scatter(pc_x, pc_y, pc_z, c=labels, alpha=0.05)

centre_x, centre_y, centre_z = (centroids[:, axis] for axis in range(3))
ax.scatter(centre_x, centre_y, centre_z, c="red", s=100, alpha=1.0)

plt.show()

### Original labels

In [None]:
# Same 3D scatter as for the k-means labels, but coloured by the dataset's own
# targets; the k-means cluster centres are still overlaid in red.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")

pc_x, pc_y, pc_z = (embeddings_reduced[:, axis] for axis in range(3))
ax.scatter(pc_x, pc_y, pc_z, c=test_dataset.targets, alpha=0.05)

centre_x, centre_y, centre_z = (centroids[:, axis] for axis in range(3))
ax.scatter(centre_x, centre_y, centre_z, c="red", s=100, alpha=1.0)

plt.show()

## Neighbours

In [None]:
from sklearn.neighbors import NearestNeighbors

In [None]:
# Settings for a nearest-neighbour inspection of the embeddings.
# NOTE(review): neither value is used anywhere in the visible notebook --
# presumably the neighbour search itself is still to be written.
n_neighbors = 10
num_examples = 10

## Misc

In [None]:
from utils import get_distances_between_centroids

In [None]:
# Pairwise distances between cluster centroids, computed by the helper from the
# local `utils` module on the full (un-reduced) embeddings.
# NOTE(review): 20 clusters are requested here while the analysis above uses
# `n_clusters` (10) -- confirm this is intentional.
embeddings_array = embeddings.to("cpu").numpy()
distances = get_distances_between_centroids(embeddings_array, n_clusters=20)

In [None]:
# Summarise the centroid distance matrix: its shape, its rank and the values.
print(f"Distances shape: {distances.shape}")
print(f"Distances rank: {np.linalg.matrix_rank(distances)}")
# The label previously read "Distancess" (typo in the printed output).
print(f"Distances:\n {distances}")

In [None]:
# Show the centroid distance matrix as a heat map using the viridis colour map.
plt.matshow(distances, cmap="viridis")