In [1]:
import argparse
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import itertools
from sklearn.mixture import GaussianMixture
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from scipy.optimize import linear_sum_assignment
from sklearn.manifold import TSNE
from tensorboardX import SummaryWriter
import numpy as np
import os

In [2]:
# Create custom dataset class
class CustomDataset(Dataset):
    """Map-style dataset pairing each sample with column 0 of ``labels``.

    Args:
        data: indexable container of samples; its length is the dataset length.
        labels: 2-D indexable label container; only ``labels[:, 0]`` is kept.
    """

    def __init__(self, data, labels):
        # Only the first label column is used as the target.
        self.labels = labels[:, 0]
        self.data = data

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

    def __len__(self):
        """Number of samples, taken from ``data``."""
        return len(self.data)

In [3]:
def cluster_acc(Y_pred, Y):
    """Clustering accuracy under the best one-to-one cluster/label matching.

    A contingency matrix ``w[p, t]`` counts how often predicted cluster ``p``
    co-occurs with true label ``t``; the Hungarian algorithm then picks the
    assignment of clusters to labels that maximises the matched count.

    Args:
        Y_pred: predicted cluster ids (array-like of integer-valued numbers).
        Y: ground-truth label ids, same number of elements as ``Y_pred``.

    Returns:
        ``(accuracy, w)`` where ``accuracy`` is the matched fraction in [0, 1]
        and ``w`` is the ``D x D`` int64 contingency matrix, with
        ``D = max(Y_pred.max(), Y.max()) + 1``.

    Raises:
        AssertionError: if the two inputs differ in number of elements.
    """
    # Cast up front so float-typed labels (e.g. straight from a float tensor)
    # and plain Python lists can be used as indices.
    Y_pred = np.asarray(Y_pred).astype(np.int64).ravel()
    Y = np.asarray(Y).astype(np.int64).ravel()
    assert Y_pred.size == Y.size

    D = int(max(Y_pred.max(), Y.max())) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Unbuffered add: repeated (pred, true) pairs each contribute a count.
    np.add.at(w, (Y_pred, Y), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / Y_pred.size, w

In [4]:
class Encoder(nn.Module):
    """1-D convolutional encoder producing a mean and log-variance per channel row.

    Two stride-2 ``Conv1d`` stages (``input_channels`` -> 128 -> 64) are followed
    by linear heads applied along the last axis. Because the conv output is
    flattened to ``(-1, length)``, every one of the 64 output channels of every
    sample becomes its own row, and ``hidden_dim`` must equal the length of the
    conv output's last axis.

    Args:
        latent_dim: size of the returned mean / log-variance vectors.
        input_channels: number of channels of the input signal.
        hidden_dim: feature size consumed by ``fc1`` (conv output length).
    """

    def __init__(self, latent_dim, input_channels, hidden_dim):
        super().__init__()

        conv_stack = [
            nn.Conv1d(input_channels, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Conv1d(128, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(True),
        ]
        self.encoder = nn.Sequential(*conv_stack)

        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_sigma = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each with one row per (sample, channel) pair."""
        features = self.encoder(x)
        # Collapse batch and channel axes; keep the length axis as features.
        rows = features.contiguous().view(-1, features.size(-1))
        hidden = self.fc1(rows)
        return self.fc_mu(hidden), self.fc_sigma(hidden)

In [5]:
class Decoder(nn.Module):
    """Mirror of the encoder: linear layers followed by two transposed convs.

    Latent rows are projected to ``hidden_dim`` features, regrouped 64 rows at
    a time into a ``(batch, 64, hidden_dim)`` tensor, and upsampled by two
    stride-2 ``ConvTranspose1d`` stages (64 -> 128 -> ``output_channels``).
    The final ``Sigmoid`` keeps every output value in [0, 1].

    Args:
        latent_dim: size of each incoming latent row.
        output_channels: number of channels of the reconstructed signal.
        hidden_dim: feature size produced by ``fc2``/``fc3``.
    """

    def __init__(self, latent_dim, output_channels, hidden_dim):
        super().__init__()

        deconv_stack = [
            nn.ConvTranspose1d(64, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.ConvTranspose1d(128, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*deconv_stack)

        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, z):
        """Decode latent rows ``z`` (row count must be a multiple of 64)."""
        hidden = self.fc3(self.fc2(z))
        # Every 64 consecutive rows become the 64 channels of one sample.
        feature_maps = hidden.contiguous().view(-1, 64, hidden.size(1))
        return self.decoder(feature_maps)

In [6]:
class VaDE(nn.Module):
    """VaDE-style model: an encoder/decoder pair with a Gaussian-mixture prior.

    Learnable prior tensors (their names are part of the checkpoint format):
      * ``pi``       -- mixture weights, shape ``(nClusters,)``
      * ``mu_c``     -- component means, shape ``(nClusters, latent_dim)``
      * ``logvar_c`` -- component log-variances, same shape as ``mu_c``

    ``args`` must provide ``nClusters`` and ``latent_dim``. ``input_channels``
    (default 299) and ``hidden_dim`` (default 25) are read from ``args`` when
    present, so an ``args`` object without them builds the same network as
    before. Tensors are placed on whatever device the model's parameters live
    on; ``args.cuda`` is not consulted inside this class.
    """

    # Floor applied to ``pi`` before taking its log. ``pi`` is an unconstrained
    # parameter, so an optimiser step can push an entry to zero or below, which
    # would otherwise turn the loss (and then every weight) into NaN.
    _PI_EPS = 1e-10

    def __init__(self, args):
        super(VaDE, self).__init__()
        latent_dim = args.latent_dim
        input_channels = getattr(args, 'input_channels', 299)
        hidden_dim = getattr(args, 'hidden_dim', 25)
        # The latent size must match mu_c / logvar_c below, so it comes from args
        # rather than being hard-coded.
        self.encoder = Encoder(latent_dim, input_channels, hidden_dim)
        self.decoder = Decoder(latent_dim, input_channels, hidden_dim)

        n_clusters = args.nClusters
        # Uniform initial cluster probabilities.
        self.pi = nn.Parameter(
            torch.full((n_clusters,), 1.0 / n_clusters, dtype=torch.float32), requires_grad=True)
        # Initial cluster means (all zero).
        self.mu_c = nn.Parameter(
            torch.zeros(n_clusters, latent_dim, dtype=torch.float32), requires_grad=True)
        # Initial cluster log-variances (all zero, i.e. unit variance).
        self.logvar_c = nn.Parameter(
            torch.zeros(n_clusters, latent_dim, dtype=torch.float32), requires_grad=True)

        self.args = args

    def _log_pi(self):
        """Log mixture weights as a ``(1, nClusters)`` row, floored to stay finite."""
        return torch.log(self.pi.clamp_min(self._PI_EPS)).unsqueeze(0)

    def pre_train(self, dataloader, pre_epoch=10):
        """Initialise the model, or restore it from ``./pretrain_model.pk``.

        When the checkpoint file does not exist: train encoder+decoder as a
        plain autoencoder (MSE) for ``pre_epoch`` epochs, copy the mean head
        into the log-variance head, fit a diagonal Gaussian mixture on the
        encoded data to initialise ``pi``/``mu_c``/``logvar_c``, and save the
        state dict. When it exists, the state dict is loaded instead and no
        training happens.

        Args:
            dataloader: iterable of ``(x, y)`` batches.
            pre_epoch: number of autoencoder epochs.
        """
        device = self.pi.device

        if os.path.exists("./pretrain_model.pk"):
            # map_location lets a checkpoint written on one device load on another.
            self.load_state_dict(torch.load('./pretrain_model.pk', map_location=device))
            return

        self._pretrain_autoencoder(dataloader, pre_epoch, device)
        # Start the log-variance head from the same weights as the mean head.
        self.encoder.fc_sigma.load_state_dict(self.encoder.fc_mu.state_dict())
        self._fit_gmm_prior(dataloader, device)

        torch.save(self.state_dict(), './pretrain_model.pk')

    def _pretrain_autoencoder(self, dataloader, pre_epoch, device):
        """Train encoder and decoder to reconstruct the input under MSE."""
        Loss = nn.MSELoss()
        optimizer = Adam(itertools.chain(self.encoder.parameters(), self.decoder.parameters()), lr=1e-3)

        print("Pre-training...")
        epoch_bar = tqdm(range(pre_epoch))
        for _epoch in epoch_bar:
            L = 0.0
            for x, _y in dataloader:
                x = x.to(device)

                # Only the mean head is used as the deterministic code here.
                z, _logvar = self.encoder(x)
                x_recon = self.decoder(z)
                loss = Loss(x_recon, x)
                L += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            epoch_bar.write('L2={:.4f}'.format(L/len(dataloader)))

    def _fit_gmm_prior(self, dataloader, device):
        """Fit a diagonal GMM on the encoded data and copy it into the prior."""
        Z = []
        Y = []
        with torch.no_grad():
            for x, y in dataloader:
                x = x.to(device)

                z_mu, z_logvar = self.encoder(x)  # one row per (sample, channel)
                assert F.mse_loss(z_mu, z_logvar) == 0, \
                    "fc_sigma should mirror fc_mu at this point"
                # The encoder emits several rows per sample; repeat each label
                # so labels line up with rows (factor derived, not hard-coded).
                y = y.repeat_interleave(z_mu.size(0) // y.size(0))
                Z.append(z_mu)
                Y.append(y)

        Z = torch.cat(Z, dim=0).cpu().numpy()
        Y = torch.cat(Y, dim=0).cpu().numpy()

        gmm = GaussianMixture(n_components=self.args.nClusters, covariance_type='diag')

        pre = gmm.fit_predict(Z)
        print('Acc={:.4f}%'.format(cluster_acc(pre, Y)[0] * 100))

        # Put the fitted mixture on the model's own device (CPU or GPU).
        self.pi.data = torch.from_numpy(gmm.weights_).float().to(device)
        self.mu_c.data = torch.from_numpy(gmm.means_).float().to(device)
        self.logvar_c.data = torch.log(torch.from_numpy(gmm.covariances_).float().to(device))

    def predict(self, x):
        """Return the most likely cluster index for every encoded row of ``x``.

        A latent sample is drawn with the reparameterization trick, so repeated
        calls can differ. The argmax is taken over ``log p(c) + log p(z|c)``
        directly: exponentiating first underflows to all-zero rows for points
        far from every component, which would always yield cluster 0.

        Returns:
            1-D numpy integer array, one entry per encoder output row.
        """
        z_mu, z_logvar = self.encoder(x)
        z = torch.randn_like(z_mu) * torch.exp(z_logvar/2) + z_mu  # reparameterization trick
        log_joint = self._log_pi() + self.gaussian_pdfs_log(z, self.mu_c, self.logvar_c)
        return np.argmax(log_joint.detach().cpu().numpy(), axis=1)

    def ELBO_loss(self, x, L=1):
        """Negative evidence lower bound for a batch ``x``.

        Args:
            x: input batch; also used as the binary-cross-entropy target.
            L: number of Monte-Carlo latent samples for the reconstruction term.

        Returns:
            Scalar tensor: reconstruction term plus the mixture-prior terms.

        NOTE(review): ``F.binary_cross_entropy`` expects targets in [0, 1];
        confirm the dataset is scaled accordingly.
        NOTE(review): the mean BCE is scaled by ``x.size(1)`` only; for inputs
        with more than two axes this is not the per-sample sum -- confirm the
        intended weighting against the regularisation terms.
        """
        z_mu, z_logvar = self.encoder(x)
        std = torch.exp(z_logvar / 2)

        L_rec = 0
        for _ in range(L):
            z = torch.randn_like(z_mu) * std + z_mu  # reparameterization trick
            x_pro = self.decoder(z)
            L_rec = L_rec + F.binary_cross_entropy(x_pro, x)
        L_rec = L_rec / L
        Loss = L_rec * x.size(1)

        log_pi = self._log_pi()                      # (1, K)
        log_sigma2_c = self.logvar_c.unsqueeze(0)    # (1, K, D)
        mu_c = self.mu_c.unsqueeze(0)                # (1, K, D)

        # Cluster responsibilities q(c|x), computed in log space so that
        # far-away components underflow gracefully instead of producing 0/0.
        z = torch.randn_like(z_mu) * std + z_mu
        log_joint = log_pi + self.gaussian_pdfs_log(z, self.mu_c, self.logvar_c)
        log_yita_c = F.log_softmax(log_joint, dim=1)  # (N, K)
        yita_c = log_yita_c.exp()

        # Expected negative log-density of q(z|x) under each mixture component.
        Loss = Loss + 0.5 * torch.mean(torch.sum(yita_c * torch.sum(
            log_sigma2_c
            + torch.exp(z_logvar.unsqueeze(1) - log_sigma2_c)
            + (z_mu.unsqueeze(1) - mu_c).pow(2) / torch.exp(log_sigma2_c), 2), 1))

        # Categorical term sum_c q(c|x) * log(pi_c / q(c|x)) plus the entropy
        # part of q(z|x).
        Loss = Loss - (torch.mean(torch.sum(yita_c * (log_pi - log_yita_c), 1))
                       + 0.5 * torch.mean(torch.sum(1 + z_logvar, 1)))

        return Loss

    def gaussian_pdfs_log(self, x, mus, log_sigma2s):
        """Log-density of every row of ``x`` under every diagonal Gaussian.

        Args:
            x: ``(N, D)`` points.
            mus: ``(K, D)`` component means.
            log_sigma2s: ``(K, D)`` component log-variances.

        Returns:
            ``(N, K)`` tensor; column ``c`` equals
            ``gaussian_pdf_log(x, mus[c:c+1], log_sigma2s[c:c+1])``.
        """
        log_sigma2s = log_sigma2s.unsqueeze(0)                    # (1, K, D)
        sq_dist = (x.unsqueeze(1) - mus.unsqueeze(0)).pow(2)      # (N, K, D)
        return -0.5 * torch.sum(
            float(np.log(np.pi * 2)) + log_sigma2s + sq_dist / torch.exp(log_sigma2s), 2)

    @staticmethod
    def gaussian_pdf_log(x, mu, log_sigma2):
        """Log-density of each row of ``x`` under one diagonal Gaussian.

        ``mu`` and ``log_sigma2`` broadcast against ``x``; the sum runs over
        axis 1 (the feature axis), giving one value per row.
        """
        return -0.5 * torch.sum(
            float(np.log(np.pi * 2)) + log_sigma2 + (x - mu).pow(2) / torch.exp(log_sigma2), 1)

In [7]:
# args setting
class Args:
    """Plain settings holder passed to the model and read by the training cells.

    Attributes are set from the constructor arguments of the same name:
    ``nClusters`` (default 15), ``latent_dim`` (default 10), ``cuda`` (default True).
    """

    def __init__(self, nClusters=15, latent_dim=10, cuda=True):
        self.nClusters, self.latent_dim, self.cuda = nClusters, latent_dim, cuda


# Module-wide settings instance built with the defaults above.
args = Args()

In [8]:
# Load the serialized samples and their labels from the working directory.
# NOTE(review): torch.load unpickles the file; only load files you trust.
data = torch.load(f='data_100_mouth.pt')
labels = torch.load(f='labels_100_mouth.pt')

In [9]:
# Wrap the loaded samples/labels in a Dataset ...
dataset = CustomDataset(data=data, labels=labels)
# ... and iterate over it in shuffled mini-batches of 64 samples.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)

In [10]:
# Build the model, move it to the GPU when requested, then run pre-training
# for up to 100 epochs.
vade = VaDE(args)
# nn.Module.cuda() moves parameters in place and returns the module itself.
vade = vade.cuda() if args.cuda else vade

vade.pre_train(dataloader=dataloader, pre_epoch=100)

In [12]:
# Train on the ELBO loss for 300 epochs, logging loss / clustering accuracy /
# learning rate to TensorBoard once per epoch.
optimizer = Adam(vade.parameters(), lr=2e-3)
# Halve the learning rate every 100 epochs (the scheduler is stepped once per
# epoch below, not once per batch).
lr_s = StepLR(optimizer, step_size=100, gamma=0.5)

writer=SummaryWriter()

epoch_bar = tqdm(range(300))
tsne = TSNE()  # not used in this cell; kept in case a later cell relies on it

for epoch in epoch_bar:
    # ---- optimisation pass -------------------------------------------------
    L = 0.0
    for x, _ in dataloader:
        if args.cuda:
            x = x.cuda()
        loss = vade.ELBO_loss(x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        L += loss.item()

    # Learning rate that was in effect during this epoch; read it before the
    # scheduler advances.
    current_lr = lr_s.get_last_lr()[0]
    lr_s.step()

    # ---- evaluation pass ---------------------------------------------------
    pre = []
    tru = []

    with torch.no_grad():
        for x, y in dataloader:
            if args.cuda:
                x = x.cuda()

            batch_pred = vade.predict(x)
            # predict() returns several rows per sample; repeat each label so
            # that labels and predictions line up one-to-one.
            y = y.repeat_interleave(len(batch_pred) // len(y))
            tru.append(y.numpy())
            pre.append(batch_pred)

    tru = np.concatenate(tru, 0)
    pre = np.concatenate(pre, 0)

    # Compute each metric once and reuse it for both logging sinks.
    epoch_loss = L/len(dataloader)
    epoch_acc = cluster_acc(pre,tru)[0]*100

    writer.add_scalar('loss',epoch_loss,epoch)
    writer.add_scalar('acc',epoch_acc,epoch)
    writer.add_scalar('lr',current_lr,epoch)

    epoch_bar.write('Loss={:.4f},ACC={:.4f}%,LR={:.4f}'.format(epoch_loss,epoch_acc,current_lr))

  0%|          | 0/300 [00:00<?, ?it/s]


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.