In [1]:
import sys 
sys.path.append("../")

from data.data import parallel_line, orthogonal, triangle, lines_3D
from train_utils import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import matplotlib.pyplot as plt
from collections import Counter

from base import *

Importing all assests from base.py...
Imported modules   : numpy, pandas, matplotlib.pyplot
Imported functions : npy, axes_off, get_var_name, shapes, tqdm, plot_history, minmax, values


In [2]:
class DenseBlock(nn.Module):
    """A single fully connected layer optionally followed by an activation.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: forwarded to ``nn.Linear`` to enable/disable the additive bias.
        activation: callable applied to the linear output, or ``None`` to
            return the raw linear output.
    """

    def __init__(self, in_features, out_features, bias, activation=None):
        super().__init__()
        self.layer = nn.Linear(in_features, out_features, bias)
        self.act = activation

    def forward(self, x):
        """Apply the linear map and, when configured, the activation."""
        projected = self.layer(x)
        return projected if self.act is None else self.act(projected)

class Autoencoder(nn.Module):
    """Fully connected autoencoder built from stacked ``DenseBlock`` layers.

    The encoder has one block per consecutive pair in ``enc_channels`` and the
    decoder one block per consecutive pair in ``dec_channels``. The same
    ``activations`` object is handed to every block, including the last
    encoder block (the embedding) and the last decoder block (the output).
    """

    def __init__(self, **kwargs):
        """
        Requires: enc_channels, dec_channels, bias, activations (nn.ReLU(), nn.GELU(), etc)

        Keyword Args:
            enc_channels: sequence of layer widths for the encoder (tuple or list).
            dec_channels: sequence of layer widths for the decoder (tuple or list).
            bias: whether every linear layer uses a bias term.
            activations: activation applied after every linear layer, or ``None``.

        Raises:
            KeyError: if any of the required keyword arguments is missing.
        """
        super(Autoencoder, self).__init__()

        # Coerce to tuples so list inputs work too: the bridging step below
        # concatenates with a tuple, which raises TypeError for a list operand.
        enc_channels = tuple(kwargs['enc_channels'])
        dec_channels = tuple(kwargs['dec_channels'])
        bias = kwargs['bias']
        activations = kwargs['activations']

        if enc_channels[-1] != dec_channels[0]:
            # The decoder must start at the embedding width; prepend it so an
            # extra bridging layer is created instead of failing at runtime.
            print("[WARN] First shape of dec_channels does not match the terminal channel in enc_channels, proceeding with additional layer...")
            dec_channels = (enc_channels[-1],) + dec_channels

        self.enc = nn.Sequential()
        for i, (n_in, n_out) in enumerate(zip(enc_channels[:-1], enc_channels[1:])):
            self.enc.add_module(f'enc_dense{i}', DenseBlock(n_in, n_out, bias=bias, activation=activations))

        self.dec = nn.Sequential()
        for i, (n_in, n_out) in enumerate(zip(dec_channels[:-1], dec_channels[1:])):
            self.dec.add_module(f'dec_dense{i}', DenseBlock(n_in, n_out, bias=bias, activation=activations))

    def forward(self, x, return_embed=False):
        """Encode then decode ``x``.

        Args:
            x: input tensor; it is cast to float32 before encoding.
            return_embed: when true, return ``(embed, out)`` instead of ``out``.

        Returns:
            The reconstruction, or the tuple ``(embedding, reconstruction)``.
        """
        x = x.float()
        embed = self.enc(x)
        out = self.dec(embed)

        if return_embed:
            return embed, out
        return out

class GroupedModel(nn.Module):
    """A bank of ``n_clusters`` independent models, one per cluster.

    Each model is built as ``model_class(**kwargs)`` and is called as
    ``model(x, True)``, which must return an ``(embed, out)`` pair. Inputs are
    centred on the corresponding cluster centre before being passed in.
    """

    def __init__(self, n_clusters, model_class, **kwargs):
        """
        Args:
            n_clusters: number of per-cluster models to instantiate.
            model_class: callable returning an ``nn.Module``; invoked once per cluster.
            **kwargs: forwarded unchanged to every ``model_class`` call.
        """
        super(GroupedModel, self).__init__()
        self.n_clusters = n_clusters

        self.AE = nn.ModuleList(model_class(**kwargs) for _ in range(n_clusters))

    def _return_embed(self, embed, out, flag):
        """Return ``(embed, out)`` when ``flag`` is truthy, otherwise just ``out``."""
        return (embed, out) if flag else out

    def forward_with_clust(self, x, centers, clust, return_embed=False):
        """ 
        To be used in the batched phase, computes output for input x, all belonging to the same cluster 

        Args:
            x: batch of inputs, shape (batch, <data shape>).
            centers: indexable collection of cluster centres; ``centers[clust]`` is used.
            clust: index of the cluster (and model) to use.
            return_embed: when true, return ``(embed, out)`` instead of ``out``.
        """
        center = centers[clust]
        # Match `forward_with_centers`: a flattened centre is reshaped to the
        # per-sample shape. When the element counts differ the centre is left
        # untouched so plain broadcasting keeps working as before.
        if center.numel() == x.shape[1:].numel():
            center = center.reshape(x.shape[1:])
        embed, out = self.AE[clust](x - center, True)

        return self._return_embed(embed, out, return_embed)

    def forward_with_centers(self, x, centers, return_embed=False):
        """ 
        To be used in warmup phase, computes output for input x for all clusters. 
        Output format (batch, n_clusters, <data shape>)

        Args:
            x: batch of inputs, shape (batch, <data shape>).
            centers: indexable collection of cluster centres; each ``centers[n]``
                is reshaped to the per-sample shape of ``x``.
            return_embed: when true, return ``(embed, out)`` instead of ``out``.
        """
        embeds, outs = [], []
        for n in range(self.n_clusters):
            e, o = self.AE[n](x - centers[n].reshape(x.shape[1:]), True)
            embeds.append(e)
            outs.append(o)

        # Stacking on a new axis 1 yields (batch, n_clusters, ...) in one
        # allocation instead of re-concatenating on every iteration.
        embed = torch.stack(embeds, dim=1)
        out = torch.stack(outs, dim=1)

        return self._return_embed(embed, out, return_embed)
    
    def forward(self, *args):
        """Disabled on purpose: callers must pick one of the explicit forward variants."""
        raise NotImplementedError("Use one of `forward_with_centers` or `forward_with_clust`")

In [3]:
# Sanity check on flat inputs of width 5: every autoencoder maps
# 5 -> 2 -> 1 -> 2 -> 5, so the embedding is one value per sample.

kwargs = dict(
    enc_channels=(5, 2, 1),
    dec_channels=(1, 2, 5),
    bias=True,
    activations=nn.ReLU(),
)

# One standalone autoencoder and a two-cluster bank built from the same config.
ae = Autoencoder(**kwargs)
gae = GroupedModel(2, Autoencoder, **kwargs)

# Random stand-ins for a batch of 64 inputs, their targets, and one centre per cluster.
dummy_input, dummy_true = torch.randn(64, 5), torch.randn(64, 5)
dummy_centers = torch.randn(2, 5)

In [4]:
# Shape checks for the three forward paths. `shapes` is imported from `base`;
# judging by the recorded output it prints each argument's name and shape.

# Single autoencoder: embed is (batch, 1), out is (batch, 5).
embed, out = ae(dummy_input, True)
shapes(embed, out)

# All clusters at once: a cluster axis is inserted at position 1,
# giving (batch, 2, 1) for embed and (batch, 2, 5) for out.
embed, out = gae.forward_with_centers(dummy_input, dummy_centers, True)
shapes(embed, out)

# Only cluster 0: same shapes as the single autoencoder.
embed, out = gae.forward_with_clust(dummy_input, dummy_centers, 0, True)
shapes(embed, out)

embed  : torch.Size([64, 1])
out    : torch.Size([64, 5])
embed  : torch.Size([64, 2, 1])
out    : torch.Size([64, 2, 5])
embed  : torch.Size([64, 1])
out    : torch.Size([64, 5])


In [5]:
def to_npy(x):
    """Convert a torch tensor to a NumPy array; pass anything else through.

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion. Non-tensor inputs are returned as the very same object.
    """
    if not isinstance(x, torch.Tensor):
        return x
    return x.detach().cpu().numpy()

In [6]:
from sklearn.cluster import kmeans_plusplus

class TensorizedAutoencoder(nn.Module):
    """Cluster-wise autoencoders with k-means++ initialised cluster centres.

    Wraps a grouped model (one autoencoder per cluster) and keeps one centre
    per cluster in ``self.centers`` with shape (n_clusters, n_flat_features).
    ``self.centers`` is a plain tensor attribute (not a registered buffer), so
    it does not follow ``module.to(device)`` and is not part of ``state_dict``.
    """

    def __init__(self, grouped_model, Y_data, regularizer_coef=0.01, random_state=None):
        """
        Args:
            grouped_model: module exposing ``n_clusters``, ``forward_with_clust``
                and ``forward_with_centers``.
            Y_data: target data used to seed the cluster centres; tensors are
                converted to NumPy first.
            regularizer_coef: weight of the squared embedding-norm penalty.
            random_state: forwarded to ``kmeans_plusplus`` for reproducible
                seeding; ``None`` keeps the previous non-deterministic behaviour.
        """
        super(TensorizedAutoencoder, self).__init__()
        self.n_clusters = grouped_model.n_clusters
        self.centers = None

        self.autoencoders = grouped_model
        self.reg = regularizer_coef

        # Can still be changed after construction; call `_init_clusters(Y)` again to re-seed.
        self.random_state = random_state
        self._init_clusters(to_npy(Y_data))

    def mse(self, x, y, dim=None):
        """Mean squared error.

        With ``dim`` given, the element-wise squared error is averaged over
        that dimension only; otherwise a single scalar mean is returned.
        """
        if dim is None:
            return F.mse_loss(x, y)
        return F.mse_loss(x, y, reduction='none').mean(dim=dim)

    def _init_clusters(self, Y):
        """Seed ``self.centers`` with k-means++ on the flattened samples of ``Y``."""
        flat = Y.reshape(len(Y), -1)
        centers = kmeans_plusplus(flat, n_clusters=self.n_clusters, random_state=self.random_state)[0]
        self.centers = torch.from_numpy(centers)
        # Compare against the flattened feature count so multi-dimensional
        # samples do not trip the check.
        assert self.centers.shape == (self.n_clusters, flat.shape[1])

    def forward_with_clust(self, x, centers, clust, return_embed=False):
        """Delegate to the grouped model's single-cluster forward pass."""
        return self.autoencoders.forward_with_clust(x, centers, clust, return_embed)

    def forward_with_centers(self, x, centers, return_embed=False):
        """Delegate to the grouped model's all-clusters forward pass."""
        return self.autoencoders.forward_with_centers(x, centers, return_embed)

    def _assign_centers_to_data(self, data, one_hot=False, centers=None):
        """Assign every sample to its nearest centre (Euclidean distance).

        Args:
            data: tensor of samples, first axis is the sample axis.
            one_hot: when true, return a one-hot matrix of shape
                (n_clusters, n_samples) instead of an index vector.
            centers: optional centres to use instead of ``self.centers``.

        Returns:
            Index tensor of shape (n_samples,), or the one-hot matrix.
        """
        # `centers or self.centers` would evaluate the truth value of a tensor,
        # which raises for tensors with more than one element.
        if centers is None:
            centers = self.centers
        flat = data.reshape(len(data), -1)
        centers = centers.reshape(self.n_clusters, -1).to(flat.device)

        # (n_samples, 1, D) - (1, n_clusters, D) -> distances of shape (n_samples, n_clusters)
        distances = torch.norm(flat[:, None, :] - centers[None, :, :], dim=2)
        labels = distances.argmin(dim=1)
        if one_hot:
            return F.one_hot(labels, self.n_clusters).T
        return labels

    def _update_centers(self, Y, clust_assign):
        """Recompute each centre as the mean of the samples assigned to it.

        Args:
            Y: samples, first axis is the sample axis.
            clust_assign: one-hot assignment of shape (n_clusters, n_samples).

        Returns:
            Float tensor of shape (n_clusters, n_flat_features). A cluster with
            no assigned samples keeps its current centre instead of becoming
            NaN through a 0/0 division.
        """
        assert clust_assign.shape == (self.n_clusters, len(Y)), "check if clust_assign is in one-hot format"
        flat = Y.reshape(len(Y), -1).float()
        # `.to()` is not in-place, so its result has to be kept.
        weights = clust_assign.to(flat.device).float()

        sums = weights @ flat
        counts = weights.sum(dim=1, keepdim=True)
        means = sums / counts.clamp(min=1.0)

        previous = self.centers.reshape(self.n_clusters, -1).to(device=flat.device, dtype=flat.dtype)
        return torch.where(counts > 0, means, previous)

    def compute_loss_warmup(self, x, y, clust_assign):
        """Warm-up loss: every sample is scored by all cluster autoencoders.

        For each sample and cluster the cost is the feature-averaged
        reconstruction MSE plus ``reg * ||embedding||^2``. Each sample is
        assigned to its cheapest cluster and the costs of the chosen clusters
        are summed over the batch.

        Args:
            x: input batch, shape (batch, <data shape>).
            y: target batch with the same number of elements per sample as the
                model output.
            clust_assign: accepted for interface compatibility; it is not used,
                assignments are recomputed from the per-cluster costs.

        Returns:
            ``(loss, new_assignment)`` where ``new_assignment`` is a float
            one-hot matrix of shape (n_clusters, batch).
        """
        batch_size = x.shape[0]

        embed, out = self.autoencoders.forward_with_centers(x, self.centers, True)
        # (batch, n_clusters, <data shape>) -> (batch, n_clusters, flat)
        out_flat = out.reshape(batch_size, self.n_clusters, -1)
        embed_flat = embed.reshape(batch_size, self.n_clusters, -1)
        # Flatten the targets as well so multi-dimensional samples line up.
        target = y.reshape(batch_size, 1, -1).expand_as(out_flat)

        per_cluster = self.mse(out_flat, target, dim=2) + self.reg * embed_flat.pow(2).sum(dim=2)

        best = per_cluster.argmin(dim=1)
        # Accumulate exactly the cost that drove the selection: the penalty
        # covers only the chosen cluster's embedding, not all clusters'.
        loss = per_cluster.gather(1, best.unsqueeze(1)).sum()

        new_assignment = F.one_hot(best, self.n_clusters).T.float()

        return loss, new_assignment

    # def compute_loss_batch(self, x, y, ):    

In [7]:
# Sanity check on the same width-5 toy data: wrap the grouped autoencoders,
# seeding the cluster centres from the dummy targets.
tae = TensorizedAutoencoder(grouped_model=gae, Y_data=dummy_true)

# One-hot nearest-centre assignment for every target, laid out as (n_clusters, n_samples).
clust_assign = tae._assign_centers_to_data(data=dummy_true, one_hot=True)

In [8]:
# Warm-up loss on a 6-sample mini-batch. The one-hot assignment is laid out as
# (n_clusters, n_samples), so it is sliced along the sample axis to line up
# with the batch rows instead of passing the full-dataset matrix.
loss, new_assignment = tae.compute_loss_warmup(
    dummy_input[10:16], dummy_true[10:16], clust_assign[:, 10:16]
)

In [12]:
# from torchviz import make_dot

# make_dot(loss).render("attached")