In [1]:
from train_utils import *

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np

import matplotlib.pyplot as plt
from collections import Counter

from base import *

device = 'cpu'

Importing all assests from base.py...
Imported modules   : numpy, pandas, matplotlib.pyplot
Imported functions : npy, axes_off, get_var_name, shapes, tqdm, plot_history, minmax, values


In [2]:
class DenseBlock(nn.Module):
    def __init__(self, in_features, out_features, bias, activation=None):
        super().__init__()
        self.layer = nn.Linear(in_features, out_features, bias)
        self.act = activation

    def forward(self, x):
        if self.act is None:
            return self.layer(x)
        
        return self.act(self.layer(x))

class Autoencoder(nn.Module):
    def __init__(self, **kwargs):
        """
        Requires: enc_channels, dec_channels, bias, activations (nn.ReLU(), nn.GELU(), etc)
        """
        super(Autoencoder, self).__init__()

        enc_channels, dec_channels = kwargs['enc_channels'], kwargs['dec_channels']
        bias = kwargs['bias']
        activations = kwargs['activations']

        if enc_channels[-1] != dec_channels[0]: 
            print("[WARN] First shape of dec_channels does not match the terminal channel in enc_channels, proceeding with additional layer...")
            dec_channels = (enc_channels[-1],)+dec_channels

        self.enc = nn.Sequential()
        for i in range(len(enc_channels)-1):
            self.enc.add_module(f'enc_dense{i}', DenseBlock(enc_channels[i], enc_channels[i+1], bias=bias, activation=activations))

        self.dec = nn.Sequential()
        for i in range(len(dec_channels)-1):
            self.dec.add_module(f'dec_dense{i}', DenseBlock(dec_channels[i], dec_channels[i+1], bias=bias, activation=activations))

    def forward(self, x, return_embed=False):
        x = x.float()
        embed = self.enc(x)
        out = self.dec(embed)

        if return_embed: return embed, out
        return out

class GroupedModel(nn.Module):
    def __init__(self, n_clusters, model_class, **kwargs):
        super(GroupedModel, self).__init__()
        self.n_clusters = n_clusters

        self.AE = nn.ModuleList(model_class(**kwargs) for i in range(n_clusters))

    def _return_embed(self, embed, out, flag):
        return (embed, out) if flag else out

    def forward_with_clust(self, x, centers, clust, return_embed=False):
        """ 
        To be used in the batched phase, computes output for input x, all belonging to the same cluster 
        """
        embed, out = self.AE[clust](x - centers[clust], True)

        return self._return_embed(embed, out, return_embed)

    def forward_with_centers(self, x, centers, return_embed=False):
        """ 
        To be used in warmup phase, computes output for input x for all clusters. 
        Output format (batch, n_clusters, <data shape>)
        """
        embed, out = self.AE[0](x - centers[0].reshape(x.shape[1:]), True)
        embed, out = embed.reshape((len(x),1)+embed.shape[1:]), out.reshape((len(x),1)+out.shape[1:])

        for n in range(1, self.n_clusters):
            e, o = self.AE[n](x - centers[n].reshape(x.shape[1:]), True)
            embed = torch.cat((embed, e.reshape((len(x),1)+e.shape[1:])), dim=1)
            out = torch.cat((out, o.reshape((len(x),1)+o.shape[1:])), dim=1)

        return self._return_embed(embed, out, return_embed)
    
    def forward(self, *args):
        raise NotImplementedError("Use one of `forward_with_centers` or `forward_with_clust`")

In [3]:
# sanity check: inputs of shape 5 (1d), AE follows 5->2->1->2->5 (1 = embed_dim)

kwargs = {'enc_channels': (5,2,1),
          'dec_channels': (1,2,5), 
          'bias': False, 
          'activations': nn.ReLU()}

ae = Autoencoder(**kwargs)
gae = GroupedModel(2, Autoencoder, **kwargs)

dummy_input = torch.randn(64, 5)
dummy_true = torch.randn(64,5) 
dummy_centers = torch.randn(2, 5)

In [4]:
embed, out = ae(dummy_input, True)
shapes(embed, out)

embed, out = gae.forward_with_centers(dummy_input, dummy_centers, True)
shapes(embed, out)

embed, out = gae.forward_with_clust(dummy_input, dummy_centers, 0, True)
shapes(embed, out)

embed  : torch.Size([64, 1])
out    : torch.Size([64, 5])
embed  : torch.Size([64, 2, 1])
out    : torch.Size([64, 2, 5])
embed  : torch.Size([64, 1])
out    : torch.Size([64, 5])


In [5]:
def to_npy(x): 
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return x

In [6]:
from sklearn.cluster import kmeans_plusplus

class TensorizedAutoencoder(nn.Module):
    def __init__(self, grouped_model, Y_data, regularizer_coef=0.01):
        super(TensorizedAutoencoder, self).__init__()
        self.n_clusters = grouped_model.n_clusters
        self.centers = None
        
        self.autoencoders = grouped_model
        self.reg = regularizer_coef

        self.random_state = None    # set manually if needed, call model._init_clusters(Y) again
        self._init_clusters(to_npy(Y_data))
    
    def mse(self, x, y, dim=None):
        return F.mse_loss(x, y, reduction='none').mean(dim=dim) if dim is not None else F.mse_loss(x, y)
    
    def _init_clusters(self, Y):
        centers = kmeans_plusplus(Y.reshape(len(Y), -1), n_clusters=self.n_clusters, random_state=self.random_state)[0]
        self.centers = torch.from_numpy(centers)
        assert self.centers.shape == (self.n_clusters, Y.shape[1])
    
    def forward_with_clust(self, x, clust, return_embed=False):
        return self.autoencoders.forward_with_clust(x, self.centers, clust, return_embed)
    
    def forward_with_centers(self, x, return_embed=False):
        return self.autoencoders.forward_with_centers(x, self.centers, return_embed)
    
    def assign_centers_to_data(self, data, one_hot=False, centers=None):
        centers = centers or self.centers
        centers = centers.reshape(self.n_clusters, -1)

        assignments = torch.tensor([])
        for i in range(self.n_clusters):
            d = torch.norm(data.reshape(len(data), -1) - self.centers[i], dim=1).reshape(-1,1)
            assignments = torch.cat((assignments,d), dim=1)
        if one_hot: return F.one_hot(assignments.argmin(dim=1), self.n_clusters).T
        return assignments.argmin(dim=1)

    def update_centers(self, Y, clust_assign):
        assert clust_assign.shape == (self.n_clusters, len(Y)), "check if clust_assign is in one-hot format"
        clust_assign.to(Y.device)
        new_centers = clust_assign.float() @ Y.reshape(len(Y), -1).float()
        new_norm = torch.sum(clust_assign, axis=1, dtype=torch.float).reshape(-1, 1) @ \
                        torch.ones(1, self.centers.shape[1], dtype=torch.float)
        new_centers = new_centers / new_norm
        return new_centers
    
    def compute_loss_warmup(self, x, y, clust_assign):
        clusts = clust_assign.argmax(dim=0)
        b = x.shape[0]

        x_, y_ = x.clone(), y.clone()

        embed, out = self.autoencoders.forward_with_centers(x_, self.centers, True)
        # (batch, n_clusters, <data shape>)

        new_assignment = torch.zeros((self.n_clusters, b), dtype=int)
        loss = 0
        for samp in range(b):
            true, e, o = y_[samp][None,:].repeat_interleave(self.n_clusters, dim=0), embed[samp], out[samp]

            e_, o_ = e.reshape(self.n_clusters, -1), o.reshape(self.n_clusters, -1)

            loss_proxy = self.mse(o_, true, dim=1) + self.reg * (torch.norm(e_, dim=1) ** 2)
            # shapes(loss_proxy)

            new_center = loss_proxy.argmin(dim=0)
            loss += self.mse(o_[clusts[samp]], true[0]) + self.reg * (torch.norm(e_) ** 2)

            new_assignment[:,samp][new_center] = 1

        return loss, new_assignment

    def _collect_data(self, x, y, clust_assign):
        clusts = clust_assign.argmax(dim=0)
        return [(i,x[clusts==i],y[clusts==i]) for i in torch.unique(clusts)]

    def compute_loss_batch(self, x, y, clust_assign):
        collected = self._collect_data(x, y, clust_assign)
        loss = 0
        for c,x,y in collected:
            x_ = x.clone() - self.centers[c]
            y_ = y.clone() - self.centers[c]

            embed, out = self.forward_with_clust(x_, c, return_embed=True)
            loss += self.mse(out, y_) + self.reg * (torch.norm(embed) ** 2)
        
        return loss / len(collected)

In [7]:
def train_tae(tae, X, Y, epochs, lr, batch_size, warmup=0.3, verbose=True, grad_clip=5):
    clust_assign = tae.assign_centers_to_data(Y, one_hot=True)
    tae.centers = tae.update_centers(Y, clust_assign)
    
    optimizer = torch.optim.Adagrad(tae.parameters(), 0.1)
    warmup_dataloader = GenericDataset(X, Y).get_dataloader(batch_size=1, shuffle=False)
    
    warmup_epochs = int(warmup*epochs)

    if verbose: print(f"PHASE 1: Warmup — {warmup_epochs}/{epochs}")
    pbar_warmup = ProgressBar(warmup_epochs, 75, verbose)
    warmup_losses = []
    for epoch in pbar_warmup.bar:
        batch_losses = []
        new_clust_assign = torch.tensor([])
        for n, (x,y) in enumerate(warmup_dataloader):
            x, y = x.to(device), y.to(device)

            loss, new_assignment = tae.compute_loss_warmup(x, y, clust_assign[:,n:n+1])
            new_clust_assign = torch.cat((new_clust_assign, new_assignment), dim=1)

            # if new_assignment.argmax() != clust_assign[:, sl_clust].argmax(): print(f"updated {samp}")

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(tae.parameters(), grad_clip)
            optimizer.step()

            batch_losses.append(loss.item())

        clust_assign = new_clust_assign.clone()
        tae.centers = (tae.update_centers(Y, clust_assign) + epoch * tae.centers) / (epoch + 1)

        warmup_losses.append(np.mean(batch_losses))
        pbar_warmup.update(warmup_losses[-1])

    
    clust_wise_data = tae._collect_data(X, Y, clust_assign)
    dataloaders = [(c, GenericDataset(x, y).get_dataloader(batch_size=batch_size, shuffle=True)) for (c,x,y) in clust_wise_data]    

    batched_epochs = (epochs - warmup_epochs)
    
    if verbose: print(f"PHASE 2: Batched — {batched_epochs}/{epochs}")
    clust_losses = []
    for data in dataloaders:
        ae = tae.autoencoders.AE[data[0]]
        optimizer = torch.optim.Adam(ae.parameters(), lr)
        clust_losses.append(train_ae(ae, data[1], optimizer, epochs=batched_epochs, verbose=verbose)[1])
        
    # pbar_batched = ProgressBar(batched_epochs, 75, verbose)
    # losses = []
    # for epoch in pbar_batched.bar:
    #     batch_losses = []
    #     for n, (x,y) in enumerate(dataloader):
    #         idx = n * batch_size

    #         x_, y_ = x.to(device), y.to(device)
    #         loss = tae.compute_loss_batch(x_, y_, clust_assign[:,idx:idx+batch_size])

    #         optimizer.zero_grad()
    #         loss.backward()
    #         # nn.utils.clip_grad_norm_(tae.parameters(), grad_clip)

    #         optimizer.step()

    #         batch_losses.append(loss.item())
        
    #     losses.append(np.mean(batch_losses))
    #     pbar_batched.update(losses[-1])

    return warmup_losses, clust_losses, clust_assign.argmax(dim=0)

In [8]:
# train_tae working test

tae = TensorizedAutoencoder(gae, dummy_true)
warmup, batched, clusts = train_tae(tae, dummy_input, dummy_true, 50, 5e-3, 8, warmup=0.2, verbose=1, grad_clip=1)

PHASE 1: Warmup — 10/50


Epoch: 10/10 |███████████████████████████████████| [00:01<00:00, loss: nan]


PHASE 2: Batched — 40/50


Epoch: 40/40 |████████████████████████████| [00:00<00:00, loss: nan, es: 0]


# Testing on datasets

In [9]:
import sys 
sys.path.append("../")

from data.data import parallel_line, orthogonal, triangle, lines_3D
from sklearn.metrics.cluster import adjusted_rand_score as ari
from sklearn.metrics import mean_squared_error as mse
from sklearn.cluster import KMeans

X, Y, X_noise, n_clusters = parallel_line(noise=0.1)
randperm = torch.randperm(len(X))
X, Y, X_noise = X[randperm].float(), Y[randperm].float(), X_noise[randperm].float()

In [10]:
kwargs = {'enc_channels': (5,2),
          'dec_channels': (2,5), 
          'bias': False, 
          'activations': None}

In [11]:
ae = Autoencoder(**kwargs)
optimizer = torch.optim.Adam(ae.parameters(), 5e-3)
dataloader = GenericDataset(X, X).get_dataloader(batch_size=8, shuffle=False)

ae, losses = train_ae(ae, dataloader, optimizer, 50, verbose=1)

Epoch: 50/50 |█████████████████████████| [00:00<00:00, loss: 0.0060, es: 7]


In [12]:
ari(KMeans(n_clusters).fit(to_npy(ae.enc(X.float()))).labels_,Y)

0.1882917549396861

In [13]:
gae = GroupedModel(n_clusters, Autoencoder, **kwargs)
tae = TensorizedAutoencoder(gae, X)

warmup_losses, losses, clusts = train_tae(tae, X.float(), X.float(), 50, 5e-3, 8, 0.4, verbose=1, grad_clip=1)

PHASE 1: Warmup — 20/50


Epoch: 20/20 |████████████████████████████████| [00:04<00:00, loss: 0.7715]


PHASE 2: Batched — 30/50


Epoch: 30/30 |█████████████████████████| [00:00<00:00, loss: 0.0796, es: 0]
Epoch: 30/30 |█████████████████████████| [00:00<00:00, loss: 0.0309, es: 0]


In [14]:
ari(clusts, Y)

0.9733321641655686

# Working with reconstruction mse

In [15]:
num_runs = 1
verbose = 1

lae, ltae, lkm = [], [], []
for run in tqdm(range(num_runs), disable=verbose):
    # kmeans_ari = ari(KMeans(n_clusters).fit(to_npy(X)).labels_,Y)
    
    ae = Autoencoder(**kwargs)
    optimizer = torch.optim.Adam(ae.parameters(), 5e-3)
    dataloader = GenericDataset(X, X).get_dataloader(batch_size=8, shuffle=False)

    ae, losses = train_ae(ae, dataloader, optimizer, 50, verbose=verbose)
    # ae_ari = ari(KMeans(n_clusters).fit(to_npy(ae.enc(X.float()))).labels_,Y)

    ae_mse = mse(to_npy(ae(X)), X)

    gae = GroupedModel(n_clusters, Autoencoder, **kwargs)
    tae = TensorizedAutoencoder(gae, X)

    warmup_losses, losses, clusts = train_tae(tae, X.float(), X.float(), 50, 5e-3, 8, 0.4, verbose=verbose, grad_clip=0.5)
    # tae_ari = ari(clusts, Y)

    tae_mse = np.zeros(X.shape)
    for i in range(len(X)):
        tae_mse[i] = to_npy(tae.forward_with_clust(X[i:i+1], clusts[i:i+1]) + tae.centers[clusts[i]])

    tae_mse = mse(tae_mse, X)

    lae.append(ae_mse); ltae.append(tae_mse)

print(f"TAE: {np.mean(ltae):.4f}")
print(f"AE: {np.mean(lae):.4f}")

# print(f"Kmeans after TAE: {np.mean(ltae):.4f}")
# print(f"Kmeans after AE: {np.mean(lae):.4f}")
# print(f"Direct kmeans: {np.mean(lkm):.4f}")

Epoch: 50/50 |█████████████████████████| [00:00<00:00, loss: 0.0061, es: 0]


PHASE 1: Warmup — 20/50


Epoch: 20/20 |████████████████████████████████| [00:05<00:00, loss: 1.0930]


PHASE 2: Batched — 30/50


Epoch: 30/30 |█████████████████████████| [00:00<00:00, loss: 0.0236, es: 1]
Epoch: 30/30 |█████████████████████████| [00:00<00:00, loss: 0.0255, es: 0]

TAE: 0.0457
AE: 0.0058



