In [6]:
import torch
import torchvision
import torchvision.datasets as datasets
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import scipy.misc
import pickle
%matplotlib inline

In [7]:
%run Utils.py

In [23]:
def g_loss(real, *argv):
    #args represent disc outputs. ALWAYS assume at least 1 loss
    #Seems to be working fine at first glance. Testing w/ dummy model.
    criterion = nn.BCELoss()
    curr_loss = criterion(argv[0], real)
    for d_out in argv[1:]:
        curr_loss = curr_loss + criterion(d_out, real)
    return curr_loss

def d_loss(desired, current_ind, *d_outs):
    """
    Discriminators play game against G trying to recognize true vs. false samples,
    play pairwise penny matching game:
        lower index discriminator tries to match,
        higher index tries NOT to match.
    
    Desired - desired output. For training vs. G or on training set.
    d_outs - all the discriminator outputs.
    current_ind - index of current discriminator being trained in d_outs
    
    """

    if current_ind > len(d_outs):
        raise IndexError("Index out of bounds")
    
    criterion = nn.BCELoss()
    d_out = d_outs[current_ind]
    curr_loss = criterion(d_out, desired)
    
    def match_pennies(d1_out, d2_out, sign):
        #Matching pennies loss function.
        """
        If 1 matching pennies w/ 2, then loss for 2 is -MSE(1, 2).
        while lsos for 1 is MSE(1, 2)
        
        If 2 matching pennies with 3, loss for 2 is MSE(2, 3)
        
        TOtal loss for 2: -MSE(1, 2) + MSE(2, 3)
        Easiest sol - probably just keeping track of index. 
        """
        crit = nn.MSELoss()
        return sign * crit(d1_out, d2_out)
    
    
    for i, other_d_out in enumerate(d_outs):
        if i < current_ind:
            sign = 1
        elif i > current_ind:
            sign = -1
        else:
            sign = 0

        scale = 1./1 if len(d_outs) == 1 else 1./(len(d_outs)-1)
            
        curr_loss = curr_loss + scale * match_pennies(d_out, other_d_out, sign)

    
    return curr_loss


# g_loss(x, y, z).backward()
# d_loss(x, 2, out, y, z)


In [24]:
#Discriminator class. Assumes 64x64 inputs.
class DiscNet(torch.nn.Module):
    def __init__(self, in_dim, out_dim, ndf = 128, ngpu = 1):
        super(DiscNet, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.ndf = ndf
        self.ngpu = ngpu
        
        self.conv = nn.Sequential(
            #Start with (N x 1 x input_size x input_size) 64
            nn.Conv2d(self.in_dim, self.ndf, 4, 2, 1),
            nn.LeakyReLU(0.2),
            #Gives me (N x ndf x input_size//2 x input_size//2) 32
            nn.Conv2d(self.ndf, self.ndf*2, 4, 2, 1),
            nn.BatchNorm2d(self.ndf*2),
            nn.LeakyReLU(0.2),
            #Then have N x 2*ndf x input_size//4 x input_size//4 16
            nn.Conv2d(self.ndf*2, self.ndf*4, 4, 2, 1),
            nn.BatchNorm2d(self.ndf*4),
            nn.LeakyReLU(0.2),
            #Now have N x 4*ndf x input_size//8 x input_size//8 8
            nn.Conv2d(self.ndf*4, self.ndf*8, 4, 2, 1),
            nn.BatchNorm2d(self.ndf*8),
            nn.LeakyReLU(0.2),
        )
        self.last = nn.Sequential(
            #Now N x 8*ndf x input_size//16 x input_size//16 (128 x 1024 x 4 x 4)
            nn.Conv2d(self.ndf*8, self.out_dim, 4, 1, 0),
            nn.Sigmoid()
            #Output N x out_dim x 
        )
        
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
    
    def forward(self, inp):
        x = self.conv(inp)
#         print(x.size())
        x = self.last(x)
        return x

    

In [25]:
#Generator class. Assumes 64x64 outputs.
class GenNet(torch.nn.Module):
    def __init__(self, in_dim, out_dim, ngf = 128, ngpu = 1):
        super(GenNet, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.ngf = ngf
        self.ngpu = ngpu
        
        self.conv = nn.Sequential(
            #start with latent dimensional input
            nn.ConvTranspose2d(self.in_dim, 8*self.ngf, 4, 1, 0),
            nn.BatchNorm2d(8*self.ngf),
            nn.ReLU(),
            #100 x 8*ngf x 2 x 2
            nn.ConvTranspose2d(8*self.ngf, 4 * self.ngf, 4, 2, 1),
            nn.BatchNorm2d(4*self.ngf),
            nn.ReLU(),
            #Now 100 x 4*ngf x 4 x 4
            nn.ConvTranspose2d(4*self.ngf, 2*self.ngf, 4, 2, 1),
            nn.BatchNorm2d(2*self.ngf),
            nn.ReLU(),
            #Now 100 x 2*ngf x 8 x 8
            nn.ConvTranspose2d(2*self.ngf, self.ngf, 4, 2, 1),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(),
            #Now 100 x ngf x 16 x 16
            nn.ConvTranspose2d(self.ngf, self.out_dim, 4, 2, 1),
            nn.Tanh()
            #Output 100 x out_dim (num_channels) x 32 x 32 images.
        )
    
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
    
    def forward(self, inp):
        return self.conv(inp)

In [26]:
img_size = 64
batch_size = 128

transform = transforms.Compose([
        transforms.Scale(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, ), std=(0.5,))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

len(train_loader)

469

In [27]:
def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

In [28]:
#GAN class container.
class GAN(object):
    def __init__(self):
        self.num_epochs = 30
        self.batch_size = 128
        self.image_size = 64
        self.z_dim = 128
        self.ndf = 128
        self.ngf = 128
        self.lr = 0.0002
        self.beta1 = 0.5
        self.beta2 = 0.999
        
        self.num_disc = 2
        
        self.ngpu = 1
        self.workers = 8
        self.dataset = 'MNIST'
        self.save_dir = 'models/'
        self.result_dir = 'results/'
        self.model_name = '{}_Disc_DCGAN_mean'.format(self.num_disc)
        self.sample_num = 64
        
        #NOTE: Change the normalization if not using MNIST.
        trans = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,))
#             transforms.Normalize(mean = (.1307, ), std = (0.3081, ))
        ])
        
        self.data_loader = torch.utils.data.DataLoader(
            datasets.MNIST(root = './data', train = True, download = True, transform = trans),
            batch_size = self.batch_size,
            shuffle = True,
            num_workers = self.workers
        )
        
#         self.num_channels = next(iter(self.data_loader))[0].size()[1]
        self.num_channels = 1 if self.dataset == 'MNIST' else 3
        
        self.device = torch.device("cuda:0" if (torch.cuda.is_available() and self.ngpu >0) else "cpu")
        
        self.D1 = DiscNet(in_dim = self.num_channels, out_dim = 1, ndf = self.ndf , ngpu = self.ngpu)
        self.G = GenNet(in_dim = self.z_dim, out_dim = self.num_channels, ngf = self.ngf, ngpu = self.ngpu)
        
        if self.num_disc >= 2:
            self.D2 = DiscNet(in_dim = self.num_channels, out_dim = 1, ndf = self.ndf , ngpu = self.ngpu)
            self.D2 = self.D2.to(self.device)
        
        self.G = self.G.to(self.device)
        self.D1 = self.D1.to(self.device)
        
#         self.criterion = nn.BCELoss().to(self.device)
        self.D_criterion = d_loss
        self.G_criterion = g_loss

        self.sample_z = torch.randn(self.batch_size, self.z_dim, 1, 1, device = self.device)
        
        if (self.device.type == 'cuda') and (self.ngpu > 1):
            self.G = nn.DataParallel(self.G, list(range(self.ngpu)))
            self.D1 = nn.DataParallel(self.D, list(range(self.ngpu)))
        
        self.G.weight_init(mean=0., std= 0.02)
        self.D1.weight_init(mean=0., std= 0.02)
        
        self.D1_optimizer = optim.Adam(self.D1.parameters(), lr = self.lr, betas = (self.beta1, self.beta2))
        self.D2_optimizer = optim.Adam(self.D2.parameters(), lr = self.lr, betas = (self.beta1, self.beta2))

        self.G_optimizer = optim.Adam(self.G.parameters(), lr = self.lr, betas = (self.beta1, self.beta2))

        
    def plot_train(self):
        real_batch = next(iter(self.data_loader))
        plt.figure(figsize=(8,8))
        plt.axis("off")
        plt.title("Training Images")
        plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(self.device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))


    def train(self):
        self.train_hist = {}
        self.train_hist['D1_loss'] = []
        self.train_hist['D2_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []
        
        real_label = 1
        fake_label = 0
        
        start_time = time.time()
        
        for epoch in range(self.num_epochs):
            
            epoch_start_time = time.time()
            
            for i, (x, _) in enumerate(self.data_loader, 0):
                batch_start_time = time.time()
                
                #TRAIN D
                self.D1.zero_grad()
                self.D2.zero_grad()
                
                b_size = x.size()[0]
                x = Variable(x.to(self.device))
                y_real = Variable(torch.ones(b_size, device = self.device))
                y_fake = Variable(torch.zeros(b_size, device = self.device))
                
                d1_output = self.D1(x).view(-1)
                d2_output = self.D2(x).view(-1)
                
                d1_real_loss = self.D_criterion(y_real, 0, d1_output, d2_output)
                d2_real_loss = self.D_criterion(y_real, 1, d1_output, d2_output)
                
#                 D_real_loss = self.criterion(output, y_real)
                del d1_output
                del d2_output
                        
                noise = Variable(torch.randn(b_size, self.z_dim, 1, 1, device = self.device))
                fake = self.G(noise)
                
                d1_output = self.D1(fake).view(-1)
                d2_output = self.D2(fake).view(-1)
                
                
#                 output = self.D(fake).view(-1)
                del noise
                del fake
                
                d1_fake_loss = self.D_criterion(y_fake, 0, d1_output)
                d2_fake_loss = self.D_criterion(y_fake, 0, d2_output)
                
#                 D_fake_loss = self.criterion(output, y_fake)
                del d1_output
                del d2_output
                            
                d1_loss = d1_real_loss + d1_fake_loss
                d2_loss = d2_real_loss + d2_real_loss
        
                del d1_real_loss
                del d1_fake_loss
                del d2_real_loss
                del d2_fake_loss
        
#                 D_loss = D_real_loss + D_fake_loss
#                 del D_real_loss
#                 del D_fake_loss
                
                self.train_hist['D1_loss'].append(d1_loss.item())
                self.train_hist['D2_loss'].append(d2_loss.item())
                
                d1_loss.backward(retain_graph=True)
                d2_loss.backward(retain_graph=True)
                self.D1_optimizer.step()
                self.D2_optimizer.step()
                del d1_loss
                del d2_loss
                
                #Train G
                self.G.zero_grad()
                noise = Variable(torch.randn(b_size, self.z_dim, 1, 1, device = self.device))
                fake = self.G(noise)
                del noise
                
                d1_output = self.D1(fake).view(-1)
                d2_output = self.D2(fake).view(-1)
                del fake

                
                G_loss = self.G_criterion(y_real, d1_output, d2_output)
                del d1_output
                del d2_output
                
                G_loss.backward()
                self.G_optimizer.step()
                
                self.train_hist['G_loss'].append(G_loss.item())
                
            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
                
            self.train_hist['total_time'].append(time.time() - start_time)
            with torch.no_grad():
                self.visualize_results((epoch+1))
            print("Completed epoch {} in {}s".format(epoch+1, time.time() - epoch_start_time))
        
        print("Done Training!")
        self.save()
        
        
        generate_animation('{}/{}/{}/{}'.format(self.result_dir, self.dataset, self.model_name, self.model_name),
                                 self.num_epochs)
        loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
        
        
    def visualize_results(self, epoch, fix=True):
        self.G.eval()
        
        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
            
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        
        samples = self.G(self.sample_z)
        
        if self.ngpu>0:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)
            
        samples = (samples + 1)/2
        
        
        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], 
                    [image_frame_dim, image_frame_dim],
        '{}/{}/{}/{}_epoch{:03}.png'.format(self.result_dir, self.dataset, self.model_name, self.model_name, epoch))
        
    
    def save(self):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
        
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D1.state_dict(), os.path.join(save_dir, self.model_name + '_D1.pkl'))
        torch.save(self.D2.state_dict(), os.path.join(save_dir, self.model_name + '_D2.pkl'))

        
        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)
    
    def load(self):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
        
        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))     
        

In [None]:
my_GAN = GAN()
my_GAN.train()

Completed epoch 1 in 864.0501117706299s
Completed epoch 2 in 865.1706988811493s
Completed epoch 3 in 863.1535921096802s
Completed epoch 4 in 861.1947169303894s
Completed epoch 5 in 860.7646996974945s
Completed epoch 6 in 859.7481396198273s
Completed epoch 7 in 860.0431108474731s
Completed epoch 8 in 857.9803256988525s
Completed epoch 9 in 857.964339017868s
Completed epoch 10 in 855.6677293777466s
Completed epoch 11 in 857.0269229412079s
Completed epoch 12 in 855.6680085659027s


In [69]:
t_hist = my_GAN.train_hist
l = []
# print(len(t_hist['per_epoch_time']))
for i in range(468, len(t_hist['per_epoch_time']), 468):
    x = t_hist['per_epoch_time'][i]
    print(x)
    l.append(x)

print("Mean time: {} seconds".format(np.mean(l)))
# t_hist['per_epoch_time'][468]

533.569614648819
525.9256238937378
523.5429208278656
520.8068869113922
517.9585914611816
517.6175518035889
517.6820816993713
512.1806309223175
508.7970519065857
509.93553352355957
509.23811078071594
507.3241722583771
503.73891377449036
504.3291959762573
505.31306314468384
502.40385222435
502.0548310279846
501.0621898174286
497.44885873794556
496.7175602912903
Mean time: 510.8823617815971 seconds
