In [1]:
import torch
from torch import nn
from torch.optim import Adam

class ClusteringNetwork(nn.Module):
    """Small MLP whose outputs are clustered against learnable centroids.

    The network maps ``input_dim`` features through one ReLU hidden layer of
    width ``hidden_dim`` to an ``n_clusters``-wide output.  ``centroids`` is a
    learnable table with one row per cluster, expressed in that same output
    space so that distances between ``forward(x)`` and ``centroids`` are well
    defined.

    Args:
        input_dim: Number of features in each input row.
        hidden_dim: Width of the hidden layer.
        n_clusters: Number of clusters; also the width of the output layer.
    """

    def __init__(self, input_dim, hidden_dim, n_clusters):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_clusters),
        )
        # The centroids must share the feature width of what forward() returns,
        # otherwise `outputs.unsqueeze(1) - centroids` cannot broadcast.  They
        # were previously sized by hidden_dim, which only lines up with the
        # output layer when hidden_dim == n_clusters; deriving the width from
        # the last layer keeps the two in sync (and is identical in that case).
        output_dim = self.network[-1].out_features
        self.centroids = nn.Parameter(torch.rand(n_clusters, output_dim))

    def forward(self, x):
        """Return the network output for ``x`` (last dimension ``n_clusters``)."""
        return self.network(x)

def wcss_loss(outputs, centroids):
    """Calculate the Within-Cluster Sum of Squares (WCSS) loss.

    Each row of ``outputs`` is matched to its nearest centroid by squared
    Euclidean distance, and those minimum distances are summed.

    Args:
        outputs: 2-D tensor of shape ``(batch, dim)``.
        centroids: Tensor whose last dimension is ``dim``, typically of shape
            ``(n_clusters, dim)``.

    Returns:
        A scalar tensor: the sum over rows of ``outputs`` of the squared
        distance to the closest centroid.

    Raises:
        ValueError: If ``outputs`` is not 2-D or the feature widths of
            ``outputs`` and ``centroids`` differ.
    """
    if outputs.dim() != 2:
        raise ValueError(
            f"outputs must be 2-D (batch, dim), got shape {tuple(outputs.shape)}"
        )
    # Check the widths explicitly: broadcasting would otherwise either fail
    # with an opaque error or, when one width is 1, silently stretch it and
    # return a number that is not a WCSS at all.
    if centroids.dim() < 1 or centroids.shape[-1] != outputs.shape[-1]:
        raise ValueError(
            "feature dimension mismatch: outputs has shape "
            f"{tuple(outputs.shape)} but centroids has shape {tuple(centroids.shape)}"
        )
    # (batch, 1, dim) - (n_clusters, dim) broadcasts to (batch, n_clusters, dim).
    squared_dist = torch.sum((outputs.unsqueeze(1) - centroids) ** 2, 2)
    nearest = torch.min(squared_dist, 1)[0]
    return torch.sum(nearest)

def train_clustering_network(data_loader, model, optimizer, epochs):
    """Train ``model`` by minimising ``wcss_loss`` against ``model.centroids``.

    For every epoch, each batch from ``data_loader`` is passed through the
    model, the WCSS loss between the outputs and ``model.centroids`` is
    back-propagated, and ``optimizer`` takes one step.  The mean per-batch
    loss is printed after each epoch.

    Args:
        data_loader: Iterable of batches; element ``0`` of each batch is used
            as the model input.  It is iterated once per epoch.
        model: Callable module exposing a ``centroids`` attribute.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        epochs: Number of passes over ``data_loader``.
    """
    for epoch in range(epochs):
        total_loss = 0.0
        # Count batches while iterating instead of calling len(data_loader):
        # this works for iterables without __len__ and avoids a
        # ZeroDivisionError when the loader yields nothing.
        n_batches = 0
        for data in data_loader:
            inputs = data[0]
            optimizer.zero_grad()
            outputs = model(inputs)

            loss = wcss_loss(outputs, model.centroids)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # With no batches there is no meaningful average; report nan.
        mean_loss = total_loss / n_batches if n_batches else float('nan')
        print(f'Epoch {epoch+1}/{epochs}, Loss: {mean_loss}')

# Example usage settings.
# Number of features in the input.
input_dim: int = 784
# Size of the hidden layer.
hidden_dim: int = 64
# Number of clusters.
n_clusters: int = 10

In [7]:
# Build a stand-alone learnable table with one row per cluster and
# hidden_dim columns, filled with uniform random values, then show its shape.
centroids = nn.Parameter(torch.rand((n_clusters, hidden_dim)), requires_grad=True)
centroids.shape

torch.Size([10, 64])