In [None]:
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import pytorch_lightning as pl
import torch
import torchvision
from torchvision.transforms import Compose, Lambda
from torch import nn

from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform

In [None]:
class SimCLR(pl.LightningModule):
    """SimCLR model: a ResNet-18 feature extractor followed by a projection head.

    Training minimises the NT-Xent loss between the projections of two
    augmented views of the same images.
    """

    def __init__(self):
        super().__init__()
        # Keep every child module of ResNet-18 except the last one, so the
        # backbone yields features instead of class scores.
        encoder = torchvision.models.resnet18()
        feature_layers = list(encoder.children())[:-1]
        self.backbone = nn.Sequential(*feature_layers)
        self.projection_head = SimCLRProjectionHead(512, 2048, 2048)
        self.criterion = NTXentLoss()

    def forward(self, x):
        """Return the projection-head output for the image batch ``x``."""
        # Collapse everything after the batch dimension into one feature vector.
        features = torch.flatten(self.backbone(x), start_dim=1)
        return self.projection_head(features)

    def training_step(self, batch, batch_index):
        """Compute the contrastive loss for one batch.

        ``batch[0]`` must unpack into exactly two views; the remaining entries
        of ``batch`` are not used.
        """
        view_a, view_b = batch[0]
        projections = (self.forward(view_a), self.forward(view_b))
        return self.criterion(*projections)

    def configure_optimizers(self):
        """Return a plain SGD optimiser over all parameters (lr=0.06)."""
        return torch.optim.SGD(self.parameters(), lr=0.06)

In [None]:
model = SimCLR()

# MNIST images are single-channel, while the ResNet backbone expects 3-channel
# input. torchvision's MNIST hands a PIL image to the transform, so the channel
# conversion has to work on PIL input and run *before* SimCLRTransform; a
# tensor-only call such as `x.expand(3, -1, -1)` fails at this point.
# Grayscale(num_output_channels=3) replicates the single channel into an RGB
# PIL image. Being a regular transform object (not a lambda), it can also be
# pickled when DataLoader workers are started by spawning processes.
#
# random_gray_scale=1.0 is passed so the grayscale augmentation is always
# applied -- NOTE(review): confirm this is the intended setting for MNIST.
# TODO: clean up and split into 2-3 focused notebooks.
simclr_transform = SimCLRTransform(input_size=32, random_gray_scale=1.0)
mnist_to_3_channels = torchvision.transforms.Grayscale(num_output_channels=3)
transform = Compose([mnist_to_3_channels, simclr_transform])

# dataset = torchvision.datasets.CIFAR10(
#     "datasets/cifar10", download=True, transform=transform, train=True
# )

# ... or MNIST
dataset = torchvision.datasets.MNIST(
    "datasets/mnist-clean-torchvision",
    download=True,
    transform=transform,
    train=True,
)

# ... or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)

# drop_last=True keeps every training batch at the full batch size.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

In [None]:
# Train on a single device, preferring the GPU whenever CUDA is available.
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

trainer = pl.Trainer(accelerator=accelerator, devices=1, max_epochs=10)

In [None]:
# Train the SimCLR model on the batches produced by `dataloader`.
trainer.fit(model=model, train_dataloaders=dataloader)

# Alternatively, restore a previously trained model from a checkpoint instead
# of training (the path below is a placeholder and must be completed):
# model = SimCLR.load_from_checkpoint(
#     r"lightning_logs\ "
#     )

## Generate embeddings

In [None]:
import matplotlib.pyplot as plt
from utils import generate_embeddings
from lightly.transforms.utils import IMAGENET_NORMALIZE

In [None]:
# Deterministic preprocessing for embedding generation: no augmentation, only
# tensor conversion followed by normalisation with the ImageNet statistics.
to_tensor = torchvision.transforms.ToTensor()
imagenet_normalize = torchvision.transforms.Normalize(
    mean=IMAGENET_NORMALIZE["mean"],
    std=IMAGENET_NORMALIZE["std"],
)
test_transform = torchvision.transforms.Compose([to_tensor, imagenet_normalize])

In [None]:
# Note: this is not a "test" set in the train/test sense. SimCLR is a
# self-supervised method, so the labels are never used; the held-out split
# only provides images to embed.
#
# The model above is trained on MNIST and the root directory is the MNIST one,
# so the embeddings are generated for MNIST as well (the previous code loaded
# CIFAR10 into the MNIST directory). To embed CIFAR10 instead, use:
# test_dataset = torchvision.datasets.CIFAR10(
#     "datasets/test-cifar10", download=True, transform=test_transform, train=False
# )

test_dataset = torchvision.datasets.MNIST(
    "datasets/mnist-clean-torchvision",
    download=True,
    # MNIST images are single-channel; replicate them to 3 channels first so
    # the 3-channel normalisation in `test_transform` can be applied.
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.Grayscale(num_output_channels=3),
            test_transform,
        ]
    ),
    train=False,
)

In [None]:
# Number of samples in `dataset` (the training dataset defined earlier).
# NOTE(review): this cell follows the `test_dataset` definition, so
# `len(test_dataset)` may have been intended -- confirm.
len(dataset)

In [None]:
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=256,
    # Important: keep the dataset order so the embeddings come out in the same
    # order as the samples of `test_dataset` (assuming `generate_embeddings`
    # preserves batch order -- TODO confirm).
    shuffle=False,
    # Embed every sample: dropping the last partial batch would silently leave
    # the tail of the dataset without embeddings.
    drop_last=False,
    num_workers=8,
)

In [None]:
# Run the model over `test_dataloader` with the project-local helper from
# `utils`. The cells below treat the result as a tensor (`.shape`, `.to("cpu")`).
embeddings = generate_embeddings(model, test_dataloader)

In [None]:
# Sanity check of the returned embeddings: container type, number of rows and
# full shape, one per line.
print(type(embeddings), len(embeddings), embeddings.shape, sep="\n")

### Random

In [None]:
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

In [None]:
# Number of PCA components to keep. The 3-D scatter plot further down reads
# components 0, 1 and 2, so this must be at least 3.
num_principal_components = 3
# Number of k-means clusters fitted on the PCA-reduced embeddings.
n_clusters = 20

In [None]:
# Fixed seed so that the projection and the cluster assignment are
# reproducible across kernel restarts: k-means uses random initialisation and
# PCA may select a randomised solver.
RANDOM_STATE = 0

# Project the embeddings onto their first principal components.
pca = PCA(n_components=num_principal_components, random_state=RANDOM_STATE)
# detach() guards against embeddings that still carry autograd history, for
# which a direct .numpy() call would raise.
embeddings_reduced = pca.fit_transform(embeddings.detach().cpu().numpy())

# Cluster the reduced embeddings; `labels` holds one cluster index per
# embedding and `centroids` one row per cluster.
kmeans = KMeans(n_clusters=n_clusters, random_state=RANDOM_STATE)
labels = kmeans.fit_predict(embeddings_reduced)
centroids = kmeans.cluster_centers_

In [None]:
# Display the shape of the k-means cluster centres computed above.
centroids.shape

In [None]:
# 3-D scatter of the PCA-reduced embeddings coloured by k-means cluster, with
# the cluster centroids overlaid in red.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# Nearly transparent points so that dense regions remain readable.
points_x, points_y, points_z = (embeddings_reduced[:, axis] for axis in range(3))
ax.scatter(points_x, points_y, points_z, c=labels, alpha=0.05)

# Large, opaque markers for the cluster centres.
centers_x, centers_y, centers_z = (centroids[:, axis] for axis in range(3))
ax.scatter(centers_x, centers_y, centers_z, c="red", s=100, alpha=1.0)

plt.show()