In [1]:
import sys
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import glob
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = {
    "train_batchsz" : 64,
    "epoch" : 500,
    "lr" : 2.0e-4,
    "betas" : (0., 0.999),
    "train_path" : '/data/dlcv/hw2/hw2_data/face/train/',
    "device" :  "cuda" if torch.cuda.is_available() else "cpu",
    "c_noise" : 100,
    "feature_dim" : 64,
    "n_critic" : 5
}
train_tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
if config["device"] == "cuda":
    torch.cuda.set_device(2)
print('Device used :', config["device"])

Device used : cuda


In [3]:
def same_seeds(seed):
    # Python built-in random module
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Torch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

same_seeds(7777)

In [4]:
def save_checkpoint(checkpoint_path, model, optimizer):
    state = {'model_state_dict': model.state_dict(),
             'optimizer_state_dict' : optimizer.state_dict()}
    torch.save(state, checkpoint_path)
    print('model saved to {}'.format(checkpoint_path))

In [5]:
class FaceDataset(Dataset):
    def __init__(self, dirpath, transform = None) -> None:
        self.transform = transform
        self.data = []
        files = glob.glob(os.path.join(dirpath, "*.png"))
        for file in files:
            image = Image.open(file)
            self.data.append(image)
        self.len = len(self.data)
        
    def __getitem__(self, index):
        img = self.data[index]
        if self.transform != None:
            img = self.transform(img)
        return img

    def __len__(self):
        return self.len

In [6]:
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.)

In [7]:
class Generator(nn.Module):
    def __init__(self, c_noise, feature_dim, c_img) -> None:
        super().__init__()
        self.project = nn.Sequential(
            nn.ConvTranspose2d(c_noise, feature_dim*16, kernel_size=4, stride=1, bias=False),
            nn.BatchNorm2d(feature_dim*16),
            nn.ReLU(True)
        )
        self.conv = nn.Sequential(
            self.dconv_bn_relu(feature_dim*16, feature_dim*8),
            self.dconv_bn_relu(feature_dim*8, feature_dim*4),
            self.dconv_bn_relu(feature_dim*4, feature_dim*2)
        )
        self.last = nn.Sequential(
            nn.ConvTranspose2d(feature_dim*2, c_img, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.Tanh() 
        )
        self.apply(weights_init)

    def dconv_bn_relu(self, in_dim, out_dim):
        return nn.Sequential(
            nn.ConvTranspose2d(in_dim, out_dim, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(True)
        )
        
    def forward(self, x):
        x1 = self.project(x)
        x2 = self.conv(x1)
        return self.last(x2)

In [8]:
class Discriminator(nn.Module):
    def __init__(self, feature_dim, c_img) -> None:
        super().__init__()
        self.disc = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(c_img, feature_dim, kernel_size=4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2),
            self.conv_bn_lrelu(feature_dim, feature_dim*2, feature_dim//4, feature_dim//4),
            self.conv_bn_lrelu(feature_dim*2, feature_dim*4, feature_dim//8, feature_dim//8),
            self.conv_bn_lrelu(feature_dim*4, feature_dim*8, feature_dim//16, feature_dim//16),
            nn.utils.spectral_norm(nn.Conv2d(feature_dim*8, 1, kernel_size=4, bias=False)),
        )
        self.apply(weights_init)

    def conv_bn_lrelu(self, in_dim, out_dim, h, w):
        return nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_dim, out_dim, kernel_size=4, stride=2, padding=1, bias=False)),
            nn.LayerNorm([out_dim, h, w]),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
    def forward(self, x):
        return self.disc(x).squeeze()

In [9]:
class TrainGan():
    def __init__(self, config, gpth=None, dpth=None) -> None:
        self.config = config
        self.G = Generator(self.config["c_noise"], self.config["feature_dim"], 3).to(self.config["device"])
        self.D = Discriminator(self.config["feature_dim"], 3).to(self.config["device"])

        self.optimizer_G = optim.Adam(self.G.parameters(), lr=self.config["lr"], betas=self.config["betas"])
        self.optimizer_D = optim.Adam(self.D.parameters(), lr=self.config["lr"], betas=self.config["betas"])

        self.z_samples = torch.randn(32, self.config["c_noise"], 1, 1, device=self.config["device"]) #fix sample used

        if gpth != None:
            if os.path.exists(gpth):
                gcheckpoint = torch.load(gpth, map_location=self.config["device"])
                self.G.load_state_dict(gcheckpoint["model_state_dict"])
                self.optimizer_G.load_state_dict(gcheckpoint["optimizer_state_dict"])
            else:
                print("path error {}".format(gpth))
        if dpth != None:
            if os.path.exists(dpth):
                dcheckpoint = torch.load(dpth, map_location=self.config["device"])
                self.D.load_state_dict(dcheckpoint["model_state_dict"])
                self.optimizer_D.load_state_dict(dcheckpoint["optimizer_state_dict"])
            else:
                print("path error {}".format(dpth))
        
    def CreateLoader(self):
        self.train_loader = DataLoader(FaceDataset(dirpath=self.config["train_path"], transform=train_tfm), batch_size=self.config["train_batchsz"], shuffle=True, pin_memory=True)
        self.refresh = len(self.train_loader.dataset) / self.config["train_batchsz"]
        print(len(self.train_loader.dataset))

    def SaveEpochResult(self, ep):
        self.G.eval()
        imgs = self.G(self.z_samples).detach().cpu()
        imgs = (imgs + 1.) / 2.0
        torchvision.utils.save_image(imgs, fp="/data/allen/gangrid/sngangp_ep{}.png".format(ep + 1), padding=0)

    def gp(self, real, fake):
        lambda_ = 10.
        bsz = real.shape[0]
        epsilon = torch.rand((bsz, 1, 1, 1), device=self.config["device"])
        xhat = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True).to(config["device"])
        xhat_logit = self.D(xhat)
        gradxhat = torch.autograd.grad(outputs=xhat_logit, inputs=xhat, grad_outputs=torch.ones_like(xhat_logit, device=self.config["device"]),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradxhat = gradxhat.view(bsz, -1)
        return lambda_*((gradxhat.norm(2, dim=1)-1)**2).mean()

    def train(self):
        self.CreateLoader()
        for ep in range(self.config["epoch"]):
            self.D.train()
            self.G.train()
            total_loss_D, total_loss_G = 0., 0.
            for (idx, real_img) in enumerate(self.train_loader):
                bsz = real_img.shape[0]
                #train Discriminator
                real_img = real_img.to(self.config["device"]) #(bsz, 3, 64, 64)
                real_logit = self.D(real_img) #(bsz, )
                random_z = torch.randn(bsz, self.config["c_noise"], 1, 1, device=self.config["device"]) #(bsz, 100, 1, 1)
                fake_img = self.G(random_z) #(bsz, 3, 64, 64)
                fake_logit = self.D(fake_img) #(bsz, )
                loss_D = -torch.mean(real_logit) + torch.mean(fake_logit) + self.gp(real_img, fake_img)
                total_loss_D += loss_D.item()
                # Discriminator backwarding
                self.D.zero_grad()
                loss_D.backward()
                self.optimizer_D.step()
                
                #train Generater
                if (idx + 1) % self.config["n_critic"] == 0:
                    random_z = torch.randn(bsz, self.config["c_noise"], 1, 1, device=self.config["device"]) #(bsz, 100, 1, 1)
                    fake_img = self.G(random_z)  #(bsz, 3, 64, 64)
                    fake_logit = self.D(fake_img) #(bsz, )
                    loss_G = -torch.mean(fake_logit)
                    total_loss_G += loss_G.item()
                    # Generator backwarding
                    self.G.zero_grad()
                    loss_G.backward()
                    self.optimizer_G.step()
            print("Epoch[{}/{}] loss_G : {:.6f} loss_D : {:.6f}".format(ep + 1, self.config["epoch"], total_loss_G / (self.refresh/self.config["n_critic"]), total_loss_D / (self.refresh)))
            if (ep + 1) % 5 == 0:
                self.SaveEpochResult(ep)
            if ep >= 100 and (ep + 1) % 10 == 0:
                save_checkpoint("/data/allen/hw2model/sngangp_G_ep{}.pth".format(ep + 1), self.G, self.optimizer_G)
                save_checkpoint("/data/allen/hw2model/sngangp_D_ep{}.pth".format(ep + 1), self.D, self.optimizer_D)
            

In [None]:
SNGAN_GP = TrainGan(config)
print(SNGAN_GP.G, SNGAN_GP.D)

In [11]:
SNGAN_GP.train()

38464
Epoch[1/500] loss_G : 26.856668 loss_D : -21.220496
Epoch[2/500] loss_G : 31.458643 loss_D : -13.720754
Epoch[3/500] loss_G : 34.311541 loss_D : -11.039145
Epoch[4/500] loss_G : 35.169863 loss_D : -8.399651
Epoch[5/500] loss_G : 33.540571 loss_D : -6.861910
Epoch[6/500] loss_G : 34.539854 loss_D : -6.243287
Epoch[7/500] loss_G : 35.998614 loss_D : -5.951208
Epoch[8/500] loss_G : 37.191012 loss_D : -5.779420
Epoch[9/500] loss_G : 38.122981 loss_D : -5.720413
Epoch[10/500] loss_G : 38.961808 loss_D : -5.583660
Epoch[11/500] loss_G : 39.887662 loss_D : -5.544386
Epoch[12/500] loss_G : 40.907555 loss_D : -5.452085
Epoch[13/500] loss_G : 41.926163 loss_D : -5.326874
Epoch[14/500] loss_G : 42.750660 loss_D : -5.155093
Epoch[15/500] loss_G : 43.512206 loss_D : -5.016008
Epoch[16/500] loss_G : 44.183177 loss_D : -4.888528
Epoch[17/500] loss_G : 45.472748 loss_D : -4.693958
Epoch[18/500] loss_G : 46.089792 loss_D : -4.567614
Epoch[19/500] loss_G : 46.589586 loss_D : -4.401463
Epoch[20/500