In [1]:
import numpy as np
import torch
import argparse
import math
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
from utils import cuda, MIEvaluator
from model import VIBencoder, Decoder, Disc, Encoder
from numbers import Number
from pathlib import Path
from thop import profile

Parameters

In [2]:
Cuda = None
epoch = 30
batch_size = 100
lr = 1e-2
eps = 1e-9
K = 10

SVIB Class

In [3]:
class SIB():
    def __init__(self, K=10, n=2, beta=1e-3):
        super(SIB, self).__init__()
        ens = []
        for i in range(n):
            en = Encoder(K=K, n=2)
            ens.append(en)
        self.n = n
        self.K = K
        self.beta = beta
        self.encode = nn.ModuleList(ens)
        self.decode = Decoder(K)
        self.disc = Disc(K, n)
        self.w = nn.Parameter(cuda(torch.ones([n, 1]), Cuda))#.cuda()

    def train(self, dataloader):
        for i, net in enumerate(self.encode):
            print(f"Training encoder {i + 1}")
            if i == 0:
                net.train()
                self.decode.train()
                optim1 = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
                optim2 = optim.Adam(self.decode.parameters(), lr=lr, betas=(0.5, 0.999))
                global_epoch = 0

                for e in range(epoch):
                    global_epoch += 1
                    global_iter = 0
                    for idx, (images, labels) in enumerate(dataloader):
                        global_iter += 1

                        x = Variable(cuda(images, Cuda))
                        y = Variable(cuda(labels, Cuda))
                        (mu, std), emb = net(x)
                        logit = self.decode(emb)

                        mu2 = torch.zeros(*mu.size()).to(mu.device)
                        std2 = torch.ones(*std.size()).to(std.device)

                        class_loss = F.cross_entropy(logit, y)
                        info_loss = 0.5 * (2 * (torch.log(std2+1e-8) - torch.log(std+1e-8))+ ((mu - mu2) / std2) ** 2+ (std / std2) ** 2- 1)
                        info_loss = info_loss.mean()
                        total_loss = class_loss + self.beta * (info_loss)


                        optim1.zero_grad()
                        optim2.zero_grad()
                        total_loss.backward()
                        optim1.step()
                        optim2.step()

            else:
                net.train()
                self.decode.eval()
                self.disc.train()
                optim0 = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
                optim1 = optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
                global_epoch = 0
                Nan = False
                for e in range(epoch):
                    if Nan:
                        break
                    global_epoch += 1
                    global_iter = 0
                    for idx, (images, labels) in enumerate(dataloader):
                        global_iter += 1

                        x = Variable(cuda(images, Cuda))
                        y = Variable(cuda(labels, Cuda))
                        (mu, std), emb = net(x)
                        logit = self.decode(emb)

                        before_mu = cuda(torch.zeros([batch_size, self.K]), Cuda)
                        before_std = cuda(torch.zeros([batch_size, self.K]), Cuda)
                        for j in range(i):
                            self.encode[j].eval()
                            (current_mu, current_std), _ = self.encode[j](x)
                            before_mu += current_mu
                            before_std += current_std
                        before_encoding = self.reparametrize_n(before_mu, before_std)

                        true = torch.concat((emb, before_encoding), dim=1)
                        index1 = torch.randperm(emb.shape[0])
                        index2 = torch.randperm(emb.shape[0])
                        shuffled1 = emb[index1, :]
                        shuffled2 = before_encoding[index2, :]
                        false = torch.concat((shuffled1, shuffled2), dim=1)
                        true_stat = self.disc(true.detach().clone())
                        false_stat = self.disc(false.detach().clone())

                        mu2 = torch.zeros(*mu.size()).to(mu.device)
                        std2 = torch.ones(*std.size()).to(std.device)

                        class_loss = F.cross_entropy(logit, y)
                        info_loss = 0.5 * (2 * (torch.log(std2+1e-8) - torch.log(std+1e-8))+ ((mu - mu2) / std2) ** 2 + (std / std2) ** 2- 1)
                        info_loss = info_loss.mean()
                        total_loss = class_loss + self.beta * (info_loss) - torch.log(true_stat+1e-8).mean()
                        if torch.isnan(total_loss):
                            Nan = True
                            break

                        optim0.zero_grad()
                        total_loss.backward(retain_graph=True)
                        optim0.step()

                        #true_stat = self.disc(true.detach().clone())
                        #false_stat = self.disc(false.detach().clone())
                        disc_loss = -torch.log((1-true_stat)+1e-8).mean()-torch.log(false_stat+1e-8).mean()

                        optim1.zero_grad()
                        disc_loss.backward()
                        optim1.step()

        print('Train W')
        optim0 = optim.Adam([self.w], lr=lr, betas=(0.5, 0.999))
        global_epoch = 0
        Nan = False
        for e in range(epoch):
            if Nan:
                break
            global_epoch += 1
            global_iter = 0
            for idx, (images, labels) in enumerate(dataloader):
                global_iter += 1

                x = Variable(cuda(images, Cuda))
                y = Variable(cuda(labels, Cuda))
                mu = cuda(torch.zeros([batch_size, self.K]), Cuda)
                std = cuda(torch.zeros([batch_size, self.K]), Cuda)
                for j in range(self.n):
                    self.encode[j].eval()
                    (current_mu, current_std), _ = self.encode[j](x)
                    mu += current_mu * self.w[j]
                    std += current_std * self.w[j]**2
                encoding = self.reparametrize_n(mu, std)
                logit = self.decode(encoding)

                mu2 = torch.zeros(*mu.size()).to(mu.device)
                std2 = torch.ones(*std.size()).to(std.device)

                class_loss = F.cross_entropy(logit, y)
                info_loss = 0.5 * (2 * (torch.log(std2+1e-8) - torch.log(std+1e-8))+ ((mu - mu2) / std2) ** 2+ (std / std2) ** 2- 1)
                info_loss = info_loss.mean()
                total_loss = class_loss + self.beta * (info_loss)

                if torch.isnan(total_loss):
                    Nan = True
                    break

                optim0.zero_grad()
                total_loss.backward()
                optim0.step()


    #def save(self, encode_path, decode_path, w_path):
    #    self.encode.load_state_dict(torch.load(encode_path))
    #    self.decode.load_state_dict(torch.load(decode_path))
    #    self.w = torch.load(w_path)
        #s = torch.sum(self.w)
        #self.w /= s.detach().clone()

    def load(self, encode_path, decode_path, w_path):
        self.encode.load_state_dict(torch.load(encode_path))
        self.decode.load_state_dict(torch.load(decode_path))
        self.w = torch.load(w_path)

    def fit(self, dataloader):

        def enc(x):
            mu = cuda(torch.zeros([batch_size, self.K]), Cuda)
            std = cuda(torch.zeros([batch_size, self.K]), Cuda)
            for j in range(self.n):
                self.encode[j].eval()
                (current_mu, current_std), _ = self.encode[j](x)
                mu += current_mu * self.w[j]
                std += current_std * self.w[j]
            emb = self.reparametrize_n(mu, std)
            return (mu, std), emb

        accuracy = 0
        global_iter = 0
        for idx, (images, labels) in enumerate(dataloader):
            global_iter += 1

            x = Variable(cuda(images, Cuda))
            y = Variable(cuda(labels, Cuda))
            mu = cuda(torch.zeros([batch_size, self.K]), Cuda)
            std = cuda(torch.zeros([batch_size, self.K]), Cuda)
            for j in range(self.n):
                self.encode[j].eval()
                (current_mu, current_std), _ = self.encode[j](x)
                mu += current_mu * self.w[j]
                std += current_std * self.w[j]
            logit = self.decode(mu)

            prediction = F.softmax(logit, dim=1).max(1)[1]
            accuracy += torch.eq(prediction, y).float().mean()

        accuracy /= global_iter
        print('acc:{:.4f}'
                  .format(accuracy.data), end=' ')

        mi = MIEvaluator(enc, self.decode, mu.device)
        izx_bound = mi.eval_mi_x_z_monte_carlo(dataloader)
        izy_bound = mi.eval_mi_y_z_variational_lb(dataloader, 10)

        print('IZY:{:.2f} IZX:{:.2f}'
                  .format(izy_bound.data, izx_bound.data))

        return izx_bound.data, izy_bound.data

    def reparametrize_n(self, mu, std, n=1):
        def expand(v):
            if isinstance(v, Number):
                return torch.Tensor([v]).expand(n, 1)
            else:
                return v.expand(n, *v.size())

        if n != 1:
            mu = expand(mu)
            std = expand(std)

        eps = Variable(cuda(std.data.new(std.size()).normal_(), std.is_cuda))

        return mu + eps * std

    def parameter_counter(self):
        a = 0
        for p in self.encode.parameters():
            l = 1
            for j in p.size():
                l *= j
            a = a + l

        b = 0
        for p in self.decode.parameters():
            l = 1
            for j in p.size():
                l *= j
            b = b + l
        return a + b

VIB Class

In [4]:
class VIB():
    def __init__(self, K=10, n=2, beta=1e-3):
        super(VIB, self).__init__()
        self.n = n
        self.K = K
        self.beta = beta
        self.encoder = VIBencoder(K=K)
        self.decoder = Decoder(K=K)

    def train(self, dataloader):
        global_epoch = 0
        optim1 = optim.Adam(self.encoder.parameters(), lr=lr, betas=(0.5, 0.999))
        optim2 = optim.Adam(self.decoder.parameters(), lr=lr, betas=(0.5, 0.999))
        for e in range(epoch):
            global_epoch += 1
            self.encoder.train()
            self.decoder.train()
            global_iter = 0
            for idx, (images, labels) in enumerate(dataloader):
                global_iter += 1

                x = Variable(cuda(images, Cuda))
                y = Variable(cuda(labels, Cuda))
                (mu, std), emb = self.encoder(x)
                logit = self.decoder(emb)

                mu2 = torch.zeros(*mu.size()).to(mu.device)
                std2 = torch.ones(*std.size()).to(std.device)

                class_loss = F.cross_entropy(logit, y)
                info_loss = 0.5 * (2 * (torch.log(std2+1e-8) - torch.log(std+1e-8))+ ((mu - mu2) / std2) ** 2+ (std / std2) ** 2- 1)
                info_loss = info_loss.mean()
                total_loss = class_loss + self.beta * (info_loss)

                optim1.zero_grad()
                optim2.zero_grad()
                total_loss.backward()
                optim1.step()
                optim2.step()



    #def save(self, encode_path, decode_path, w_path):
    #    self.encode.load_state_dict(torch.load(encode_path))
    #    self.decode.load_state_dict(torch.load(decode_path))
    #    self.w = torch.load(w_path)

    def load(self, encode_path, decode_path):
        self.encoder.load_state_dict(torch.load(encode_path))
        self.decoder.load_state_dict(torch.load(decode_path))

    def fit(self, dataloader):
        global_epoch = 0
        accuracy = 0
        for idx, (images, labels) in enumerate(dataloader):
            global_epoch += 1
            self.encoder.eval()
            self.decoder.eval()
            x = Variable(cuda(images, Cuda))
            y = Variable(cuda(labels, Cuda))
            (mu, std), emb = self.encoder(x)
            logit = self.decoder(mu)


            prediction = F.softmax(logit, dim=1).max(1)[1]
            accuracy += torch.eq(prediction, y).float().mean()

        accuracy /= global_epoch
        print('acc:{:.4f}'
                      .format(accuracy.data), end=' ')

        mi = MIEvaluator(self.encoder, self.decoder, 'cpu')
        izx_bound = mi.eval_mi_x_z_monte_carlo(dataloader)
        izy_bound = mi.eval_mi_y_z_variational_lb(dataloader, 10)

        print('IZY:{:.2f} IZX:{:.2f}'
                      .format(izy_bound.data, izx_bound.data))

        return izx_bound.data, izy_bound.data

    def parameter_counter(self):
        a = 0
        for p in self.encoder.parameters():
            l = 1
            for j in p.size():
                l *= j
            a = a+l

        b = 0
        for p in self.decoder.parameters():
            l = 1
            for j in p.size():
                l *= j
            b = b + l
        return a+b


DataLoader

In [5]:
transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307), (0.3081))
         ])

train_dataset = torchvision.datasets.MNIST(root='./data',
                                               train=True,
                                               transform=transform,
                                               download=True)
test_dataset = torchvision.datasets.MNIST(root='./data',
                                              train=False,
                                              transform=transform,
                                              download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=100,
                                              shuffle=False)

Compare numbers of parameters

In [6]:
vib_mnist = VIB(beta=torch.tensor(10 ** 0).float())
sib_mnist = SIB(beta=torch.tensor(10 ** 0).float()) #k=10, n=2
print(vib_mnist.parameter_counter())
print(sib_mnist.parameter_counter())

470274
437526


Result of VIB

In [7]:
vib_mnist.train(train_loader)
ixz, izy = vib_mnist.fit(test_loader)

acc:0.9429 IZY:2.01 IZX:3.05


result of SIB

In [8]:
sib_mnist.train(train_loader)
ixz, izy = sib_mnist.fit(test_loader)

Training encoder 1
Training encoder 2
Train W
acc:0.9568 IZY:2.12 IZX:2.98
