In [151]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import wandb
import torch.optim as optim
from src.datasets.polynomial import PolynomialDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from sklearn.cluster import AgglomerativeClustering

In [116]:
class Encoder(nn.Module):
    def __init__(self, input_dim, seq_len, cnn_kernel, cnn_stride, mp_kernel, mp_stride, lstm_hidden_dim) -> None:
        super().__init__()

        self.input_dim = input_dim
        self.seq_len = seq_len
        self.cnn_kernel = cnn_kernel
        self.cnn_stride = cnn_stride
        self.mp_kernel = mp_kernel
        self.mp_stride = mp_stride
        self.lstm_hidden_dim = lstm_hidden_dim

        self.cnn = nn.Conv1d(
            in_channels=self.input_dim,
            out_channels=1,
            kernel_size=self.cnn_kernel,
            stride=self.cnn_stride,
            padding=0,
            dilation=1,
        )
        self.max_pool = nn.MaxPool1d(
            kernel_size=self.mp_kernel, stride=self.mp_stride, padding=0, dilation=1
        )

        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=self.lstm_hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0,
            bidirectional=True,
        )
    
    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = F.relu(self.cnn(x))
        x = self.max_pool(x)

        x = x.permute(0, 2, 1)
        x, (_, _) = self.lstm(x)

        x = x[:, :, : self.lstm_hidden_dim] + x[:, :, self.lstm_hidden_dim :]
        return x


class Decoder(nn.Module):
    def __init__(self, upsample_scale, input_dim, hidden_dim, deconv_kernel, deconv_stride) -> None:
        super().__init__()

        self.upsample_scale = upsample_scale
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.deconv_kernel = deconv_kernel
        self.deconv_stride = deconv_stride

        self.upsample = nn.Upsample(scale_factor=self.upsample_scale)
        self.deconv_cnn = nn.ConvTranspose1d(
            in_channels=self.hidden_dim,
            out_channels=self.input_dim,
            kernel_size=self.deconv_kernel,
            stride=self.deconv_stride,
            padding=0,
        )

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.upsample(x)

        x = self.deconv_cnn(x)
        x = x.permute(0, 2, 1)
        return x


class DTC:
    def __init__(self) -> None:
        super().__init__()

        self.encoder = Encoder(input_dim=1, seq_len=100, cnn_kernel=10, cnn_stride=3, mp_kernel=10, mp_stride=3, lstm_hidden_dim=1)
        self.decoder = Decoder(upsample_scale=2, input_dim=1, hidden_dim=1, deconv_kernel=10, deconv_stride=6)

        self.autoencoder_pretrained = True

    def pretrain_autoencoder(self):
        for _ in range(10):
            self.train_autoencoder_one_epoch()





    def forward(self, x):
        if not self.autoencoder_pretrained:
            raise Exception('Autoencoder not pretrained.')

        if not self.centroids_initialised:
            raise Exception('Cluster centroids not initialised.')

        l = self.encoder(x)

        return l


def euclidean_distance(x, y):
    '''
    Return (x.shape[0], y.shape[0]) matrix where each element is d(x_i, y_i) 
    where x_i is the i-th time series in x => x_i = x[i].
    '''
    a = x.repeat(1,1,y.shape[0]).permute(0,2,1)
    b = y.repeat(x.shape[0],1,1).reshape(a.shape)
    return torch.sqrt(torch.sum((a - b)**2, dim=2))





In [54]:
x = torch.rand(size=(2, 100, 1))
x.shape

torch.Size([2, 100, 1])

In [58]:
run = wandb.init(project="DTC", name='test')
dataset = PolynomialDataset(
    run,
    "tristanbester1/DTC/polynomial_dataset_X:v0",
    "tristanbester1/DTC/polynomial_dataset_Y:v0",
)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size]
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtristanbester1[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [67]:
encoder = Encoder(input_dim=1, seq_len=100, cnn_kernel=10, cnn_stride=3, mp_kernel=10, mp_stride=3, lstm_hidden_dim=1)
decoder = Decoder(upsample_scale=2, input_dim=1, hidden_dim=1, deconv_kernel=10, deconv_stride=6)

autoencoder = nn.Sequential(encoder, decoder)



In [121]:
train_loader = DataLoader(dataset=train_dataset, batch_size=10)
test_loader = DataLoader(dataset=test_dataset, batch_size=1)

device = torch.device('cpu')
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=5, threshold=0.01, verbose=True,
)

In [117]:
def train_autoencoder_one_epoch(model, optimizer, criterion, data_loader, device):
    model.train()
    ave_loss = 0

    for x, _ in data_loader:
        x = x.to(device)

        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, x)
        loss.backward()
        optimizer.step()

        ave_loss += loss.item()
    return ave_loss / len(data_loader)

def init_centroids(autoencoder, data_loader):
    L = []

    for x, _ in train_loader:
        L.append(encoder(x).detach())
    L = torch.cat(L)

    similarity_matrix = euclidean_distance(L, L)

    cluster_algo = AgglomerativeClustering(
        n_clusters=3,
        affinity='euclidean',
        linkage='complete',
    )

    assignments = cluster_algo.fit_predict(similarity_matrix)

    centroids = []

    for i in np.unique(assignments):
        centroid = L[assignments == i].mean(dim=0).unsqueeze(0)
        centroids.append(centroid)

    centroids = torch.cat(centroids)    
    return centroids

In [77]:
train_autoencoder_one_epoch(autoencoder, optimizer, criterion, train_loader, device)

0.2597264308482409

In [119]:
centroids = init_centroids(autoencoder, train_loader)

  out = hierarchy.linkage(X, method=linkage, metric=affinity)


In [120]:
centroids

tensor([[[ 0.8481],
         [ 0.6248],
         [ 0.2972],
         [ 0.0330],
         [-0.0683],
         [-0.0994],
         [-0.1102],
         [-0.0787]],

        [[-0.2263],
         [-0.1735],
         [-0.1474],
         [-0.1528],
         [-0.1621],
         [-0.1779],
         [-0.2171],
         [-0.2192]],

        [[ 1.0319],
         [ 0.9851],
         [ 0.8217],
         [ 0.5410],
         [ 0.2272],
         [-0.1579],
         [-0.2714],
         [-0.2208]]])

In [146]:
class ClusteringLayer(nn.Module):
    def __init__(self, centroids) -> None:
        super().__init__()

        self.centroids = nn.Parameter(centroids)

    def students_t_distribution_kernel(self, x, alpha):
        num = torch.pow((1 + x/alpha), -(alpha+1)/2)
        denom = num.sum(dim=1).reshape(-1,1).repeat(1, self.centroids.shape[0])
        return num / denom

    def target_distribution(self, Q):
        F = Q.sum(dim=0)
        num = (Q**2) / F
        denom = num.sum(dim=1).reshape(-1,1).repeat(1,self.centroids.shape[0])
        return num / denom

    def forward(self, x):
        D = euclidean_distance(x, self.centroids)
        
        Q = self.students_t_distribution_kernel(D, 3)
        F = Q.sum(dim=0)
        
        P = self.target_distribution(Q)
        
        log_P = torch.log(P)
        log_Q = torch.log(Q)
        return log_Q, log_P


In [148]:
cluster_layer = ClusteringLayer(centroids)

In [156]:
criterion = nn.KLDivLoss(log_target=True, reduction='batchmean')
optimizer = optim.Adam(cluster_layer.parameters(), lr=0.001)

In [157]:
for x, _ in train_loader:
    l = encoder(x).detach()
    l_q, l_p = cluster_layer(l)
    loss = criterion(l_q, l_p)
    loss.backward()
    optimizer.step()
    # print(loss.item())
