In [1]:
import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from torch.utils.data import DataLoader
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from tqdm import tqdm
from torch.optim import Adam
from matplotlib import pyplot as plt
import json
import os

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# PyTorch implementations of Polynomial Kernels

def polynomial_kernel(X, Y, degree=2, gamma=1.0, coef0=1.0):
    """
    Polynomial kernel function.
    :param X: torch.Tensor of shape (n_samples_1, n_features)
    :param Y: torch.Tensor of shape (n_samples_2, n_features)
    :param degree: int, default=2
    :param gamma: float, default=1.0
    :param coef0: float, default=1.0
    :return: torch.Tensor of shape (n_samples_1, n_samples_2)
    """
    K = (gamma * torch.mm(X, Y.t()) + coef0) ** degree
    return K

def rbf_kernel(X, Y, gamma=None):

    """
    Rbf kernel function.
    :param X: torch.Tensor of shape (n_samples_1, n_features)
    :param Y: torch.Tensor of shape (n_samples_2, n_features)
    :param gamma: float or None, default=None
    :return: torch.Tensor of shape (n_samples_1, n_samples_2)
    """

    if gamma == None:
        gamma = 1.0/X.size(1) # gamma = 1/n_features
    
    # K(x, y) = exp(-gamma ||x-y||^2)
    d_XY = torch.cdist(X, Y, p=2) # pairwise distances between X and Y rows. Shape = (n_samples_1, n_samples_2)
    K = torch.exp(-gamma * d_XY ** 2)

    return K

def laplacian_kernel(X, Y, gamma=None):
    """
    Laplacian kernel function.
    :param X: torch.Tensor of shape (n_samples_1, n_features)
    :param Y: torch.Tensor of shape (n_samples_2, n_features)
    :param gamma: float or None, default=None
    :return: torch.Tensor of shape (n_samples_1, n_samples_2)
    """

    if gamma == None:
        gamma = 1.0/X.size(1) # gamma = 1/n_features
    
    # K(x, y) = exp(-gamma ||x-y||)
    d_XY = torch.cdist(X, Y, p=1) # pairwise distances between X and Y rows. Shape = (n_samples_1, n_samples_2)
    K = torch.exp(-gamma * d_XY)

    return K

def sigmoid_kernel(X, Y, gamma=1.0, coef0=1.0):
    """
    Sigmoid kernel function.
    :param X: torch.Tensor of shape (n_samples_1, n_features)
    :param Y: torch.Tensor of shape (n_samples_2, n_features)
    :param gamma: float, default=1.0
    :param coef0: float, default=1.0
    :return: torch.Tensor of shape (n_samples_1, n_samples_2)
    """
    K = torch.tanh(gamma * torch.mm(X, Y.t()) + coef0)
    return K


In [4]:
def binary_distance(X, Y):
    """Compute distance matrix between  rows of X, Y.

    d(x_i, y_j) = 1 if x_i == y_j, 0 in other case.

    for all rows x_i in X, y_j in Y

    """
    return (X.unsqueeze(1) == Y.unsqueeze(0)).all(-1).float()


def kernel(X, *args, **kwargs):
    """
    Compute similarity matrix of an array X using a variety of kernels
    Parameters:
        X = input data (torch.Tensor or numpy array)
    """
    # kernel parameters

    degree = kwargs.get("degree", None)
    gamma = kwargs.get("gamma", None)
    coef = kwargs.get("coef", None)
    kernel_type = kwargs.get("kernel_type", None)

    if kernel_type == "polynomial":


        # K(x,y) = (gamma * <x,y> + coef)^degree, for vectors x,y

        if (gamma != None) & (coef != None) & (degree != None):
            return polynomial_kernel(X = X, Y = X, degree = degree, gamma = gamma, coef0 = coef)
        else:
            return None

    elif kernel_type == "sigmoid":
 
        # K(x,y) = tanh(gamma * <x,y> + coef), for vectors x,y

        if (gamma != None) & (coef != None):
            return sigmoid_kernel(X = X, Y = X, gamma = gamma, coef0 = coef)
        else:
            return None

    elif kernel_type == "rbf":

        # K(x, y) = exp(-gamma ||x-y||^2)

        if (gamma != None):
            return rbf_kernel(X = X, Y = X, gamma = gamma)
        else:
            return None

    elif kernel_type == "laplacian":

        # K(x, y) = exp(-gamma ||x-y||_1)

        if (gamma != None):
            return laplacian_kernel(X = X, Y = X, gamma = gamma)
        else:
            return None

    else:
        return None


In [10]:
class SVBLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(SVBLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        nn.init.orthogonal_(self.linear.weight)

    def forward(self, input):
        return self.linear(input)
    
    def singular_value_bounding(self):
        with torch.no_grad():
            # Descomposición SVD manual
            weight = self.linear.weight.data
            u, s, vh = torch.linalg.svd(weight, full_matrices=False)
            # Ajuste de valores singulares
            s = torch.clamp(s, min=0.9, max=1.1)
            # Reconstrucción de la matriz de pesos
            self.linear.weight.copy_(torch.mm(u, torch.mm(torch.diag(s), vh)))

class OrthogonalNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(OrthogonalNN, self).__init__()
        self.fc1 = SVBLinear(input_dim, 128)
        self.fc2 = SVBLinear(128, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
    def apply_svb(self):
        self.fc1.singular_value_bounding()
        self.fc2.singular_value_bounding()

class Encoder_SLMVP(nn.Module):
    def __init__(self, encoder, kernel_parameters_X, optimizer, label_indep = False, kernel_parameters_Y = None):
        super().__init__()
        self.encoder = encoder
        self.kernel_parameters_X = kernel_parameters_X
        self.optimizer = optimizer
        self.label_indep = label_indep
        self.kernel_parameters_Y = kernel_parameters_Y
        self.criterion = nn.MSELoss()
    def encode(self, x):
        return self.encoder(x)
    
    def train_model(self, data_loader):
        self.train()
        
        epoch_losses = []

        for X, y in tqdm(data_loader):
            self.optimizer.zero_grad()
            Z = self.encode(X)

            # Calculation of similarity matrices
            K_X = kernel(X, **self.kernel_parameters_X)

            if self.label_indep: # data with different labels has not relationship (e.g: multiclass classification)
                K_Y = binary_distance(y.unsqueeze(1), y.unsqueeze(1))
            else:
                if self.kernel_parameters_Y == None:
                    self.kernel_parameters_Y = self.kernel_parameters_X
                K_Y = kernel(y, **self.kernel_parameters_Y)
            K_XY = K_X * K_Y
            D = torch.diag(torch.sum(K_XY, dim = 1))
            loss = torch.norm(torch.matmul(torch.matmul(torch.t(Z), D), Z) - torch.eye(Z.size(1)))
            epoch_losses.append(loss)

            loss.backward()
            self.optimizer.step()
        
        avg_loss = torch.mean(torch.Tensor(epoch_losses))
        return avg_loss.item()
    
    def test_model(self, data_loader):
        self.eval()
        epoch_losses = []

        with torch.no_grad():
            for X, y in data_loader:
                Z = self.encode(X)

                # Calculation of similarity matrices
                K_X = kernel(X, **self.kernel_parameters_X)
                if self.label_indep: # data with different labels has not relationship (e.g: multiclass classification)
                    K_Y = binary_distance(y.unsqueeze(1), y.unsqueeze(1))
                else:
                    if self.kernel_parameters_Y == None:
                        self.kernel_parameters_Y = self.kernel_parameters_X
                    K_Y = kernel(y, **self.kernel_parameters_Y)
                K_XY = K_X * K_Y
                D = torch.diag(torch.sum(K_XY, dim = 1))
                loss = torch.norm(torch.matmul(torch.matmul(torch.t(Z), D), Z) - torch.eye(Z.size(1)))
                epoch_losses.append(loss)
            
        avg_loss = torch.mean(torch.Tensor(epoch_losses))
        return avg_loss.item()

In [11]:
lr = 0.001 # learning rate
batch_size = 64
epochs = 1000
latent_size = 64 # latent space dimension

kernel_param_X = {

    "kernel_type": "rbf",
    "degree": 3,
    "gamma": 1e-2,
    "coef": 0.5e2

}

# kernel function parameters for Y_train / Y_test. If is equal to None, we use the same parameters.

kernel_param_Y = None

"""
kernel_param_Y = {

    "kernel_type": "polynomial",
    "degree": 3,
    "gamma": 1.,
    "coef": 1.

}
"""

label_indep = True # True for categoric labels, False for numerical labels

In [12]:


# train_data = pd.read_csv("/content/tfm_esteban/fashion_MNIST/fashion-mnist_train.csv")
# test_data = pd.read_csv("/content/tfm_esteban/fashion_MNIST/fashion-mnist_test.csv")

train_data = pd.read_csv("../fashion_MNIST/fashion-mnist_train.csv")
test_data = pd.read_csv("../fashion_MNIST/fashion-mnist_test.csv")
data = pd.concat([train_data, test_data], ignore_index = True)
# print(train_data.shape, test_data.shape, data.shape)
# data.head()
train_data, test_data = train_test_split(data, test_size=0.2, shuffle=True, stratify=data["label"], random_state=20)
train_data, val_data = train_test_split(train_data, test_size=0.1, shuffle=True, stratify=train_data["label"], random_state=20)

# print(train_data["label"].value_counts(normalize=True))
# print(val_data["label"].value_counts(normalize=True))
# print(test_data["label"].value_counts(normalize=True))

X_train = np.array(train_data.drop("label", axis = 1, inplace=False))/255
X_train = torch.tensor(X_train, dtype = torch.float)
y_train = train_data["label"].values
y_train = torch.tensor(y_train, dtype = torch.float)
X_val = np.array(val_data.drop("label", axis = 1, inplace=False))/255
X_val = torch.tensor(X_val, dtype = torch.float)
y_val = val_data["label"].values
y_val = torch.tensor(y_val, dtype = torch.float)
X_test = np.array(test_data.drop("label", axis = 1, inplace=False))/255
X_test = torch.tensor(X_test, dtype = torch.float)
y_test = test_data["label"].values
y_test = torch.tensor(y_test, dtype = torch.float)

train_loader = DataLoader(list(zip(X_train, y_train)), shuffle = True, batch_size=batch_size)
val_loader = DataLoader(list(zip(X_val, y_val)), shuffle = True, batch_size=batch_size)
test_loader = DataLoader(list(zip(X_test, y_test)), shuffle = False, batch_size=batch_size)

input_size = X_train.size(1)

In [20]:
encoder = OrthogonalNN(input_size, latent_size)
optimizer = Adam(list(encoder.parameters()), lr = lr)

model = Encoder_SLMVP(encoder, kernel_param_X, optimizer, label_indep, kernel_param_Y)

In [None]:
# # Datos de prueba
# input_dim = 10
# output_dim = 5
# model = OrthogonalNN(input_dim, output_dim)

# x = torch.randn(20, input_dim)
# target = torch.randn(20, output_dim)

# # Definición de la pérdida y el optimizador
# criterion = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# # Entrenamiento
# num_epochs = 100
# for epoch in range(num_epochs):
#     optimizer.zero_grad()
#     outputs = model(x)
#     loss = criterion(outputs, target)
#     loss.backward()
#     optimizer.step()
    
#     # Aplicación periódica del método SVB
#     if epoch % 10 == 0:
#         model.apply_svb()

#     print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')



In [21]:
for epoch in range(epochs):
    print(f"Epoch: {epoch + 1}/{epochs}")
    avg_loss_train = model.train_model(train_loader)
    avg_loss_val = model.test_model(val_loader)

    print(f"avg_loss_train: {avg_loss_train}")
    print(f"avg_loss_val: {avg_loss_val}")

Epoch: 1/1000


100%|██████████| 788/788 [00:02<00:00, 347.43it/s]


avg_loss_train: 11.380729675292969
avg_loss_val: 7.249608993530273
Epoch: 2/1000


100%|██████████| 788/788 [00:02<00:00, 342.30it/s]


avg_loss_train: 7.157425403594971
avg_loss_val: 7.196472644805908
Epoch: 3/1000


100%|██████████| 788/788 [00:02<00:00, 336.85it/s]


avg_loss_train: 7.20368766784668
avg_loss_val: 7.275579452514648
Epoch: 4/1000


100%|██████████| 788/788 [00:02<00:00, 364.65it/s]


avg_loss_train: 7.491987228393555
avg_loss_val: 7.613012790679932
Epoch: 5/1000


100%|██████████| 788/788 [00:02<00:00, 360.81it/s]


avg_loss_train: 7.670621871948242
avg_loss_val: 7.772434711456299
Epoch: 6/1000


100%|██████████| 788/788 [00:02<00:00, 356.72it/s]


avg_loss_train: 7.804819583892822
avg_loss_val: 8.069823265075684
Epoch: 7/1000


100%|██████████| 788/788 [00:02<00:00, 329.95it/s]


avg_loss_train: 7.835549831390381
avg_loss_val: 7.850907802581787
Epoch: 8/1000


100%|██████████| 788/788 [00:02<00:00, 367.10it/s]


avg_loss_train: 7.836110591888428
avg_loss_val: 7.861189365386963
Epoch: 9/1000


100%|██████████| 788/788 [00:02<00:00, 366.86it/s]


avg_loss_train: 7.844554424285889
avg_loss_val: 7.813504219055176
Epoch: 10/1000


100%|██████████| 788/788 [00:02<00:00, 356.24it/s]


avg_loss_train: 7.779885768890381
avg_loss_val: 7.781555652618408
Epoch: 11/1000


100%|██████████| 788/788 [00:02<00:00, 345.19it/s]


avg_loss_train: 7.773061275482178
avg_loss_val: 7.730377197265625
Epoch: 12/1000


100%|██████████| 788/788 [00:02<00:00, 358.23it/s]


avg_loss_train: 7.776986122131348
avg_loss_val: 7.738363265991211
Epoch: 13/1000


100%|██████████| 788/788 [00:02<00:00, 361.64it/s]


avg_loss_train: 7.740954875946045
avg_loss_val: 7.7389302253723145
Epoch: 14/1000


100%|██████████| 788/788 [00:02<00:00, 344.09it/s]


avg_loss_train: 7.709323406219482
avg_loss_val: 7.686236381530762
Epoch: 15/1000


100%|██████████| 788/788 [00:02<00:00, 356.24it/s]


avg_loss_train: 7.689550876617432
avg_loss_val: 7.792572021484375
Epoch: 16/1000


100%|██████████| 788/788 [00:02<00:00, 365.32it/s]


avg_loss_train: 7.699148654937744
avg_loss_val: 7.685966491699219
Epoch: 17/1000


100%|██████████| 788/788 [00:02<00:00, 367.03it/s]


avg_loss_train: 7.7046966552734375
avg_loss_val: 7.746533393859863
Epoch: 18/1000


100%|██████████| 788/788 [00:02<00:00, 358.67it/s]


avg_loss_train: 7.667443752288818
avg_loss_val: 7.733426094055176
Epoch: 19/1000


100%|██████████| 788/788 [00:02<00:00, 333.38it/s]


avg_loss_train: 7.699223041534424
avg_loss_val: 7.698245048522949
Epoch: 20/1000


100%|██████████| 788/788 [00:02<00:00, 367.48it/s]


avg_loss_train: 7.659453392028809
avg_loss_val: 7.676483631134033
Epoch: 21/1000


100%|██████████| 788/788 [00:02<00:00, 339.32it/s]


avg_loss_train: 7.610566139221191
avg_loss_val: 7.649920463562012
Epoch: 22/1000


100%|██████████| 788/788 [00:02<00:00, 357.69it/s]


avg_loss_train: 7.585335731506348
avg_loss_val: 7.566735744476318
Epoch: 23/1000


100%|██████████| 788/788 [00:02<00:00, 320.67it/s]


avg_loss_train: 7.5711517333984375
avg_loss_val: 7.510639667510986
Epoch: 24/1000


100%|██████████| 788/788 [00:02<00:00, 312.04it/s]


avg_loss_train: 7.530494213104248
avg_loss_val: 7.499542236328125
Epoch: 25/1000


100%|██████████| 788/788 [00:02<00:00, 323.59it/s]


avg_loss_train: 7.523716926574707
avg_loss_val: 7.547187328338623
Epoch: 26/1000


100%|██████████| 788/788 [00:02<00:00, 326.74it/s]


avg_loss_train: 7.537128448486328
avg_loss_val: 7.502420425415039
Epoch: 27/1000


100%|██████████| 788/788 [00:02<00:00, 301.78it/s]


avg_loss_train: 7.479146480560303
avg_loss_val: 7.489846706390381
Epoch: 28/1000


100%|██████████| 788/788 [00:02<00:00, 327.90it/s]


avg_loss_train: 7.490170955657959
avg_loss_val: 7.4649553298950195
Epoch: 29/1000


100%|██████████| 788/788 [00:02<00:00, 308.38it/s]


avg_loss_train: 7.429542541503906
avg_loss_val: 7.440506458282471
Epoch: 30/1000


100%|██████████| 788/788 [00:02<00:00, 314.88it/s]


avg_loss_train: 7.432459354400635
avg_loss_val: 7.455714702606201
Epoch: 31/1000


100%|██████████| 788/788 [00:02<00:00, 322.52it/s]


avg_loss_train: 7.40048360824585
avg_loss_val: 7.469567775726318
Epoch: 32/1000


100%|██████████| 788/788 [00:02<00:00, 336.16it/s]


avg_loss_train: 7.414164066314697
avg_loss_val: 7.364654064178467
Epoch: 33/1000


100%|██████████| 788/788 [00:02<00:00, 295.64it/s]


avg_loss_train: 7.406744956970215
avg_loss_val: 7.414378643035889
Epoch: 34/1000


100%|██████████| 788/788 [00:02<00:00, 313.28it/s]


avg_loss_train: 7.3852763175964355
avg_loss_val: 7.403575420379639
Epoch: 35/1000


100%|██████████| 788/788 [00:02<00:00, 316.69it/s]


avg_loss_train: 7.350806713104248
avg_loss_val: 7.382757663726807
Epoch: 36/1000


100%|██████████| 788/788 [00:02<00:00, 330.24it/s]


avg_loss_train: 7.322176456451416
avg_loss_val: 7.31947135925293
Epoch: 37/1000


100%|██████████| 788/788 [00:02<00:00, 331.44it/s]


avg_loss_train: 7.322463512420654
avg_loss_val: 7.423341751098633
Epoch: 38/1000


100%|██████████| 788/788 [00:02<00:00, 335.70it/s]


avg_loss_train: 7.319526195526123
avg_loss_val: 7.239123344421387
Epoch: 39/1000


100%|██████████| 788/788 [00:02<00:00, 331.21it/s]


avg_loss_train: 7.3119025230407715
avg_loss_val: 7.403135299682617
Epoch: 40/1000


100%|██████████| 788/788 [00:02<00:00, 326.95it/s]


avg_loss_train: 7.319656848907471
avg_loss_val: 7.253818988800049
Epoch: 41/1000


100%|██████████| 788/788 [00:02<00:00, 336.97it/s]


avg_loss_train: 7.3299174308776855
avg_loss_val: 7.419661998748779
Epoch: 42/1000


100%|██████████| 788/788 [00:02<00:00, 334.02it/s]


avg_loss_train: 7.331532001495361
avg_loss_val: 7.3431291580200195
Epoch: 43/1000


100%|██████████| 788/788 [00:02<00:00, 329.68it/s]


avg_loss_train: 7.355350494384766
avg_loss_val: 7.415063858032227
Epoch: 44/1000


100%|██████████| 788/788 [00:02<00:00, 338.62it/s]


avg_loss_train: 7.330095291137695
avg_loss_val: 7.276691913604736
Epoch: 45/1000


100%|██████████| 788/788 [00:02<00:00, 339.86it/s]


avg_loss_train: 7.307084083557129
avg_loss_val: 7.343867778778076
Epoch: 46/1000


100%|██████████| 788/788 [00:02<00:00, 326.61it/s]


avg_loss_train: 7.295622825622559
avg_loss_val: 7.265249252319336
Epoch: 47/1000


100%|██████████| 788/788 [00:02<00:00, 335.44it/s]


avg_loss_train: 7.284976005554199
avg_loss_val: 7.379389762878418
Epoch: 48/1000


100%|██████████| 788/788 [00:02<00:00, 334.45it/s]


avg_loss_train: 7.292080402374268
avg_loss_val: 7.270242214202881
Epoch: 49/1000


100%|██████████| 788/788 [00:02<00:00, 340.74it/s]


avg_loss_train: 7.306924819946289
avg_loss_val: 7.3097453117370605
Epoch: 50/1000


100%|██████████| 788/788 [00:02<00:00, 340.09it/s]


avg_loss_train: 7.264791011810303
avg_loss_val: 7.300506591796875
Epoch: 51/1000


100%|██████████| 788/788 [00:02<00:00, 320.03it/s]


avg_loss_train: 7.281074523925781
avg_loss_val: 7.289846420288086
Epoch: 52/1000


100%|██████████| 788/788 [00:02<00:00, 323.85it/s]


avg_loss_train: 7.303655624389648
avg_loss_val: 7.413479804992676
Epoch: 53/1000


100%|██████████| 788/788 [00:02<00:00, 340.44it/s]


avg_loss_train: 7.287981986999512
avg_loss_val: 7.284619331359863
Epoch: 54/1000


100%|██████████| 788/788 [00:02<00:00, 345.84it/s]


avg_loss_train: 7.3012471199035645
avg_loss_val: 7.383560657501221
Epoch: 55/1000


100%|██████████| 788/788 [00:02<00:00, 316.56it/s]


avg_loss_train: 7.281055927276611
avg_loss_val: 7.271629333496094
Epoch: 56/1000


100%|██████████| 788/788 [00:02<00:00, 299.23it/s]


avg_loss_train: 7.278247356414795
avg_loss_val: 7.238658428192139
Epoch: 57/1000


100%|██████████| 788/788 [00:02<00:00, 326.14it/s]


avg_loss_train: 7.276091575622559
avg_loss_val: 7.2304606437683105
Epoch: 58/1000


100%|██████████| 788/788 [00:02<00:00, 322.21it/s]


avg_loss_train: 7.258589744567871
avg_loss_val: 7.23405122756958
Epoch: 59/1000


100%|██████████| 788/788 [00:02<00:00, 312.41it/s]


avg_loss_train: 7.270962715148926
avg_loss_val: 7.338042736053467
Epoch: 60/1000


100%|██████████| 788/788 [00:02<00:00, 326.27it/s]


avg_loss_train: 7.295009136199951
avg_loss_val: 7.291595935821533
Epoch: 61/1000


100%|██████████| 788/788 [00:02<00:00, 333.14it/s]


avg_loss_train: 7.258220672607422
avg_loss_val: 7.282768726348877
Epoch: 62/1000


100%|██████████| 788/788 [00:02<00:00, 329.55it/s]


avg_loss_train: 7.271157741546631
avg_loss_val: 7.3576765060424805
Epoch: 63/1000


100%|██████████| 788/788 [00:02<00:00, 285.45it/s]


avg_loss_train: 7.281064033508301
avg_loss_val: 7.33717679977417
Epoch: 64/1000


100%|██████████| 788/788 [00:02<00:00, 296.65it/s]


avg_loss_train: 7.263907432556152
avg_loss_val: 7.252191066741943
Epoch: 65/1000


100%|██████████| 788/788 [00:02<00:00, 303.97it/s]


avg_loss_train: 7.24806547164917
avg_loss_val: 7.263692855834961
Epoch: 66/1000


100%|██████████| 788/788 [00:02<00:00, 317.71it/s]


avg_loss_train: 7.2632246017456055
avg_loss_val: 7.293390274047852
Epoch: 67/1000


100%|██████████| 788/788 [00:02<00:00, 313.66it/s]


avg_loss_train: 7.244937896728516
avg_loss_val: 7.271871566772461
Epoch: 68/1000


100%|██████████| 788/788 [00:02<00:00, 317.82it/s]


avg_loss_train: 7.246330261230469
avg_loss_val: 7.293807506561279
Epoch: 69/1000


100%|██████████| 788/788 [00:02<00:00, 318.74it/s]


avg_loss_train: 7.278800964355469
avg_loss_val: 7.258474349975586
Epoch: 70/1000


100%|██████████| 788/788 [00:02<00:00, 324.12it/s]


avg_loss_train: 7.245744228363037
avg_loss_val: 7.236482620239258
Epoch: 71/1000


100%|██████████| 788/788 [00:02<00:00, 313.78it/s]


avg_loss_train: 7.263181209564209
avg_loss_val: 7.300449848175049
Epoch: 72/1000


100%|██████████| 788/788 [00:02<00:00, 323.12it/s]


avg_loss_train: 7.257339000701904
avg_loss_val: 7.242433071136475
Epoch: 73/1000


100%|██████████| 788/788 [00:02<00:00, 318.48it/s]


avg_loss_train: 7.237981796264648
avg_loss_val: 7.259289264678955
Epoch: 74/1000


100%|██████████| 788/788 [00:02<00:00, 313.35it/s]


avg_loss_train: 7.2577738761901855
avg_loss_val: 7.273898124694824
Epoch: 75/1000


100%|██████████| 788/788 [00:02<00:00, 314.02it/s]


avg_loss_train: 7.253095626831055
avg_loss_val: 7.246579170227051
Epoch: 76/1000


100%|██████████| 788/788 [00:02<00:00, 304.45it/s]


avg_loss_train: 7.255299091339111
avg_loss_val: 7.3464813232421875
Epoch: 77/1000


100%|██████████| 788/788 [00:02<00:00, 289.02it/s]


avg_loss_train: 7.253601551055908
avg_loss_val: 7.271107196807861
Epoch: 78/1000


100%|██████████| 788/788 [00:02<00:00, 307.17it/s]


avg_loss_train: 7.276113033294678
avg_loss_val: 7.2567853927612305
Epoch: 79/1000


100%|██████████| 788/788 [00:02<00:00, 312.82it/s]


avg_loss_train: 7.238954067230225
avg_loss_val: 7.235325813293457
Epoch: 80/1000


100%|██████████| 788/788 [00:02<00:00, 294.10it/s]


avg_loss_train: 7.250070095062256
avg_loss_val: 7.274184703826904
Epoch: 81/1000


100%|██████████| 788/788 [00:02<00:00, 300.72it/s]


avg_loss_train: 7.24190616607666
avg_loss_val: 7.372656345367432
Epoch: 82/1000


100%|██████████| 788/788 [00:02<00:00, 294.51it/s]


avg_loss_train: 7.2402825355529785
avg_loss_val: 7.2545084953308105
Epoch: 83/1000


100%|██████████| 788/788 [00:02<00:00, 311.91it/s]


avg_loss_train: 7.235558986663818
avg_loss_val: 7.222206115722656
Epoch: 84/1000


100%|██████████| 788/788 [00:02<00:00, 278.78it/s]


avg_loss_train: 7.249754428863525
avg_loss_val: 7.244662761688232
Epoch: 85/1000


100%|██████████| 788/788 [00:02<00:00, 283.96it/s]


avg_loss_train: 7.229766845703125
avg_loss_val: 7.264463424682617
Epoch: 86/1000


100%|██████████| 788/788 [00:02<00:00, 275.86it/s]


avg_loss_train: 7.267078876495361
avg_loss_val: 7.3079833984375
Epoch: 87/1000


100%|██████████| 788/788 [00:02<00:00, 302.22it/s]


avg_loss_train: 7.241648197174072
avg_loss_val: 7.243194580078125
Epoch: 88/1000


100%|██████████| 788/788 [00:02<00:00, 293.65it/s]


avg_loss_train: 7.251438140869141
avg_loss_val: 7.2953338623046875
Epoch: 89/1000


100%|██████████| 788/788 [00:02<00:00, 289.85it/s]


avg_loss_train: 7.280024528503418
avg_loss_val: 7.277259349822998
Epoch: 90/1000


100%|██████████| 788/788 [00:02<00:00, 303.85it/s]


avg_loss_train: 7.251596927642822
avg_loss_val: 7.269667148590088
Epoch: 91/1000


100%|██████████| 788/788 [00:02<00:00, 307.53it/s]


avg_loss_train: 7.246298313140869
avg_loss_val: 7.278906345367432
Epoch: 92/1000


100%|██████████| 788/788 [00:02<00:00, 305.92it/s]


avg_loss_train: 7.253587245941162
avg_loss_val: 7.2781453132629395
Epoch: 93/1000


100%|██████████| 788/788 [00:02<00:00, 308.43it/s]


avg_loss_train: 7.232203960418701
avg_loss_val: 7.284095764160156
Epoch: 94/1000


100%|██████████| 788/788 [00:02<00:00, 278.59it/s]


avg_loss_train: 7.247283458709717
avg_loss_val: 7.215182781219482
Epoch: 95/1000


100%|██████████| 788/788 [00:02<00:00, 303.74it/s]


avg_loss_train: 7.255198001861572
avg_loss_val: 7.285014629364014
Epoch: 96/1000


100%|██████████| 788/788 [00:02<00:00, 301.65it/s]


avg_loss_train: 7.2989373207092285
avg_loss_val: 7.343233108520508
Epoch: 97/1000


100%|██████████| 788/788 [00:02<00:00, 308.51it/s]


avg_loss_train: 7.272292613983154
avg_loss_val: 7.261857986450195
Epoch: 98/1000


100%|██████████| 788/788 [00:02<00:00, 305.74it/s]


avg_loss_train: 7.2642598152160645
avg_loss_val: 7.247343063354492
Epoch: 99/1000


100%|██████████| 788/788 [00:02<00:00, 291.65it/s]


avg_loss_train: 7.238604545593262
avg_loss_val: 7.265821933746338
Epoch: 100/1000


100%|██████████| 788/788 [00:02<00:00, 297.61it/s]


avg_loss_train: 7.255998134613037
avg_loss_val: 7.321259021759033
Epoch: 101/1000


100%|██████████| 788/788 [00:02<00:00, 299.92it/s]


avg_loss_train: 7.249810218811035
avg_loss_val: 7.277750015258789
Epoch: 102/1000


100%|██████████| 788/788 [00:02<00:00, 314.29it/s]


avg_loss_train: 7.258325576782227
avg_loss_val: 7.239675998687744
Epoch: 103/1000


100%|██████████| 788/788 [00:02<00:00, 275.85it/s]


avg_loss_train: 7.226365566253662
avg_loss_val: 7.217477798461914
Epoch: 104/1000


100%|██████████| 788/788 [00:02<00:00, 267.59it/s]


avg_loss_train: 7.233608722686768
avg_loss_val: 7.313718795776367
Epoch: 105/1000


100%|██████████| 788/788 [00:03<00:00, 262.52it/s]


avg_loss_train: 7.257530689239502
avg_loss_val: 7.342048168182373
Epoch: 106/1000


100%|██████████| 788/788 [00:02<00:00, 282.12it/s]


avg_loss_train: 7.272395133972168
avg_loss_val: 7.236944198608398
Epoch: 107/1000


100%|██████████| 788/788 [00:02<00:00, 278.25it/s]


avg_loss_train: 7.244349956512451
avg_loss_val: 7.231488227844238
Epoch: 108/1000


100%|██████████| 788/788 [00:02<00:00, 276.24it/s]


avg_loss_train: 7.226653099060059
avg_loss_val: 7.28633451461792
Epoch: 109/1000


100%|██████████| 788/788 [00:03<00:00, 258.65it/s]


avg_loss_train: 7.267690181732178
avg_loss_val: 7.322962760925293
Epoch: 110/1000


100%|██████████| 788/788 [00:03<00:00, 248.05it/s]


avg_loss_train: 7.281299114227295
avg_loss_val: 7.291266918182373
Epoch: 111/1000


100%|██████████| 788/788 [00:03<00:00, 225.91it/s]


avg_loss_train: 7.235583305358887
avg_loss_val: 7.2036333084106445
Epoch: 112/1000


100%|██████████| 788/788 [00:03<00:00, 229.08it/s]


avg_loss_train: 7.241645812988281
avg_loss_val: 7.221997261047363
Epoch: 113/1000


100%|██████████| 788/788 [00:03<00:00, 240.61it/s]


avg_loss_train: 7.227609157562256
avg_loss_val: 7.2942280769348145
Epoch: 114/1000


100%|██████████| 788/788 [00:03<00:00, 245.26it/s]


avg_loss_train: 7.258840560913086
avg_loss_val: 7.283763408660889
Epoch: 115/1000


100%|██████████| 788/788 [00:03<00:00, 237.44it/s]


avg_loss_train: 7.243851184844971
avg_loss_val: 7.312512397766113
Epoch: 116/1000


100%|██████████| 788/788 [00:03<00:00, 232.37it/s]


avg_loss_train: 7.2940874099731445
avg_loss_val: 7.305505752563477
Epoch: 117/1000


100%|██████████| 788/788 [00:03<00:00, 237.49it/s]


avg_loss_train: 7.2673187255859375
avg_loss_val: 7.228419780731201
Epoch: 118/1000


100%|██████████| 788/788 [00:03<00:00, 241.56it/s]


avg_loss_train: 7.25789737701416
avg_loss_val: 7.269142150878906
Epoch: 119/1000


100%|██████████| 788/788 [00:03<00:00, 236.42it/s]


avg_loss_train: 7.242605209350586
avg_loss_val: 7.262825012207031
Epoch: 120/1000


100%|██████████| 788/788 [00:03<00:00, 223.14it/s]


avg_loss_train: 7.251901149749756
avg_loss_val: 7.302767753601074
Epoch: 121/1000


100%|██████████| 788/788 [00:03<00:00, 237.21it/s]


avg_loss_train: 7.26939058303833
avg_loss_val: 7.239141941070557
Epoch: 122/1000


100%|██████████| 788/788 [00:03<00:00, 228.99it/s]


avg_loss_train: 7.241206645965576
avg_loss_val: 7.237858295440674
Epoch: 123/1000


100%|██████████| 788/788 [00:03<00:00, 213.11it/s]


avg_loss_train: 7.255202770233154
avg_loss_val: 7.238568305969238
Epoch: 124/1000


100%|██████████| 788/788 [00:03<00:00, 214.87it/s]


avg_loss_train: 7.230442523956299
avg_loss_val: 7.2112650871276855
Epoch: 125/1000


100%|██████████| 788/788 [00:03<00:00, 230.10it/s]


avg_loss_train: 7.238275527954102
avg_loss_val: 7.247949600219727
Epoch: 126/1000


100%|██████████| 788/788 [00:03<00:00, 237.28it/s]


avg_loss_train: 7.2184576988220215
avg_loss_val: 7.217030048370361
Epoch: 127/1000


100%|██████████| 788/788 [00:03<00:00, 226.30it/s]


avg_loss_train: 7.23561429977417
avg_loss_val: 7.251287460327148
Epoch: 128/1000


100%|██████████| 788/788 [00:03<00:00, 236.21it/s]


avg_loss_train: 7.253345489501953
avg_loss_val: 7.332481384277344
Epoch: 129/1000


100%|██████████| 788/788 [00:03<00:00, 234.59it/s]


avg_loss_train: 7.25978422164917
avg_loss_val: 7.219163417816162
Epoch: 130/1000


100%|██████████| 788/788 [00:03<00:00, 234.52it/s]


avg_loss_train: 7.240865707397461
avg_loss_val: 7.30422830581665
Epoch: 131/1000


100%|██████████| 788/788 [00:03<00:00, 233.84it/s]


avg_loss_train: 7.244297981262207
avg_loss_val: 7.293086528778076
Epoch: 132/1000


100%|██████████| 788/788 [00:03<00:00, 231.63it/s]


avg_loss_train: 7.261462688446045
avg_loss_val: 7.240945816040039
Epoch: 133/1000


100%|██████████| 788/788 [00:03<00:00, 234.52it/s]


avg_loss_train: 7.221101760864258
avg_loss_val: 7.235773086547852
Epoch: 134/1000


100%|██████████| 788/788 [00:03<00:00, 240.99it/s]


avg_loss_train: 7.241812229156494
avg_loss_val: 7.239965915679932
Epoch: 135/1000


100%|██████████| 788/788 [00:03<00:00, 240.91it/s]


avg_loss_train: 7.215260982513428
avg_loss_val: 7.207281589508057
Epoch: 136/1000


100%|██████████| 788/788 [00:03<00:00, 232.79it/s]


avg_loss_train: 7.231685638427734
avg_loss_val: 7.280168056488037
Epoch: 137/1000


100%|██████████| 788/788 [00:03<00:00, 236.36it/s]


avg_loss_train: 7.256988525390625
avg_loss_val: 7.293243408203125
Epoch: 138/1000


100%|██████████| 788/788 [00:03<00:00, 232.58it/s]


avg_loss_train: 7.254482746124268
avg_loss_val: 7.241973876953125
Epoch: 139/1000


100%|██████████| 788/788 [00:03<00:00, 242.35it/s]


avg_loss_train: 7.253037929534912
avg_loss_val: 7.280149459838867
Epoch: 140/1000


100%|██████████| 788/788 [00:03<00:00, 247.08it/s]


avg_loss_train: 7.2365946769714355
avg_loss_val: 7.294290542602539
Epoch: 141/1000


100%|██████████| 788/788 [00:03<00:00, 241.65it/s]


avg_loss_train: 7.2469096183776855
avg_loss_val: 7.230136394500732
Epoch: 142/1000


100%|██████████| 788/788 [00:03<00:00, 230.58it/s]


avg_loss_train: 7.223343849182129
avg_loss_val: 7.298620223999023
Epoch: 143/1000


100%|██████████| 788/788 [00:03<00:00, 243.04it/s]


avg_loss_train: 7.249794960021973
avg_loss_val: 7.263702392578125
Epoch: 144/1000


100%|██████████| 788/788 [00:03<00:00, 238.21it/s]


avg_loss_train: 7.233143329620361
avg_loss_val: 7.25086784362793
Epoch: 145/1000


100%|██████████| 788/788 [00:03<00:00, 233.87it/s]


avg_loss_train: 7.236526012420654
avg_loss_val: 7.245935440063477
Epoch: 146/1000


100%|██████████| 788/788 [00:03<00:00, 231.65it/s]


avg_loss_train: 7.231919288635254
avg_loss_val: 7.219418525695801
Epoch: 147/1000


100%|██████████| 788/788 [00:03<00:00, 234.24it/s]


avg_loss_train: 7.237698078155518
avg_loss_val: 7.255433559417725
Epoch: 148/1000


100%|██████████| 788/788 [00:03<00:00, 240.13it/s]


avg_loss_train: 7.255357265472412
avg_loss_val: 7.264097213745117
Epoch: 149/1000


100%|██████████| 788/788 [00:03<00:00, 243.61it/s]


avg_loss_train: 7.224000930786133
avg_loss_val: 7.22709321975708
Epoch: 150/1000


100%|██████████| 788/788 [00:03<00:00, 248.59it/s]


avg_loss_train: 7.230929851531982
avg_loss_val: 7.202674388885498
Epoch: 151/1000


100%|██████████| 788/788 [00:03<00:00, 247.48it/s]


avg_loss_train: 7.230776786804199
avg_loss_val: 7.229377746582031
Epoch: 152/1000


100%|██████████| 788/788 [00:03<00:00, 238.71it/s]


avg_loss_train: 7.2414937019348145
avg_loss_val: 7.271265029907227
Epoch: 153/1000


100%|██████████| 788/788 [00:03<00:00, 248.83it/s]


avg_loss_train: 7.243581771850586
avg_loss_val: 7.222368240356445
Epoch: 154/1000


100%|██████████| 788/788 [00:03<00:00, 250.48it/s]


avg_loss_train: 7.219336032867432
avg_loss_val: 7.2775983810424805
Epoch: 155/1000


100%|██████████| 788/788 [00:03<00:00, 224.89it/s]


avg_loss_train: 7.216922760009766
avg_loss_val: 7.239681720733643
Epoch: 156/1000


100%|██████████| 788/788 [00:03<00:00, 237.92it/s]


avg_loss_train: 7.2356367111206055
avg_loss_val: 7.28166389465332
Epoch: 157/1000


100%|██████████| 788/788 [00:03<00:00, 233.87it/s]


avg_loss_train: 7.2260284423828125
avg_loss_val: 7.218220233917236
Epoch: 158/1000


100%|██████████| 788/788 [00:03<00:00, 236.98it/s]


avg_loss_train: 7.234623908996582
avg_loss_val: 7.24202299118042
Epoch: 159/1000


100%|██████████| 788/788 [00:03<00:00, 231.57it/s]


avg_loss_train: 7.229676723480225
avg_loss_val: 7.280241012573242
Epoch: 160/1000


100%|██████████| 788/788 [00:03<00:00, 247.67it/s]


avg_loss_train: 7.226516246795654
avg_loss_val: 7.215224742889404
Epoch: 161/1000


100%|██████████| 788/788 [00:03<00:00, 223.22it/s]


avg_loss_train: 7.2068328857421875
avg_loss_val: 7.245645523071289
Epoch: 162/1000


100%|██████████| 788/788 [00:03<00:00, 222.01it/s]


avg_loss_train: 7.223431587219238
avg_loss_val: 7.218987941741943
Epoch: 163/1000


100%|██████████| 788/788 [00:03<00:00, 242.59it/s]


avg_loss_train: 7.242233753204346
avg_loss_val: 7.307583808898926
Epoch: 164/1000


100%|██████████| 788/788 [00:03<00:00, 236.14it/s]


avg_loss_train: 7.224029541015625
avg_loss_val: 7.229559421539307
Epoch: 165/1000


100%|██████████| 788/788 [00:03<00:00, 238.61it/s]


avg_loss_train: 7.228494644165039
avg_loss_val: 7.214987277984619
Epoch: 166/1000


100%|██████████| 788/788 [00:03<00:00, 245.98it/s]


avg_loss_train: 7.229432582855225
avg_loss_val: 7.249701976776123
Epoch: 167/1000


100%|██████████| 788/788 [00:03<00:00, 239.81it/s]


avg_loss_train: 7.2325358390808105
avg_loss_val: 7.329684257507324
Epoch: 168/1000


100%|██████████| 788/788 [00:03<00:00, 243.55it/s]


avg_loss_train: 7.232181549072266
avg_loss_val: 7.262760639190674
Epoch: 169/1000


100%|██████████| 788/788 [00:03<00:00, 233.20it/s]


avg_loss_train: 7.250699043273926
avg_loss_val: 7.2873640060424805
Epoch: 170/1000


100%|██████████| 788/788 [00:03<00:00, 235.22it/s]


avg_loss_train: 7.219258785247803
avg_loss_val: 7.227958679199219
Epoch: 171/1000


100%|██████████| 788/788 [00:03<00:00, 239.11it/s]


avg_loss_train: 7.24249792098999
avg_loss_val: 7.315942764282227
Epoch: 172/1000


100%|██████████| 788/788 [00:03<00:00, 235.54it/s]


avg_loss_train: 7.237788200378418
avg_loss_val: 7.292859077453613
Epoch: 173/1000


100%|██████████| 788/788 [00:03<00:00, 237.49it/s]


avg_loss_train: 7.218732833862305
avg_loss_val: 7.234119892120361
Epoch: 174/1000


100%|██████████| 788/788 [00:03<00:00, 228.92it/s]


avg_loss_train: 7.24501371383667
avg_loss_val: 7.234248638153076
Epoch: 175/1000


 63%|██████▎   | 500/788 [00:02<00:01, 206.44it/s]


KeyboardInterrupt: 

In [22]:
model.test_model(test_loader)

7.216961860656738

In [23]:
for X, y in test_loader:
    Z = model.encode(X)
    break

In [24]:
Z

tensor([[-7.1741e-03, -8.8770e-03,  9.1274e-05,  ..., -3.7526e-03,
         -1.4313e-02, -1.3468e-02],
        [ 1.1667e-02,  1.1655e-01, -4.0782e-03,  ..., -3.6326e-02,
         -1.5243e-02, -2.2210e-02],
        [ 2.0192e-03, -4.9030e-03,  8.7791e-05,  ..., -4.9374e-03,
          5.9507e-04, -1.6853e-02],
        ...,
        [ 7.7304e-03, -2.8578e-02,  5.0133e-03,  ...,  8.6156e-02,
         -4.3978e-03,  1.4429e-02],
        [ 5.5903e-03,  2.3808e-01, -2.3548e-03,  ..., -1.5278e-02,
          1.6067e-01, -3.4925e-02],
        [-1.3213e-02, -1.0972e-02,  9.3293e-04,  ...,  2.4970e-02,
         -1.0905e-02,  6.0221e-04]], grad_fn=<AddmmBackward0>)

In [27]:
K_X = kernel(X, **kernel_param_X)
K_Y = binary_distance(y.unsqueeze(1), y.unsqueeze(1))

K_XY = K_X * K_Y
D = torch.diag(torch.sum(K_XY, dim = 1))
torch.matmul(torch.matmul(torch.t(Z), D), Z)

tensor([[ 1.5503e-02,  5.1791e-02, -6.5794e-04,  ..., -1.9561e-04,
         -1.5435e-02, -2.6198e-04],
        [ 5.1791e-02,  7.7149e-01, -5.1065e-03,  ...,  2.2171e-04,
          1.4602e-01, -7.0928e-02],
        [-6.5794e-04, -5.1065e-03,  3.9073e-03,  ...,  5.2087e-03,
         -1.6297e-02,  1.8563e-03],
        ...,
        [-1.9561e-04,  2.2172e-04,  5.2087e-03,  ...,  2.1537e-01,
         -2.8651e-02,  1.7989e-02],
        [-1.5435e-02,  1.4602e-01, -1.6297e-02,  ..., -2.8651e-02,
          6.5723e-01, -4.9386e-02],
        [-2.6198e-04, -7.0928e-02,  1.8563e-03,  ...,  1.7989e-02,
         -4.9386e-02,  4.4094e-02]], grad_fn=<MmBackward0>)