In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.optim import Adam
import itertools
from sklearn.mixture import GaussianMixture
from sklearn.metrics import accuracy_score
import numpy as np
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset


def get_data(data="MNIST", batch_size=128):
    """Build train/test ``DataLoader`` objects for a torchvision dataset.

    Each image is converted to a tensor by ``transforms.ToTensor()`` and then
    flattened with ``torch.flatten``. Files are stored under
    ``./mnist_data/``; only the training split triggers a download.

    Args:
        data: Dataset name, either ``"MNIST"`` or ``"FashionMNIST"``.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles, the
        test loader does not.

    Raises:
        ValueError: If ``data`` is not one of the supported dataset names.
    """
    # Validate before touching torchvision so an unsupported name fails with
    # a clear message instead of an unbound-variable error further down.
    if data == "MNIST":
        dataset_cls = datasets.MNIST
    elif data == "FashionMNIST":
        dataset_cls = datasets.FashionMNIST
    else:
        raise ValueError(
            f"Unsupported dataset {data!r}; expected 'MNIST' or 'FashionMNIST'"
        )

    def make_transform():
        # A fresh pipeline per split, as in the original per-dataset code.
        return transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(torch.flatten)]
        )

    train_dataset = dataset_cls(
        root="./mnist_data/",
        train=True,
        transform=make_transform(),
        download=True,
    )
    test_dataset = dataset_cls(
        root="./mnist_data/",
        train=False,
        transform=make_transform(),
        download=False,
    )

    # Data Loader (Input Pipeline)
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True
    )
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader


def cluster_acc(Y_pred, Y):
    """Compute clustering accuracy under the best one-to-one label mapping.

    A count matrix ``w`` is built with ``w[p, t]`` equal to the number of
    samples predicted as cluster ``p`` whose reference label is ``t``. The
    assignment between predicted and reference labels that maximises the
    matched count is found with ``scipy.optimize.linear_sum_assignment``.

    Args:
        Y_pred: Integer array of predicted cluster ids.
        Y: Integer array of reference labels with the same number of elements.

    Returns:
        ``(accuracy, w)`` where ``accuracy`` is the matched fraction and ``w``
        is the ``(D, D)`` int64 count matrix, ``D`` being the largest label
        in either array plus one.

    Raises:
        ValueError: If the two arrays differ in size.
    """
    # Imported locally: the file does not import scipy at module level.
    from scipy.optimize import linear_sum_assignment

    Y_pred = np.asarray(Y_pred).ravel()
    Y = np.asarray(Y).ravel()
    if Y_pred.size != Y.size:
        raise ValueError(
            f"Y_pred and Y must have the same size, got {Y_pred.size} and {Y.size}"
        )

    D = int(max(Y_pred.max(), Y.max())) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Unbuffered add so repeated (pred, true) pairs are all counted.
    np.add.at(w, (Y_pred, Y), 1)

    # linear_sum_assignment minimises cost and returns (row_ind, col_ind),
    # so the counts are turned into a cost with ``w.max() - w``.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / Y_pred.size, w


def get_hidden_layer(in_dim, out_dim):
    """Return a two-element list: a linear layer followed by an in-place ReLU.

    The result is a plain list so callers can unpack it into
    ``nn.Sequential``.
    """
    linear = nn.Linear(in_features=in_dim, out_features=out_dim)
    activation = nn.ReLU(inplace=True)
    return [linear, activation]


class Encoder(nn.Module):
    """MLP encoder producing the mean and log-variance of the latent code.

    The trunk is a stack of ``Linear + ReLU`` blocks, one per entry of
    ``hidden_dims``. Two linear heads on top of the last hidden width output
    ``stat_dim`` values each.

    Args:
        input_dim: Size of each flattened input vector.
        hidden_dims: Widths of the hidden layers, in order. Any non-empty
            sequence is accepted.
        stat_dim: Dimensionality of the latent statistics.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(self, input_dim=784, hidden_dims=(512, 512, 2048), stat_dim=10):
        super().__init__()
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        # The heads are registered before the trunk to keep the original
        # parameter order (and therefore the RNG consumption at init).
        self.mu_l = nn.Linear(hidden_dims[-1], stat_dim)
        self.log_sigma2_l = nn.Linear(hidden_dims[-1], stat_dim)

        # Chain input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1].
        widths = [input_dim, *hidden_dims]
        layers = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            layers.extend(get_hidden_layer(in_dim, out_dim))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(mu, log_sigma2)`` computed from the shared trunk output."""
        e = self.encoder(x)
        return self.mu_l(e), self.log_sigma2_l(e)


class Decoder(nn.Module):
    """MLP decoder mapping a latent code back to input space.

    The hidden widths are traversed in reverse order, each as a
    ``Linear + ReLU`` block. A final linear layer followed by a sigmoid
    produces ``input_dim`` outputs in ``[0, 1]``.

    Args:
        input_dim: Size of each reconstructed vector.
        hidden_dims: Hidden layer widths, given in encoder order. Any
            non-empty sequence is accepted.
        stat_dim: Dimensionality of the latent code.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(self, input_dim=784, hidden_dims=(512, 512, 2048), stat_dim=10):
        super().__init__()
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        # Chain stat_dim -> hidden_dims[-1] -> ... -> hidden_dims[0].
        widths = [stat_dim, *reversed(hidden_dims)]
        layers = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            layers.extend(get_hidden_layer(in_dim, out_dim))
        layers.append(nn.Linear(hidden_dims[0], input_dim))
        layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent codes ``z`` into reconstructions."""
        return self.decoder(z)


class VaDE(nn.Module):
    """Variational autoencoder with a learnable Gaussian-mixture prior.

    The model owns an ``Encoder``, a ``Decoder`` and three mixture
    parameters: ``pi_`` (component weights), ``mu_c`` (component means) and
    ``log_sigma2_c`` (component log-variances, diagonal).

    Args:
        n_clusters: Number of mixture components.
        stat_dim: Dimensionality of the latent space.
        hidden_dims: Hidden layer widths forwarded to encoder and decoder.
        input_dim: Size of each flattened input vector.
        cuda: Request GPU execution; honoured only if CUDA is available.

    Attributes:
        use_cuda: Whether the model was moved to the GPU at construction.
            The flag is deliberately NOT stored as ``self.cuda``: that name
            is the ``nn.Module.cuda()`` method and shadowing it with a bool
            made the constructor fail on GPU machines.

    Note:
        Multi-GPU execution is left to the caller, e.g. by wrapping the
        instance in ``nn.DataParallel`` from outside. Rebinding ``self``
        inside ``__init__`` never had any effect.
    """

    def __init__(
        self,
        n_clusters,
        stat_dim,
        hidden_dims=(512, 512, 2048),
        input_dim=784,
        cuda=True,
    ):
        super().__init__()
        self.n_clusters = n_clusters
        self.stat_dim = stat_dim
        self.use_cuda = bool(cuda) and torch.cuda.is_available()
        self.encoder = Encoder(
            input_dim=input_dim, hidden_dims=hidden_dims, stat_dim=stat_dim
        )
        self.decoder = Decoder(
            input_dim=input_dim, hidden_dims=hidden_dims, stat_dim=stat_dim
        )
        # Uniform mixture weights, zero means, unit variances (log-variance 0).
        self.pi_ = nn.Parameter(
            torch.ones(self.n_clusters, dtype=torch.float32) / self.n_clusters
        )
        self.mu_c = nn.Parameter(
            torch.zeros(self.n_clusters, self.stat_dim, dtype=torch.float32)
        )
        self.log_sigma2_c = nn.Parameter(
            torch.zeros(self.n_clusters, self.stat_dim, dtype=torch.float32)
        )
        if self.use_cuda:
            self.to(torch.device("cuda"))

    @property
    def device(self):
        """Device the mixture parameters currently live on."""
        return self.pi_.device

    def pre_train(self, dataloader, pre_epoch=10, path="./pretrained_model.pk"):
        """Pretrain the autoencoder and initialise the mixture with a GMM.

        If ``path`` exists its state dict is loaded and nothing is trained.
        Otherwise the encoder mean head and the decoder are trained with an
        MSE reconstruction loss, the log-variance head is overwritten with
        the mean head's weights, a diagonal ``GaussianMixture`` is fitted on
        the encoded data, the mixture parameters are set from it, and the
        resulting state dict is saved to ``path``.

        Args:
            dataloader: Iterable yielding ``(x, y)`` batches; ``y`` is unused.
            pre_epoch: Number of reconstruction epochs.
            path: Checkpoint file. An existing file is loaded as-is, so use
                a different path when the model configuration changes.
        """
        if os.path.exists(path):
            # map_location lets a checkpoint written on GPU load on CPU.
            self.load_state_dict(torch.load(path, map_location=self.device))
            return

        loss_fn = nn.MSELoss()
        opti = Adam(
            itertools.chain(self.encoder.parameters(), self.decoder.parameters())
        )
        print("Pretraining......")
        epoch_bar = tqdm(range(pre_epoch))
        for _ in epoch_bar:
            running = 0.0
            for x, _labels in dataloader:
                x = x.to(self.device)
                z, _log_var = self.encoder(x)
                reconstruction = self.decoder(z)
                loss = loss_fn(reconstruction, x)

                opti.zero_grad()
                loss.backward()
                opti.step()
                running += loss.item()
            epoch_bar.write("L2={:.4f}".format(running / len(dataloader)))

        # Start the log-variance head from the trained mean head.
        self.encoder.log_sigma2_l.load_state_dict(self.encoder.mu_l.state_dict())

        latents = []
        with torch.no_grad():
            for x, _labels in dataloader:
                x = x.to(self.device)
                z1, z2 = self.encoder(x)
                # Sanity check: both heads now share weights.
                assert F.mse_loss(z1, z2) == 0
                latents.append(z1)
        Z = torch.cat(latents, 0).cpu().numpy()

        gmm = GaussianMixture(n_components=self.n_clusters, covariance_type="diag")
        gmm.fit(Z)

        device = self.device
        with torch.no_grad():
            self.pi_.copy_(
                torch.as_tensor(gmm.weights_, dtype=torch.float32, device=device)
            )
            self.mu_c.copy_(
                torch.as_tensor(gmm.means_, dtype=torch.float32, device=device)
            )
            self.log_sigma2_c.copy_(
                torch.log(
                    torch.as_tensor(
                        gmm.covariances_, dtype=torch.float32, device=device
                    )
                )
            )
        torch.save(self.state_dict(), path)

    def fit(self, dataloader, epochs=100, lr=2e-3, gamma=0.95, log_dir="./logs"):
        """Optimise all parameters on the ELBO loss.

        Uses Adam with a ``StepLR`` schedule (step size 10) and writes the
        per-epoch mean loss and learning rate to TensorBoard.

        Args:
            dataloader: Iterable yielding ``(x, y)`` batches; ``y`` is unused.
            epochs: Number of passes over ``dataloader``.
            lr: Initial learning rate.
            gamma: Multiplicative learning-rate decay applied by ``StepLR``.
            log_dir: Directory for the TensorBoard event files.
        """
        # Imported here so the class does not rely on a later notebook cell
        # having brought these names into the module namespace.
        from torch.optim.lr_scheduler import StepLR
        from torch.utils.tensorboard import SummaryWriter

        opti = Adam(self.parameters(), lr=lr)
        lr_s = StepLR(opti, step_size=10, gamma=gamma)
        writer = SummaryWriter(log_dir)
        try:
            epoch_bar = tqdm(range(epochs))
            for epoch in epoch_bar:
                running = 0.0
                for x, _labels in dataloader:
                    x = x.to(self.device)
                    loss = self.ELBO_Loss(x)
                    opti.zero_grad()
                    loss.backward()
                    opti.step()
                    running += loss.item()
                lr_s.step()

                mean_loss = running / len(dataloader)
                current_lr = lr_s.get_last_lr()[0]
                writer.add_scalar("loss", mean_loss, epoch)
                writer.add_scalar("lr", current_lr, epoch)
                epoch_bar.write(
                    "Loss={:.4f},LR={:.4f}".format(mean_loss, current_lr)
                )
        finally:
            # Flush and release the event file even if training is interrupted.
            writer.close()

    def train(self, dataloader=True, epochs=100, lr=2e-3, gamma=0.95):
        """Run training, or switch train/eval mode when given a bool.

        ``nn.Module.eval()`` and parent modules call ``train(mode)`` with a
        bool; that call is forwarded to ``nn.Module.train`` so the standard
        module API keeps working. Any other argument is treated as a data
        loader and handed to :meth:`fit`.

        Returns:
            ``self`` when called with a bool, otherwise ``None``.
        """
        if isinstance(dataloader, bool):
            return super().train(dataloader)
        return self.fit(dataloader, epochs=epochs, lr=lr, gamma=gamma)

    def predict(self, x):
        """Return the most likely mixture component for each row of ``x``.

        A latent sample is drawn from the encoder's Gaussian, so results are
        stochastic. The argmax is taken over log-probabilities, which is
        equivalent to the argmax over probabilities but does not underflow.

        Returns:
            NumPy integer array with one cluster index per input row.
        """
        with torch.no_grad():
            z_mu, z_sigma2_log = self.encoder(x)
            z = torch.randn_like(z_mu) * torch.exp(z_sigma2_log / 2) + z_mu
            log_joint = torch.log(self.pi_.unsqueeze(0)) + self.gaussian_pdfs_log(
                z, self.mu_c, self.log_sigma2_c
            )
        return np.argmax(log_joint.cpu().numpy(), axis=1)

    def ELBO_Loss(self, x, L=1, det=1e-10):
        """Compute the negative ELBO for a batch.

        Args:
            x: Batch of inputs with values in ``[0, 1]`` (required by the
                binary cross-entropy reconstruction term).
            L: Number of latent samples used for the reconstruction term.
            det: Small constant added to the component responsibilities
                before normalisation.

        Returns:
            Scalar loss tensor.
        """
        z_mu, z_sigma2_log = self.encoder(x)
        z_std = torch.exp(z_sigma2_log / 2)

        L_rec = 0
        for _ in range(L):
            z = torch.randn_like(z_mu) * z_std + z_mu
            x_pro = self.decoder(z)  # x_pro sometimes has nans
            try:
                L_rec += F.binary_cross_entropy(x_pro, x)
            except RuntimeError:
                # Best effort: report the offending range and skip the
                # sample. Only RuntimeError is caught so that interrupts and
                # shape errors still propagate.
                print(x_pro.min(), x_pro.max())
        L_rec = L_rec / L
        # binary_cross_entropy averages over elements; rescale per feature.
        Loss = L_rec * x.size(1)

        # Responsibilities of each mixture component for a fresh sample.
        z = torch.randn_like(z_mu) * z_std + z_mu
        y_c = (
            torch.exp(
                torch.log(self.pi_.unsqueeze(0))
                + self.gaussian_pdfs_log(z, self.mu_c, self.log_sigma2_c)
            )
            + det
        )
        y_c = y_c / y_c.sum(1, keepdim=True)  # batch_size*Clusters

        log_sigma2_c = self.log_sigma2_c.unsqueeze(0)
        mu_c = self.mu_c.unsqueeze(0)
        per_cluster = torch.sum(
            log_sigma2_c
            + torch.exp(z_sigma2_log.unsqueeze(1) - log_sigma2_c)
            + (z_mu.unsqueeze(1) - mu_c).pow(2) / torch.exp(log_sigma2_c),
            2,
        )
        Loss = Loss + 0.5 * torch.mean(torch.sum(y_c * per_cluster, 1))
        Loss = Loss - (
            torch.mean(torch.sum(y_c * torch.log(self.pi_.unsqueeze(0) / y_c), 1))
            + 0.5 * torch.mean(torch.sum(1 + z_sigma2_log, 1))
        )
        return Loss

    def gaussian_pdfs_log(self, x, mus, log_sigma2s):
        """Log-density of ``x`` under every component, one column per cluster."""
        columns = [
            self.gaussian_pdf_log(
                x, mus[c : c + 1, :], log_sigma2s[c : c + 1, :]
            ).view(-1, 1)
            for c in range(self.n_clusters)
        ]
        return torch.cat(columns, 1)

    @staticmethod
    def gaussian_pdf_log(x, mu, log_sigma2):
        """Log-density of a diagonal Gaussian, summed over dimension 1."""
        log_two_pi = float(np.log(np.pi * 2))
        return -0.5 * (
            torch.sum(
                log_two_pi + log_sigma2 + (x - mu).pow(2) / torch.exp(log_sigma2),
                1,
            )
        )


In [4]:
from tqdm import tqdm
import numpy as np
from torch.optim import Adam
from sklearn.metrics import accuracy_score
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn


# Run configuration: loader batch size, number of mixture components, and
# latent dimensionality.
batch_size = 800
n_clusters = 10
stat_dim = 10

# Only the training loader is used; the test loader is discarded.
DL, _ = get_data(data="MNIST", batch_size=batch_size)

# Build the model, initialise it with a short pretraining run, then train.
vade = VaDE(n_clusters=n_clusters, stat_dim=stat_dim)
vade.pre_train(dataloader=DL, pre_epoch=2)
vade.train(dataloader=DL, epochs=3)

 33%|███▎      | 1/3 [00:36<01:13, 36.93s/it]

Loss=214.1563,LR=0.0020


 67%|██████▋   | 2/3 [01:15<00:37, 37.87s/it]

Loss=169.0510,LR=0.0020


100%|██████████| 3/3 [01:52<00:00, 37.64s/it]

Loss=153.0145,LR=0.0020



