In [5]:
import sys
sys.path.append("../")

import torch
import numpy as np
import matplotlib.pyplot as plt

from tae_module import *
from data.data import parallel_line, orthogonal, triangle, lines_3D
import torch.nn as nn


class TensorizedAutoencoder(nn.Module):
    def __init__(self, autoencoder, Y_data, n_clusters, regularizer_coef=0.01, device='cpu'):
        super().__init__()
        self.logging = True
        self.device = device
        self.n_clusters = n_clusters
        self.centers = None
        
        self.autoencoders = GroupedAE(autoencoder, n_clusters)
        self.mse = nn.functional.mse_loss
        self.reg = regularizer_coef

        self.random_state = None    # set manually if needed, call model._init_clusters(Y) again
        self._init_clusters(Y_data)

    def _init_clusters(self, Y):
        self.centers, _ = kmeans_plusplus(Y.reshape(len(Y), -1).numpy(), n_clusters=self.n_clusters, random_state=self.random_state)
        self.centers = torch.from_numpy(self.centers)

    def forward_pass(self, x, clust_idx=None, centers=None, return_embed=False):
        """
        Getting output for a given batch of inputs for 
        1. computed centers and Nearest-Center-Assignment (NCA)
        2. input override centers and NCA
        3. input override assigned clusters
        """
        if self.centers is None and centers is None:
            print("Cannot compute output!") 
            return

        device = x.device
        centers = centers if centers is not None else self.centers
        centers = centers.to(device).reshape(self.n_clusters,-1)
        if clust_idx is None:
            clust_idx = self._assign_centers_to_data(x, centers)

        embed, out = torch.tensor([]).to(device), torch.tensor([]).to(device)
        for idx in range(len(x)):
            samp = x[idx:idx+1]     # 1, 3, 32, 32
            clust = clust_idx[idx]
            e, o = self.autoencoders.AE[clust](samp-centers[clust].reshape(samp.shape), True)
            embed = torch.cat((embed, e))
            out = torch.cat((out, o))

        if return_embed: return embed, out
        return out
    
    def compute_loss(self, x, y, centers, assigned_clusters):
        x_ = x.clone()  # [batch, 3, 32, 32]
        y_ = y.clone()  # [batch, 3, 32, 32]
        
        embed, out = self.autoencoders(x_, centers, True)   # [batch, n_clusters, 3, 32, 32]

        losses, new_idxs = torch.tensor([]).to(x_.device), torch.tensor([])
        for idx in range(len(x_)):
            embed_, out_, true_ = embed[idx], out[idx], y_[idx]
            # embed_, out_: (n_clusters, 3, 32, 32)

            embed_norm = embed_.clone()
            while len(embed_norm.shape) > 1:
                embed_norm = torch.norm(embed_norm, dim=1)

            dims = [i for i in range(1,len(x_.shape[1:])+1)]
            mse_proxy = torch.sum((out_ - true_.unsqueeze(0)) ** 2, dim=dims)
            loss_proxy = mse_proxy + self.reg * (embed_norm ** 2)

            new_idx = loss_proxy.argmin()   #reassigned clust
            
            clust = assigned_clusters[idx]  
            loss = self.mse(out_[clust], true_) + self.reg * (embed_norm[clust] ** 2) #using currently assigned

            losses = torch.cat((losses, loss.unsqueeze(0)))
            new_idxs = torch.cat((new_idxs, new_idx.unsqueeze(0)))

        new_indices = nn.functional.one_hot(new_idxs.long(), self.n_clusters).T
        losses = sum(losses)/len(losses)        

        return losses, new_indices

    def _update_centers(self, Y, clust_assign):
        clust_assign.to(Y.device)
        new_centers = clust_assign.float() @ Y.reshape(len(Y), -1).float()
        new_norm = torch.sum(clust_assign, axis=1, dtype=torch.float).reshape(-1, 1) @ \
                        torch.ones(1, self.centers.shape[1], dtype=torch.float)
        new_centers = new_centers / new_norm
        return new_centers

    def _assign_centers_to_data(self, data, centers=None):
        centers = centers if centers is not None else self.centers
        centers = centers.reshape(self.n_clusters, -1)
        
        assignments = torch.tensor([])
        for i in range(self.n_clusters):
            d = torch.norm(data.reshape(len(data), -1) - self.centers[i], dim=1).reshape(-1,1)
            assignments = torch.cat((assignments,d), dim=1)
        return assignments.argmin(dim=1)

    def train(self, X, Y, Dataset, epochs, lr, batch_size, **dataset_kwargs):
        self.configure_optimizers(lr)

        device = self.device
        X = X.to(device)
        Y = Y.to(device)

        clust_assign = self._assign_centers_to_data(Y)
        clust_assign_onehot = nn.functional.one_hot(clust_assign).T.float()

        self.centers = self._update_centers(Y, clust_assign_onehot)

        dataset = Dataset(X, Y, **dataset_kwargs)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=False)

        pbar = ProgressBar(epochs)
        es = EarlyStopping()

        losses = []
        for epoch in pbar.bar:
            
            for batch_idx, (x, y) in enumerate(dataloader):
                index = batch_idx * batch_size
                x, y = x.to(device), y.to(device)
                loss, new_indices = self.compute_loss(x, y, self.centers, clust_assign)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                clust_assign_onehot[:,index:index+batch_size] = new_indices
                clust_assign = clust_assign_onehot.argmax(0)

                new_centers = self._update_centers(Y, clust_assign_onehot) 
                self.centers = (index * self.centers + batch_size * new_centers) / (index + batch_size)

                losses.append(loss.item())

            es.update(losses[-1])
            pbar.update(epoch, losses[-1], es.es)
        
        return losses
    
    def configure_optimizers(self, lr):
        self.optimizer = torch.optim.Adagrad(self.parameters(), lr=lr)

device = 'cpu'#'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

X, Y, X_noise, n_clusters = triangle(noise=0.1)

ae = Autoencoder(
    enc_channels=(5,2), 
    dec_channels=(2,5),
    bias=False
)

tae = TensorizedAutoencoder(ae, X, n_clusters=3).to(device)

dummy_in = torch.randn((64,5)).to(device)
dummy_c = torch.randn((3,5)).to(device)
dummy_ass = tae._assign_centers_to_data(dummy_in, dummy_c)

tae.autoencoders(dummy_in,dummy_c).shape
tae.forward_pass(dummy_in, return_embed=False).shape
loss, new_indices = tae.compute_loss(dummy_in, dummy_in, tae.centers, dummy_ass)
print(loss, new_indices.shape)

ImportError: attempted relative import with no known parent package

In [28]:
dataset = DummyDataset(X, X)

In [29]:
loader = torch.utils.data.DataLoader(dataset, shuffle=False)

In [31]:
for n, (x,y) in enumerate(loader):
    print((x == X[n]).all())

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)

In [24]:
from inpainting_utils import DummyDataset

losses = tae.train(X.float(), X.float(), DummyDataset, 50, 1e-3, 16)

Epoch 50: 100%|███████████████████████████████████| 50/50 [00:14<00:00,  3.39it/s, loss: nan, es: 0]
