In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torchvision.models import resnet50
from tqdm import tqdm

In [None]:
# Select the compute device: prefer CUDA, then MPS, and otherwise fall back to the CPU

# Pick the first available backend in order of preference: CUDA, MPS, CPU.
# The MPS backend is looked up with getattr because PyTorch builds that do not
# ship it have no `torch.backends.mps` attribute; a direct access would raise
# AttributeError on a machine without CUDA instead of reaching the CPU branch.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('Using CUDA')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
    print('Using MPS')
else:
    device = torch.device('cpu')
    print('Using CPU')


In [None]:
# Use a pretrained ResNet-50 from torchvision to compute image embeddings: load the pretrained network and remove its last layer

class ResNet(nn.Module):
    """Frozen, pretrained ResNet-50 backbone used as a fixed embedding extractor.

    The network is fetched through ``torch.hub`` with pretrained weights, its
    last child module is dropped, and the remaining children are wrapped in an
    ``nn.Sequential``. Every parameter has ``requires_grad`` disabled and the
    backbone is always kept in evaluation mode.
    """

    def __init__(self):
        super().__init__()
        # Load into a local first so the full (untruncated) network is never
        # registered as a submodule of this wrapper.
        backbone = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
        # Keep every top-level child except the last one.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.resnet.eval()
        for param in self.resnet.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set the training flag on the wrapper but keep the backbone in eval mode.

        ``nn.Module.train`` recurses into children, so without this override a
        call to ``.train()`` on this module (or on any parent containing it)
        would silently undo the ``eval()`` done in ``__init__`` and let
        mode-dependent layers in the frozen backbone change behaviour.

        Args:
            mode: Same meaning as in ``nn.Module.train``.

        Returns:
            This module, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, x):
        """Run ``x`` through the truncated backbone and return its output unchanged.

        The output is not flattened here.
        """
        return self.resnet(x)

In [None]:
# Define a clustering model that takes the embeddings and outputs log soft-assignment scores over the clusters; the cluster centers are learnable parameters of the model

class ClusteringModel(nn.Module):
    """Soft clustering head with learnable cluster centers.

    Args:
        n_clusters: Number of cluster centers.
        embedding_dim: Size of each center, i.e. the length of an input
            embedding once flattened.

    Attributes:
        n_clusters: The number of clusters passed at construction.
        cluster_centers: Learnable ``(n_clusters, embedding_dim)`` parameter,
            initialised from a standard normal distribution.
    """

    def __init__(self, n_clusters, embedding_dim):
        super().__init__()
        self.n_clusters = n_clusters
        self.cluster_centers = nn.Parameter(torch.randn(n_clusters, embedding_dim))

    def forward(self, x):
        """Return log soft-assignments of each input to each cluster.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into a single embedding vector.

        Returns:
            Tensor of shape ``(batch, n_clusters)``: the log-softmax, over the
            cluster dimension, of the negated squared Euclidean distance
            between each flattened input and each cluster center.
        """
        # reshape instead of view: view raises on inputs whose memory layout
        # cannot be flattened without a copy (e.g. permuted tensors), while
        # reshape copies only when it has to.
        flat = x.reshape(x.size(0), -1)
        # Broadcast (batch, 1, dim) against (1, n_clusters, dim).
        diff = flat.unsqueeze(1) - self.cluster_centers.unsqueeze(0)
        distances = torch.sum(diff ** 2, dim=2)
        # Closer centers get higher (less negative) log-probabilities.
        return F.log_softmax(-distances, dim=1)