In [None]:
# Mount Google Drive at /content/gdrive so files stored there (e.g. the training
# image folder configured below) are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
from datetime import datetime
import os
import numpy as np
from glob import glob
from PIL import Image

import torchvision.transforms as transforms
import torch.utils.data as data

import torch
from torch import autograd
from torch.nn import functional as F
from torch import nn
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data

In [None]:
# Before running this notebook, create a folder named "SinGAN_data" in your Google Drive
# and upload the training image (a .jpg file) into it.
data_dir = '/content/gdrive/MyDrive/SinGAN_data'

# To validate a pretrained model, set load_model to that run's folder name under ./logs
# and set validation to 1.
validation = 0
load_model = None

# GAN loss variant: 'zerogp' (zero-centered gradient penalty) or 'wgangp'.
gantype = 'zerogp'

batch_size = 1
gpu = 0
# Side length (pixels) of the coarsest scale of the image pyramid.
img_size_min = 25
# When validating a pretrained model, change img_size_max to 1025.
img_size_max = 250

In [None]:
# Name of this run: reuse the pretrained model's name when one is given, otherwise
# build a fresh timestamped name that also records the GAN loss type.
if load_model is not None:
    this_time_model = load_model
else:
    this_time_model = f'SinGAN_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}_{gantype}'

log_dir = os.path.join('./logs', this_time_model)
res_dir = os.path.join('./results', this_time_model)

# exist_ok=True keeps this cell idempotent: re-running it never raises because a
# directory is already there.
os.makedirs('./logs', exist_ok=True)
if load_model is None:
    # A pretrained run must already own its log directory (it holds the checkpoints),
    # so the per-run log directory is only created for new runs.
    os.makedirs(log_dir, exist_ok=True)
os.makedirs(res_dir, exist_ok=True)

In [None]:
class SinDataset(data.Dataset):
    """Dataset exposing the single training image used by SinGAN.

    The first ``*.jpg`` file (in sorted order) found in ``dir`` is the one and
    only sample of the dataset.

    Args:
        dir: folder that contains the training image.
        transform: callable applied to the RGB PIL image before it is returned.

    Raises:
        FileNotFoundError: if ``dir`` contains no ``*.jpg`` file.
    """

    def __init__(self, dir, transform):
        self.data_dir = dir
        self.transform = transform
        candidates = sorted(glob(os.path.join(self.data_dir, '*.jpg')))
        if not candidates:
            raise FileNotFoundError(
                "no '*.jpg' image found in '{}'".format(self.data_dir))
        # Despite its name, image_dir holds the path of the image FILE.
        self.image_dir = candidates[0]

    def __len__(self):
        # Exactly one image. (len(self.image_dir) would be the number of
        # characters in the path, not the number of samples.)
        return 1

    def __getitem__(self, idx):
        # Raising IndexError past the single sample lets plain iteration
        # (``for x in dataset``) terminate.
        if idx not in (0, -1):
            raise IndexError('SinDataset holds a single image; got index {}'.format(idx))
        with open(self.image_dir, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to be read while the file is open.
            img = img.convert('RGB')
            return self.transform(img)


def get_dataset(data_dir):
    """Build the training and validation datasets for the image in ``data_dir``.

    Both datasets use the same preprocessing: resize to 256x256, convert to a
    tensor, then normalize every channel with mean 0.5 and std 0.5.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple of ``SinDataset`` objects.
    """
    def _make_transform():
        # A fresh pipeline per dataset, so the two do not share transform objects.
        steps = [
            transforms.Resize((256,256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        return transforms.Compose(steps)

    train_pipeline = _make_transform()
    val_pipeline = _make_transform()

    return (SinDataset(data_dir, transform = train_pipeline),
            SinDataset(data_dir, transform = val_pipeline))

In [None]:
# datasets
train_dataset, _ = get_dataset(data_dir)
# The dataset holds a single image, so worker processes only add start-up cost
# (and a worker-count warning on small Colab machines); load it in the main
# process instead. The batch size comes from the configuration cell.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size,
                                              shuffle = False, num_workers = 0,
                                              pin_memory= True)

  cpuset_checked))


In [None]:
class Discriminator(nn.Module):
    """Multi-scale fully-convolutional discriminator (one sub-network per scale).

    Every scale is a stack of five 3x3 convolutions with padding 1, so the
    output is a one-channel map with the same spatial size as the input.
    ``forward`` always uses the sub-network of ``current_scale``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.hidden = 32          # channel width of the scale built next
        self.current_scale = 0    # index of the scale used by forward()

        self.discriminators = nn.ModuleList()
        self.discriminators.append(self._build_scale())

    def _build_scale(self):
        """Create one untrained per-scale discriminator with ``self.hidden`` channels."""
        blocks = [nn.Sequential(nn.Conv2d(3, self.hidden, 3, 1, 1),
                                nn.LeakyReLU(0.2))]
        for _ in range(3):
            blocks.append(nn.Sequential(nn.Conv2d(self.hidden, self.hidden, 3, 1, 1),
                                        nn.BatchNorm2d(self.hidden),
                                        nn.LeakyReLU(0.2)))
        blocks.append(nn.Sequential(nn.Conv2d(self.hidden, 1, 3, 1, 1)))
        # Sequential-of-Sequentials keeps state-dict keys as
        # "discriminators.<scale>.<block>.<layer>.*".
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return the logit map of the current scale's discriminator for ``x``."""
        out = self.discriminators[self.current_scale](x)
        return out

    def progress(self):
        """Append the discriminator for the next scale and make it current.

        The channel width doubles every fourth scale. When the width is
        unchanged the new sub-network starts from the previous scale's weights;
        otherwise shapes differ and it keeps its fresh initialization.
        """
        self.current_scale += 1
        widened = self.current_scale % 4 == 0
        if widened:
            self.hidden *= 2

        next_disc = self._build_scale()

        if not widened:
            # continue learning from the previous discriminator's parameters
            next_disc.load_state_dict(self.discriminators[-1].state_dict())

        self.discriminators.append(next_disc)
        print("PROGRESSION DONE")

In [None]:
class Generator(nn.Module):
    """Multi-scale residual generator (one sub-network per pyramid scale).

    Every scale is a stack of five unpadded 3x3 convolutions, so each
    sub-network shrinks its input by 10 pixels per spatial dimension; inputs
    are therefore expected to be padded by 5 pixels on every side.

    Args:
        img_size_min: side length of the coarsest scale.
        num_scale: number of up-scaling steps (the pyramid has ``num_scale + 1`` sizes).
        scalefactor: size ratio between consecutive scales.
    """

    def __init__(self, img_size_min, num_scale, scalefactor = 4/3):
        super(Generator, self).__init__()
        self.hidden = 32          # channel width of the scale built next
        self.current_scale = 0    # index of the finest scale currently generated
        self.img_size_min = img_size_min
        self.scalefactor = scalefactor
        self.num_scale = num_scale

        self.size_list = [int(self.img_size_min * (self.scalefactor ** i)) for i in range(num_scale + 1)]
        print(f"size_list : {self.size_list}")

        self.generators = nn.ModuleList()
        self.generators.append(self._build_scale())

    def _build_scale(self):
        """Create one untrained per-scale generator with ``self.hidden`` channels."""
        blocks = [nn.Sequential(nn.Conv2d(3, self.hidden, 3, 1),
                                nn.BatchNorm2d(self.hidden),
                                nn.LeakyReLU(0.2))]
        for _ in range(3):
            blocks.append(nn.Sequential(nn.Conv2d(self.hidden, self.hidden, 3, 1),
                                        nn.BatchNorm2d(self.hidden),
                                        nn.LeakyReLU(0.2)))
        blocks.append(nn.Sequential(nn.Conv2d(self.hidden, 3, 3, 1),
                                    nn.Tanh()))
        # Sequential-of-Sequentials keeps state-dict keys as
        # "generators.<scale>.<block>.<layer>.*".
        return nn.Sequential(*blocks)

    def forward(self, z, img = None):
        """Generate one image per scale, from scale 0 up to ``current_scale``.

        Args:
            z: sequence of noise tensors, one per scale, each already padded by
                5 pixels per side relative to that scale's image size.
            img: optional tensor used as the scale-0 image instead of running
                the scale-0 generator on ``z[0]``.

        Returns:
            List of generated images, coarsest first.
        """
        # Identity check: comparing a tensor to None with != is not a reliable test.
        if img is not None:
            out = img
        else:
            out = self.generators[0](z[0])
        ret = [out]
        for i in range(1, self.current_scale + 1):
            out = F.interpolate(out, (self.size_list[i], self.size_list[i]), mode = 'bilinear', align_corners = True)
            prev = out
            # Pad so the five unpadded convolutions return the un-padded size,
            # then inject this scale's noise.
            out = F.pad(out, [5,5,5,5], value = 0)
            out += z[i]
            # Residual learning: the sub-network only adds detail on top of
            # the upsampled coarser image.
            out = self.generators[i](out) + prev
            ret.append(out)

        return ret

    def progress(self):
        """Append the generator for the next scale and make it current.

        The channel width doubles every fourth scale. When the width is
        unchanged the new sub-network starts from the previous scale's weights;
        otherwise shapes differ and it keeps its fresh initialization.
        """
        self.current_scale += 1
        widened = self.current_scale % 4 == 0
        if widened:
            self.hidden *= 2

        next_gene = self._build_scale()

        if not widened:
            # continue learning from the previous generator's parameters
            next_gene.load_state_dict(self.generators[-1].state_dict())

        self.generators.append(next_gene)

In [None]:
# models
# Size ratio between two consecutive scales of the image pyramid.
scale_factor = 4/3
min_max_ratio = img_size_max / img_size_min
# Number of up-scaling steps needed to grow from img_size_min to roughly img_size_max.
num_scale = int(np.round(np.log(min_max_ratio)/np.log(scale_factor)))
# Side length (in pixels) of the square image at every scale, coarsest first.
size_list = [int(img_size_min * scale_factor ** i) for i in range(num_scale + 1)]

In [None]:
discriminator = Discriminator()
# Pass the configured minimum size (not a literal 25) so the generator's internal
# size_list always matches the global size_list computed from img_size_min.
generator = Generator(img_size_min, num_scale, scale_factor)

size_list : [25, 33, 44, 59, 79, 105, 140, 187, 249]


In [None]:
# Select CUDA device 0 and move both networks onto it.
torch.cuda.set_device(0)
discriminator = discriminator.cuda(0)
generator = generator.cuda(0)

In [None]:
# optimizers
# Adam (lr 5e-4, betas (0.5, 0.999)) over the scale-0 sub-networks only; the
# training loop creates new optimizers for each later scale.
dis_opt = torch.optim.Adam(discriminator.discriminators[0].parameters(), 5e-4, (0.5, 0.999))
gen_opt = torch.optim.Adam(generator.generators[0].parameters(), 5e-4, (0.5, 0.999))

In [None]:
# Stage (pyramid scale) at which training starts; overwritten when a checkpoint is restored.
stage = 0
if load_model is not None:
    # checkpoint.txt lists one checkpoint file name per line; the last non-empty
    # line is the most recent checkpoint. The with-block closes the file handle.
    with open(os.path.join(log_dir, "checkpoint.txt"), 'r') as check_load:
        saved_names = [line.strip() for line in check_load if line.strip()]
    load_file = os.path.join(log_dir, saved_names[-1]) if saved_names else None

    if load_file is not None and os.path.isfile(load_file):
        print("=> loading checkpoint '{}'".format(load_file))
        checkpoint = torch.load(load_file, map_location='cpu')
        stage = int(checkpoint['stage'])

        # Grow both networks to the number of scales stored in the checkpoint so
        # the state dicts fit.
        for _ in range(stage):
            generator.progress()
            discriminator.progress()

        torch.cuda.set_device(0)
        discriminator = discriminator.cuda(0)
        generator = generator.cuda(0)

        print("stage: ",stage)
        discriminator.load_state_dict(checkpoint['D_state_dict'])
        generator.load_state_dict(checkpoint['G_state_dict'])

        # Only the newest scale is trained: freeze all earlier scales
        # (requires_grad is not part of a state dict, so it must be redone here).
        for net_idx in range(generator.current_scale):
            for param in generator.generators[net_idx].parameters():
                param.requires_grad = False
            for param in discriminator.discriminators[net_idx].parameters():
                param.requires_grad = False

        # The checkpointed optimizers were created for the newest scale's
        # parameters, so rebuild them over those parameters before restoring
        # their state; otherwise they would keep updating the scale-0 networks.
        dis_opt = torch.optim.Adam(discriminator.discriminators[discriminator.current_scale].parameters(),
                                   5e-4, (0.5, 0.999))
        gen_opt = torch.optim.Adam(generator.generators[generator.current_scale].parameters(),
                                   5e-4, (0.5, 0.999))
        dis_opt.load_state_dict(checkpoint['d_optimizer'])
        gen_opt.load_state_dict(checkpoint['g_optimizer'])
        print("=> loaded checkpoint '{}' (stage {})"
              .format(load_file, checkpoint['stage']))
    else:
        print("=> no checkpoint found at '{}'".format(log_dir))

In [None]:
# Training
# Latents used for the reconstruction loss: Gaussian noise at the coarsest scale
# and all-zero tensors at every finer scale. Each latent is zero-padded by 5 pixels
# per side because a per-scale generator (five unpadded 3x3 convolutions) shrinks
# its input by 10 pixels per dimension.
# NOTE(review): no random seed is set, so fixed_latents[0] differs between runs —
# confirm whether that matters when validating a pretrained model.
fixed_latents = [F.pad(torch.randn(batch_size, 3, size_list[0], size_list[0]), [5,5,5,5], value = 0)]
zero_latents = [F.pad(torch.zeros(batch_size, 3, size_list[idx], size_list[idx]), [5,5,5,5], value = 0) for idx in range(1, num_scale + 1)]
fixed_latents = fixed_latents + zero_latents

In [None]:
# util functions
def compute_grad_gp(d_out, x_in):
    """Squared-gradient penalty of ``d_out`` with respect to ``x_in``.

    Args:
        d_out: discriminator output computed from ``x_in``.
        x_in: input tensor with ``requires_grad`` enabled.

    Returns:
        1-D tensor with one entry per sample: the sum of squared input
        gradients. The graph is kept so the penalty can be back-propagated.
    """
    (input_grad,) = autograd.grad(
        outputs=d_out.sum(),
        inputs=x_in,
        create_graph=True,
        retain_graph=True,
    )
    squared = input_grad.pow(2)
    assert squared.size() == x_in.size()
    n_samples = x_in.size(0)
    per_sample = squared.view(n_samples, -1)
    return per_sample.sum(1)


def compute_grad_gp_wgan(D, x_real, x_fake):
    """WGAN-GP gradient penalty on random interpolates of real and fake samples.

    Args:
        D: callable (the critic) applied to the interpolated batch.
        x_real: batch of real samples, shape (N, C, H, W).
        x_fake: batch of generated samples, same shape as ``x_real``.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad D(x_hat)||_2 - 1) ** 2``.
    """
    # One mixing coefficient per sample, created on the inputs' device so the
    # function works on the CPU and on any GPU, not just cuda:0.
    alpha = torch.rand(x_real.size(0), 1, 1, 1, device=x_real.device)

    # Detach: the penalty differentiates w.r.t. the interpolate itself, not
    # through the generator that produced x_fake.
    x_interpolate = ((1 - alpha) * x_real + alpha * x_fake).detach()
    x_interpolate.requires_grad = True
    d_inter_logit = D(x_interpolate)
    grad = torch.autograd.grad(d_inter_logit, x_interpolate,
                               grad_outputs=torch.ones_like(d_inter_logit), create_graph=True)[0]

    # reshape (not view) also accepts non-contiguous gradients.
    norm = grad.reshape(grad.size(0), -1).norm(p=2, dim=1)

    d_gp = ((norm - 1) ** 2).mean()
    return d_gp

def save_checkpoint(state, check_list, log_dir, epoch=0):
    """Save ``state`` to ``<log_dir>/model_<epoch>.ckpt`` and record it in the index.

    Args:
        state: object to serialize with ``torch.save`` (here: a dict of state dicts).
        check_list: writable text file object (the checkpoint index); the
            checkpoint's file name is appended as one line.
        log_dir: directory that receives the checkpoint file.
        epoch: number used in the checkpoint file name.
    """
    file_name = 'model_{}.ckpt'.format(epoch)
    torch.save(state, os.path.join(log_dir, file_name))
    check_list.write(file_name + '\n')
    # Flush right away so the index entry reaches disk even if the run is
    # interrupted before the index file is closed.
    check_list.flush()


class AverageMeter(object):
    """Tracks the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` observations."""
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
        self.val = val

In [None]:
from tqdm import trange
import torchvision.utils as vutils

def train(data_loader, generator, discriminator, d_opt, g_opt, stage_idx, z, size_list, res_dir, gantype, num_scale,
          epochs=2000, decay_lr=1600):
    """Train the generator/discriminator pair of one pyramid scale.

    Args:
        data_loader: loader yielding the (single) training image batch.
        generator: multi-scale generator; its current scale is the one trained.
        discriminator: multi-scale discriminator; its current scale is the one trained.
        d_opt: optimizer over the current discriminator scale's parameters.
        g_opt: optimizer over the current generator scale's parameters.
        stage_idx: index of the scale being trained.
        z: list of fixed (reconstruction) latents, one per scale. The list is
            updated in place with tensors moved to the GPU when one is available.
        size_list: side length of the image at every scale.
        res_dir: directory that receives the saved reference image.
        gantype: ``'wgangp'`` or ``'zerogp'``.
        num_scale: index of the last scale (used for the progress bar only).
        epochs: number of training iterations for this scale.
        decay_lr: iteration at which both learning rates are multiplied by 0.1.

    Raises:
        ValueError: if ``gantype`` is not one of the supported loss types.
    """
    if gantype not in ('wgangp', 'zerogp'):
        # Fail early instead of hitting an undefined loss variable mid-loop.
        raise ValueError("gantype must be 'wgangp' or 'zerogp', got {!r}".format(gantype))

    generator.train()
    discriminator.train()

    # next() is the Python 3 iterator protocol (iterators have no .next() method).
    origin = next(iter(data_loader))

    if torch.cuda.is_available():
        for z_idx in range(len(z)):
            z[z_idx] = z[z_idx].cuda(0, non_blocking=True)
        origin = origin.cuda(0, non_blocking = True)
    # All tensors created below follow the image's device, so the loop also runs on the CPU.
    device = origin.device

    x_in = F.interpolate(origin, (size_list[stage_idx], size_list[stage_idx]), mode = 'bilinear', align_corners = True)
    vutils.save_image(x_in.detach().cpu(), os.path.join(res_dir, "ORGINAL_{}.png".format(stage_idx)), nrow = 1, normalize = True)

    # Real image resized to every scale; only indices >= 1 are read below.
    x_in_list = [x_in]
    for idx in range(1, stage_idx + 1):
        x_in_list.append(F.interpolate(origin, (size_list[idx], size_list[idx]), mode = 'bilinear', align_corners = True))

    tqdm_train = trange(0, epochs, initial = 0, total = epochs)

    d_losses = AverageMeter()
    g_losses = AverageMeter()
    for i in tqdm_train:
        if i == decay_lr:
            for params in d_opt.param_groups:
                params['lr'] *= 0.1

            for params in g_opt.param_groups:
                params['lr'] *= 0.1
            print("Generator and Discriminator's learning rate updated")

        # update Generator's weights
        for _ in range(3):
            g_opt.zero_grad()

            # Reconstruction path: fixed latents must reproduce the real image.
            out = generator(z)

            g_mse = F.mse_loss(out[-1], x_in)

            sqrt_rmse = [1.0]
            # calc rmse for every scale (except stage 0)
            for idx in range(1, stage_idx + 1):
                sqrt_rmse.append(torch.sqrt(F.mse_loss(out[idx], x_in_list[idx])))

            # Build one random latent per scale, scaled by that scale's
            # reconstruction RMSE and zero-padded by 5 pixels per side.
            z_list = [F.pad(sqrt_rmse[z_idx] * torch.randn(1, 3, size_list[z_idx],
                                               size_list[z_idx]).to(device, non_blocking=True),
                            [5, 5, 5, 5], value=0) for z_idx in range(stage_idx + 1)]

            x_fake_list = generator(z_list)
            g_fake_logit = discriminator(x_fake_list[-1])

            if gantype == 'wgangp':
                # wgan gp
                g_fake = -torch.mean(g_fake_logit, (2, 3))
                g_loss = g_fake + 10.0 * g_mse
            else:
                # zero centered GP; ones_like already matches the logits' device
                ones = torch.ones_like(g_fake_logit)
                g_fake = F.binary_cross_entropy_with_logits(g_fake_logit, ones, reduction='none').mean()
                g_loss = g_fake + 100.0 * g_mse

            g_loss.backward()
            g_opt.step()
            g_losses.update(g_loss.item(), x_in.size(0))

        # Update Discriminator's weights
        for _ in range(3):
            # Fresh leaf that requires grad for the gradient penalty; x_in itself
            # stays grad-free so the generator step does not back-propagate into it.
            x_real = x_in.detach().requires_grad_(True)

            d_opt.zero_grad()
            # The fake sample is only an input to the discriminator here, so no
            # generator graph is needed.
            with torch.no_grad():
                x_fake = generator(z_list)[-1]

            d_fake_logit = discriminator(x_fake)
            d_real_logit = discriminator(x_real)

            if gantype == 'wgangp':
                # wgan gp
                d_fake = torch.mean(d_fake_logit, (2, 3))
                d_real = -torch.mean(d_real_logit, (2, 3))
                d_gp = compute_grad_gp_wgan(discriminator, x_real, x_fake)
                d_loss = d_real + d_fake + 0.1 * d_gp
            else:
                # zero centered GP
                ones = torch.ones_like(d_real_logit)
                zeros = torch.zeros_like(d_fake_logit)
                d_fake = F.binary_cross_entropy_with_logits(d_fake_logit, zeros, reduction='none').mean()
                d_real = F.binary_cross_entropy_with_logits(d_real_logit, ones, reduction='none').mean()
                d_gp = compute_grad_gp(torch.mean(d_real_logit, (2, 3)), x_real)
                d_loss = d_real + d_fake + 10.0 * d_gp

            d_loss.backward()
            d_opt.step()
            d_losses.update(d_loss.item(), x_in.size(0))

        tqdm_train.set_description(f'Stage: [{stage_idx}/{num_scale}] Avg Loss: D[{d_losses.avg : .3f}] G[{g_losses.avg : .3f}] RMSE[{sqrt_rmse[-1] : .3f}]')

In [None]:
def validation_func(data_loader, generator, discriminator, stage_idx, z, size_list, res_dir, validation,
                    num_samples=50):
    """Save the reference image, its reconstruction and random samples for one scale.

    Args:
        data_loader: loader yielding the (single) image batch.
        generator: multi-scale generator (switched to eval mode).
        discriminator: multi-scale discriminator (switched to eval mode, otherwise unused).
        stage_idx: index of the scale to evaluate.
        z: list of fixed (reconstruction) latents, one per scale. The list is
            updated in place with tensors moved to the image's device.
        size_list: side length of the image at every scale.
        res_dir: directory that receives the output images.
        validation: truthy when validating a pretrained model; output files then
            get a ``validation_`` prefix and the per-scale noise amplitude is
            divided by 100.
        num_samples: number of random samples to generate and save.
    """
    discriminator.eval()
    generator.eval()

    origin = next(iter(data_loader))

    if torch.cuda.is_available():
        origin = origin.cuda(0, non_blocking = True)
    # Latents and noise follow the image's device instead of assuming cuda:0.
    device = origin.device

    x_in = F.interpolate(origin, (size_list[stage_idx], size_list[stage_idx]), mode='bilinear', align_corners=True)
    vutils.save_image(x_in.detach().cpu(), os.path.join(res_dir, 'ORG_{}.png'.format(stage_idx)),
                      nrow=1, normalize=True)
    # Real image resized to every scale; only indices >= 1 are read below.
    x_in_list = [x_in]
    for xidx in range(1, stage_idx + 1):
        x_tmp = F.interpolate(origin, (size_list[xidx], size_list[xidx]), mode='bilinear', align_corners=True)
        x_in_list.append(x_tmp)

    for z_idx in range(len(z)):
        z[z_idx] = z[z_idx].to(device, non_blocking=True)

    # File-name prefix that separates validation outputs from training-time outputs.
    prefix = 'validation_' if validation else ''

    with torch.no_grad():
        out = generator(z)

        # calculate rmse for each scale
        rmse_list = [1.0]
        for rmseidx in range(1, stage_idx + 1):
            rmse = torch.sqrt(F.mse_loss(out[rmseidx], x_in_list[rmseidx]))
            if validation:
                rmse /= 100.0
            rmse_list.append(rmse)
        if len(rmse_list) > 1:
            # No noise is injected at the finest scale.
            rmse_list[-1] = 0.0

        vutils.save_image(out[-1].detach().cpu(),
                          os.path.join(res_dir, prefix + 'REC_{}.png'.format(stage_idx)),
                          nrow=1, normalize=True)

        for k in range(num_samples):
            z_list = [F.pad(rmse_list[z_idx] * torch.randn(1, 3, size_list[z_idx],
                                               size_list[z_idx]).to(device, non_blocking=True),
                            [5, 5, 5, 5], value=0) for z_idx in range(stage_idx + 1)]
            x_fake_list = generator(z_list)
            vutils.save_image(x_fake_list[-1].detach().cpu(),
                              os.path.join(res_dir, prefix + 'GEN_{}_{}.png'.format(stage_idx, k)),
                              nrow=1, normalize=True)



In [None]:
if validation:
    # validation_func takes size_list between the latents and the result directory.
    validation_func(train_loader, generator, discriminator, stage, fixed_latents, size_list, res_dir, validation)
else:
    for stage_idx in range(stage, num_scale + 1):

        train(train_loader, generator, discriminator, dis_opt, gen_opt, stage_idx, fixed_latents, size_list, res_dir, gantype, num_scale)
        validation_func(train_loader, generator, discriminator, stage_idx, fixed_latents, size_list, res_dir, validation)

        # Add the next scale to both networks (new sub-networks are created on the CPU).
        discriminator.progress()
        generator.progress()
        if torch.cuda.is_available():
            discriminator = discriminator.cuda(0)
            generator = generator.cuda(0)

        # Only the finest scale is trained: freeze every earlier scale.
        for net_idx in range(generator.current_scale):
            for param in generator.generators[net_idx].parameters():
                param.requires_grad = False
            for param in discriminator.discriminators[net_idx].parameters():
                param.requires_grad = False

        # Fresh optimizers over the new scale's parameters.
        dis_opt = torch.optim.Adam(discriminator.discriminators[discriminator.current_scale].parameters(),
                                    5e-4, (0.5, 0.999))
        gen_opt = torch.optim.Adam(generator.generators[generator.current_scale].parameters(),
                                    5e-4, (0.5, 0.999))

        # Open the checkpoint index in append mode for each save. This also works
        # when training resumes at a stage > 0, and the file is always closed
        # (hence flushed) even if a later stage is interrupted.
        with open(os.path.join(log_dir, "checkpoint.txt"), "a+") as check_list:
            save_checkpoint({
                'stage': stage_idx + 1,
                'D_state_dict': discriminator.state_dict(),
                'G_state_dict': generator.state_dict(),
                'd_optimizer': dis_opt.state_dict(),
                'g_optimizer': gen_opt.state_dict()
            }, check_list, log_dir, stage_idx + 1)

  cpuset_checked))
Stage: [0/8] Avg Loss: D[ 0.063] G[ 5.536] RMSE[ 1.000]:  80%|████████  | 1602/2000 [01:23<00:20, 19.12it/s]

Generator and Discriminator's learning rate updated


Stage: [0/8] Avg Loss: D[ 0.055] G[ 5.667] RMSE[ 1.000]: 100%|██████████| 2000/2000 [01:44<00:00, 19.10it/s]


PROGRESSION DONE


Stage: [1/8] Avg Loss: D[ 0.679] G[ 2.714] RMSE[ 0.025]:  80%|████████  | 1603/2000 [01:38<00:24, 16.26it/s]

Generator and Discriminator's learning rate updated


Stage: [1/8] Avg Loss: D[ 0.583] G[ 2.946] RMSE[ 0.023]: 100%|██████████| 2000/2000 [02:02<00:00, 16.27it/s]


PROGRESSION DONE


Stage: [2/8] Avg Loss: D[ 0.702] G[ 2.360] RMSE[ 0.035]:  80%|████████  | 1601/2000 [01:46<00:27, 14.44it/s]

Generator and Discriminator's learning rate updated


Stage: [2/8] Avg Loss: D[ 0.602] G[ 2.611] RMSE[ 0.042]: 100%|██████████| 2000/2000 [02:13<00:00, 14.98it/s]


PROGRESSION DONE


Stage: [3/8] Avg Loss: D[ 0.815] G[ 2.086] RMSE[ 0.032]:  80%|████████  | 1601/2000 [01:58<00:28, 14.04it/s]

Generator and Discriminator's learning rate updated


Stage: [3/8] Avg Loss: D[ 0.728] G[ 2.290] RMSE[ 0.029]: 100%|██████████| 2000/2000 [02:28<00:00, 13.50it/s]


PROGRESSION DONE


Stage: [4/8] Avg Loss: D[ 0.953] G[ 1.740] RMSE[ 0.037]:  80%|████████  | 1601/2000 [04:03<00:59,  6.67it/s]

Generator and Discriminator's learning rate updated


Stage: [4/8] Avg Loss: D[ 0.828] G[ 1.981] RMSE[ 0.031]: 100%|██████████| 2000/2000 [05:04<00:00,  6.57it/s]


PROGRESSION DONE


Stage: [5/8] Avg Loss: D[ 0.992] G[ 1.449] RMSE[ 0.032]:  80%|████████  | 1601/2000 [04:56<01:13,  5.45it/s]

Generator and Discriminator's learning rate updated


Stage: [5/8] Avg Loss: D[ 0.904] G[ 1.594] RMSE[ 0.024]: 100%|██████████| 2000/2000 [06:10<00:00,  5.39it/s]


PROGRESSION DONE


Stage: [6/8] Avg Loss: D[ 0.893] G[ 1.483] RMSE[ 0.025]:  80%|████████  | 1600/2000 [12:20<03:04,  2.16it/s]

Generator and Discriminator's learning rate updated


Stage: [6/8] Avg Loss: D[ 0.821] G[ 1.631] RMSE[ 0.023]: 100%|██████████| 2000/2000 [15:24<00:00,  2.16it/s]


PROGRESSION DONE


Stage: [7/8] Avg Loss: D[ 0.945] G[ 1.293] RMSE[ 0.024]:  80%|████████  | 1600/2000 [15:46<03:56,  1.69it/s]

Generator and Discriminator's learning rate updated


Stage: [7/8] Avg Loss: D[ 0.888] G[ 1.412] RMSE[ 0.022]: 100%|██████████| 2000/2000 [19:43<00:00,  1.69it/s]


PROGRESSION DONE


Stage: [8/8] Avg Loss: D[ 1.241] G[ 1.257] RMSE[ 0.022]:  80%|████████  | 1600/2000 [33:46<08:28,  1.27s/it]

Generator and Discriminator's learning rate updated


Stage: [8/8] Avg Loss: D[ 1.263] G[ 1.159] RMSE[ 0.021]: 100%|██████████| 2000/2000 [42:13<00:00,  1.27s/it]


PROGRESSION DONE
