<a href="https://colab.research.google.com/github/adivas24/NNFLAssignment/blob/main/MNIST_fm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the ImprovedGAN-pytorch repository and make the clone the working directory.
# Relative paths used further down (e.g. '../data') are then resolved against this
# directory, assuming no later directory change.
!git clone https://github.com/Sleepychord/ImprovedGAN-pytorch.git
%cd /content/ImprovedGAN-pytorch

Cloning into 'ImprovedGAN-pytorch'...
remote: Enumerating objects: 77, done.[K
remote: Total 77 (delta 0), reused 0 (delta 0), pack-reused 77[K
Unpacking objects: 100% (77/77), done.
/content/ImprovedGAN-pytorch


In [80]:
import torch
from torchvision import datasets, transforms
import numpy as np
from torch.nn.parameter import Parameter
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import pdb
import math
from __future__ import print_function 
import torch.optim as optim
from torch.utils.data import DataLoader,TensorDataset
import sys
import argparse
import os
!pip3 install tensorboardX
import tensorboardX



In [81]:
def log_sum_exp(x, axis = 1):
    """Return ``log(sum(exp(x)))`` reduced along ``axis`` in a numerically stable way.

    Args:
        x: Input tensor.
        axis: Dimension to reduce over (removed from the result). Defaults to 1.

    Returns:
        Tensor with the shape of ``x`` minus the ``axis`` dimension.
    """
    # The previous hand-rolled version subtracted an undefined name `m` and always
    # took the maximum over dim 1, ignoring `axis`; torch.logsumexp performs the
    # same max-shifted computation along the requested dimension.
    return torch.logsumexp(x, dim=axis)

def rnp(L, stdv, weight_scale = 1.):
    """Re-initialise the weight of a linear layer in place from a zero-mean normal.

    The standard deviation is ``weight_scale / sqrt(L.weight.size()[0])``, i.e. it
    is scaled by the first dimension of the weight matrix.

    Args:
        L: A ``torch.nn.Linear`` instance (subclasses are accepted).
        stdv: Unused; kept so existing call sites keep working.
        weight_scale: Numerator of the standard deviation.

    Raises:
        AssertionError: If ``L`` is not a ``torch.nn.Linear``.
    """
    assert isinstance(L, torch.nn.Linear)
    # `normal_` is the in-place initialiser; the un-suffixed `normal` is deprecated.
    torch.nn.init.normal_(L.weight, std=weight_scale / math.sqrt(L.weight.size()[0]))
    
class CustomLayer(torch.nn.Module):
    """Linear layer whose weight rows are normalised to a fixed L2 norm.

    In ``forward`` every row of ``weight`` is divided by its L2 norm and multiplied
    by ``weight_scale`` before being used in ``F.linear``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If true, add a learnable bias initialised to zeros.
        weight_scale: If given (int or float), a learnable per-output scale of shape
            ``(out_features, 1)`` initialised to this value. If ``None``, the scale
            is the constant 1 and is not a parameter.
        weight_init_stdv: Standard deviation of the normal weight initialisation.

    Raises:
        AssertionError: If ``weight_scale`` is neither ``None`` nor an int/float.
    """
    def __init__(self, in_features, out_features, bias=True, weight_scale=None, weight_init_stdv=0.1):
        super(CustomLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.randn(out_features, in_features) * weight_init_stdv)
        if bias:
            self.bias = Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)
        if weight_scale is not None:
            # Accept any real-valued scale; the former exact-int check rejected 1.0 or 0.5.
            assert isinstance(weight_scale, (int, float))
            self.weight_scale = Parameter(torch.ones(out_features, 1) * weight_scale)
        else:
            self.weight_scale = 1

    def forward(self, x):
        """Apply the row-normalised, rescaled weight (plus bias, if any) to ``x``."""
        # One norm per output unit; keepdim so it broadcasts across each weight row.
        row_norm = torch.sqrt(torch.sum(self.weight ** 2, dim = 1, keepdim = True))
        W = self.weight * self.weight_scale / row_norm
        return F.linear(x, W, self.bias)

In [61]:
class Discriminator(nn.Module):
    """Fully connected classifier used as the GAN discriminator.

    Five ``CustomLayer`` hidden layers (1000, 500, 250, 250, 250 units) with ReLU
    activations, followed by a final ``CustomLayer`` producing ``output_dim`` values.
    In training mode Gaussian noise is added to the input (std 0.3) and to every
    hidden activation (std 0.5); in eval mode no noise is added.

    Args:
        input_dim: Flattened input size; inputs are reshaped to ``(-1, input_dim)``.
        output_dim: Number of outputs of the final layer.
    """
    def __init__(self, input_dim = 28*28, output_dim = 10):
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self.layers = torch.nn.ModuleList([
            CustomLayer(input_dim, 1000),
            CustomLayer(1000, 500),
            CustomLayer(500, 250),
            CustomLayer(250, 250),
            CustomLayer(250, 250)]
        )
        self.final = CustomLayer(250, output_dim, weight_scale=1)

    def forward(self, x, feature_matching = False, cuda = False):
        """Run the discriminator.

        Args:
            x: Input tensor; reshaped to ``(-1, input_dim)``.
            feature_matching: If true, also return the last hidden activation
                (taken before noise is added to it).
            cuda: Ignored. Noise is created on the device of the tensor it is
                added to; the argument is kept for backward compatibility.

        Returns:
            The final-layer output, or ``(last_hidden_activation, output)`` when
            ``feature_matching`` is true.
        """
        x = x.view(-1, self.input_dim)
        if self.training:
            # Sample the noise directly on x's device so no manual transfer is needed.
            x = x + torch.randn(x.size(), device=x.device) * 0.3
        for layer in self.layers:
            x_f = F.relu(layer(x))
            if self.training:
                x = x_f + torch.randn(x_f.size(), device=x_f.device) * 0.5
            else:
                x = x_f
        if feature_matching:
            return x_f, self.final(x)
        return self.final(x)


class Generator(nn.Module):
    """Fully connected generator mapping uniform noise to ``output_dim`` values.

    Two bias-free ``nn.Linear`` layers (500 units each), each followed by a
    non-affine ``BatchNorm1d``, a learnable additive bias and softplus, then a
    ``CustomLayer`` to ``output_dim`` with a final softplus.

    Args:
        z_dim: Dimensionality of the latent noise vector.
        output_dim: Number of values produced per sample.
    """
    def __init__(self, z_dim, output_dim = 28 ** 2):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.layerfc1 = nn.Linear(z_dim, 500, bias = False)
        self.layerbn1 = nn.BatchNorm1d(500, affine = False, eps=1e-6, momentum = 0.5)
        self.layerfc2 = nn.Linear(500, 500, bias = False)
        self.layerbn2 = nn.BatchNorm1d(500, affine = False, eps=1e-6, momentum = 0.5)
        self.layerfc3 = CustomLayer(500, output_dim, weight_scale = 1)
        # Learnable shifts applied after the non-affine batch norms.
        self.bn1_b = Parameter(torch.zeros(500))
        self.bn2_b = Parameter(torch.zeros(500))
        # `xavier_uniform_` is the in-place initialiser; the un-suffixed name is deprecated.
        nn.init.xavier_uniform_(self.layerfc1.weight)
        nn.init.xavier_uniform_(self.layerfc2.weight)

    def forward(self, batch_size, cuda = False):
        """Generate ``batch_size`` samples from fresh uniform noise.

        Args:
            batch_size: Number of samples to generate.
            cuda: Ignored. The noise is created on the device of the first layer's
                weight; the argument is kept for backward compatibility.

        Returns:
            Tensor of shape ``(batch_size, output_dim)``.
        """
        # The former Variable(..., volatile=...) wrapper relied on an argument that
        # PyTorch has removed; a plain tensor behaves the same way.
        z = torch.rand(batch_size, self.z_dim, device=self.layerfc1.weight.device)
        h = F.softplus(self.layerbn1(self.layerfc1(z)) + self.bn1_b)
        h = F.softplus(self.layerbn2(self.layerfc2(h)) + self.bn2_b)
        return F.softplus(self.layerfc3(h))


In [77]:
class GAN(object):
    """Trainer for a semi-supervised GAN with a feature-matching generator loss.

    Args:
        G: Generator module, called as ``G(batch_size, cuda=False)``.
        D: Discriminator module, called as ``D(x, feature_matching=..., cuda=False)``.
        lab: Labeled dataset; ``train`` reads its ``.tensors`` attribute.
        unlab: Unlabeled dataset yielding ``(datum, label)`` pairs (labels unused).
        test: Test dataset yielding ``(datum, label)`` pairs.
    """
    def __init__(self, G, D, lab, unlab, test):
        self.G ,self.D, self.labeled, self.unlabeled, self.test  = G, D, lab, unlab, test
        self.writer = tensorboardX.SummaryWriter(log_dir='./logfile')
        # using ADAM optimizer
        self.Doptim = optim.Adam(self.D.parameters(), lr=0.003, betas= (0.5, 0.999))
        self.Goptim = optim.Adam(self.G.parameters(), lr=0.003, betas = (0.5,0.999))

    def trainDiscriminator(self, x_label, y, x_unlabel):
        """Run one discriminator update.

        Args:
            x_label: Batch of labeled inputs.
            y: Integer class labels for ``x_label``.
            x_unlabel: Batch of unlabeled inputs.

        Returns:
            Tuple ``(supervised_loss, unsupervised_loss, accuracy)``: the two losses
            as NumPy values and the labeled-batch accuracy as a tensor.
        """
        label_out = self.D(x_label, cuda=False)
        unlabel_out = self.D(x_unlabel, cuda=False)
        # Detach the generated batch so this step only produces gradients for D.
        fake = self.G(x_unlabel.size()[0], cuda = False).view(x_unlabel.size()).detach()
        fake_out = self.D(fake, cuda=False)

        # Supervised term: minus the logit of the true class plus log-sum-exp of all logits.
        ls = -torch.mean(torch.gather(label_out, 1, y.unsqueeze(1))) + torch.mean(log_sum_exp(label_out))
        lse_unlabel = log_sum_exp(unlabel_out)
        lse_fake = log_sum_exp(fake_out)
        lu = 0.5 * (-torch.mean(lse_unlabel) + torch.mean(F.softplus(lse_unlabel)) + torch.mean(F.softplus(lse_fake)))
        loss = ls + lu

        self.Doptim.zero_grad()
        loss.backward()
        self.Doptim.step()
        accuracy = torch.mean((label_out.max(1)[1] == y).float())
        return ls.data.cpu().numpy(), lu.data.cpu().numpy(), accuracy

    def trainGenerator(self, x_unlabel):
        """Run one generator update using feature matching.

        The loss is the mean squared difference between the batch-averaged
        discriminator features of generated samples and of ``x_unlabel``.

        Args:
            x_unlabel: Batch of unlabeled inputs.

        Returns:
            The feature-matching loss as a NumPy value.
        """
        # FEATURE MATCHING
        fake_images = self.G(x_unlabel.size()[0], cuda = False).view(x_unlabel.size())
        generated, _ = self.D(fake_images, feature_matching=True, cuda=False)
        unlabel, _ = self.D(x_unlabel, feature_matching=True, cuda=False)

        generated = torch.mean(generated, dim = 0)
        unlabel = torch.mean(unlabel, dim = 0)
        loss_feature_matching = torch.mean((generated - unlabel) ** 2)
        self.Goptim.zero_grad()
        # Clear D's gradients as well; only the generator optimiser steps here.
        self.Doptim.zero_grad()
        loss_feature_matching.backward()
        self.Goptim.step()
        return loss_feature_matching.data.cpu().numpy()

    def train(self):
        """Train for 10 epochs with batch size 100, evaluating after every epoch.

        The labeled set is tiled so it has at least as many samples as the unlabeled
        set. Each step runs one discriminator update and one generator update; after
        epoch 1 the generator update is repeated when its loss exceeds 1. Scalars and
        histograms are written to the summary writer every 100 batches.
        """
        times = int(np.ceil(len(self.unlabeled) * 1. / len(self.labeled)))
        t1 = self.labeled.tensors[0].clone()
        t2 = self.labeled.tensors[1].clone()
        # NOTE(review): repeat(times, 1, 1, 1) assumes the labeled inputs are 4-D - confirm for other datasets.
        tile_labeled = TensorDataset(t1.repeat(times,1,1,1),t2.repeat(times))
        gn = 0
        for epoch in range(10):
            self.G.train()
            self.D.train()
            unlabel_loader1 = DataLoader(self.unlabeled, batch_size = 100, shuffle=True, drop_last=True, num_workers = 4)
            unlabel_loader2 = iter(DataLoader(self.unlabeled, batch_size = 100, shuffle=True, drop_last=True, num_workers = 4))
            label_loader = iter(DataLoader(tile_labeled, batch_size = 100, shuffle=True, drop_last=True, num_workers = 4))
            loss_supervised = loss_unsupervised = loss_gen = accuracy = 0.
            batch_num = 0
            for (unlabel1, _label1) in unlabel_loader1:
                batch_num += 1
                # Use the built-in next(): DataLoader iterators no longer expose a
                # Python-2 style .next() method.
                unlabel2, _label2 = next(unlabel_loader2)
                x, y = next(label_loader)
                ll, lu, acc = self.trainDiscriminator(x, y, unlabel1)

                loss_supervised += ll
                loss_unsupervised += lu
                accuracy += acc
                lg = self.trainGenerator(unlabel2)
                if epoch > 1 and lg > 1:
                    lg = self.trainGenerator(unlabel2)
                loss_gen += lg
                # batch_num is already 1-based here, so report it as-is.
                if batch_num % 100 == 0:
                    print('Training: %d / %d' % (batch_num, len(unlabel_loader1)))
                    gn += 1
                    # NOTE(review): both models are still in training mode here, so
                    # the logged features include the discriminator's injected noise.
                    with torch.no_grad():
                        self.writer.add_scalars('loss', {'loss_supervised':ll, 'loss_unsupervised':lu, 'loss_gen':lg}, gn)
                        self.writer.add_histogram('real_feature', self.D(x, cuda=False, feature_matching = True)[0], gn)
                        self.writer.add_histogram('fake_feature', self.D(self.G(100, cuda = False), cuda=False, feature_matching = True)[0], gn)
                        self.writer.add_histogram('fc3_bias', self.G.layerfc3.bias, gn)
                        self.writer.add_histogram('D_feature_weight', self.D.layers[-1].weight, gn)
                    self.D.train()
                    self.G.train()
            loss_supervised /= batch_num
            loss_unsupervised /= batch_num
            loss_gen /= batch_num
            accuracy /= batch_num
            print("Iteration %d, loss_supervised = %.4f, loss_unsupervised = %.4f, loss_gen = %.4f train acc = %.4f" % (epoch, loss_supervised, loss_unsupervised, loss_gen, accuracy))
            sys.stdout.flush()
            # Evaluate on the test set after every epoch.
            print("Eval: correct %d / %d"  % (self.eval(), len(self.test)))

    def predict(self, x):
        """Return the index of the largest discriminator output for each row of ``x``."""
        with torch.no_grad():
            pred = torch.max(self.D(x, cuda=False), 1)[1].data
        return pred

    def eval(self):
        """Switch both models to eval mode and count correct test predictions.

        Returns:
            A tensor holding the number of test samples predicted correctly.
        """
        self.G.eval()
        self.D.eval()
        d, l = [], []
        for (datum, label) in self.test:
            d.append(datum)
            l.append(label)
        x, y = torch.stack(d), torch.LongTensor(l)
        pred = self.predict(x)
        return torch.sum(pred == y)

    def draw(self, batch_size):
        """Switch the generator to eval mode and return ``batch_size`` generated samples."""
        self.G.eval()
        return self.G(batch_size, cuda=False)

In [79]:
# Seed both RNGs the notebook draws from: NumPy drives the labeled-subset
# permutation, while torch drives weight initialisation, the injected noise and
# DataLoader shuffling.
np.random.seed(1)
torch.manual_seed(1)

def get_labelled_MNIST(class_num):
    """Draw a class-balanced labeled subset of the MNIST training set.

    The training set is visited in a random order (NumPy global RNG) and samples
    are kept until every one of the 10 classes has ``class_num`` examples.

    Args:
        class_num: Number of labeled examples to keep per class.

    Returns:
        ``TensorDataset`` of float images and long labels with up to
        ``10 * class_num`` samples.
    """
    raw_dataset = datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),]))
    class_total = 10 * [0]
    data, labels = [], []
    for idx in np.random.permutation(len(raw_dataset)):
        datum, label = raw_dataset[idx]
        if class_total[label] < class_num:
            data.append(datum.numpy())
            labels.append(label)
            class_total[label] += 1
            # Stop as soon as every class quota has been filled.
            if len(data) >= 10*class_num:
                break
    return TensorDataset(torch.FloatTensor(np.array(data)), torch.LongTensor(np.array(labels)))

# Labeled subset: 10 examples per class.
MNIST_labelled = get_labelled_MNIST(10)

# The full training split doubles as the unlabeled set (its labels are not used by the trainer).
raw_dataset = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(),]),
)
MNIST_unlabelled = raw_dataset

# Held-out split used for the per-epoch evaluation.
MNIST_test = datasets.MNIST(
    '../data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(),]),
)

# Generator with a 100-dimensional latent input, default discriminator.
model = GAN(
    Generator(100),
    Discriminator(),
    MNIST_labelled,
    MNIST_unlabelled,
    MNIST_test,
)
model.train()



Training: 100 / 600
Training: 200 / 600
Training: 300 / 600
Training: 400 / 600
Training: 500 / 600
Training: 600 / 600
Iteration 0, loss_supervised = 0.1515, loss_unsupervised = 0.4526, loss_gen = 0.1472 train acc = 0.9483
Eval: correct 8016 / 10000
Training: 100 / 600
Training: 200 / 600
Training: 300 / 600
Training: 400 / 600
Training: 500 / 600
Training: 600 / 600
Iteration 1, loss_supervised = 0.0092, loss_unsupervised = 0.3951, loss_gen = 0.2532 train acc = 0.9983
Eval: correct 8620 / 10000
Training: 100 / 600
Training: 200 / 600
Training: 300 / 600
Training: 400 / 600
Training: 500 / 600
Training: 600 / 600
Iteration 2, loss_supervised = 0.0066, loss_unsupervised = 0.3806, loss_gen = 0.3698 train acc = 0.9989
Eval: correct 8757 / 10000
Training: 100 / 600
Training: 200 / 600
Training: 300 / 600
Training: 400 / 600
Training: 500 / 600
Training: 600 / 600
Iteration 3, loss_supervised = 0.0056, loss_unsupervised = 0.3771, loss_gen = 0.5710 train acc = 0.9991
Eval: correct 8976 / 10