In [0]:
!pip install -U -q PyDrive
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Choose a local (Colab) directory to store the data.
# exist_ok=True keeps re-runs of this cell working when the directory is
# already there, while real failures (e.g. permission errors, or the path
# being a regular file) are raised instead of being silently swallowed.
local_download_path = os.path.expanduser('~/data')
os.makedirs(local_download_path, exist_ok=True)

# 2. Auto-iterate using the query syntax
#    https://developers.google.com/drive/v2/web/search-parameters
# The query lists every file whose parent is the given Drive folder id.
file_list = drive.ListFile(
    {'q': "'1mCsY5LEsgCnc0Txv0rpAUhKVPWVkbw5I' in parents"}).GetList()

for f in file_list:
  # 3. Create a file handle by id and download its content into the local
  #    directory, keeping the Drive title as the local file name.
  print('title: %s, id: %s' % (f['title'], f['id']))
  fname = os.path.join(local_download_path, f['title'])
  print('downloading to {}'.format(fname))
  f_ = drive.CreateFile({'id': f['id']})
  f_.GetContentFile(fname)

!tar -xzf ~/data/faces.tar.gz -C ~/data
!ls ~/data
os.makedirs('~/pretrain')
os.makedirs('~/samples')

In [0]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = (torch.device("cuda:0")
          if torch.cuda.is_available() else torch.device("cpu"))

# Seed the PyTorch RNG so that runs are repeatable.
torch.manual_seed(0)

# NOISE_LENGTH: size of one noise vector.
# NOISE_BASELINE: a fixed batch of 25 noise vectors (enough for a 5 x 5 grid),
# drawn once right after seeding.
NOISE_LENGTH = 100
NOISE_BASELINE = torch.randn(25, NOISE_LENGTH, device=device)

In [0]:
def visualize_data(dataloader):
    """Display a grid of up to 64 images from the first batch of ``dataloader``.

    Args:
        dataloader: Iterable whose batches are indexable, with the image
            tensor at index 0.
    """
    first_batch = next(iter(dataloader))

    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Training images")

    samples = first_batch[0][:64]
    grid = vutils.make_grid(samples, normalize=True, range=(-1, 1))
    # Move the first axis last (permute(1, 2, 0)) before handing it to imshow.
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()


def visualize_batch(data, batches_done):
    """Display the first 25 images of ``data`` as a 5-per-row grid.

    Args:
        data: Image tensor; it is moved to the CPU and detached before
            plotting.
        batches_done: Batch counter shown in the figure title.
    """
    plt.figure(figsize=(5, 5))
    plt.axis("off")
    plt.title("Batches done %d" % batches_done)

    preview = data.cpu().detach()[:25]
    grid = vutils.make_grid(preview, nrow=5, normalize=True, range=(-1, 1))
    # Move the first axis last (permute(1, 2, 0)) before handing it to imshow.
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()


def load_dataset(root, batch_size):
    """Build a shuffled DataLoader of 64x64 images from an image folder.

    Each image is resized so its shorter side is 64 pixels, centre-cropped to
    64x64, converted to a tensor and normalized per channel with mean 0.5 and
    std 0.5.

    Args:
        root: Directory passed to ``torchvision.datasets.ImageFolder``.
        batch_size: Number of images per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset that shuffles the
        data and uses 2 worker processes with pinned memory.
    """
    preprocess = transforms.Compose([
        transforms.Resize(64),
        # Resize(64) only fixes the shorter side, so a non-square image would
        # come out as 64xN. Cropping guarantees every sample is exactly 64x64
        # and can be stacked into a batch; for square images it is a no-op.
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = dset.ImageFolder(root, preprocess)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=2,
                                             pin_memory=True)
    return dataloader

In [0]:
class Generator(nn.Module):
    """Generator network: noise vector -> 3-channel 64x64 image in [-1, 1].

    A linear layer projects the noise to a 1024x4x4 feature map, which four
    transposed convolutions upsample to 8x8, 16x16, 32x32 and finally 64x64.
    The output is squashed with ``tanh``.

    Args:
        noise_length: Size of each input noise vector. Defaults to 100.
    """

    def __init__(self, noise_length=100):
        super(Generator, self).__init__()

        # Project the noise to 1024 * 4 * 4 values; reshaped in forward().
        self.fc = nn.Linear(noise_length, 1024 * 4 * 4)
        self.proc = nn.Sequential(
            # 4x4 -> 8x8 (kernel 5, stride 1, no padding)
            nn.ConvTranspose2d(1024, 512, 5, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # 16x16 -> 32x32
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # 32x32 -> 64x64, 3 output channels
            nn.ConvTranspose2d(128, 3, 4, 2, 1),
        )

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: Tensor whose last dimension equals ``noise_length``.

        Returns:
            Tensor of shape ``(N, 3, 64, 64)`` with values in ``[-1, 1]``.
        """
        output = self.fc(x).view(-1, 1024, 4, 4)
        output = self.proc(output)
        output = torch.tanh(output)
        return output


class Critic(nn.Module):
    """Critic network: 3-channel 64x64 image -> one unbounded score per image.

    Four stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8
    -> 4) and a final 4x4 convolution reduces each image to a single value.
    """

    def __init__(self):
        super().__init__()

        # First stage: biased convolution, no batch norm.
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Middle stages: bias-free convolution + batch norm + LeakyReLU.
        for in_channels, out_channels in ((64, 128), (128, 256), (256, 512)):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=4,
                                    stride=2, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final stage: collapse the remaining 4x4 map to one score.
        layers.append(nn.Conv2d(512, 1, kernel_size=4))

        self.proc = nn.Sequential(*layers)

    def forward(self, input):
        """Return a 1-D tensor of critic scores, one entry per output value."""
        scores = self.proc(input)
        return scores.flatten()

In [0]:
def train(dataloader, n_epochs, lr, n_critic=5, c=0.01,
          sample_dir='~/samples', checkpoint_dir='~/pretrain'):
    """Train a generator/critic pair with a weight-clipped Wasserstein loss.

    Every batch performs one critic update, after which all critic
    parameters are clamped to ``[-c, c]``. The generator is updated on every
    ``n_critic``-th batch of an epoch (batch indices 0, n_critic, ...).
    Every 100 batches a 5-per-row grid generated from ``NOISE_BASELINE`` is
    saved to ``sample_dir`` and displayed, and at the end of each epoch both
    state dicts are written to ``checkpoint_dir``.

    Args:
        dataloader: Yields ``(images, labels)`` batches; must expose
            ``batch_size`` and support ``len()``.
        n_epochs: Number of passes over ``dataloader``.
        lr: Learning rate for both RMSprop optimizers.
        n_critic: The generator is updated once every ``n_critic`` batches.
        c: Clamp bound applied to the critic's parameters.
        sample_dir: Directory for preview images; ``~`` is expanded and the
            directory is created if missing.
        checkpoint_dir: Directory for per-epoch state dicts; ``~`` is
            expanded and the directory is created if missing.

    Returns:
        Tuple ``(netG, netD)`` of the trained generator and critic.
    """
    # Local import so this cell does not depend on another cell's imports.
    import os

    def weights_init(m):
        # Small-variance normal init for conv/linear weights; batch-norm
        # scales start around 1 and the touched biases start at 0.
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # save_image and torch.save do not expand '~' themselves, so resolve the
    # output directories up front and make sure they exist.
    sample_dir = os.path.expanduser(sample_dir)
    checkpoint_dir = os.path.expanduser(checkpoint_dir)
    os.makedirs(sample_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    netG = Generator().to(device)
    netG.apply(weights_init)
    optG = optim.RMSprop(netG.parameters(), lr)

    netD = Critic().to(device)
    netD.apply(weights_init)
    optD = optim.RMSprop(netD.parameters(), lr)

    # Reusable noise buffer, refilled in place before every forward pass.
    noise = torch.randn(dataloader.batch_size, NOISE_LENGTH, device=device)

    batches_done = 0

    for epoch in range(n_epochs):
        for i, (data, _) in enumerate(dataloader):
            data = data.to(device)

            # Train critic
            optD.zero_grad()
            noise.normal_()
            # Gradients from G are not used, so detach to avoid computing them.
            # The critic maximizes D(real) - D(fake), i.e. minimizes its
            # negation.
            lossD = -(netD(data).mean() - netD(netG(noise).detach()).mean())
            lossD.backward()
            optD.step()
            # Clamp the weights
            for p in netD.parameters():
                p.data.clamp_(-c, c)

            # Train generator
            if i % n_critic == 0:
                optG.zero_grad()
                noise.normal_()
                # Minimize EM distance
                lossG = -netD(netG(noise)).mean()
                lossG.backward()
                optG.step()
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                      (epoch, n_epochs, batches_done % len(dataloader),
                       len(dataloader), lossD.item(), lossG.item()))

            if batches_done % 100 == 0:
                # Preview only: no gradients are needed for this pass.
                with torch.no_grad():
                    fake = netG(NOISE_BASELINE)
                # nrow is passed by keyword because the third positional
                # parameter of save_image is not nrow in every torchvision
                # release.
                # NOTE(review): newer torchvision releases appear to rename
                # `range` to `value_range` - confirm against the installed
                # version.
                vutils.save_image(fake,
                                  os.path.join(sample_dir,
                                               "%d.png" % batches_done),
                                  nrow=5,
                                  normalize=True,
                                  range=(-1, 1))
                visualize_batch(fake, batches_done)

            batches_done += 1

        # Save the model
        torch.save(netG.state_dict(),
                   os.path.join(checkpoint_dir, 'netG_epoch_%d.pth' % epoch))
        torch.save(netD.state_dict(),
                   os.path.join(checkpoint_dir, 'netD_epoch_%d.pth' % epoch))

    return netG, netD

In [0]:
def main():
    """Load the images under '~/data', preview one batch, then run training."""
    batch_size = 64
    n_epochs = 500
    learning_rate = 0.00005

    loader = load_dataset('~/data', batch_size)
    visualize_data(loader)

    train(loader, n_epochs, learning_rate)


# Only start the run when this code is executed as the main module.
if __name__ == "__main__":
    main()