In [1]:
import sys
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import glob
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = {
    "train_batchsz" : 128,
    "epoch" : 200,
    "lr" : 1.0e-4,
    "betas" : (0.5, 0.999),
    "train_path" : '/data/dlcv/hw2/hw2_data/face/train/',
    "val_path" : '/data/dlcv/hw2/hw2_data/face/val/',
    "G_path" : '/data/allen/hw2model/dcgan_G.pth',
    "D_path" :  '/data/allen/hw2model/dcgan_D.pth',
    "device" :  "cuda" if torch.cuda.is_available() else "cpu",
    "c_noise" : 100,
    "feature_dim" : 64
}
train_tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
if config["device"] == "cuda":
    torch.cuda.set_device(7)
print('Device used :', config["device"])

Device used : cuda


In [3]:
def same_seeds(seed):
    # Python built-in random module
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Torch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

same_seeds(7777)

In [4]:
def save_checkpoint(checkpoint_path, model, optimizer):
    state = {'model_state_dict': model.state_dict(),
             'optimizer_state_dict' : optimizer.state_dict()}
    torch.save(state, checkpoint_path)
    print('model saved to {}'.format(checkpoint_path))

In [5]:
class FaceDataset(Dataset):
    def __init__(self, dirpath, transform = None) -> None:
        self.transform = transform
        self.data = []
        files = glob.glob(os.path.join(dirpath, "*.png"))
        for file in files:
            image = Image.open(file)
            self.data.append(image)
        self.len = len(self.data)
        
    def __getitem__(self, index):
        img = self.data[index]
        if self.transform != None:
            img = self.transform(img)
        return img

    def __len__(self):
        return self.len

In [6]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [7]:
class Generator(nn.Module):
    def __init__(self, c_noise, feature_dim, c_img) -> None:
        super().__init__()
        self.project = nn.Sequential(
            nn.ConvTranspose2d(c_noise, feature_dim*16, kernel_size=4, stride=1, bias=False),
            nn.BatchNorm2d(feature_dim*16),
            nn.ReLU(True)
        )
        self.conv = nn.Sequential(
            self.dconv_bn_relu(feature_dim*16, feature_dim*8),
            self.dconv_bn_relu(feature_dim*8, feature_dim*4),
            self.dconv_bn_relu(feature_dim*4, feature_dim*2)
        )
        self.last = nn.Sequential(
            nn.ConvTranspose2d(feature_dim*2, c_img, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.Tanh() 
        )
        self.apply(weights_init)

    def dconv_bn_relu(self, in_dim, out_dim):
        return nn.Sequential(
            nn.ConvTranspose2d(in_dim, out_dim, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(True)
        )
        
    def forward(self, x):
        x1 = self.project(x)
        x2 = self.conv(x1)
        # print("Generator x1 {} x2 {}".format(x1.shape, x2.shape))
        return self.last(x2)

In [8]:
class Discriminator(nn.Module):
    def __init__(self, feature_dim, c_img) -> None:
        super().__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(c_img, feature_dim, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
            self.conv_bn_lrelu(feature_dim, feature_dim*2),
            self.conv_bn_lrelu(feature_dim*2, feature_dim*4),
            self.conv_bn_lrelu(feature_dim*4, feature_dim*8),
            nn.Conv2d(feature_dim*8, 1, kernel_size=4, bias=False),
            nn.Sigmoid()
        )
        self.apply(weights_init)

    def conv_bn_lrelu(self, in_dim, out_dim):
        return nn.Sequential(
            nn.Conv2d(in_dim, out_dim, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
    def forward(self, x):
        return self.disc(x).squeeze()
        

In [9]:
class TrainGan():
    def __init__(self, config) -> None:
        self.config = config
        self.G = Generator(self.config["c_noise"], self.config["feature_dim"], 3)
        self.D = Discriminator(self.config["feature_dim"], 3)

        self.optimizer_G = optim.Adam(self.G.parameters(), lr=self.config["lr"], betas=self.config["betas"])
        self.optimizer_D = optim.Adam(self.D.parameters(), lr=self.config["lr"], betas=self.config["betas"])
        self.criterion = nn.BCELoss()

        self.z_samples = torch.randn(32, self.config["c_noise"], 1, 1, device=self.config["device"]) #fix sample used
        
    def CreateLoader(self):
        self.train_loader = DataLoader(FaceDataset(dirpath=self.config["train_path"], transform=train_tfm), batch_size=self.config["train_batchsz"], shuffle=True, pin_memory=True)
        self.refresh = len(self.train_loader.dataset) / self.config["train_batchsz"]
        print(len(self.train_loader.dataset))

    def SaveEpochResult(self, ep):
        self.G.eval()
        imgs = self.G(self.z_samples).detach().cpu()
        imgs = (imgs + 1.) / 2.0
        torchvision.utils.save_image(imgs, fp="/data/allen/gangrid/dcgan_ep{}.png".format(ep + 1), padding=0)

    def train(self):
        self.CreateLoader()
        self.G = self.G.to(self.config["device"])
        self.D = self.D.to(self.config["device"])
        for ep in range(self.config["epoch"]):
            self.D.train()
            self.G.train()
            total_loss_D, total_loss_G = 0., 0.
            for (idx, real_img) in enumerate(self.train_loader):
                bsz = real_img.shape[0]
                #train Discriminator
                real_img = real_img.to(self.config["device"]) #(bsz, 3, 64, 64)
                real_label = torch.ones((bsz, ), device=self.config["device"]) #(bsz, )
                real_logit = self.D(real_img) #(bsz, )
                # print("r_image {} r_label {} r_logit".format(real_img.shape, real_label.shape, real_logit.shape))
                real_loss = self.criterion(real_logit, real_label)
                random_z = torch.randn(bsz, self.config["c_noise"], 1, 1, device=self.config["device"]) #(bsz, 100, 1, 1)
                fake_img = self.G(random_z) #(bsz, 3, 64, 64)
                fake_label = torch.zeros((bsz, ), device=self.config["device"]) #(bsz, )
                fake_logit = self.D(fake_img) #(bsz, )
                # print("f_image {} f_label {} f_logit".format(fake_img.shape, fake_label.shape, fake_logit.shape))
                fake_loss = self.criterion(fake_logit, fake_label)
                loss_D = (real_loss + fake_loss) / 2
                total_loss_D += loss_D.item()
                # Discriminator backwarding
                self.D.zero_grad()
                loss_D.backward()
                self.optimizer_D.step()

                #train Generater
                random_z = torch.randn(bsz, self.config["c_noise"], 1, 1, device=self.config["device"]) #(bsz, 100, 1, 1)
                fake_img = self.G(random_z)  #(bsz, 3, 64, 64)
                fake_logit = self.D(fake_img) #(bsz, )
                loss_G = self.criterion(fake_logit, real_label)
                total_loss_G += loss_G.item()
                # Generator backwarding
                self.G.zero_grad()
                loss_G.backward()
                self.optimizer_G.step()

            self.SaveEpochResult(ep)
            print("Epoch[{}/{}] loss_G : {:.6f} loss_D : {:.6f}".format(ep + 1, self.config["epoch"], total_loss_G / (self.refresh), total_loss_D / (self.refresh)))
            if (ep + 1) % 100 == 0:
                save_checkpoint("/data/allen/hw2model/dcgan_G_ep{}.pth".format(ep + 1), self.G, self.optimizer_G)
                save_checkpoint("/data/allen/hw2model/dcgan_D_ep{}.pth".format(ep + 1), self.D, self.optimizer_D)
            

In [10]:
DCGAN = TrainGan(config)

In [11]:
DCGAN.train()

38464
Epoch[1/200] loss_G : 6.680003 loss_D : 0.454092
Epoch[2/200] loss_G : 3.352254 loss_D : 0.397286
Epoch[3/200] loss_G : 3.180087 loss_D : 0.358949
Epoch[4/200] loss_G : 3.074823 loss_D : 0.439574
Epoch[5/200] loss_G : 3.498047 loss_D : 0.391640
Epoch[6/200] loss_G : 3.329298 loss_D : 0.394997
Epoch[7/200] loss_G : 3.035551 loss_D : 0.381383
Epoch[8/200] loss_G : 2.944751 loss_D : 0.385395
Epoch[9/200] loss_G : 3.007794 loss_D : 0.356832
Epoch[10/200] loss_G : 2.798196 loss_D : 0.366431
Epoch[11/200] loss_G : 2.652982 loss_D : 0.347175
Epoch[12/200] loss_G : 2.597107 loss_D : 0.349507
Epoch[13/200] loss_G : 2.521463 loss_D : 0.369286
Epoch[14/200] loss_G : 2.615038 loss_D : 0.350668
Epoch[15/200] loss_G : 2.528749 loss_D : 0.352190
Epoch[16/200] loss_G : 2.500395 loss_D : 0.359663
Epoch[17/200] loss_G : 2.463727 loss_D : 0.368085
Epoch[18/200] loss_G : 2.377447 loss_D : 0.378761
Epoch[19/200] loss_G : 2.390643 loss_D : 0.348742
Epoch[20/200] loss_G : 2.330604 loss_D : 0.369489
Epo