In [6]:
import numpy as  np
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

t = np.load('cluster_labels.npy')
X = np.load('latents_vector.npy')
X = X.reshape(103250, 1024)

X_train, X_test, y_train, y_test = train_test_split(X, t, test_size=0.2, random_state=42)

X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)

train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)

batch_size = 64
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

In [7]:
import torch.nn as nn
import torch.optim as optim

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [8]:
from models import ClusterNetwork

input_dim = X_train.shape[1]
output_dim = len(torch.unique(y_train))

model = ClusterNetwork(input_dim, output_dim)
model.apply(init_weights)
model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

In [9]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.cuda(), labels.cuda()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        
    average_test_loss = test_loss / len(test_loader)
    print(f"Test Loss: {average_test_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

Test Loss: 0.9078, Test Accuracy: 71.19%
Test Loss: 0.7972, Test Accuracy: 72.63%
Test Loss: 0.7502, Test Accuracy: 73.50%
Test Loss: 0.7085, Test Accuracy: 74.24%
Test Loss: 0.6920, Test Accuracy: 74.43%
Test Loss: 0.6840, Test Accuracy: 74.93%
Test Loss: 0.6545, Test Accuracy: 75.54%
Test Loss: 0.6458, Test Accuracy: 76.00%
Test Loss: 0.6546, Test Accuracy: 75.64%
Test Loss: 0.6399, Test Accuracy: 76.10%
