In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torch.optim as optim

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix, normalized_mutual_info_score
from unidip.dip import diptst
from tqdm import tqdm
import time
import datetime

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_backend_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend_name)
print(f'Will use {device}')

#### Define Encoder and Decoder

In [None]:
class Encoder(nn.Module):
    """Fully connected encoder mapping ``d`` input features to ``m`` latent values.

    Layer widths are d -> 500 -> 500 -> 2000 -> m, with a ReLU after every
    linear layer except the last one.
    """

    def __init__(self, d, m=5):
        super().__init__()
        hidden_widths = (d, 500, 500, 2000)
        layers = []
        # Build the hidden stack in order so parameter names / init order stay the same.
        for n_in, n_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_widths[-1], m))  # linear latent output, no activation
        self.encoder = nn.Sequential(*layers)

    def forward(self, inputs):
        """Return the latent representation of ``inputs``."""
        return self.encoder(inputs)
    
class Decoder(nn.Module):
    """Fully connected decoder mapping ``m`` latent values back to ``d`` features.

    Layer widths are m -> 2000 -> 500 -> 500 -> d, with ReLU between the linear
    layers and a final Sigmoid, so every output lies in [0, 1].
    """

    def __init__(self, d, m=5):
        super().__init__()
        stack = [
            nn.Linear(m, 2000),
            nn.ReLU(),
            nn.Linear(2000, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, d),
            nn.Sigmoid(),
        ]
        # The attribute is called `encoder` although it holds the decoder stack;
        # the name determines the state_dict keys, so it is kept as is.
        self.encoder = nn.Sequential(*stack)

    def forward(self, inputs):
        """Return the reconstruction for the latent codes ``inputs``."""
        return self.encoder(inputs)

#### Define Loss

In [None]:
class Lrec(nn.Module):
    """Reconstruction loss: squared error summed over the last axis, then averaged."""

    def __init__(self):
        super().__init__()

    def forward(self, output, targets):
        """Return mean over all leading axes of ``sum((output - targets)**2, dim=-1)``."""
        residual = output - targets
        per_sample = residual.square().sum(dim=-1)
        return per_sample.mean()
    
class Lclu(nn.Module):
    """Clustering loss: P-weighted squared distances to the centres, rescaled by centre spread.

    The weighted squared distance of every sample to every centre is averaged
    over the batch and multiplied by ``(1 + std) / mean``, where ``mean`` and
    ``std`` are taken over the pairwise centre distances in ``Dc``.
    """

    def __init__(self):
        super(Lclu, self).__init__()

    def forward(self, encoded_inputs, encoded_centres, Dc, P):
        """Compute the clustering loss.

        Args:
            encoded_inputs: latent codes of the batch, shape (batch, m).
            encoded_centres: latent codes of the cluster centres, shape (k, m).
            Dc: (k, k) matrix of pairwise distances between the encoded centres.
            P: (batch, k) weights of every sample towards every centre.

        Returns:
            Scalar tensor with the (scaled) clustering loss.
        """
        k = Dc.shape[0]
        sq_dist = torch.sum(torch.square(torch.unsqueeze(encoded_inputs, 1) - encoded_centres), -1)
        compression = torch.mean(torch.sum(P * sq_dist, -1))
        if k < 2:
            # With a single centre there are no pairwise distances; the previous
            # 0/0 normalisation produced NaN, so return the unscaled term instead.
            return compression

        # Statistics over the k*(k-1) off-diagonal entries only. The former code
        # divided by k**2-k but summed squared deviations over all k*k entries,
        # so the diagonal leaked into the standard deviation.
        off_diagonal = ~torch.eye(k, dtype=torch.bool, device=Dc.device)
        pair_dist = Dc[off_diagonal]
        mean = torch.mean(pair_dist)
        var = torch.mean(torch.square(pair_dist - mean))
        # sqrt has an infinite derivative at 0 (e.g. k == 2, where both
        # distances are equal); keep the zero variance itself in that case so
        # the backward pass stays finite.
        std = torch.sqrt(var) if var.item() > 0 else var
        return ((1 + std) / mean) * compression

#### Define pre-train function

In [None]:
def Pre_train_AE(dataset, epochs=100, lr=0.001, batch_size=256):
    """Pre-train a fresh Encoder/Decoder pair on ``dataset`` with the reconstruction loss.

    Args:
        dataset: object exposing ``len()`` and a tensor attribute ``data`` whose
            first axis indexes the samples; the remaining axes are flattened.
        epochs: number of passes over the data.
        lr: learning rate of the two Adam optimisers.
        batch_size: number of samples per mini-batch.

    Returns:
        Tuple ``(enc, dec)`` of the trained modules, placed on ``device``.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError('Pre_train_AE needs a non-empty dataset')
    # int(): np.prod returns a NumPy scalar (a float for an empty shape tuple).
    N_feature = int(np.prod(dataset.data.shape[1:]))
    # Cast to float32 like DipDECK does, so integer-typed data also works.
    flat_data = dataset.data.type(torch.float32).view(-1, N_feature).to(device)
    enc = Encoder(N_feature).to(device)
    dec = Decoder(N_feature).to(device)
    optimizer1 = torch.optim.Adam(enc.parameters(), lr=lr)
    optimizer2 = torch.optim.Adam(dec.parameters(), lr=lr)
    Lrec_function = Lrec().to(device)

    for epoch in range(epochs):
        total_loss = 0.0
        random_perm = np.random.permutation(n_samples)
        for batch_start in range(0, n_samples, batch_size):
            indices = random_perm[batch_start:batch_start + batch_size]
            inputs = flat_data[indices]

            enc.zero_grad()  # clear the gradients of the previous step
            dec.zero_grad()  # clear the gradients of the previous step

            encoded = enc(inputs)
            decoded = dec(encoded)
            loss = Lrec_function(decoded, inputs)
            loss.backward()

            optimizer1.step()
            optimizer2.step()
            # .item() detaches the value; adding the tensor itself kept every
            # batch's autograd graph alive until the end of the epoch.
            total_loss += loss.item()
        # Reported value: sum of the per-batch mean losses divided by the number of samples.
        total_loss /= n_samples

        print(f'[{epoch+1}/{epochs}] Loss: {total_loss}')
    return enc, dec

#### Define DipDECK

In [None]:
class DipDECK:
    """Autoencoder-based clustering that merges clusters using pairwise dip-test p-values.

    The constructor runs the whole procedure:
      1. K-means with ``k_init`` clusters on the encoded data,
      2. each K-means centre is replaced by the data point whose code is closest,
      3. a symmetric matrix of pairwise dip-test p-values is built,
      4. the autoencoder is trained with ``Lrec + Lclu``; after every epoch the
         pair of clusters with the largest p-value is merged while that value
         is >= ``P_threshold`` (a merge restarts the epoch counter).

    Attributes set by the constructor:
      data       (N, F) float32 tensor with the flattened samples (CPU)
      labels     (N,) int64 tensor with the current cluster index per sample (CPU)
      centres    (k, F) tensor with one representative data point per cluster
      DipMatrix  (k, k) symmetric p-value matrix with ones on the diagonal
      nDipMatrix row-normalised DipMatrix, used as weights in the clustering loss
      k          current number of clusters
    """

    def __init__(self, dataset, enc, dec, k_init=15, P_threshold=0.9, epochs=50, batch_size = 256):
        """Cluster ``dataset`` with the pre-trained modules ``enc`` and ``dec``.

        Args:
            dataset: object exposing ``len()`` and a tensor attribute ``data``.
            enc, dec: encoder / decoder modules; all tensors are moved to the
                device that holds the encoder parameters.
            k_init: initial number of K-means clusters.
            P_threshold: clusters are merged while the largest pairwise
                p-value is at least this value.
            epochs: number of consecutive epochs without a merge to train for.
            batch_size: mini-batch size.
        """
        self.BATCHSIZE = batch_size
        # int(): np.prod returns a NumPy scalar (a float for an empty shape tuple).
        self.N_DATA, self.N_FEATURE = len(dataset), int(np.prod(dataset.data.shape[1:]))
        self.data = dataset.data.type(torch.float32).view(-1, self.N_FEATURE)
        self.P_threshold = P_threshold
        self.k = k_init
        self.enc = enc
        self.dec = dec
        print('******************** Apply K-means ********************')
        with torch.no_grad():
            Km_model = KMeans(self.k)
            Km_model.fit(enc(self.data.to(self._device)).cpu().numpy())
        self.kmCentres = torch.tensor(Km_model.cluster_centers_, dtype=torch.float32)
        self.labels = torch.tensor(Km_model.labels_, dtype=torch.int64)
        print('********** finding closest points to kmCentres **********')
        self.centres = self.find_centres1()
        self.DipMatrix = self.build_DipMatrix()
        print('******************** Batch optimizing ********************')
        self.Batch_optimize(epochs)

    @property
    def _device(self):
        """Device holding the encoder parameters (CPU when the encoder has none)."""
        first_param = next(iter(self.enc.parameters()), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    @staticmethod
    def _centre_distances(encoded_centres):
        """Return the (k, k) matrix of Euclidean distances between the rows of ``encoded_centres``.

        Zero distances (the diagonal, or coinciding centres) are produced
        without passing through ``sqrt``, whose derivative at 0 is infinite, so
        the result can be back-propagated through safely.
        """
        diff = torch.unsqueeze(encoded_centres, 1) - torch.unsqueeze(encoded_centres, 0)
        sq_dist = torch.sum(torch.square(diff), -1)
        positive = sq_dist > 0
        safe_sq = torch.where(positive, sq_dist, torch.ones_like(sq_dist))
        return torch.where(positive, torch.sqrt(safe_sq), torch.zeros_like(sq_dist))

    @staticmethod
    def _nearest_points(points, centre, n):
        """Return the rows of ``points`` whose squared distance to ``centre`` is at most the n-th smallest one.

        ``n`` is 1-based and must not exceed ``len(points)``; rows tied with the
        n-th smallest distance are all kept.
        """
        dist = torch.sum(torch.square(points - centre), dim=-1)
        cutoff = torch.kthvalue(dist, n).values
        return points[dist <= cutoff]

    def Batch_optimize(self, epochs, lr=0.0001):
        """Train encoder and decoder with ``Lrec + Lclu`` and merge clusters after each epoch.

        Training stops once ``epochs`` consecutive epochs finished without a
        merge; every merge resets the epoch counter to 0.
        """
        optimizer1 = torch.optim.Adam(self.enc.parameters(), lr=lr)
        optimizer2 = torch.optim.Adam(self.dec.parameters(), lr=lr)
        dev = self._device
        Lrec_function = Lrec().to(dev)
        Lclu_function = Lclu().to(dev)
        n_batches = -(-self.N_DATA // self.BATCHSIZE)  # ceil division

        i = 0
        while i<epochs:
            print(f'********** epoch {i+1} of {epochs} **********')
            random_perm = np.random.permutation(self.N_DATA)
            for j, batch_start in enumerate(range(0, self.N_DATA, self.BATCHSIZE)):
                indices = random_perm[batch_start:batch_start + self.BATCHSIZE]
                if i!=0:
                    # Labels are left untouched in the first epoch after the
                    # initialisation or after a merge.
                    self.update_labels(indices)
                inputs = self.data[indices].to(dev)
                encoded_centres = self.enc(self.centres.to(dev))
                # Pairwise centre distances. The former expression used
                # r_i + r_j + <c_i, c_j>, which is not a distance.
                Dc = self._centre_distances(encoded_centres)
                specific_PMetrix = self.nDipMatrix[self.labels[indices]].to(dev)

                self.enc.zero_grad()
                self.dec.zero_grad()

                encoded = self.enc(inputs)
                decoded = self.dec(encoded)
                L1 = Lrec_function(decoded, inputs)
                L2 = Lclu_function(encoded, encoded_centres, Dc, specific_PMetrix)
                L = L1 + L2
                L.backward()

                optimizer1.step()
                optimizer2.step()
                print(f'[{j+1}/{n_batches}] batch loss:{L/self.BATCHSIZE} (Lrec:{L1/self.BATCHSIZE}, Lclu:{L2/self.BATCHSIZE})')

            # Refresh the label of every sample (previously only the first
            # `k` samples were refreshed because range(self.k) was passed).
            self.update_labels(np.arange(self.N_DATA))
            self.centres = self.find_centres2()
            self.DipMatrix = self.build_DipMatrix()
            i += 1
            with torch.no_grad():
                # k > 1 guards against merging a single cluster with itself
                # when P_threshold is <= 0.
                while self.k > 1 and torch.max(self.DipMatrix-torch.eye(self.k))>=self.P_threshold:
                    n_cols = self.DipMatrix.shape[1]
                    argmax = int(torch.argmax(self.DipMatrix-torch.eye(self.k)))
                    Ci, Cj = sorted(divmod(argmax, n_cols))
                    print(f'********** merging (remain {self.k-1} cluster) **********')
                    self._merge_clusters(Ci, Cj)
                    self.DipMatrix = self.update_DipMatrix(Ci, Cj)
                    i = 0

    @torch.no_grad()
    def _merge_clusters(self, Ci, Cj):
        """Merge cluster ``Cj`` into ``Ci`` (``Ci < Cj``) and decrement ``k``.

        Updates ``labels`` (indices above ``Cj`` shift down by one) and
        ``centres`` (row ``Ci`` replaced by the merged centre, row ``Cj``
        removed). ``DipMatrix`` is not touched here.
        """
        # The merged centre needs the labels and centres of both clusters as
        # they are *before* relabelling; afterwards index Cj denotes another cluster.
        new_centre = self.find_merged_centre(Ci, Cj)
        self.labels[self.labels==Cj] = Ci
        shifted = self.labels>Cj
        self.labels[shifted] = self.labels[shifted]-1
        self.centres[Ci] = new_centre
        self.centres = torch.concat([self.centres[:Cj,:], self.centres[Cj+1:,:]], 0)
        self.k -= 1

    @torch.no_grad()
    def find_centres1(self):
        """Return, for every K-means centre, the data point whose code is closest to it (CPU tensor)."""
        dev = self._device
        encoded_data = self.enc(self.data.to(dev))
        dist = torch.sum(torch.square(encoded_data-torch.unsqueeze(self.kmCentres.to(dev), 1)),-1)
        # self.data lives on the CPU, so the index tensor has to be moved there too.
        return self.data[torch.argmin(dist, -1).cpu()]

    @torch.no_grad()
    def find_centres2(self):
        """Return, for every cluster, the data point whose code is closest to the cluster's mean code.

        A cluster without members keeps its previous centre.
        """
        encoded_data = self.enc(self.data.to(self._device))
        centres = torch.zeros(self.k, self.data.shape[1])
        for i in range(self.k):
            members = self.labels==i
            if not bool(torch.any(members)):
                # The mean of an empty selection is NaN; keep the old centre.
                centres[i] = self.centres[i]
                continue
            centre_ = torch.mean(encoded_data[members], axis=0)
            dist = torch.sum(torch.square(encoded_data-centre_), axis=-1)
            centres[i] = self.data[int(torch.argmin(dist))]
        return centres

    @torch.no_grad()
    def find_merged_centre(self, Ci, Cj):
        """Return the data point of clusters ``Ci``/``Cj`` closest to their size-weighted mean centre code.

        Must be called while ``labels`` and ``centres`` still describe both
        clusters separately. The result is a CPU tensor.
        """
        dev = self._device
        data_CiCj = self.data[torch.logical_or(self.labels==Ci,self.labels==Cj)].to(dev)
        encoded_data_CiCj = self.enc(data_CiCj)
        encoded_centres = self.enc(self.centres.to(dev))
        centre_Ci, centre_Cj = encoded_centres[Ci], encoded_centres[Cj]
        N_Ci, N_Cj = torch.sum(self.labels==Ci), torch.sum(self.labels==Cj)
        weighted_centre = (N_Ci*centre_Ci+N_Cj*centre_Cj)/(N_Ci+N_Cj)
        dist = torch.sum(torch.square(encoded_data_CiCj-weighted_centre), axis=-1)
        return data_CiCj[torch.argmin(dist)].cpu()

    def _pair_dip_pvalue(self, encoded_data, encoded_centres, a, b):
        """Return the dip-test p-value for clusters ``a`` and ``b``.

        The codes of both clusters are projected onto the difference of the two
        encoded centres and tested. If one cluster is more than twice as large
        as the other (and the smaller one is not empty), a second test uses the
        smaller cluster plus the points of the larger one nearest to the
        smaller centre, and the minimum of both p-values is returned.
        ``encoded_data`` and ``encoded_centres`` are expected on the CPU.
        """
        in_a, in_b = self.labels==a, self.labels==b
        direction = encoded_centres[a]-encoded_centres[b]
        points = encoded_data[torch.logical_or(in_a, in_b)]
        p_value = diptst(torch.sum(points*direction, dim=-1).numpy(), is_hist=False)[1]
        n_a, n_b = int(torch.sum(in_a)), int(torch.sum(in_b))
        (small, large, n_small, n_large) = (a, b, n_a, n_b) if n_a<=n_b else (b, a, n_b, n_a)
        if n_large>2*n_small and n_small!=0:
            nearest_large = self._nearest_points(encoded_data[self.labels==large],
                                                 encoded_centres[small], 2*n_small)
            points = torch.concat([encoded_data[self.labels==small], nearest_large])
            p_subset = diptst(torch.sum(points*direction, dim=-1).numpy(), is_hist=False)[1]
            p_value = min(p_value, p_subset)
        return p_value

    @torch.no_grad()
    def build_DipMatrix(self):
        """Compute the full symmetric p-value matrix, refresh ``nDipMatrix`` and return the matrix."""
        print(f'******************** building DipMatrix ********************')
        dev = self._device
        encoded_data = self.enc(self.data.to(dev)).cpu()
        encoded_centres = self.enc(self.centres.to(dev)).cpu()
        dip_matrix = torch.eye(self.k, dtype=torch.float32)
        with tqdm(total=self.k*(self.k-1)//2) as pbar:
            for i in range(self.k):
                for j in range(i+1,self.k):
                    dip_matrix[i,j] = dip_matrix[j,i] = self._pair_dip_pvalue(encoded_data, encoded_centres, i, j)
                    pbar.update(1)
        self.nDipMatrix = self.Matrix2affine(dip_matrix)
        return dip_matrix

    @torch.no_grad()
    def update_DipMatrix(self, idx, jdx):
        """Return the p-value matrix after cluster ``jdx`` has been merged into ``idx``.

        Expects ``labels``, ``centres`` and ``k`` to describe the state after
        the merge. Row/column ``jdx`` of the old matrix is dropped and only
        row/column ``idx`` is recomputed. Also refreshes ``nDipMatrix``.
        """
        print(f'******************** updating DipMatrix ********************')
        dev = self._device
        encoded_data = self.enc(self.data.to(dev)).cpu()
        encoded_centres = self.enc(self.centres.to(dev)).cpu()
        dip_matrix = torch.concat([self.DipMatrix[:,:jdx], self.DipMatrix[:,jdx+1:]], 1)
        dip_matrix = torch.concat([dip_matrix[:jdx,:], dip_matrix[jdx+1:,:]], 0)
        with tqdm(total=self.k-1) as pbar:
            for j in range(self.k):
                if j==idx:
                    continue
                dip_matrix[idx,j] = dip_matrix[j,idx] = self._pair_dip_pvalue(encoded_data, encoded_centres, idx, j)
                pbar.update(1)
        self.nDipMatrix = self.Matrix2affine(dip_matrix)
        print(dip_matrix)
        return dip_matrix

    def Matrix2affine(self, matrix):
        """Divide column j of ``matrix`` by the sum of row j and transpose the result.

        For a symmetric ``matrix`` (as built by this class) every row of the
        result sums to 1.
        """
        return (matrix/matrix.sum(1)).T

    @torch.no_grad()
    def update_labels(self, indices):
        """Assign the samples at ``indices`` to the cluster whose encoded centre is nearest.

        Only the requested samples are encoded (the whole dataset used to be
        re-encoded for every mini-batch). ``indices`` may be any integer
        sequence: a list, a range or a NumPy array.
        """
        idx = torch.as_tensor(np.asarray(indices), dtype=torch.int64)
        if idx.numel() == 0:
            return
        dev = self._device
        encoded_batch = self.enc(self.data[idx].to(dev))
        encoded_centres = self.enc(self.centres.to(dev))
        D = torch.sum(torch.square(encoded_batch-torch.unsqueeze(encoded_centres, 1)),-1)
        self.labels[idx] = torch.argmin(D, axis=0).type(torch.int64).cpu()

#### Test

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms

import matplotlib.pyplot as plt

In [None]:
# Load the MNIST training split from ./dataset; download is disabled, so the
# files have to be present there already.
train_data = datasets.MNIST(
    root='./dataset',
    train=True,
    download=False,
    transform=transforms.ToTensor(),
)
# Divide the stored image tensor by 255 (presumably 0..255 pixel values -> [0, 1]; confirm).
train_data.data = train_data.data / 255

In [None]:
# Pre-train the autoencoder and report the wall-clock time it took.
start = time.time()
enc, dec = Pre_train_AE(train_data, epochs=100, lr=0.001)
end = time.time()
pretrain_elapsed = datetime.timedelta(seconds=end-start)
print(f'\n預訓練耗時{str(pretrain_elapsed)}')

In [None]:
# Run DipDECK starting from 35 clusters and report the final cluster count and wall-clock time.
start = time.time()
Model = DipDECK(train_data, enc.to(device), dec.to(device), k_init=35, epochs=50)
end = time.time()
clustering_elapsed = datetime.timedelta(seconds=end-start)
print(f'共分了{Model.k}群，耗時{str(clustering_elapsed)}')

#### confusion_matrix

In [None]:
# #Data=60000, #pre-train epoch=100, #train epoch=50
# Rows are the real labels, columns the cluster indices.
confusionmatrix = confusion_matrix(train_data.targets, Model.labels)
n_rows, n_cols = confusionmatrix.shape

fig = plt.Figure(figsize=(5,5))
ax = fig.add_subplot(1,1,1)
ax.set_xlabel('cluster', fontsize=16)
ax.set_ylabel('real label', fontsize=16)
ax.set_xticks(range(n_cols))
ax.set_yticks(range(n_rows))
# The last column is left out of the image (the ticks above still cover every column).
ax.imshow(confusionmatrix[:,:-1], cmap='PuRd')
fig

#### rearranged confusion_matrix

In [None]:
# Hand-picked column order for nine clusters; it has to be adapted whenever
# the clustering result changes.
cluster_order = [2,3,0,1,4,5,6,8,7]
sorted_confusion_matrix = confusionmatrix[:, cluster_order]

fig = plt.Figure(figsize=(5,5))
ax = fig.add_subplot(1,1,1)
ax.set_xlabel('rearranged cluster', fontsize=16)
ax.set_ylabel('real', fontsize=16)
# Tick positions follow the shape of the full (unsorted) confusion matrix.
n_rows, n_cols = confusionmatrix.shape
ax.set_xticks(range(n_cols))
ax.set_yticks(range(n_rows))
ax.imshow(sorted_confusion_matrix, cmap='PuRd')
fig

#### NMI score

In [None]:
# Normalised mutual information between the real labels and the cluster assignment.
nmi = normalized_mutual_info_score(train_data.targets, Model.labels)
print(f'NMI={nmi}')

#### Show the original, encoded and decoded training data

In [None]:
from ipywidgets import interact
import ipywidgets as widgets

# Cluster index whose members are shown below.
label = 10

# Note: nn.Module.cpu() moves the modules held by Model in place.
encoder = Model.enc.cpu()
decoder = Model.dec.cpu()

# Flattened images assigned to the selected cluster.
in_cluster = Model.labels==label
dat = train_data.data.view(-1,784)[in_cluster]

# Encode / decode once, then keep plain NumPy copies for plotting.
_latent = encoder(dat)
_reconstruction = decoder(_latent)
encoded_img = _latent.detach().numpy()
decoded_img = _reconstruction.detach().numpy()

def f(i):
    """Return a figure with the i-th selected image, its 5-value code and its reconstruction.

    Reads the globals ``dat``, ``encoded_img`` and ``decoded_img`` and also
    prints the code as a 1x5 array.
    """
    fig = plt.Figure(figsize=(10,10))
    ax1, ax2, ax3 = (fig.add_subplot(1,3,position) for position in (1, 2, 3))
    # Only the two image panels hide their axes; the code panel keeps them.
    for image_axis in (ax1, ax3):
        image_axis.axis('off')
    latent_row = encoded_img[i].reshape(1,5)
    ax1.imshow(dat[i].reshape(28,28), cmap='gray')
    ax2.imshow(latent_row, cmap='gray')
    print(latent_row)
    ax3.imshow(decoded_img[i].reshape(28,28), cmap='gray')
    return fig
    
# Slider over the indices of the selected samples.
sample_slider = widgets.IntSlider(min=0, max=dat.shape[0]-1, step=1, value=0)
interact(f, i=sample_slider)

#### save DipDECK model

In [None]:
import pickle

# Base name of the output file ('.pkl' is appended). The original placeholder
# had no value, which is a syntax error; replace the default with your own name.
filename = 'dipdeck_model'  # your file name

# Two objects are written back to back into the same file: first the
# (enc, dec) tuple, then the DipDECK instance. Reading them back therefore
# takes two consecutive pickle.load calls on the opened file.
with open(f'{filename}.pkl', 'wb') as outp:
    pickle.dump((enc, dec), outp, pickle.HIGHEST_PROTOCOL)
    pickle.dump(Model, outp, pickle.HIGHEST_PROTOCOL)