In [None]:
%run env_setup.py
%matplotlib notebook
import importlib

In [None]:
import os
import torch
from torch import FloatTensor, nn
from torch.autograd import Variable
import torchvision
import yama
import yama.vision as yamavision
from tqdm import tqdm_notebook as tqdm

In [None]:
from yama.vision.datasets import LocalStorage, PaperSpaceGradientStorage

In [None]:
# Storage backend for the yama datasets (the LSUN cell below reads through it), rooted at
# ../../_data resolved against the notebook's current working directory.
# Presumably for Paperspace Gradient runs, swap in the commented-out PaperSpaceGradientStorage().
storage = LocalStorage(os.path.abspath('../../_data'))
#storage = PaperSpaceGradientStorage()

In [None]:
# Training configuration.
batch_size = 64          # DataLoader batch size
image_size = (64, 64)    # (height, width) images are resized to
noise_size = 100         # channel count of the generator's input noise

In [None]:
# Route HTTP(S) traffic (e.g. the CIFAR10 download=True below) through a proxy on localhost:1087.
# NOTE(review): machine-specific setting — change or remove if you are not behind this proxy.
%env http_proxy=http://127.0.0.1:1087
%env https_proxy=http://127.0.0.1:1087

In [None]:
# CIFAR-10 resized to image_size and normalised from [0, 1] to [-1, 1], the range of the
# generator's Tanh output.
# NOTE: the LSUN cell below rebinds `data`; whichever dataset cell ran last feeds the DataLoader.
data_path = os.path.abspath('../../_data/cifar10')
data = torchvision.datasets.CIFAR10(
    root=data_path,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]),
)

In [None]:
# LSUN bedrooms. NOTE: this rebinds `data`, replacing the CIFAR10 dataset from the cell above.
# Resize with an int scales the shorter side (keeping the aspect ratio) and CenterCrop then cuts
# out image_size; resizing straight to the (h, w) tuple would squash non-square images and leave
# CenterCrop with nothing to do.
data_path = '../../_data/lsun'  # not passed to the loader below, which is given `storage`
data = yama.vision.datasets.LSUN(storage, classes=['bedroom_train'],
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize(max(image_size)),
        torchvision.transforms.CenterCrop(image_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
)

In [None]:
# os.cpu_count() returns None when the CPU count cannot be determined; DataLoader needs an int.
dataloader = torch.utils.data.DataLoader(data, batch_size, shuffle=True,
                                         num_workers=os.cpu_count() or 0)
n = len(dataloader)  # batches per epoch
# Infer the (channels, height, width) model input size from the first sample.
sample_x, sample_y = data[0]
input_size = tuple(sample_x.size())
# Only a cell's last expression is displayed, so show the summary here rather than mid-cell.
n, input_size

In [None]:
class DCGan_D(nn.Module):
    """DCGAN-style critic that maps a batch of square images to one batch-averaged score.

    Layout: a stride-2 4x4 "b1" block, `mid_layers` stride-1 3x3 "mid" blocks, stride-2 4x4
    "pyramid" blocks (each doubling the channel count) until the feature map is at most 4
    pixels wide, and a final bias-free convolution whose kernel spans the remaining map.
    Every block is Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

    Args:
        input_size: (channels, height, width) of the input images; height must equal width.
        feature_num: output channels of the first block.
        mid_layers: number of extra 3x3 blocks after the first one.
    """

    def __init__(self, input_size, feature_num, mid_layers=1):
        super().__init__()
        in_ch = input_size[0]
        spatial = input_size[1:]
        assert spatial[0] == spatial[1]
        side = spatial[0]

        layers = nn.Sequential()

        def add_block(tag, c_in, c_out, kernel, stride=1, padding=0):
            # Conv2d -> BatchNorm2d -> LeakyReLU(0.2), registered under names derived from `tag`;
            # returns c_out so successive calls can chain channel counts.
            layers.add_module('{}_conv_{}_{}_{}'.format(tag, c_in, c_out, kernel),
                              nn.Conv2d(c_in, c_out, kernel, stride, padding))
            layers.add_module('{}_batchnorm'.format(tag), nn.BatchNorm2d(c_out))
            layers.add_module('{}_LeakyRelu'.format(tag), nn.LeakyReLU(0.2))
            return c_out

        channels = add_block('b1', in_ch, feature_num, 4, stride=2, padding=1)

        # NOTE(review): mid blocks output feature_num // 2 channels, so the pyramid below starts
        # from half the width it would have with mid_layers=0 (DCGan_G's mid blocks keep their
        # width instead) — confirm this asymmetry is intended.
        mid_width = feature_num // 2
        for idx in range(mid_layers):
            channels = add_block('mid-{}'.format(idx), channels, mid_width, 3, padding=1)

        size = side // 2  # spatial size after the stride-2 first block
        while size > 4:
            channels = add_block('pyramid-{}'.format(size), channels, channels * 2, 4,
                                 stride=2, padding=1)
            size //= 2

        layers.add_module('final-{}-conv'.format(channels),
                          nn.Conv2d(channels, 1, size, bias=False))
        self.main = layers

    def forward(self, images):
        # Per-sample scores are averaged over the batch (dim 0) and flattened to shape (1,).
        return self.main(images).mean(0).view(1)
        

class DCGan_G(nn.Module):
    """DCGAN-style generator that upsamples noise to square images of `out_size`.

    Layout: a 4x4 ConvTranspose2d "init" block, stride-2 "pyramid" blocks that double the
    spatial size and halve the channels up to img_size // 2, `mid_layers` shape-preserving
    3x3 blocks, then a final stride-2 ConvTranspose2d to `img_channels` followed by Tanh.
    Every intermediate block is ConvTranspose2d -> BatchNorm2d -> ReLU.

    Args:
        noise_len: channel count of the input noise.
        out_size: (channels, height, width) of generated images; height must equal width and
            be a power of two that is a multiple of 16.
        feature_num: channel count of the last pyramid block (for even feature_num); earlier
            blocks have proportionally more channels.
        mid_layers: number of extra 3x3 stride-1 blocks before the output layer.
    """

    def __init__(self, noise_len, out_size, feature_num, mid_layers=1):
        super().__init__()
        main = nn.Sequential()

        def deconv_block(name, in_channels, out_channels, kernel_size, stride=1, padding=0):
            # ConvTranspose2d -> BatchNorm2d -> ReLU appended to `main`; returns out_channels so
            # consecutive calls can chain channel counts.
            main.add_module("{name}_conv_{in_channels}_{out_channels}_{kernel_size}".format(**locals()),
                            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding))
            main.add_module("{}_batch_norm".format(name), nn.BatchNorm2d(out_channels))
            main.add_module("{}_relu".format(name), nn.ReLU())
            return out_channels

        img_channels, img_size = out_size[0], out_size[1:]
        assert img_size[0] == img_size[1]
        img_size = img_size[0]
        assert img_size % 16 == 0
        # The spatial size doubles from 4 at every upsampling step, so only powers of two are
        # reachable; a size such as 48 would otherwise make the sizing loop below never end.
        assert (img_size & (img_size - 1)) == 0, 'image size must be a power of two'

        # Channels of the 4x4 "init" block: feature_num // 2 doubled once per factor of two
        # between 4 and img_size, so halving in each pyramid block ends at feature_num.
        tmp_img_size, feature_num = 4, feature_num // 2
        while tmp_img_size < img_size:
            feature_num *= 2
            tmp_img_size *= 2

        # For 1x1 noise: (batch, noise_len, 1, 1) -> (batch, feature_num, 4, 4).
        last_channels = deconv_block('init', noise_len, feature_num, 4)

        # Stride-2 blocks: double the spatial size and halve the channels up to img_size // 2.
        feature_size = 4
        while feature_size < img_size // 2:
            last_channels = deconv_block('pyramid_{}'.format(feature_size), last_channels,
                                         last_channels // 2, 4, 2, 1)
            feature_size *= 2

        # Shape-preserving 3x3 blocks.
        for l in range(mid_layers):
            last_channels = deconv_block('mid_{}'.format(l), last_channels,
                                         last_channels, 3, 1, 1)

        # Final upsampling to img_size with img_channels outputs; Tanh bounds pixels to [-1, 1].
        main.add_module('final_convt',
                        nn.ConvTranspose2d(last_channels, img_channels, 4, 2, 1))
        main.add_module('final_tanh', nn.Tanh())

        self.main = main

    def forward(self, in_noise):
        # Feed the argument through the network (not any module-level `noise` variable).
        return self.main(in_noise)

def make_noise(batch_size, noise_channels):
    """Sample standard-normal generator input of shape (batch_size, noise_channels, 1, 1)."""
    shape = (batch_size, noise_channels, 1, 1)
    samples = torch.randn(*shape)
    return Variable(samples)

def make_trainable(m, b=True):
    """Enable (b=True) or disable (b=False) gradient tracking for every parameter of module `m`.

    `train` uses this to freeze the critic during the generator update.
    """
    for v in m.parameters():
        # The attribute is `requires_grad`; assigning `require_grad` would just create an
        # unused attribute and freeze nothing.
        v.requires_grad = b
    
def train(D, G, opt_D, opt_G, loader, epochs, batch_size, noise_channels, first=True, use_gpu=False,
          clip_value=0.01):
    """Train critic `D` and generator `G` with a WGAN-style alternating schedule.

    Each loader batch drives one critic update that minimises ``D(real) - D(fake)``. After
    ``d_iter_tgt`` consecutive critic updates the generator gets one update that minimises
    ``D(G(z))``. ``d_iter_tgt`` is 100 for the first 25 generator updates (when ``first`` is
    true) and on every 500th generator update; otherwise it is 5.

    Args:
        D, G: critic and generator modules.
        opt_D, opt_G: optimizers over D's and G's parameters.
        loader: iterable of ``(images, labels)`` batches; labels are ignored.
        epochs: number of passes over ``loader``.
        batch_size: number of noise samples per generator update.
        noise_channels: channel count passed to ``make_noise``.
        first: enable the 100-step critic warm-up for the first 25 generator updates.
        use_gpu: move real images and noise to CUDA (the models are assumed to be there already).
        clip_value: critic weights are clamped to ``[-clip_value, clip_value]`` before every
            critic update (WGAN weight clipping); ``None`` disables clipping.
    """
    def sample_noise(n_samples):
        # make_noise builds the tensor on the CPU; move it next to the models on GPU runs.
        z = make_noise(n_samples, noise_channels)
        return z.cuda() if use_gpu else z

    make_trainable(D)
    # Schedule counters live outside the epoch loop so the warm-up and the periodic long
    # critic phases follow generator updates across the whole run rather than per epoch.
    gen_iter = 0
    d_iter = 0
    d_iter_tgt = 0
    for ep in range(epochs):
        bar = tqdm(loader, desc='{}/{}'.format(ep, epochs))
        for real_x, _labels in bar:
            if use_gpu:
                real_x = real_x.cuda()
            real_x = Variable(real_x)

            # Start a new critic phase once the previous one has finished.
            if d_iter_tgt == 0 or d_iter >= d_iter_tgt:
                is_warm_up = first and gen_iter < 25
                d_iter_tgt = 100 if is_warm_up or gen_iter % 500 == 0 else 5
                d_iter = 0

            # Critic update on a detached fake batch.
            if clip_value is not None:
                for param in D.parameters():
                    param.data.clamp_(-clip_value, clip_value)
            opt_D.zero_grad()  # otherwise gradients accumulate across steps
            fake = G(sample_noise(real_x.size(0))).detach()
            err = D(real_x) - D(fake)
            err.backward()
            opt_D.step()
            bar.set_postfix(loss=float(err.data))
            d_iter += 1

            # Generator update at the end of a critic phase: lower D(fake), the direction in
            # which the critic scores real images.
            if d_iter == d_iter_tgt:
                make_trainable(D, False)
                opt_G.zero_grad()
                fake = G(sample_noise(batch_size))
                D(fake).backward()
                opt_G.step()
                gen_iter += 1
                make_trainable(D, True)

def weight_init(m):
    """DCGAN-style weight initialisation, intended for use with ``Module.apply``.

    * BatchNorm2d: weight ~ N(1.0, 0.02), bias set to 0.
    * Conv2d / ConvTranspose2d: weight ~ N(0, 0.02); bias left untouched.
    * Any other module is left unchanged.
    """
    if isinstance(m, nn.BatchNorm2d):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        return
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        m.weight.data.normal_(0, 0.02)

In [None]:
p = next(iter(dataloader))

In [None]:
x=DCGan_D(input_size, 64, 1).cuda()(Variable(p[0].cuda())); x.size()

In [None]:
noise = make_noise(batch_size, noise_size).cuda()
DCGan_G(noise_size, input_size, feature_num=64, mid_layers=1).cuda()(Variable(p[0].cuda())).size()

In [None]:
g(noise).data.cpu().numpy()[0]

In [None]:
g = DCGan_G(noise_size, input_size, feature_num=64, mid_layers=1).cuda()
d = DCGan_D(input_size, feature_num=64, mid_layers=1).cuda()

opt_d = torch.optim.RMSprop(d.parameters())
opt_g = torch.optim.RMSprop(g.parameters())

for m in [d, g]:
    m.apply(weight_init)
train(d, g, opt_d, opt_g, dataloader,
      epochs=100, batch_size=32, noise_channels=noise_size, use_gpu=True)