In [None]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torch
import torchvision.models as models
import numpy as np
from sklearn.cluster import KMeans

In [None]:
class Clustering:
    """Cluster images by running K-Means on CNN embeddings.

    A torchvision ResNet-18 loaded with ``weights="DEFAULT"`` has its last
    child module removed and the remaining layers are used as a fixed feature
    extractor. The flattened outputs are clustered with scikit-learn's
    ``KMeans``.

    Args:
        n_cluster: Number of clusters handed to ``KMeans``.
        device: Anything accepted by ``torch.device``. The model and every
            batch are moved to this device.
        random_state: Optional seed forwarded to ``KMeans``. The default
            ``None`` keeps the previous (unseeded) behaviour; pass an int to
            make cluster assignments reproducible across runs.

    Attributes:
        device: The resolved ``torch.device``.
        model: The truncated network, in eval mode, living on ``device``.
        kmeans: The ``KMeans`` estimator (fitted after ``fit`` is called).
        n_cluster: The requested number of clusters.
    """

    def __init__(self, n_cluster, device="cuda", random_state=None):
        self.device = torch.device(device)

        backbone = models.resnet18(weights="DEFAULT")
        # Keep every child module except the last one, so the network emits
        # embeddings rather than the output of its final layer.
        extractor = torch.nn.Sequential(*list(backbone.children())[:-1])
        # eval() is applied to the final container (not only to the network
        # it was built from) so the whole extractor is consistently in
        # inference mode.
        self.model = extractor.to(self.device).eval()

        self.kmeans = KMeans(n_clusters=n_cluster, random_state=random_state)
        self.n_cluster = n_cluster

    def extract_features(self, dataloader):
        """Run the extractor over every batch and stack the results.

        Args:
            dataloader: Iterable yielding ``(images, target)`` pairs; the
                second element is ignored.

        Returns:
            A 2-D NumPy array with one row per sample, in iteration order.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        features = []
        with torch.no_grad():  # inference only; no autograd graph needed
            for images, _ in dataloader:
                images = images.to(self.device)
                output = self.model(images)
                # reshape (unlike view) also works when the model output is
                # not contiguous in memory.
                output = output.reshape(output.size(0), -1)
                features.append(output.cpu().numpy())
        if not features:
            raise ValueError("dataloader yielded no batches; cannot extract features")
        return np.vstack(features)

    def fit(self, dataloader):
        """Extract features from ``dataloader`` and fit K-Means on them.

        Returns:
            A ``(labels, features)`` tuple: the per-sample cluster ids from
            ``KMeans.labels_`` and the feature matrix they were fitted on.
        """
        features = self.extract_features(dataloader)
        self.kmeans.fit(features)
        labels = self.kmeans.labels_
        return labels, features

In [None]:
# Per-channel mean / std used by Normalize. These look like the ImageNet
# statistics that torchvision documents for its pretrained models -- confirm
# against the transforms bundled with the chosen weights.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# CIFAR-10 training split, downloaded into ./data when it is not already there.
dataset = CIFAR10(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)

# shuffle=False keeps the batches in dataset order.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA instead of
# raising when the model is moved to the device.
clustering = Clustering(
    n_cluster=10,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [None]:
# Extract features for every batch in `dataloader`, fit K-Means on them, and
# keep both the per-sample cluster ids and the feature matrix for plotting.
labels, features = clustering.fit(dataloader=dataloader)

In [None]:
# Project the feature matrix onto its first two principal components so the
# clusters can be drawn in a plane. random_state pins the projection in case
# scikit-learn selects a randomized SVD solver for this input.
pca = PCA(n_components=2, random_state=0)
reduced_features = pca.fit_transform(features)

fig, ax = plt.subplots(figsize=(10, 8))
for cluster in range(clustering.n_cluster):
    # A boolean mask selects this cluster's rows directly and yields plain
    # 1-D coordinate arrays (indexing with the tuple returned by np.where
    # produced (1, k)-shaped arrays instead).
    cluster_mask = labels == cluster
    ax.scatter(
        reduced_features[cluster_mask, 0],
        reduced_features[cluster_mask, 1],
        label=f'Cluster {cluster}',
    )
ax.legend()
ax.set_title("Image Clustering using K-Means")
ax.set_xlabel("PCA Component 1")
ax.set_ylabel("PCA Component 2")
plt.show()