In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torchvision.utils import make_grid
from torchvision import transforms
import itertools
from PIL import Image
import numpy as np
import os
import yaml
import random
import matplotlib.pyplot as plt

In [None]:
# Dotted, lower-case suffixes: "face.PNG" matches, but an extension-less name such as
# "snapshotpng" does not. ".jpeg" is accepted alongside ".jpg".
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")

class ImagetoImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    Item ``idx`` returns ``(imageA, imageB)``. Each image is an RGB ``uint8`` numpy
    array (H, W, 3), or whatever ``transforms`` turns it into. The dataset length is
    that of the larger domain. When ``idx`` runs past the end of the smaller domain,
    a random image of that domain is used instead.

    Args:
        domainA_dir: directory holding domain-A images.
        domainB_dir: directory holding domain-B images.
        transforms: optional callable applied independently to each numpy image.

    Raises:
        ValueError: if either directory contains no image files.
    """

    def __init__(self, domainA_dir, domainB_dir, transforms=None):
        self.imagesA = self._scan(domainA_dir)
        self.imagesB = self._scan(domainB_dir)
        # Fail here with a clear message. Otherwise an empty domain only surfaces
        # later as np.random.randint(0) or a DataLoader error.
        if not self.imagesA or not self.imagesB:
            raise ValueError(
                f"No images with extensions {IMG_EXTENSIONS} found in "
                f"{domainA_dir!r} ({len(self.imagesA)}) / {domainB_dir!r} ({len(self.imagesB)})"
            )

        self.transforms = transforms

        self.lenA = len(self.imagesA)
        self.lenB = len(self.imagesB)

    @staticmethod
    def _scan(directory):
        """Return image paths in ``directory``, sorted so ordering is deterministic."""
        return [os.path.join(directory, name) for name in sorted(os.listdir(directory))
                if name.lower().endswith(IMG_EXTENSIONS)]

    def __len__(self):
        return max(self.lenA, self.lenB)

    def __getitem__(self, idx):
        idx_a = idx_b = idx
        # Only the smaller domain can run out of indices; pad it by random sampling.
        if idx_a >= self.lenA:
            idx_a = np.random.randint(self.lenA)
        if idx_b >= self.lenB:
            idx_b = np.random.randint(self.lenB)

        # Context managers close the underlying file handles promptly.
        with Image.open(self.imagesA[idx_a]) as img_a:
            imageA = np.array(img_a.convert("RGB"))
        with Image.open(self.imagesB[idx_b]) as img_b:
            imageB = np.array(img_b.convert("RGB"))

        if self.transforms is not None:
            imageA = self.transforms(imageA)
            imageB = self.transforms(imageB)

        return imageA, imageB

In [None]:
class ConvBlock2d(nn.Module):
    """A ``padding='same'`` 2-D convolution, optionally followed by ReLU.

    Args:
        in_feature: input channels.
        out_feature: output channels.
        kernel_size: convolution kernel size (int or tuple).
        stride: convolution stride. PyTorch only supports ``padding='same'``
            with stride 1, which is the default here.
        activation: ``'relu'`` applies ReLU. Any other value (callers in this
            file pass the string ``'None'``) returns the raw convolution output.
    """

    def __init__(self, in_feature, out_feature, kernel_size, stride=(1, 1), activation='relu'):
        super().__init__()
        # `conv` is part of the state_dict key ("conv.weight"); keep the name.
        self.conv = nn.Conv2d(
            in_feature,
            out_feature,
            kernel_size=kernel_size,
            stride=stride,
            padding='same',
        )
        self.activation = activation

    def forward(self, x):
        out = self.conv(x)
        return F.relu(out) if self.activation == 'relu' else out

In [None]:
class MultiResBlock(nn.Module):
    """MultiRes block: a chain of three 3x3 conv branches plus a 1x1 shortcut.

    The branches are chained (each feeds the next), so the outputs have effective
    receptive fields of 3x3, 5x5 and 7x7. That is where the attribute names come
    from, even though every kernel is 3x3. The three outputs are concatenated
    along channels, added to the 1x1 shortcut, and passed through ReLU.

    Channel split: out_feature // 6, out_feature // 3, and the remainder. The
    remainder guarantees the concatenation has exactly ``out_feature`` channels.
    Note that out_feature < 6 leaves the first branch with zero channels.
    The 1x1 shortcut uses ConvBlock2d's default 'relu' activation.

    Attribute names are state_dict keys; keep them stable for saved checkpoints.
    """

    def __init__(self, in_feature, out_feature):
        super(MultiResBlock, self).__init__()
        n_small = out_feature // 6
        n_mid = out_feature // 3
        n_large = out_feature - n_small - n_mid
        self.conv_3x3 = ConvBlock2d(in_feature, n_small, kernel_size=3)
        self.conv_5x5 = ConvBlock2d(n_small, n_mid, kernel_size=3)
        self.conv_7x7 = ConvBlock2d(n_mid, n_large, kernel_size=3)

        self.conv_1x1 = ConvBlock2d(in_feature, out_feature, kernel_size=1)

    def forward(self, x):
        # Run the chained branches, keeping each intermediate output.
        branch_outputs = []
        h = x
        for branch in (self.conv_3x3, self.conv_5x5, self.conv_7x7):
            h = branch(h)
            branch_outputs.append(h)

        merged = torch.cat(branch_outputs, dim=1)
        shortcut = self.conv_1x1(x)
        return F.relu(merged + shortcut)

In [None]:
class ResPath(nn.Module):
    """Chain of residual conv stages applied to an encoder skip connection.

    Each of the ``length`` stages computes ``relu(conv3x3(x) + conv1x1(x))``.
    The 3x3 branch carries its own ReLU; the 1x1 shortcut is linear
    (``activation='None'``). The first stage maps in_feature -> out_feature
    channels, and later stages keep out_feature.

    ``residuals`` / ``convs`` are state_dict keys; keep names and ordering stable.
    """

    def __init__(self, in_feature, out_feature, length):
        super(ResPath, self).__init__()
        self.respath_length = length
        self.residuals = torch.nn.ModuleList([])
        self.convs = torch.nn.ModuleList([])

        for stage in range(self.respath_length):
            channels_in = in_feature if stage == 0 else out_feature
            # Shortcut is built before the conv in each stage (preserves init order).
            self.residuals.append(ConvBlock2d(channels_in, out_feature, kernel_size=(1, 1), activation='None'))
            self.convs.append(ConvBlock2d(channels_in, out_feature, kernel_size=(3, 3), activation='relu'))

    def forward(self, x):
        for shortcut, conv in zip(self.residuals, self.convs):
            skip = shortcut(x)
            x = torch.nn.functional.relu(conv(x) + skip)
        return x

In [None]:
class MultiResUNet(nn.Module):
    """U-Net-style generator built from MultiResBlocks and ResPath skip connections.

    Encoder: five MultiResBlocks separated by four 2x2 max-pools. The first four
    encoder outputs each go through a ResPath (lengths 4, 3, 2, 1, shallow to
    deep). Decoder: at each level the deeper feature map is upsampled 2x with a
    transposed conv, concatenated with the ResPath output, and fused by a
    MultiResBlock. A final 1x1 conv with no activation produces the output.

    Args:
        in_feature: number of input channels.
        out_feature: number of output channels.
        alpha: width multiplier applied to every MultiResBlock's channel count.
        ngf: base channel count; level k (0-based) uses ngf * 2**k before alpha.

    NOTE(review): with four 2x downsamplings, H and W presumably must be
    divisible by 16 or the skip concatenations will not line up — confirm.
    The output is unbounded; callers in this file map it from [-1, 1] to [0, 1].
    Attribute names below are state_dict keys, so renaming them breaks loading
    of previously saved generator checkpoints.
    """
    def __init__(self, in_feature, out_feature, alpha=1.667, ngf = 32):
        super(MultiResUNet, self).__init__()
        #encoder
        # Level 1 (full resolution).
        feature1 = int(ngf * alpha)
        self.multi1 = MultiResBlock(in_feature, feature1)
        self.pool1 = nn.MaxPool2d(2)
        self.respath1 = ResPath(feature1, ngf, length=4)

        # Level 2 (1/2 resolution).
        feature2 = int(ngf * 2 * alpha)
        self.multi2 = MultiResBlock(feature1, feature2)
        self.pool2 = nn.MaxPool2d(2)
        self.respath2 = ResPath(feature2, ngf * 2, length=3)

        # Level 3 (1/4 resolution).
        feature3 = int(ngf * 4 * alpha)
        self.multi3 = MultiResBlock(feature2, feature3)
        self.pool3 = nn.MaxPool2d(2)
        self.respath3 = ResPath(feature3, ngf * 4, length=2)

        # Level 4 (1/8 resolution).
        feature4 = int(ngf * 8 * alpha)
        self.multi4 = MultiResBlock(feature3, feature4)
        self.pool4 = nn.MaxPool2d(2)
        self.respath4 = ResPath(feature4, ngf * 8, length=1)

        # Bottleneck (1/16 resolution).
        feature5 = int(ngf * 16 * alpha)
        self.multi5 = MultiResBlock(feature4, feature5)

        #decoder
        # Each stage upsamples to ngf * 2**k channels and concatenates with a ResPath
        # output of the same width, hence the `* 2` in each MultiResBlock input.
        out_feature5 = feature5
        self.upsample1 = nn.ConvTranspose2d(out_feature5, ngf * 8, kernel_size = (2, 2), stride = (2, 2))  
        out_feature4 = int(ngf * 8 * alpha)
        self.multi6 = MultiResBlock(ngf * 8 * 2, out_feature4)
        
        self.upsample2 = nn.ConvTranspose2d(out_feature4, ngf * 4, kernel_size = (2, 2), stride = (2, 2))
        out_feature3 = int(ngf * 4 * alpha)  
        self.multi7 = MultiResBlock(ngf * 4 * 2, out_feature3)

        self.upsample3 = nn.ConvTranspose2d(out_feature3, ngf * 2, kernel_size = (2, 2), stride = (2, 2))
        out_feature2 = int(ngf * 2 * alpha)
        self.multi8 = MultiResBlock(ngf * 2 * 2, out_feature2)

        self.upsample4 = nn.ConvTranspose2d(out_feature2, ngf, kernel_size = (2, 2), stride = (2, 2))
        out_feature1 = int(ngf * alpha)
        self.multi9 = MultiResBlock(ngf * 2, out_feature1)

        # Linear 1x1 projection to the output channels (no activation).
        self.conv_final = ConvBlock2d(out_feature1, out_feature, kernel_size = (1,1), activation='None')

    def forward(self, x):
        #encoder
        layer1 = self.multi1(x)        
        layer2 = self.multi2(self.pool1(layer1))
        layer3 = self.multi3(self.pool2(layer2))
        layer4 = self.multi4(self.pool3(layer3))
        layer5 = self.multi5(self.pool4(layer4))
        #decoder
        # Names are reused: each layerK now holds the decoder output at level K.
        layer4 = self.multi6(torch.cat([self.upsample1(layer5), self.respath4(layer4)], axis=1))
        layer3 = self.multi7(torch.cat([self.upsample2(layer4), self.respath3(layer3)], axis=1))
        layer2 = self.multi8(torch.cat([self.upsample3(layer3), self.respath2(layer2)], axis=1))
        layer1 = self.multi9(torch.cat([self.upsample4(layer2), self.respath1(layer1)], axis=1))

        return self.conv_final(layer1)


In [None]:
class Discriminator(nn.Module):
    """MultiResBlock encoder with a linear head giving P(image is real) in (0, 1).

    Args:
        in_feature: input channels.
        alpha: width multiplier for the MultiResBlock channel counts.
        ndf: base channel count; level k (0-based) uses ndf * 2**k before alpha.
        img_size: spatial size (H == W) of the images this discriminator scores.
            Five 2x2 max-pools shrink it by 32 (floor), so the head has
            ``feature5 * (img_size // 32) ** 2`` inputs. The default 256 reproduces
            the previously hard-coded 8 x 8 head, so existing checkpoints still load.

    Raises:
        ValueError: if img_size < 32 (nothing would be left after five poolings).
    """

    def __init__(self, in_feature, alpha=0.667, ndf=32, img_size=256):
        super(Discriminator, self).__init__()
        if img_size < 32:
            raise ValueError(f"img_size must be >= 32 (five 2x poolings), got {img_size}")

        # Attribute names below are state_dict keys; keep them stable.
        feature1 = int(ndf * alpha)
        self.multi1 = MultiResBlock(in_feature, feature1)
        self.pool1 = nn.MaxPool2d(2)

        feature2 = int(ndf * 2 * alpha)
        self.multi2 = MultiResBlock(feature1, feature2)
        self.pool2 = nn.MaxPool2d(2)

        feature3 = int(ndf * 4 * alpha)
        self.multi3 = MultiResBlock(feature2, feature3)
        self.pool3 = nn.MaxPool2d(2)

        feature4 = int(ndf * 8 * alpha)
        self.multi4 = MultiResBlock(feature3, feature4)
        self.pool4 = nn.MaxPool2d(2)

        feature5 = int(ndf * 16 * alpha)
        self.multi5 = MultiResBlock(feature4, feature5)
        self.pool5 = nn.MaxPool2d(2)

        # Repeated floor-halving five times equals floor division by 32.
        side = img_size // 32
        self.FC = nn.Linear(feature5 * side * side, 1)

    def forward(self, x):
        #encoder
        layer = self.multi1(x)
        layer = self.multi2(self.pool1(layer))
        layer = self.multi3(self.pool2(layer))
        layer = self.multi4(self.pool3(layer))
        layer = self.multi5(self.pool4(layer))
        # Flatten (N, C, h, w) -> (N, C*h*w) for the linear head.
        logits = self.FC(self.pool5(layer).flatten(1))
        return torch.sigmoid(logits)


In [None]:
def _list_test_images(directory):
    """Return sorted paths of .png/.jpg files (case-insensitive) in ``directory``."""
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))
            if name.lower().endswith(('.png', '.jpg'))]


def _to_display(tensor):
    """Map a (1, 3, H, W) tensor from [-1, 1] to an (H, W, 3) numpy array in [0, 1]."""
    arr = (tensor.detach().cpu().squeeze(0).permute(1, 2, 0).numpy() + 1.0) / 2.0
    # Generator output is unbounded; clip so imshow does not warn or clip silently.
    return arr.clip(0.0, 1.0)


def _plot_translations(model, image_paths, trans, out_file):
    """Translate every image with ``model`` and save and show a 2-row comparison grid.

    Row 0 holds the inputs and row 1 the model outputs. The figure is written to
    ``out_file`` and closed afterwards, because this runs repeatedly during training.
    """
    if not image_paths:
        print(f"No test images found; skipping {out_file}")
        return
    # squeeze=False keeps `ax` 2-D even with a single image (one column).
    fig, ax = plt.subplots(2, len(image_paths), figsize=(20, 10), squeeze=False)
    for col, path in enumerate(image_paths):
        with Image.open(path) as pil_img:
            img = trans(pil_img.convert('RGB')).unsqueeze(0)
        translated = model(img)
        ax[0, col].imshow(_to_display(img))
        ax[1, col].imshow(_to_display(translated))
    # Save BEFORE show: on inline/non-interactive backends show() releases the
    # figure, so a later plt.savefig() would write an empty image.
    fig.savefig(out_file)
    plt.show()
    plt.close(fig)


@torch.no_grad()
def infer(config='configs/aging_gan.yaml', checkpoint_dir='./pretrained/',
          image_dir='./archive/test_image/'):
    """Plot young->old and old->young translations of the test images.

    Args:
        config: YAML config providing 'gen_alpha', 'ngf' and the checkpoint
            file names under 'y2o' / 'o2y'.
        checkpoint_dir: directory containing the generator state_dicts.
        image_dir: directory with 'testY' (young) and 'testO' (old) subfolders.

    Side effects: shows two figures and writes "mygraph_y2o.png" and
    "mygraph_o2y.png" to the current working directory.
    """
    with open(config) as file:
        configs = yaml.load(file, Loader=yaml.FullLoader)
    trans = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    # (checkpoint key in config, source-image folder, output figure)
    directions = (
        ('y2o', os.path.join(image_dir, 'testY'), "mygraph_y2o.png"),
        ('o2y', os.path.join(image_dir, 'testO'), "mygraph_o2y.png"),
    )
    for ckpt_key, src_dir, out_file in directions:
        model = MultiResUNet(3, 3, configs['gen_alpha'], configs['ngf'])
        ckpt = torch.load(os.path.join(checkpoint_dir, configs[ckpt_key]), map_location='cpu')
        model.load_state_dict(ckpt)
        model.eval()
        image_paths = _list_test_images(src_dir)
        random.shuffle(image_paths)
        _plot_translations(model, image_paths, trans, out_file)



In [None]:
class GAN(pl.LightningModule):
    """CycleGAN-style module translating between a young (Y) and an old (O) domain.

    Holds two MultiResUNet generators (genY2O, genO2Y) and two Discriminators
    (disY, disO). Optimization is manual: each training step performs one
    generator update followed by one discriminator update.

    ``hparams`` is a dict (the YAML config in this notebook) that must provide:
    gen_alpha, ngf, dis_alpha, ndf, adv_weight, identity_weight, cycle_weight,
    lr, weight_decay, img_size, augment_rotation, domainY_dir, domainO_dir,
    batch_size, num_workers, y2o, o2y.
    """

    # Every PREVIEW_EVERY batches: log sample grids, save generator weights, run infer().
    PREVIEW_EVERY = 500
    PREVIEW_DIR = './pretrained/'

    def load_from_checkpoint(self, checkpoint_path):
        """Load both generators' state_dicts from the directory ``checkpoint_path``.

        NOTE: this instance method shadows LightningModule's classmethod of the
        same name. It restores only the generators, not the discriminators.
        """
        y2o = os.path.join(checkpoint_path, self.hparams['y2o'])
        o2y = os.path.join(checkpoint_path, self.hparams['o2y'])
        # map_location='cpu' lets GPU-saved weights load on CPU-only machines;
        # load_state_dict copies them into the existing parameters.
        self.genY2O.load_state_dict(torch.load(y2o, map_location='cpu'))
        self.genO2Y.load_state_dict(torch.load(o2y, map_location='cpu'))

    # Alias for callers that use the shorter name.
    load_checkpoint = load_from_checkpoint

    def __init__(self, hparams):
        super(GAN, self).__init__()
        self.automatic_optimization = False
        self.save_hyperparameters(hparams)
        self.genY2O = MultiResUNet(3, 3, self.hparams['gen_alpha'], self.hparams['ngf'])
        self.genO2Y = MultiResUNet(3, 3, self.hparams['gen_alpha'], self.hparams['ngf'])
        self.disY = Discriminator(3, self.hparams['dis_alpha'], self.hparams['ndf'])
        self.disO = Discriminator(3, self.hparams['dis_alpha'], self.hparams['ndf'])

        # Cache slots for generated/real images (not read anywhere in this module).
        self.generated_Y = None
        self.generated_O = None
        self.real_Y = None
        self.real_O = None

    def forward(self, x):
        """Translate young images to old."""
        return self.genY2O(x)

    def training_step(self, batch, batch_idx):
        """One manual-optimization step: update the generators, then the discriminators.

        ``batch`` is ``(real_Y, real_O)`` as produced by ImagetoImageDataset.
        """
        g_optim, d_optim = self.optimizers()
        real_Y, real_O = batch

        # --- generator update ---
        g_optim.zero_grad()
        # Freeze D so gradients flow through it into G without accumulating in D.
        self.disO.requires_grad_(False)
        self.disY.requires_grad_(False)
        try:
            g_loss = self._generator_loss(real_Y, real_O)
            self.manual_backward(g_loss)
            g_optim.step()
        finally:
            self.disO.requires_grad_(True)
            self.disY.requires_grad_(True)
        self.log('Loss/Generator', g_loss.detach())

        # --- discriminator update ---
        d_optim.zero_grad()
        d_loss = self._discriminator_loss(real_Y, real_O)
        self.manual_backward(d_loss)
        d_optim.step()
        self.log('Loss/Discriminator', d_loss.detach())

        # Preview after both updates so no autograd graph is held alive meanwhile.
        if batch_idx % self.PREVIEW_EVERY == 0:
            self._preview(real_Y, real_O)

    def _generator_loss(self, real_Y, real_O):
        """Weighted sum of the adversarial, identity and cycle-consistency losses."""
        fake_O = self.genY2O(real_Y)
        fake_Y = self.genO2Y(real_O)  # old -> young (previously genO2Y(real_Y), the identity input)

        # Adversarial: each fake should be scored as real (1) by its domain's D.
        pred_O = self.disO(fake_O)
        pred_Y = self.disY(fake_Y)
        loss_Y2O = F.binary_cross_entropy(pred_O, torch.ones_like(pred_O))
        loss_O2Y = F.binary_cross_entropy(pred_Y, torch.ones_like(pred_Y))

        # Cycle consistency: translating there and back should reconstruct the input.
        loss_Y2O2Y = F.mse_loss(self.genO2Y(fake_O), real_Y)
        loss_O2Y2O = F.mse_loss(self.genY2O(fake_Y), real_O)

        # Identity: an image already in the target domain should pass through unchanged.
        loss_Y2Y = F.mse_loss(self.genO2Y(real_Y), real_Y)
        loss_O2O = F.mse_loss(self.genY2O(real_O), real_O)

        return ((loss_Y2O + loss_O2Y) * self.hparams['adv_weight']
                + (loss_Y2Y + loss_O2O) * self.hparams['identity_weight']
                + (loss_Y2O2Y + loss_O2Y2O) * self.hparams['cycle_weight'])

    def _discriminator_loss(self, real_Y, real_O):
        """BCE for both discriminators: real images -> 1, generated images -> 0."""
        # Fakes are produced without autograd, so no gradient reaches the generators.
        with torch.no_grad():
            fake_Y = self.genO2Y(real_O)
            fake_O = self.genY2O(real_Y)

        pred_RY = self.disY(real_Y)
        pred_RO = self.disO(real_O)  # real old faces are judged by the old-domain D
        pred_FY = self.disY(fake_Y)
        pred_FO = self.disO(fake_O)

        loss_RY = F.binary_cross_entropy(pred_RY, torch.ones_like(pred_RY))
        loss_RO = F.binary_cross_entropy(pred_RO, torch.ones_like(pred_RO))
        loss_FY = F.binary_cross_entropy(pred_FY, torch.zeros_like(pred_FY))
        loss_FO = F.binary_cross_entropy(pred_FO, torch.zeros_like(pred_FO))

        return loss_RO + loss_FO + loss_RY + loss_FY

    def _preview(self, real_Y, real_O):
        """Log sample grids, save both generators to PREVIEW_DIR, and run infer()."""
        self._log_samples(real_Y, real_O)
        # The directory may not exist yet on the very first batch.
        os.makedirs(self.PREVIEW_DIR, exist_ok=True)
        torch.save(self.genY2O.state_dict(), os.path.join(self.PREVIEW_DIR, self.hparams['y2o']))
        torch.save(self.genO2Y.state_dict(), os.path.join(self.PREVIEW_DIR, self.hparams['o2y']))
        # NOTE(review): infer() reads its own YAML config and default checkpoint dir;
        # presumably these match hparams / PREVIEW_DIR — confirm.
        infer()

    @torch.no_grad()
    def _log_samples(self, real_Y, real_O):
        """Write real and translated image grids to the logger, if it supports images."""
        experiment = getattr(self.logger, 'experiment', None)
        if experiment is None or not hasattr(experiment, 'add_image'):
            return  # e.g. no logger, or a logger without image support
        self.genY2O.eval()
        self.genO2Y.eval()
        try:
            fake_Y = self.genO2Y(real_O)
            fake_O = self.genY2O(real_Y)
        finally:
            self.genY2O.train()
            self.genO2Y.train()
        # global_step (not current_epoch) so several previews within one epoch
        # get distinct steps.
        step = self.global_step
        for tag, images in (('Real/Y', real_Y), ('Real/O', real_O),
                            ('Generated/Y', fake_Y), ('Generated/O', fake_O)):
            experiment.add_image(tag, make_grid(images, normalize=True, scale_each=True), step)

    def configure_optimizers(self):
        """Adam for both generators jointly, and Adam for both discriminators jointly."""
        g_optim = torch.optim.Adam(itertools.chain(self.genY2O.parameters(), self.genO2Y.parameters()),
                                   lr=self.hparams['lr'], betas=(0.5, 0.999),
                                   weight_decay=self.hparams['weight_decay'])
        d_optim = torch.optim.Adam(itertools.chain(self.disY.parameters(),
                                                   self.disO.parameters()),
                                   lr=self.hparams['lr'],
                                   betas=(0.5, 0.999),
                                   weight_decay=self.hparams['weight_decay'])
        return [g_optim, d_optim], []

    def train_dataloader(self):
        """Shuffled loader over unpaired (young, old) images with flip/crop/rotation augmentation."""
        train_transform = transforms.Compose([
            # The dataset yields numpy HWC uint8 arrays; convert for PIL-based transforms.
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.hparams['img_size'] + 50, self.hparams['img_size'] + 50)),
            transforms.RandomCrop(self.hparams['img_size']),
            # NOTE(review): degrees=(0, d) samples angles in [0, d] only (one direction);
            # degrees=d would give [-d, d]. Confirm the asymmetry is intended.
            transforms.RandomRotation(degrees=(0, int(self.hparams['augment_rotation']))),
            transforms.ToTensor(),
            # Maps [0, 1] to [-1, 1].
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        dataset = ImagetoImageDataset(self.hparams['domainY_dir'], self.hparams['domainO_dir'], train_transform)
        print(f"Using {len(dataset)} images for training")

        return DataLoader(dataset,
                          batch_size=self.hparams['batch_size'],
                          num_workers=self.hparams['num_workers'],
                          shuffle=True)



In [None]:
def train(config='configs/aging_gan.yaml', load_checkpoint_dir=None, save_checkpoint_dir=None):
    """Train the aging GAN from a YAML config and save both generators' weights.

    Args:
        config: path to the YAML config; its contents become the GAN's hparams.
        load_checkpoint_dir: optional directory with pre-trained generator weights
            (file names taken from the config's 'y2o' / 'o2y' keys).
        save_checkpoint_dir: directory for the final weights; defaults to 'pretrained/'.
    """
    with open(config) as file:
        configs = yaml.load(file, Loader=yaml.FullLoader)
    configs['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.set_float32_matmul_precision('medium')
    print(configs)

    output_path = save_checkpoint_dir or 'pretrained/'
    # Create the output directory before training so a bad path fails fast,
    # not after a full training run.
    os.makedirs(output_path, exist_ok=True)

    model = GAN(configs)
    if load_checkpoint_dir:
        # GAN exposes this loader as `load_from_checkpoint` (an instance method).
        model.load_from_checkpoint(load_checkpoint_dir)

    trainer = Trainer(max_epochs=configs['epochs'])
    trainer.fit(model)

    # os.path.join works whether or not output_path ends with a separator.
    torch.save(model.genY2O.state_dict(), os.path.join(output_path, configs['y2o']))
    torch.save(model.genO2Y.state_dict(), os.path.join(output_path, configs['o2y']))

In [None]:
# Train with the default config (final generator weights go to 'pretrained/'),
# then plot young->old and old->young translations of the test images.
train()
infer()