In [15]:
import torch
import torchvision
import torchvision.datasets as datasets
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import scipy.misc
import pickle
%matplotlib inline

In [16]:
%run Utils.py

In [99]:
class DiscNet(torch.nn.Module):
    def __init__(self, in_dim, out_dim, ndf = 128, ngpu = 1):
        super(DiscNet, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.ndf = ndf
        self.ngpu = ngpu
        
        self.conv = nn.Sequential(
            #Start with (N x 1 x input_size x input_size) 64
            nn.Conv2d(self.in_dim, self.ndf, 4, 2, 1),
            nn.LeakyReLU(0.2),
            #Gives me (N x ndf x input_size//2 x input_size//2) 32
            nn.Conv2d(self.ndf, self.ndf*2, 4, 2, 1),
            nn.BatchNorm2d(self.ndf*2),
            nn.LeakyReLU(0.2),
            #Then have N x 2*ndf x input_size//4 x input_size//4 16
            nn.Conv2d(self.ndf*2, self.ndf*4, 4, 2, 1),
            nn.BatchNorm2d(self.ndf*4),
            nn.LeakyReLU(0.2),
            #Now have N x 4*ndf x input_size//8 x input_size//8 8
            nn.Conv2d(self.ndf*4, self.ndf*8, 4, 2, 1),
            nn.BatchNorm2d(self.ndf*8),
            nn.LeakyReLU(0.2),
        )
        self.last = nn.Sequential(
            #Now N x 8*ndf x input_size//16 x input_size//16 (128 x 1024 x 4 x 4)
#             nn.Conv2d(self.ndf*8, self.out_dim, 4, 1, 0),
            nn.Conv2d(self.ndf*8, self.out_dim, 4, 1, 0),
            nn.Sigmoid()
            #Output N x out_dim x 
        )
        
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
    
    def forward(self, inp):
        x = self.conv(inp)
        print(x.size())
        x = self.last(x)
        return x

    

In [100]:
class GenNet(torch.nn.Module):
    def __init__(self, in_dim, out_dim, ngf = 128, ngpu = 1):
        super(GenNet, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.ngf = ngf
        self.ngpu = ngpu
        
        self.conv = nn.Sequential(
            #start with latent dimensional input
            nn.ConvTranspose2d(self.in_dim, 8*self.ngf, 4, 1, 0),
            nn.BatchNorm2d(8*self.ngf),
            nn.ReLU(),
            #100 x 8*ngf x 2 x 2
            nn.ConvTranspose2d(8*self.ngf, 4 * self.ngf, 4, 1, 0),
            nn.BatchNorm2d(4*self.ngf),
            nn.ReLU(),
            #Now 100 x 4*ngf x 4 x 4
            nn.ConvTranspose2d(4*self.ngf, 2*self.ngf, 4, 2, 1),
            nn.BatchNorm2d(2*self.ngf),
            nn.ReLU(),
            #Now 100 x 2*ngf x 8 x 8
            nn.ConvTranspose2d(2*self.ngf, self.ngf, 4, 2, 1),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(),
            #Now 100 x ngf x 16 x 16
            nn.ConvTranspose2d(self.ngf, self.out_dim, 4, 2, 1),
            nn.Tanh()
            #Output 100 x out_dim (num_channels) x 32 x 32 images.
        )
    
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
    
    def forward(self, inp):
        return self.conv(inp)

In [103]:
img_size = 64
batch_size = 128

transform = transforms.Compose([
        transforms.Scale(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, ), std=(0.5,))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

D = DiscNet(1, 1, 128)

x = train_loader.__iter__().__next__()[0]

D(x).size(), D(x).view(-1).size()

torch.Size([128, 1024, 4, 4])
torch.Size([128, 1024, 4, 4])


(torch.Size([128, 1, 1, 1]), torch.Size([128]))

In [89]:
def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()
        

In [107]:
class GAN(object):
    def __init__(self):
        self.num_epochs = 20
        self.batch_size = 128
        self.image_size = 64
        self.z_dim = 100
        self.ndf = 128
        self.ngf = 128
        self.lr = 0.0002
        self.beta1 = 0.5
        self.beta2 = 0.999
        
        self.ngpu = 2
        self.dataset = 'MNIST'
        self.save_dir = 'models/'
        self.result_dir = 'results/'
        self.model_name = 'DCGAN'
        self.sample_num = 25
        self.num_workers = 2
        
        #NOTE: Change the normalization if not using MNIST.
        trans = transforms.Compose([
#             transforms.Resize(self.image_size),
#             transforms.CenterCrop(self.image_size),
            transforms.Scale(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,))
#             transforms.Normalize(mean = (.1307, ), std = (0.3081, ))
        ])
        
        self.data_loader = torch.utils.data.DataLoader(
            datasets.MNIST(root = './data', train = True, download = True, transform = trans),
            batch_size = self.batch_size,
            shuffle = True,
            num_workers = self.num_workers
        )
        data = self.data_loader.__iter__().__next__()[0]
        self.num_channels = data.size()[1]
        
        self.device = torch.device("cuda:0" if (torch.cuda.is_available() and self.ngpu >0) else "cpu")
        
        
        self.D = DiscNet(in_dim = self.num_channels, out_dim = 1, ndf = self.ndf , ngpu = self.ngpu)
        self.G = GenNet(in_dim = self.z_dim, out_dim = self.num_channels, ngf = self.ngf, ngpu = self.ngpu)
        
        self.G = self.G.to(self.device)
        self.D = self.D.to(self.device)
        
        self.criterion = nn.BCELoss()
        self.sample_z = torch.randn(self.batch_size, self.z_dim, 1, 1, device = self.device)
        
        if (self.device.type == 'cuda') and (self.ngpu > 1):
            self.G = nn.DataParallel(self.G, list(range(self.ngpu)))
            self.D = nn.DataParallel(self.D, list(range(self.ngpu)))
        
        self.G.weight_init(mean=0., std= 0.02)
        self.D.weight_init(mean=0., std= 0.02)
        
        self.D_optimizer = optim.Adam(self.D.parameters(), lr = self.lr, betas = (self.beta1, self.beta2))
        self.G_optimizer = optim.Adam(self.G.parameters(), lr = self.lr, betas = (self.beta1, self.beta2))

        
    def plot_train(self):
        real_batch = next(iter(self.data_loader))
        plt.figure(figsize=(8,8))
        plt.axis("off")
        plt.title("Training Images")
        plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(self.device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))


    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []
        
        real_label = 1
        fake_label = 0
        
        start_time = time.time()
        
        for epoch in range(self.num_epochs):
            
            epoch_start_time = time.time()
            
            for i, (x, _) in enumerate(self.data_loader, 0):
                
                #TRAIN D
                self.D.zero_grad()
                
                b_size = x.size()[0]
                x = Variable(x.to(self.device))
                y_real = Variable(torch.ones(b_size).to(self.device))
                y_fake = Variable(torch.zeros(b_size).to(self.device))
                
#                 real_im = x.to(self.device)
#                 b_size = real_im.size()[0]
#                 label = torch.full((b_size, ), real_label, device = self.device)
                
                output = self.D(x).view(-1)
#                 print(self.D(real_im).size(), output.size(), y_real.size())
                D_real_loss = self.criterion(output, y_real)
                
                noise = Variable(torch.randn(b_size, self.z_dim, 1, 1, device = self.device))
                fake = self.G(noise)
#                 label.fill_(fake_label)
                output = self.D(fake).view(-1)
                D_fake_loss = self.criterion(output, y_fake)
                                
                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.item())
                D_loss.backward()
                self.D_optimizer.step()
                
                #Train G
                self.G.zero_grad()
#                 label.fill_(real_label)
                noise = Variable(torch.randn(b_size, self.z_dim, 1, 1, device = self.device))
                fake = G(noise)
                
                output = self.D(fake).view(-1)
                G_loss = self.criterion(output, y_real)
                
                G_loss.backward()
                self.G_optimizer.step()
                
                self.train_hist['G_loss'].append(G_loss.item())
                
                self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
                
            self.train_hist['total_time'].append(time.time() - start_time)
            with torch.no_grad():
                self.visualize_results((epoch+1))
            print("Completed epoch {}".format(epoch+1))
        
        print("Done Training!")
        self.save()
        
        
        generate_animation('{}/{}/{}/{}'.format(self.result_dir, self.dataset, self.model_name, self.model_name),
                                 self.epoch)
        loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
        
        
    def visualize_results(self, epoch, fix=True):
        self.G.eval()
        
        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
            
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        
        samples = self.G(self.sample_z)
        
        if self.ngpu>0:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)
            
        samples = (samples + 1)/2
        
        
        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], 
                    [image_frame_dim, image_frame_dim],
        '{}/{}/{}/{}_epoch{:03}.png'.format(self.result_dir, self.dataset, self.model_name, self.model_name, epoch))
        
    
    def save(self):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
        
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
        
        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)
    
    def load(self):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
        
        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))     
        

In [108]:
my_GAN = GAN()
my_GAN.train()

AssertionError: Invalid device id

In [109]:
torch.cuda.device_count()

1