In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision
from tqdm.notebook import tqdm
!jupyter nbextension enable --py widgetsnbextension
from torch.utils.tensorboard import SummaryWriter
import albumentations as A
import torch.nn.functional as F
from math import log2
from albumentations.pytorch import ToTensorV2
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
device="cuda:7" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark=True

  warn(


Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [2]:
class ConvBlock(nn.Module):
    def __init__(self,in_channels,out_channels,
                 discriminator=False,use_act=True,use_bn=True,**kwargs) -> None:
        super(ConvBlock,self).__init__()
        self.con=nn.Conv2d(in_channels,out_channels,**kwargs,bias=not use_bn)
        self.bn=nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.use_act=use_act
        self.act=(
            nn.LeakyReLU(0.2,inplace=True)
            if discriminator else nn.PReLU(num_parameters=out_channels)
        )
    def forward(self,x):
        return self.act(self.bn(self.con(x))) if self.use_act else self.bn(self.con(x))

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self,in_channels) -> None:
        super().__init__()
        self.b1=ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.b2=ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False
        )
    def forward(self,x):
        out=self.b1(x)
        out=self.b2(out)
        return out+x

In [4]:
class UpSampleBlock(nn.Module):
    def __init__(self,in_channels,scale_factor) -> None:
        super().__init__()
        # self.con=nn.Conv2d(in_channels,in_channels*scale_factor**2,3,1,1)
        # self.ps=nn.PixelShuffle(scale_factor)
        # self.prelu=nn.PReLU(num_parameters=in_channels)
        self.con=nn.Sequential(
            nn.Conv2d(in_channels,in_channels*scale_factor**2,3,1,1),
            nn.PixelShuffle(scale_factor),
            nn.PReLU(num_parameters=in_channels)
        )
    def forward(self,x):
        return self.con(x)

In [5]:
class Generator(nn.Module):
    def __init__(self,in_channels=3,num_features=64,num_blocks=16) -> None:
        super().__init__()
        self.inital_block=ConvBlock(in_channels,num_features,kernel_size=9,
                                    stride=1,padding=4,use_bn=False)
        self.residuals=nn.Sequential(*[ResidualBlock(num_features) for _ in range(num_blocks)])
        self.convblock=ConvBlock(num_features,num_features,kernel_size=3,stride=1,
                                 padding=1,use_act=False)
        self.upsamples=nn.Sequential(
            UpSampleBlock(num_features,2),UpSampleBlock(num_features,2)
        )
        self.final_layer=nn.Conv2d(num_features,in_channels,kernel_size=9,stride=1,padding=4)
        
    def forward(self,x):
        inital=self.inital_block(x)
        x=self.residuals(inital)
        x=self.convblock(x)+inital
        x=self.upsamples(x)
        
        return torch.tanh(self.final_layer(x))

In [6]:
class Discriminator(nn.Module):
    def __init__(self,in_channels=3,features=[64,64,128,128,256,256,512,512]) -> None:
        super().__init__()
        blocks=[]
        for i,feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1+i%2,
                    padding=1,
                    use_act=True,
                    use_bn=True if i!=0 else False,
                    
                )
            )
            in_channels = feature
        self.convblock=nn.Sequential(*blocks)
        self.classifier=nn.Sequential(
            nn.AdaptiveAvgPool2d((6,6)),
            nn.Flatten(),
            nn.Linear(512*6*6,1024),
            nn.Linear(1024,1),
        )
    def forward(self,x):
        x=self.convblock(x)
        return self.classifier(x)

In [7]:
low_resolution=24
with torch.cuda.amp.autocast():
    x=torch.randn((5,3,low_resolution,low_resolution))
    g=Generator()
    d=Discriminator()
    gen_out=g(x)
    d_out=d(gen_out)
    print(gen_out.shape)
    print(d_out.shape)

torch.Size([5, 3, 96, 96])
torch.Size([5, 1])


In [8]:
class VggLoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.vgg=torchvision.models.vgg19(pretrained=True).features[:36].eval().to(device)
        self.loss=nn.MSELoss()
        for param in self.vgg.parameters():
            param.requires_grad=False
    
    def forward(self,in_features,target):
            vgg_in_features=self.vgg(in_features)
            vgg_out_features=self.vgg(target)
            
            return self.loss(vgg_in_features,vgg_out_features)
    

# Dataset preparation

In [9]:
HIGH_RES = 96
LOW_RES = HIGH_RES // 4
highres_transform = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
)
lowres_transform = A.Compose(
    [
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=Image.BICUBIC),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)
both_transforms = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)
test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

In [10]:
class SrGanDataset(Dataset):
    def __init__(self, root_dir,train=True):
        super(SrGanDataset, self).__init__()
        self.data = []
        self.root_dir = root_dir
        self.class_names = os.listdir(root_dir)
        self.train=train
        for index, name in enumerate(self.class_names):
            files = os.listdir(os.path.join(root_dir, name))
            self.data += list(zip(files, [index] * len(files)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_file, label = self.data[index]
        root_and_dir = os.path.join(self.root_dir, self.class_names[label])

        image = np.array(Image.open(os.path.join(root_and_dir, img_file)))
        if self.train:
            image = both_transforms(image=image)["image"]
            high_res = highres_transform(image=image)["image"]
            low_res = lowres_transform(image=image)["image"]
            return low_res, high_res
        else:
            image=test_transform(image=image)["image"]
            
            return image

In [11]:
# dataset=SrGanDataset('/mnt/disk1/Gulshan/GAN/ProGAN/celeba_hq/train')
# train_loader = DataLoader(dataset, batch_size=2, num_workers=8)
# x,y=next(iter(train_loader))
# print(x.shape,y.shape) #torch.Size([2, 3, 24, 24]) torch.Size([2, 3, 96, 96])
# x,y=np.array(x),np.array(y)
# x=x.transpose(0,2,3,1)
# y=y.transpose(0,2,3,1)
# plt.axis('off')
# plt.imshow(x[1,...])
# plt.show()
# plt.axis('off')
# plt.imshow(y[1,...])
# plt.show()

# Hyperparameters

In [12]:
lr=1e-4
num_epochs=100
bs=1
num_worker=4
high_res=96
low_res=high_res//4
img_channesl=3

In [13]:
generator=Generator().to(device)
discriminator=Discriminator().to(device)
opt_gen = optim.Adam(generator.parameters(), lr=lr, betas=(0.9, 0.999))
opt_disc = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.9, 0.999))
mse = nn.MSELoss()
bce = nn.BCEWithLogitsLoss()
vgg_loss = VggLoss()
scaler_critic = torch.cuda.amp.GradScaler()
scaler_gen = torch.cuda.amp.GradScaler()
writer_real= SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")



In [14]:
dataset=SrGanDataset('/mnt/disk1/Gulshan/GAN/ProGAN/celeba_hq/train')
train_loader = DataLoader(dataset, batch_size=bs, num_workers=8)
dataset=SrGanDataset('/mnt/disk1/Gulshan/GAN/ProGAN/celeba_hq/val',train=False)
val_loader = DataLoader(dataset, batch_size=bs, num_workers=8)
# x=next(iter(val_loader))
# print(x.shape,) #torch.Size([16, 3, 1024, 1024])
# x=np.array(x)
# x=x.transpose(0,2,3,1)
# plt.axis('off')
# plt.imshow(x[1,...])
# plt.show()

In [15]:
def train():    
    generator.train()
    discriminator.train()
    for epoch in tqdm(range(num_epochs),total=num_epochs):
        for batch_idx,(low_res, high_res) in tqdm(enumerate(train_loader)):
            high_res = high_res.to(device)
            low_res = low_res.to(device)
            with torch.cuda.amp.autocast():
                fake = generator(low_res)
                disc_real = discriminator(high_res)
                disc_fake = discriminator(fake.detach())
                disc_loss_real = bce(
                    disc_real, torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real)
                )
                disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
                loss_disc = disc_loss_fake + disc_loss_real
            opt_disc.zero_grad()
            scaler_critic.scale(loss_disc).backward()
            scaler_critic.step(opt_disc)
            scaler_critic.update()
            with torch.cuda.amp.autocast():
                disc_fake = discriminator(fake)
                #l2_loss = mse(fake, high_res)
                adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
                loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
                gen_loss = loss_for_vgg + adversarial_loss
            opt_gen.zero_grad()
            scaler_critic.scale(gen_loss).backward()
            scaler_critic.step(opt_gen)
            scaler_critic.update()
            
            if batch_idx % 100 == 0:
                high_scale_img=next(iter(val_loader))
                with torch.no_grad():
                    upscaled_img = generator(high_scale_img.to(device))
                    img_grid_real = torchvision.utils.make_grid(high_scale_img[:bs],normalize=True)
                    img_grid_fake = torchvision.utils.make_grid(upscaled_img[:bs], normalize=True)
                    writer_real.add_image("Real", img_grid_real, global_step=epoch)
                    writer_fake.add_image("Fake", img_grid_fake, global_step=epoch)
                    writer_fake.add_scalar("gen_loss ", gen_loss.item(), global_step=epoch)
                    writer_real.add_scalar("loss_disc ", loss_disc.item(), global_step=epoch)
                      
        print(f"epoch:{epoch}/num of epochs:{num_epochs},Gen loss : {gen_loss.item()}, Disc loss: {loss_disc.item()}")
train()

  0%|          | 0/100 [00:00<?, ?it/s]

0it [00:00, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.00 GiB. GPU  has a total capacity of 10.75 GiB of which 1.59 GiB is free. Including non-PyTorch memory, this process has 9.16 GiB memory in use. Of the allocated memory 6.35 GiB is allocated by PyTorch, and 2.59 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)