### Deep GLVQ

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,random_split

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np

from mnist_helpers import process_mnist_dataset

import copy

from sklvq import GLVQ

In [3]:
import copy

##### Loading MNIST Dataset

In [4]:
batch_size = 256
validation_split = 0.2

train_data,validation_data,test_data = process_mnist_dataset(batch_size,validation_split)

train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(dataset=validation_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self,n_dims_for_classif):
        super(Encoder, self).__init__()

        self.fc1 = nn.Linear(784, 392)
        self.fc2 = nn.Linear(392, 194)
        self.fc3 = nn.Linear(194, n_dims_for_classif)

        self.bn1 = nn.BatchNorm1d(392)
        self.bn2 = nn.BatchNorm1d(194)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)

        x = self.fc1(x)
        x = self.bn1(x)  
        x = F.leaky_relu(x, negative_slope=0.01)

        x = self.fc2(x)
        x = self.bn2(x)  
        x = F.leaky_relu(x, negative_slope=0.01)

        x = self.fc3(x)  

        return x

##### GLVQ Loss Function

In [19]:
import torch
import torch.nn.functional as F

def L_GLVQ(X, y, P, Py, gamma=0.4, epsilon=1e-8):
    pairwise_distances = torch.cdist(X, P)  

    same_class_mask = (Py == y.unsqueeze(1))
    diff_class_mask = (Py != y.unsqueeze(1))

    large_number = 1e9

    d_plus = torch.where(same_class_mask, pairwise_distances, torch.full_like(pairwise_distances, large_number))
    d_plus = torch.min(d_plus, dim=1).values  

    d_minus = torch.where(diff_class_mask, pairwise_distances, torch.full_like(pairwise_distances, large_number))
    d_minus = torch.min(d_minus, dim=1).values  

    raw_loss = torch.sigmoid((d_plus - d_minus) / (d_plus + d_minus + epsilon))
    
    loss = torch.mean(torch.where(raw_loss > gamma, raw_loss - gamma, torch.zeros_like(raw_loss)))

    return loss

###### Training Deep GLVQ Encoder

In [20]:
def get_val_acc(model,val_loader,fitted_glvq_object):

    latents = []
    labels = []

    for batch_idx,(data,target) in enumerate(val_loader):

        cur_latents = model(data).detach().numpy()
        cur_labels = target.detach().numpy()

        latents.extend(cur_latents)
        labels.extend(cur_labels)

    glvq_acc = fitted_glvq_object.score(latents,labels)


    return glvq_acc

In [21]:
n_dims_for_classif = 10
encoder = Encoder(n_dims_for_classif)

learning_rate = 1e-4
optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
n_epochs = 40

In [22]:
import torch
import torch.nn.functional as F

#standard GLVQ loss
gamma = 0

for epoch in range(n_epochs):
    
    for batch_idx,(data,target) in tqdm(enumerate(train_loader),desc = f'Epoch {epoch + 1} / {n_epochs}'):
    
        optimizer.zero_grad()

        latents = encoder(data)

        latents = (latents - torch.mean(latents,axis = 0)) / torch.std(latents, axis = 0)

        if batch_idx % 10 == 0:

            glvq = GLVQ(
                distance_type="squared-euclidean",
                activation_type="sigmoid",
                activation_params={"beta": 1},
                solver_type="steepest-gradient-descent",
                solver_params={"max_runs": 20, "step_size": 0.1},
                prototype_n_per_class=8
            )

            latents_np = latents.detach().numpy()
            targets_np = target.detach().numpy()

            glvq.fit(latents_np,targets_np)

            P = torch.Tensor(glvq.prototypes_)
            Py = torch.Tensor(glvq.prototypes_labels_)

        glvq_loss = L_GLVQ(
            latents,target,P,Py,gamma = gamma
        )

        glvq_loss.backward()
        optimizer.step()

    val_acc = get_val_acc(encoder,validation_loader,glvq)
    print(f'Epoch {epoch + 1} validation accuracy: {val_acc}\n')
            

Epoch 1 / 40: 188it [01:09,  2.71it/s]


Epoch 1 validation accuracy: 0.89875



Epoch 2 / 40: 0it [00:01, ?it/s]


KeyboardInterrupt: 