In [None]:
####

In [None]:
import torch ## Deep Learning Framework
from torch import nn ## Neural Nets

from torchsummary import summary ## To get summary of model
import os
from torch.utils.data import Dataset , DataLoader ## Custom dataset , dataloader
from torchvision import transforms ## transformation for image
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm

In [None]:
## Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def show_tensor_images(image_tensor, num_images=2, size=(3 , 256 , 256)):
  """Plot the first `num_images` images of a batch as one grid (5 images per row).

  image_tensor: tensor that can be viewed as (-1, *size); it is detached and copied to
    the CPU first. Pixel values are passed to matplotlib as-is (no rescaling).
  num_images: number of leading images to include in the grid.
  size: channel-first shape of a single image, used to un-flatten the tensor.
  """
  batch = image_tensor.detach().cpu().view(-1, *size)
  grid = make_grid(batch[:num_images], nrow=5)
  # make_grid returns (C, H, W); matplotlib wants channels last.
  plt.imshow(grid.permute(1, 2, 0).squeeze())
  plt.show()

In [None]:
class Linear(nn.Module):
    """Fully connected layer with optional BatchNorm1d, LeakyReLU and Dropout stages.

    Stages always run in the order linear -> norm -> activation -> dropout; each optional
    stage is enabled by its flag, which is re-read on every forward call.

    Args:
        in_channels: number of input features.
        out_channels: number of output features.
        use_norm: append nn.BatchNorm1d(out_channels).
        use_activation: append nn.LeakyReLU(n_slope).
        use_dropout: append nn.Dropout(p_dropout).
        n_slope: negative slope of the LeakyReLU.
        p_dropout: dropout probability.
    """
    def __init__(self , 
                 in_channels , 
                 out_channels , 
                 use_norm = True , 
                 use_activation = True , 
                 use_dropout = False , 
                 n_slope = 0.2 , 
                 p_dropout = 0.5):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_dropout = use_dropout

        # Attribute names below become state_dict keys, so they must stay stable.
        self.linear = nn.Linear(in_channels, out_channels)
        if use_norm:
            self.norm = nn.BatchNorm1d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(n_slope)
        if use_dropout:
            self.dropout = nn.Dropout(p_dropout)

    def forward(self , x):
        stages = [self.linear]
        if self.use_norm:
            stages.append(self.norm)
        if self.use_activation:
            stages.append(self.activation)
        if self.use_dropout:
            stages.append(self.dropout)
        for stage in stages:
            x = stage(x)
        return x

In [None]:
## Linear smoke test: a random (2, 512) batch should be projected to (2, 256).
linear = Linear(512 , 256)
linear = linear.to(device)
x = torch.randn(2 , 512).to(device)
out = linear(x)
out.shape

In [None]:
class Mapping_Network(nn.Module):
    """Stack of eight `Linear` blocks (BatchNorm1d + LeakyReLU defaults).

    Widths go in_channels -> hidden_dim -> 2*hidden_dim -> ... -> 64*hidden_dim -> out_channels,
    i.e. the hidden width doubles six times before the final projection.

    NOTE(review): Generator.forward samples its style latent `w` directly with torch.randn,
    so this network is not used by the Generator as written.
    """
    def __init__(self , 
                 in_channels = 512, 
                 out_channels = 512 , 
                 hidden_dim = 32):
        super().__init__()

        # Feature width at every layer boundary: in, h, 2h, 4h, ..., 64h, out.
        widths = [in_channels] + [hidden_dim * 2 ** i for i in range(7)] + [out_channels]
        self.layers = nn.Sequential(
            *[Linear(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self , x):
        return self.layers(x)

In [None]:
## Mapping network smoke test: (2, 512) latents in, (2, 512) features out.
mapping_network = Mapping_Network()
mapping_network = mapping_network.to(device)
x = torch.randn(2 , 512).to(device)
out = mapping_network(x)
out.shape

In [None]:
class Conv(nn.Module):
    """Conv2d followed by optional InstanceNorm2d, LeakyReLU and 2x2/stride-2 max-pool.

    With the defaults (3x3 kernel, stride 1, padding 1) the convolution keeps H and W, so the
    block halves the spatial size only when use_pool=True.

    Args:
        in_channels / out_channels: Conv2d channel counts.
        kernel_size, stride, padding: forwarded to nn.Conv2d.
        use_norm: apply nn.InstanceNorm2d(out_channels).
        use_activation: apply nn.LeakyReLU(n_slope).
        use_pool: apply nn.MaxPool2d(2, stride=2).
        n_slope: negative slope of the LeakyReLU.
    """
    def __init__(self , 
                 in_channels, 
                 out_channels , 
                 kernel_size = (3 , 3) , 
                 stride = (1 , 1) , 
                 padding = 1 , 
                 use_norm = True , 
                 use_activation = True , 
                 use_pool = True , 
                 n_slope = 0.2):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_pool = use_pool

        # Attribute names below are state_dict keys; keep them stable.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(n_slope)
        if use_pool:
            self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    def forward(self , x):
        out = self.conv(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        out = self.pool(out) if self.use_pool else out
        return out

In [None]:
## Layer/shape summary of one Conv block on a 3x512x512 input.
## torchsummary.summary defaults to device="cuda"; passing the active device type lets this
## cell run on CPU-only machines as well.
conv = Conv(3 , 512).to(device)
summary(conv , (3 , 512 , 512) , device=device.type)

In [None]:
class ConvT(nn.Module):
    """Learned upsampling: ConvTranspose2d -> optional InstanceNorm2d -> optional LeakyReLU,
    optionally repeated with a second ConvTranspose2d (out -> out channels).

    With the defaults (kernel 2x2, stride 2, padding 0) every transposed convolution doubles
    H and W, so use_second_convT=True upsamples by 4x in total.

    Args:
        in_channels / out_channels: channel counts of the first transposed conv.
        kernel_size, stride, padding: forwarded to both nn.ConvTranspose2d layers.
        use_norm: apply nn.InstanceNorm2d(out_channels) after each transposed conv.
        use_activation: apply nn.LeakyReLU(n_slope) after each transposed conv.
        use_second_convT: add the second transposed conv (+ its norm/activation).
        n_slope: negative slope of the LeakyReLU.
    """
    def __init__(self , 
                 in_channels, 
                 out_channels , 
                 kernel_size = (2 , 2) , 
                 stride = (2 ,2) , 
                 padding = 0 , 
                 use_norm = True , 
                 use_activation = True , 
                 use_second_convT = False , 
                 n_slope = 0.2):
        super(ConvT , self).__init__()
        
        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_second_convT = use_second_convT

        self.convT1 = nn.ConvTranspose2d(in_channels , out_channels , kernel_size , stride , padding)
        if self.use_norm:
            self.norm1 = nn.InstanceNorm2d(out_channels)
        if self.use_activation:
            # Parameter-free, so one module is shared by both stages.
            self.activation = nn.LeakyReLU(n_slope)
        if self.use_second_convT:
            self.convT2 = nn.ConvTranspose2d(out_channels , out_channels , kernel_size , stride , padding)
            if self.use_norm:
                # convT2 emits out_channels (the old out_channels * 2 mismatched the input and
                # made InstanceNorm2d warn on every call). InstanceNorm2d here has no parameters
                # or buffers, so existing state_dicts still load.
                self.norm2 = nn.InstanceNorm2d(out_channels)
    
    def forward(self , x):
        x = self.convT1(x)
        if self.use_norm:
            x = self.norm1(x)
        if self.use_activation:
            x = self.activation(x)
        if self.use_second_convT:
            x = self.convT2(x)
            if self.use_norm:
                x = self.norm2(x)
            if self.use_activation:
                x = self.activation(x)
        return x

In [None]:
## Layer/shape summary of a two-step ConvT (512 -> 256 channels, 2x2 -> 8x8).
## torchsummary.summary defaults to device="cuda"; passing the active device type lets this
## cell run on CPU-only machines as well.
convT = ConvT(512 , 256 , use_second_convT=True).to(device)
summary(convT , (512 , 2 , 2) , device=device.type)

In [None]:
class AdaIN(nn.Module):
    """Adaptive instance normalization.

    Instance-normalizes the content tensor `x`, then re-scales and re-shifts every
    (sample, channel) slice with the standard deviation and mean of the matching slice of
    the style tensor `y`:

        out = std(y) * InstanceNorm(x) + mean(y)

    The statistics of `y` are computed over its spatial dims (2, 3) with keepdim, so `y` must
    be 4-D with the same batch and channel sizes as `x`; its H/W may differ from x's. For a
    1x1 `y` the std is sqrt(eps), so the result stays finite.

    Args:
        in_channels: passed to nn.InstanceNorm2d (affine=False, so no parameters depend on it).
        eps: added to the style variance before the square root.
    """
    def __init__(self , in_channels , eps = 1e-5):
        super(AdaIN , self).__init__()

        self.in_channels = in_channels
        self.eps = eps
        self.norm = nn.InstanceNorm2d(in_channels)

    def forward(self , x , y):
        # The previous version used torch.mean(y) as the scale and torch.var(y) as the shift:
        # two scalars pooled over the whole batch, so every sample's style leaked into every
        # other sample and mean/variance played swapped roles.
        y_mean = y.mean(dim=(2, 3), keepdim=True)
        y_var = ((y - y_mean) ** 2).mean(dim=(2, 3), keepdim=True)
        y_std = torch.sqrt(y_var + self.eps)
        return y_std * self.norm(x) + y_mean

In [None]:
class A_Block(nn.Module):
    """One upsampling step of the style ("A") path, implemented as a single `ConvT`.

    The first block (first_layer=True) enables ConvT's second transposed convolution;
    all other blocks use one transposed convolution.
    """
    ## [512 , 256 , 128 , 64 , 32 , 16 , 8 , 4]
    def __init__(self ,
                 in_channels , 
                 out_channels , 
                 first_layer = False):
        super().__init__()
        self.convT = ConvT(in_channels, out_channels, use_second_convT=bool(first_layer))

    def forward(self , x):
        return self.convT(x)

In [None]:
class A_Net(nn.Module):
    """Style ("A") path: turns a latent `w` into the feature maps used by Style_Block.

    Stage l runs network1[l] (an A_Block upsampling step) followed by network2[l] (a Conv
    that halves the channel count without pooling). forward(x, layer) runs stages 0..layer
    in order and returns the outputs of both steps of the last stage.
    """
    def __init__(self):
        super().__init__()

        widths = [512, 256, 128, 64, 32, 16, 8, 4]
        self.network1 = nn.ModuleList(
            [A_Block(512, 512, first_layer=True)]
            + [A_Block(c, c) for c in widths[1:]]
        )
        self.network2 = nn.ModuleList(
            [Conv(c, c // 2, use_pool=False) for c in widths[:-1]]
        )

    def forward(self , x , layer):
        # layer == 0 is simply the one-stage case of this loop.
        for stage in range(layer + 1):
            x1 = self.network1[stage](x)
            x2 = self.network2[stage](x1)
            x = x2
        return x1 , x2

In [None]:
## A_Net smoke test: run all stages up to 6 on a (2, 512, 1, 1) latent.
a_ = A_Net()
a_ = a_.to(device)
x = torch.randn(2 , 512 , 1 , 1).to(device)
out1 , out2 = a_(x , 6)
print(out1.shape , out2.shape)

In [None]:
class B_Net(nn.Module):
    """Noise ("B") path: 1x1 convolutions that map a noise tensor to per-stage features.

    Stage l runs conv[l] (c -> c channels) followed by conv_[l] (c -> c // 2 channels).
    forward(x, layer) runs stages 0..layer in order and returns the outputs of both convs
    of the last stage.
    """
    def __init__(self):
        super(B_Net , self).__init__()

        widths = [512, 256, 128, 64, 32, 16, 8, 4]
        # The noise fed in is (N, 512, 1, 1) (see Generator.forward), and instance
        # normalization over a single spatial element is degenerate: recent PyTorch raises
        # "Expected more than 1 spatial element when training", and otherwise every value
        # would be normalized to 0. These convs therefore skip the norm. InstanceNorm2d had no
        # parameters or buffers, so existing state_dict keys are unchanged.
        self.conv = nn.ModuleList([
            Conv(c , c , kernel_size=1 , stride=1 , padding=0 , use_norm=False , use_pool=False)
            for c in widths
        ])

        self.conv_ = nn.ModuleList([
            Conv(c , c // 2 , kernel_size=1 , stride=1 , padding=0 , use_norm=False , use_pool=False)
            for c in widths[:-1]
        ])

    def forward(self , x , layer):
        # layer == 0 is the one-stage case of the same loop.
        for stage in range(layer + 1):
            x1 = self.conv[stage](x)
            x2 = self.conv_[stage](x1)
            x = x2
        return x1 , x2

In [None]:
## B_Net smoke test: run all stages up to 6 on (2, 512, 1, 1) noise.
b_ = B_Net()
b_ = b_.to(device)
x = torch.randn(2 , 512 , 1 , 1).to(device)
out1 , out2 = b_(x , 6)
print(out1.shape , out2.shape)

In [None]:
class Style_Block(nn.Module):
    """One synthesis stage of the generator.

    forward(x, w, noise, layer):
      * for layer != 0: upsample x 2x (nn.Upsample, default nearest) and apply conv1;
      * add noise features b1, AdaIN with style features a1;
      * conv2 (in_channels -> out_channels);
      * add noise features b2, AdaIN with style features a2.
    (b1, b2) come from this block's own B_Net on `noise`, and (a1, a2) from its own A_Net on `w`,
    both evaluated at stage `layer`. The single `adain` module is used for both AdaIN steps.
    """
    def __init__(self , 
                 in_channels , 
                 out_channels ):
        super().__init__()

        self.adain = AdaIN(in_channels)

        self.conv1 = Conv(in_channels, in_channels, use_pool=False)
        self.conv2 = Conv(in_channels, out_channels, use_pool=False)

        self.a_net = A_Net()
        self.b_net = B_Net()

        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self , x , w , noise , layer):
        b1, b2 = self.b_net(noise, layer)
        a1, a2 = self.a_net(w, layer)
        if layer != 0:
            # Only later stages grow the resolution; stage 0 works at the input size.
            x = self.conv1(self.upsample(x))
        x = self.adain(x + b1, a1)
        x = self.conv2(x)
        x = self.adain(x + b2, a2)
        return x

In [None]:
## Style_Block smoke test at stage 1: 256-channel 4x4 features are upsampled to 8x8 and
## mapped to 128 channels.
styled_conv = Style_Block(256 , 128)
styled_conv = styled_conv.to(device)
x = torch.randn(2 , 256 , 4 , 4).to(device)
w = torch.randn(2 , 512 , 1 , 1).to(device)
noise = torch.randn(2 , 512 , 1 , 1).to(device)
x = styled_conv(x , w , noise , 1)
x.shape

In [None]:
class Generator(nn.Module):
    """Progressive generator: seven Style_Blocks (512 -> 256 -> ... -> 4 channels), each with
    a 1x1 to-RGB Conv head.

    forward(x, layer) runs only block `layer` and returns (features, rgb), where
    rgb = to_rgb[layer](features). At layer 0 the incoming `x` is ignored and replaced by a
    random (batch_size, 512, 4, 4) map. A fresh style latent `w` and noise, each of shape
    (batch_size, 512, 1, 1), are sampled on every call and moved to `self.device`.

    Args:
        batch_size: number of samples drawn per forward call.
        device: device for the sampled tensors (defaults to the notebook-level `device`).
    """
    def __init__(self , 
                 batch_size , 
                 device = device):
        super().__init__()

        self.batch_size = batch_size
        self.device = device
        widths = [512, 256, 128, 64, 32, 16, 8, 4]
        self.network = nn.ModuleList([
            Style_Block(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])
        ])
        self.to_rgb = nn.ModuleList([
            Conv(c, 3, kernel_size=(1, 1), stride=(1, 1), padding=0, use_pool=False)
            for c in widths[1:]
        ])

    def forward(self , x , layer):
        w = torch.randn(self.batch_size, 512, 1, 1).to(self.device)
        noise = torch.randn(self.batch_size, 512, 1, 1).to(self.device)
        if layer == 0:
            # Stage 0 starts from its own random 4x4 map instead of the caller's x.
            x = torch.randn((self.batch_size, 512, 4, 4)).to(self.device)
        x = self.network[layer](x, w, noise, layer)
        return x , self.to_rgb[layer](x)

In [None]:
## Generator smoke test at stage 0 (forward replaces this x with fresh noise at stage 0).
generator = Generator(2)
generator = generator.to(device)
x = torch.randn(2 , 512 , 4 , 4).to(device)
out , out_rbg = generator(x , 0)
out.shape , out_rbg.shape

In [None]:
class Discriminator(nn.Module):
    """Progressive discriminator with one 1x1 from-RGB entry point per resolution.

    forward(x, layer): from_rgb[layer] lifts the RGB batch to features, then
    network[layer], ..., network[5] each apply a Conv whose default 2x2 max-pool halves H and W.
    The flattened result must have 64 features (e.g. 4 channels at 4x4), which is what a square
    image of side 4 * 2 ** (6 - layer) produces. The head ends in nn.Sigmoid, so outputs are
    probabilities of shape (N, 1); pair it with a probability-space loss such as nn.BCELoss.
    """
    def __init__(self):
        super().__init__()

        widths = [256, 128, 64, 32, 16, 8, 4]
        self.from_rgb = nn.ModuleList([
            Conv(3, c, kernel_size=(1, 1), stride=(1, 1), padding=0, use_pool=False)
            for c in widths
        ])

        self.network = nn.ModuleList([Conv(c, c // 2) for c in widths[:-1]])

        self.flatten = nn.Flatten()
        self.linear = nn.Sequential(
            Linear(64, 32),
            Linear(32, 16),
            Linear(16, 8),
            Linear(8, 4),
            Linear(4, 1, use_activation=False, use_norm=False),
            nn.Sigmoid(),
        )

    def forward(self , x , layer):
        x = self.from_rgb[layer](x)
        for idx in range(layer, len(self.network)):
            x = self.network[idx](x)
        return self.linear(self.flatten(x))

In [None]:
## Discriminator smoke test: an 8x8 RGB batch enters at index 5 (the 8x8 stage).
discriminator = Discriminator()
discriminator = discriminator.to(device)
x = torch.randn(2 , 3 , 8 , 8).to(device)
out = discriminator(x , 5)
out.shape

In [None]:
def test():
    """Run every generator stage on freshly sampled inputs of the matching shape and score the
    RGB output with the discriminator, printing shapes and the feature value range.

    NOTE(review): another `def test()` in a later cell redefines this name, so calling
    `test()` after that cell runs executes the later version, not this one.
    """
    gen = Generator(2).to(device)
    disc = Discriminator().to(device)
    channels, shape_ = 512, 4
    for layer in range(7):
        x = torch.randn(2, channels, shape_, shape_).to(device)
        print(f'x {x.shape} , layer {layer}')
        x, x_rgb = gen(x, layer)
        print(torch.max(x), torch.min(x))
        disc_pred = disc(x_rgb, 6 - layer)
        channels //= 2
        # Stage 0 keeps the 4x4 resolution; every later stage doubles it.
        if layer > 0:
            shape_ *= 2
        print(f'x {x.shape} , disc pred {disc_pred.shape}')

In [None]:
def test():
    """Chain all seven generator stages and score each RGB output with the discriminator.

    The features returned by stage `layer` are fed straight into stage `layer + 1`, and the
    RGB output of stage `layer` goes to discriminator entry point 6 - layer. Shapes are printed
    at every stage. Runs under torch.no_grad() because it is a shape check only (no backward).
    """
    gen = Generator(2).to(device)
    disc = Discriminator().to(device)
    x = torch.randn(2 , 512 , 4 , 4).to(device)
    with torch.no_grad():
        for layer in range(0 , 7):
            print(f'x {x.shape} , layer {layer}')
            x , x_rgb = gen(x , layer)
            disc_pred = disc(x_rgb , 6 - layer)
            print(f'x {x.shape} , x_rgb {x_rgb.shape} , disc pred {disc_pred.shape}')

In [None]:
## Run the generator/discriminator shape smoke test (whichever `def test` ran most recently).
test()

In [None]:
def resize_on_the_fly(img , layer):
    """Resize an image batch to the square resolution of progressive stage `layer`.

    Stage 0 is 4x4 and each later stage doubles the side, up to 256x256 at stage 6. `layer`
    indexes a fixed 7-entry table, so values above 6 raise IndexError and negative values
    index from the end like any Python list.
    """
    side = [4, 8, 16, 32, 64, 128, 256][layer]
    return transforms.functional.resize(img, [side, side])

In [None]:
## resize_on_the_fly smoke test: a 512x512 batch at stage 1 should come out 8x8.
x = torch.randn(2 , 3 , 512 , 512)
x = x.to(device)
x = resize_on_the_fly(x , 1)
x.shape

In [None]:
def get_dataset_mapped(root_dir = '/content/drive/MyDrive/celeb_hq/celeba_hq/train/'):
    """Collect the paths of all files stored exactly one level below `root_dir`.

    Expected layout: ``root_dir/<sub_dir>/<image file>`` (e.g. one sub-directory per class).
    Entries directly under `root_dir` that are not directories (e.g. ``.DS_Store``) used to make
    ``os.listdir`` raise, and nested directories inside a sub-directory used to be returned as
    if they were images; both are now skipped. Names are sorted so the list is reproducible.

    Returns:
        list[str]: paths of the form os.path.join(root_dir, sub_dir, file_name).
    """
    img_paths = []
    for sub_dir in sorted(os.listdir(root_dir)):
        sub_path = os.path.join(root_dir, sub_dir)
        if not os.path.isdir(sub_path):
            continue
        for file_name in sorted(os.listdir(sub_path)):
            file_path = os.path.join(sub_path, file_name)
            if os.path.isfile(file_path):
                img_paths.append(file_path)
    return img_paths

In [None]:
## Scan the training directory once; Dataset_ binds this list as its default `img_paths`,
## so this cell must run before the Dataset_ class cell.
img_paths = get_dataset_mapped()

In [None]:
class Dataset_(Dataset):
    """Map-style dataset over a list of image file paths.

    Each item is read with plt.imread, converted to a channel-first tensor and, when
    `transforms` is truthy, passed through it.

    Args:
        transforms: optional callable applied to the (C, H, W) tensor.
        img_paths: list of file paths. The default is bound to the notebook-level `img_paths`
            when this class is defined, so that cell has to run first.
    """
    def __init__(self , 
                transforms = None, 
                img_paths = img_paths):
        super().__init__()
        self.img_paths = img_paths
        self.transforms = transforms

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self , idx):
        pixels = np.asarray(plt.imread(self.img_paths[idx]))
        # assumes an H x W x C image; a grayscale file (no channel axis) would fail the
        # permute below — TODO confirm every image in the dataset is colour.
        img_tensor = torch.from_numpy(pixels).permute(2, 0, 1)
        if self.transforms:
            img_tensor = self.transforms(img_tensor)
        return img_tensor

In [None]:
## Per-image preprocessing used by Dataset_: CHW tensor -> PIL image -> 256x256 resize ->
## float tensor (ToTensor maps 8-bit images to [0, 1]).
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

In [None]:
## Training dataset over the scanned `img_paths`, preprocessed with `transform`.
dataset = Dataset_(transforms=transform)

In [None]:
## Losses. The Discriminator head already ends in nn.Sigmoid, so its outputs are
## probabilities; BCEWithLogitsLoss would apply a second sigmoid (squashing every prediction
## into roughly (0.5, 0.73) and flattening the gradients). Plain BCE is the matching loss.
adv_criterion = nn.BCELoss()
recon_criterion = nn.L1Loss()
lambda_recon = 200  ## weight of the L1 reconstruction term in the generator loss
betas = (0.5 , 0.999)  ## Adam betas for both optimizers

## Schedule and sizes.
## NOTE(review): range(n_layers) trains stages 0..5 (up to 128x128); the Generator defines a
## 7th stage (256x256, matching target_shape) that is never trained with n_layers = 6 — confirm.
n_layers = 6
n_epochs = 1
in_channels = 3
out_channels = 3
display_step = 100
batch_size = 2
lr = 0.0002
target_shape = 256

In [None]:
## drop_last=True: Generator always draws exactly `batch_size` samples, and the discriminator's
## BatchNorm1d layers cannot normalize a batch of one in training mode. A short final batch
## (len(dataset) not divisible by batch_size) would otherwise crash or mis-broadcast the L1 loss.
dataloader = DataLoader(dataset , batch_size=batch_size , shuffle=True , drop_last=True)

In [None]:
## Peek at one batch: plot it and print its value range.
for x in dataloader:
    show_tensor_images(x)
    print(x.max() , x.min())
    break

In [None]:
## Build both networks first, then one Adam optimizer per network with shared lr / betas.
generator = Generator(batch_size).to(device)
discriminator = Discriminator().to(device)
opt_generator = torch.optim.Adam(generator.parameters(), lr=lr, betas=betas)
opt_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

In [None]:
def weights_init(m):
    """Re-draw the weights of Conv2d / ConvTranspose2d modules from N(0, 0.02) in place.

    Intended for `module.apply(weights_init)`. Biases and every other module type
    (e.g. nn.Linear, norm layers) are left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.02)

In [None]:
## Re-initialise conv weights of both networks in place. Module.apply returns the same module,
## so the optimizers created earlier still reference these parameters.
generator, discriminator = generator.apply(weights_init), discriminator.apply(weights_init)

In [None]:
def get_generator_loss(fake , 
                       real , 
                       layer , 
                       generator = generator , 
                       discriminator = discriminator , 
                       adv_criterion = adv_criterion , 
                       recon_criterion = recon_criterion , 
                       lambda_recon = lambda_recon):
    """Generator objective at progressive stage `layer`: adversarial + weighted L1 terms.

    The discriminator scores `fake` at entry point 6 - layer. The adversarial target is ONES,
    so the generator is rewarded when its images are classified as real. (The previous target
    of zeros pushed the generator in the same direction as the discriminator.)

    Args:
        fake: generated RGB batch for stage `layer`.
        real: real batch at the same resolution, compared with `recon_criterion`.
        layer: generator stage index.
        generator: not used in the body; kept for interface compatibility.
        discriminator, adv_criterion, recon_criterion, lambda_recon: defaults are bound to the
            notebook globals at the time this cell runs.

    Returns:
        Scalar tensor: adv_criterion(pred, 1) + lambda_recon * recon_criterion(fake, real).
    """
    disc_pred = discriminator(fake , 6 - layer)
    adv_loss = adv_criterion(disc_pred , torch.ones_like(disc_pred))
    recon_loss = recon_criterion(fake , real)
    return adv_loss + lambda_recon * recon_loss

In [None]:
## Debug aid: anomaly detection records forward tracebacks so autograd errors point at the op
## that produced them. It slows every backward pass; switch it off once training is stable.
torch.autograd.set_detect_anomaly(mode=True)

In [None]:
def train():
    """Progressive GAN training loop.

    For every stage `layer` in range(n_layers) and every epoch, each real batch drives a pass
    over stages l = 0..layer:

    * generator step: stage 0 and the current stage `layer` are trained. Intermediate stages
      only run forward without gradients to produce the features fed to the next stage;
    * discriminator step: the real batch resized to stage l's resolution vs. a no-grad fake.

    Every `display_step` steps the windowed mean losses are printed and a real/fake pair is
    plotted. Both networks are checkpointed when the windowed mean generator loss is the lowest
    seen so far.

    Relies on notebook globals: generator, discriminator, opt_generator, opt_discriminator,
    dataloader, adv_criterion, get_generator_loss, resize_on_the_fly, show_tensor_images,
    n_layers, n_epochs, batch_size, display_step, device.
    """
    mean_generator_loss = 0
    mean_discriminator_loss = 0
    cur_step = 0
    # Lowest windowed generator loss so far. The old code started at 0 and saved whenever the
    # loss went UP, so it kept the worst checkpoint.
    best_loss = float('inf')
    for layer in range(n_layers):
        for epoch in range(n_epochs):
            for img in tqdm(dataloader):
                # Keep the full-resolution batch and resize from it for every stage. The old code
                # overwrote `img` with each resize, so stage l > 0 upsampled the 4x4 version.
                img_full = img.to(device)
                x_prev = torch.randn(batch_size , 512 , 4 , 4).to(device)
                x1 = torch.randn_like(x_prev)
                for l in range(layer + 1):
                    real = resize_on_the_fly(img_full , l)

                    # ---- generator step ----
                    train_stage = (l == 0) or (l == layer)
                    opt_generator.zero_grad()
                    with torch.set_grad_enabled(train_stage):
                        x_next , fake_img = generator(x_prev , l)
                        generator_loss = get_generator_loss(fake_img , real , l)
                    # Detach so the next stage never backpropagates into this one.
                    x_prev = x_next.detach()
                    if train_stage:
                        generator_loss.backward()
                        opt_generator.step()

                    # ---- discriminator step ----
                    # zero_grad also clears the discriminator grads left by generator_loss.backward().
                    opt_discriminator.zero_grad()
                    with torch.no_grad():
                        if l == 0:
                            img_ , fake_img_ = generator(x1 , l)
                        else:
                            img_ , fake_img_ = generator(img_ , l)
                    disc_fake_pred = discriminator(fake_img_ , 6 - l)
                    disc_real_pred = discriminator(real , 6 - l)
                    disc_fake_loss = adv_criterion(disc_fake_pred , torch.zeros_like(disc_fake_pred))
                    disc_real_loss = adv_criterion(disc_real_pred , torch.ones_like(disc_real_pred))
                    discriminator_loss = (disc_fake_loss + disc_real_loss) / 2.0
                    discriminator_loss.backward()
                    opt_discriminator.step()

                    mean_discriminator_loss += discriminator_loss.item() / display_step
                    mean_generator_loss += generator_loss.item() / display_step

                    if cur_step % display_step == 0:
                        if cur_step > 0:
                            print(f"Epoch {epoch}: layer{l} Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
                            # Checkpoint on a complete display window, not on a partial running sum.
                            if mean_generator_loss < best_loss:
                                print('Saving ....')
                                torch.save(generator.state_dict() , '/content/drive/MyDrive/StyleGAN_Weights/Generator.pt')
                                torch.save(discriminator.state_dict() , '/content/drive/MyDrive/StyleGAN_Weights/Discriminator.pt')
                                best_loss = mean_generator_loss
                        else:
                            print("Pretrained initial state")
                        print('real image')
                        show_tensor_images(real , size=real.shape[1:])
                        print('Generated image')
                        show_tensor_images(fake_img_ , size=fake_img_.shape[1:])
                        mean_generator_loss = 0
                        mean_discriminator_loss = 0
                    cur_step += 1

In [None]:
## Launch progressive training with the models, optimizers and settings configured above.
train()