In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pickle
from torch.optim.lr_scheduler import ExponentialLR

# Train on the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Mount Google Drive only when running on Colab. Elsewhere (e.g. Kaggle, where this
# notebook reads its data from /kaggle/input) `google.colab` is not installed, and an
# unconditional import would abort "Restart & Run All" at this cell.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')

In [None]:
class FacesDataset(Dataset):
    """Map-style dataset that pairs the i-th image with the i-th label.

    Args:
        imgs: indexable collection of images (anything supporting ``len`` and ``[]``).
        labels: indexable collection of labels, aligned with ``imgs``.
    """

    def __init__(self, imgs, labels):
        self.imgs, self.labels = imgs, labels

    def __len__(self):
        # The dataset size is defined by the image collection.
        return len(self.imgs)

    def __getitem__(self, idx):
        image = self.imgs[idx]
        label = self.labels[idx]
        return image, label

def load_dataset(batch_size, augment=True,
                 faces_path='/kaggle/input/ap-cfggan-ds/faces_sex_images_2.pt',
                 labels_path='/kaggle/input/ap-cfggan-ds/labels_sex_images_2.pt',
                 n_men=3367, n_target=5721, n_trim=25):
    """Load the face/label tensors and wrap them in shuffling DataLoaders.

    Args:
        batch_size: samples per batch for both loaders.
        augment: if True, append horizontally flipped copies of randomly chosen
            images from the first ``n_men`` rows, labelled 1, so that group grows
            towards ``n_target`` samples.
        faces_path, labels_path: files loaded with ``torch.load``. Faces are
            permuted with ``(0, 3, 1, 2)``, i.e. assumed stored channels-last.
        n_men: number of leading rows treated as the label-1 group when
            augmenting (assumes the file is ordered that way -- TODO confirm).
        n_target: desired size of that group after augmentation.
        n_trim: trailing samples dropped from both the original data and the
            flipped copies when augmenting.

    Returns:
        ``(dataloader, dataloader_faces_only)``: the first yields
        ``(images, labels)`` batches, the second yields image batches only.
    """
    faces = torch.load(faces_path).permute((0, 3, 1, 2))
    labels = torch.load(labels_path)

    if augment:
        men = faces[:n_men]
        indices = torch.randperm(len(men))[:n_target - n_men]
        selected_men = F.hflip(men[indices])
        n_extra = max(len(selected_men) - n_trim, 0)
        faces = torch.cat((faces[:max(len(faces) - n_trim, 0)], selected_men[:n_extra]), 0)
        labels = torch.cat((labels[:max(len(labels) - n_trim, 0)], torch.ones(n_extra)), 0)

    print(faces.shape, labels.shape)

    # Page-locked memory only speeds up host->GPU copies; on a CPU-only machine
    # requesting it merely triggers a warning.
    pin = torch.cuda.is_available()
    dataloader = DataLoader(FacesDataset(faces, labels), batch_size=batch_size, shuffle=True, pin_memory=pin)
    dataloader_faces_only = DataLoader(faces, batch_size=batch_size, shuffle=True, pin_memory=pin)

    return dataloader, dataloader_faces_only

batch_size = 64
# Labelled loader and images-only loader over the raw (non-augmented) data.
dataloader, dataloader_faces_only = load_dataset(batch_size=batch_size, augment=False)

In [None]:
# Number of batches per epoch in the images-only loader.
len(dataloader_faces_only)

In [None]:
# Reload the raw tensors for interactive inspection (independent of load_dataset).
labels = torch.load('/kaggle/input/ap-cfggan-ds/labels_sex_images_2.pt')
faces = torch.load(
    '/kaggle/input/ap-cfggan-ds/faces_sex_images_2.pt'
).permute(0, 3, 1, 2)  # assumed channels-last on disk -> (N, C, H, W)

In [None]:
# Class balance: (number of label-1 samples, number of label-0 samples).
len(labels[labels == 1]), len(labels[labels == 0])

In [None]:
# Oversample the first 3367 images (assumed to be the label-1 group -- TODO confirm
# the ordering): draw 5721 - 3367 of them without replacement and mirror them
# horizontally.
men = faces[:3367]
indices = torch.randperm(len(men))[:5721 - 3367]
selected_men = F.hflip(men[indices])

In [None]:
# Append the mirrored images, all labelled 1, to the dataset.
# NOTE(review): not idempotent -- re-running this cell appends the copies again.
faces = torch.cat((faces, selected_men), 0)
labels = torch.cat((labels, torch.ones(len(selected_men))), 0)
faces.shape, labels.shape

In [None]:
# Preview one augmented image; flipping it again shows it in its original orientation.
# `* 0.5 + 0.5` maps pixels from (presumably) [-1, 1] back to [0, 1] for display.
plt.imshow(F.to_pil_image(F.hflip(selected_men[3])*0.5 + 0.5 ))
labels

In [None]:
from torch.utils.data import DataLoader

# Alternative data source: the first 142 batches' worth of images from a raw .npy array.
# NOTE(review): this rebinds `dataloader` to an images-only loader, replacing the
# labelled (images, labels) loader returned by load_dataset; the labelled
# plot_images below indexes batch[1] as labels and will not work with it.
batch_size = 64
with open('/kaggle/input/ap-dataset/faces(1).npy', 'rb') as f:
    n_imgs = 142 * batch_size
    faces = np.load(f)[:n_imgs]
    print(faces.shape)

# Permute (0, 3, 1, 2): stored channels-last (assumed) -> (N, C, H, W).
dataloader = DataLoader(
    torch.from_numpy(faces).permute((0, 3, 1, 2)),
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)

In [None]:
# With the images-only loader from the previous cell this is the shape of one image;
# with the labelled loader from load_dataset it would be the shape of the image batch.
next(iter(dataloader))[0].shape

In [None]:
def plot_images(sqr = 5 , class_=1, loader=None, noise_std=0.15):
    """Show a ``sqr`` x ``sqr`` grid of real images whose label equals ``class_``.

    Gaussian noise with standard deviation ``noise_std`` is added to each image
    before it is displayed.

    Args:
        sqr: grid side length; ``sqr * sqr`` images are shown.
        class_: label value to select.
        loader: iterable of ``(images, labels)`` batches; defaults to the global
            ``dataloader``.
        noise_std: standard deviation of the added noise.

    Raises:
        ValueError: if a full pass over ``loader`` yields no image with that label.
    """
    if loader is None:
        loader = dataloader
    needed = sqr * sqr

    # Reuse one iterator (restarting it only when exhausted) rather than building
    # a fresh shuffled iterator for every candidate image; stop with an error
    # instead of looping forever if the class never appears.
    picked = []
    batches = iter(loader)
    while len(picked) < needed:
        try:
            images, batch_labels = next(batches)
        except StopIteration:
            if not picked:
                raise ValueError(f"loader contains no images with label {class_}")
            batches = iter(loader)
            continue
        picked.extend(images[batch_labels == class_])

    fig = plt.figure(figsize=(15, 15))
    # A figure-level title; plt.title() here would create an extra full-size axes.
    fig.suptitle("Real Images", fontsize=35)
    for position, image in enumerate(picked[:needed], start=1):
        noisy = image + torch.randn(image.size()) * noise_std
        ax = fig.add_subplot(sqr, sqr, position)
        # to_pil_image converts float tensors via mul(255).byte(); clamp first so
        # values pushed outside [0, 1] by the noise do not wrap around.
        ax.imshow(F.to_pil_image((noisy * 0.5 + 0.5).clamp(0, 1)))
        ax.set_xticks([])
        ax.set_yticks([])

# to plot images
plot_images(5)

In [None]:
# 5x5 grid of real images whose label is 0.
plot_images(5, 0)

In [None]:
# Shape of one image batch from the images-only loader.
next(iter(dataloader_faces_only)).shape

In [None]:
def plot_images(sqr = 5, loader=None):
    """Show a ``sqr`` x ``sqr`` grid of images from an images-only loader.

    NOTE(review): this redefinition shadows the labelled ``plot_images`` defined
    earlier in the notebook; consider giving it its own name.

    Args:
        sqr: grid side length; ``sqr * sqr`` images are shown.
        loader: iterable of image batches; defaults to the global
            ``dataloader_faces_only``.

    Raises:
        ValueError: if ``loader`` yields no images at all.
    """
    if loader is None:
        loader = dataloader_faces_only
    needed = sqr * sqr

    # Collect enough images from one iterator (restarted only if exhausted); the
    # grid may need more images than a single batch holds.
    images = []
    batches = iter(loader)
    while len(images) < needed:
        try:
            images.extend(next(batches))
        except StopIteration:
            if not images:
                raise ValueError("loader yielded no images")
            batches = iter(loader)

    fig = plt.figure(figsize=(10, 10))
    # A figure-level title; plt.title() here would create an extra full-size axes.
    fig.suptitle("Real Images", fontsize=35)
    for position, image in enumerate(images[:needed], start=1):
        ax = fig.add_subplot(sqr, sqr, position)
        # to_pil_image converts float tensors via mul(255).byte(); clamp so any
        # out-of-range value does not wrap around.
        ax.imshow(F.to_pil_image((image * 0.5 + 0.5).clamp(0, 1)))
        ax.set_xticks([])
        ax.set_yticks([])

plot_images(5)

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier for 3-channel images.

    Five stride-2 conv blocks (Conv -> BatchNorm -> LeakyReLU) each halve the
    spatial size. The last block has 64 channels and the head takes
    1024 = 64 * 4 * 4 features, i.e. it expects 128x128 inputs. Returns a
    sigmoid probability of shape (batch, 1).
    """

    def __init__(self, **kwargs):
        super().__init__()
        # (in_channels, out_channels) of conv_block1 .. conv_block5, in order.
        stages = [(3, 64), (64, 128), (128, 256), (256, 512), (512, 64)]
        for number, (c_in, c_out) in enumerate(stages, start=1):
            setattr(self, f"conv_block{number}", self._make_block(c_in, c_out))
        self.linear1 = nn.Sequential(
            nn.Linear(64 * 4 * 4, 100),
            nn.LeakyReLU(0.2),
            nn.Linear(100, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _make_block(c_in, c_out):
        """Stride-2 4x4 conv (halves H and W), BatchNorm, LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.2),
        )

    def forward(self, features):
        out = features
        for block in (self.conv_block1, self.conv_block2, self.conv_block3,
                      self.conv_block4, self.conv_block5):
            out = block(out)
        return self.linear1(out.reshape(out.size(0), -1))

## Generator

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: latent code -> 3x128x128 image in [-1, 1].

    Expects noise shaped (batch, latent_dim, 1, 1). The first transposed conv
    yields 512x4x4; four stride-2 stages double the resolution up to 64x64
    (channels 256/128/64/32), and the output layer doubles once more to 128x128
    followed by Tanh.
    """

    def __init__(self, latent_dim, **kwargs):
        super().__init__()
        self.latent_dim = latent_dim
        # 1x1 -> 4x4 projection of the latent code.
        self.upsampling_block1 = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
        )
        channel_pairs = ((512, 256), (256, 128), (128, 64), (64, 32))
        for number, (c_in, c_out) in enumerate(channel_pairs, start=2):
            setattr(self, f"upsampling_block{number}", self._make_upsampling(c_in, c_out))

        self.out_layer = nn.Sequential(
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _make_upsampling(c_in, c_out):
        """Stride-2 transposed conv (doubles H and W), BatchNorm, LeakyReLU, Dropout."""
        return nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
        )

    def forward(self, features):
        x = features
        for stage in (self.upsampling_block1, self.upsampling_block2, self.upsampling_block3,
                      self.upsampling_block4, self.upsampling_block5):
            x = stage(x)
        return self.out_layer(x)

In [None]:
class Generator(nn.Module):
    """Encoder-decoder generator: latent vector -> 3x128x128 image in [-1, 1].

    ``features`` has shape (batch, latent_dim). A bias-free linear layer expands
    it into a 3x128x128 tensor, which is encoded down to 512x16x16 and decoded
    back up; the final Tanh bounds the output.
    Spatial sizes: 128 -> 64 -> 32 -> 16 (encoder), 16 -> 32 -> 64 -> 127 -> 128 (decoder).
    """

    def __init__(self, latent_dim, **kwargs):
        super().__init__()
        self.latent_dim = latent_dim
        # Modules are created in the same order as before, so seeded runs get
        # identical initial weights and state_dict keys.
        self.linear1 = nn.Linear(in_features=latent_dim, out_features=3 * 128 * 128, bias=False)

        # Encoder
        self.downsampling_block1 = self._encoder_stage(3, 128)
        self.downsampling_block2 = self._encoder_stage(128, 256)
        self.downsampling_block3 = nn.Sequential(
            nn.ConvTranspose2d(256, 512, kernel_size=4, stride=1, padding=2, bias=False),
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=2, bias=False),
            nn.LeakyReLU(0.2),
        )  # -> 512x16x16

        # Decoder
        # NOTE(review): upsampling_block2/3 end in BatchNorm with no activation, so
        # they feed the next transposed conv linearly -- confirm this is intended.
        self.upsampling_block1 = nn.Sequential(self._decoder_stage(512, 512), nn.LeakyReLU(0.2))
        self.upsampling_block2 = self._decoder_stage(512, 256)
        self.upsampling_block3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=2, bias=False),
            nn.ConvTranspose2d(128, 128, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
        )

        self.out_layer = nn.Sequential(
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=1, padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _encoder_stage(c_in, c_out):
        """4x4 conv (stride 1) then 4x4 conv (stride 2), BatchNorm, LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=4, stride=1, padding=1, bias=False),
            nn.Conv2d(c_out, c_out, kernel_size=4, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.2),
        )

    @staticmethod
    def _decoder_stage(c_in, c_out):
        """4x4 transposed conv (stride 1) then 4x4 transposed conv (stride 2), BatchNorm."""
        return nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=1, padding=1, bias=False),
            nn.ConvTranspose2d(c_out, c_out, kernel_size=4, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(c_out),
        )

    def forward(self, features):
        x = self.linear1(features).view(-1, 3, 128, 128)
        for stage in (
            self.downsampling_block1, self.downsampling_block2, self.downsampling_block3,
            self.upsampling_block1, self.upsampling_block2, self.upsampling_block3,
        ):
            x = stage(x)
        return self.out_layer(x)

## FGGAN

In [None]:
class FGGAN(nn.Module):
    """Unconditional GAN wrapper: a generator, a discriminator and a training loop.

    Both networks are re-initialised with ``weights_init_normal``. Call
    ``compile`` to attach optimizers and loss criteria before ``fit``.
    ``g_losses`` / ``d_losses`` collect the mean loss of every epoch.
    """

    def __init__(self, generator, discriminator, **kwargs):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

        self.generator.apply(self.weights_init_normal)
        self.discriminator.apply(self.weights_init_normal)

        self.g_losses = []
        self.d_losses = []

    def forward(self, features):
        return self.generator(features)

    def compile(self,
            generator_optimizer,
            discriminator_optimizer,
            generator_loss_criterion,
            discriminator_loss_criterion
        ):
        """Attach the optimizers and loss criteria used by the train_* methods."""
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.generator_loss_criterion = generator_loss_criterion
        self.discriminator_loss_criterion = discriminator_loss_criterion

    def train_generator(self, noise, batch_size):
        """One generator step: push D(G(noise)) towards the "real" target 1.

        ``batch_size`` must equal ``noise.size(0)``. Returns the loss as a float.
        """
        self.generator_optimizer.zero_grad()
        fake_output = self.discriminator(self.generator(noise))

        generator_labels = torch.ones(batch_size, device=device)
        # reshape(-1) instead of squeeze(): squeeze() turns a batch of one into a
        # 0-d tensor that no longer matches the (1,)-shaped target.
        generator_loss = self.generator_loss_criterion(fake_output.reshape(-1), generator_labels)
        generator_loss.backward()

        # Gradient clipping (exploding gradient)
        torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 1)
        self.generator_optimizer.step()

        return generator_loss.item()

    def train_discriminator(self, X, noise, batch_size):
        """One discriminator step on real images ``X`` and fakes generated from ``noise``.

        Targets are 0 for fakes and a smoothed 0.9 for reals. ``batch_size`` must
        equal both ``X.size(0)`` and ``noise.size(0)``. Returns the summed loss.
        """
        self.discriminator_optimizer.zero_grad()
        # detach(): this step does not backpropagate into the generator.
        fake_output = self.discriminator(self.generator(noise).detach())
        real_output = self.discriminator(X)

        fake_targets = torch.zeros(batch_size, device=device)
        real_targets = torch.full((batch_size,), 0.9, device=device)
        discriminator_loss = (
            self.discriminator_loss_criterion(real_output.reshape(-1), real_targets)
            + self.discriminator_loss_criterion(fake_output.reshape(-1), fake_targets)
        )
        discriminator_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 1)
        self.discriminator_optimizer.step()

        return discriminator_loss.item()

    def fit(self, X, epochs=10, batch_size=64, latent_dim=100, n_disc=1, n_gen=1):
        """Train for ``epochs`` passes over ``X``, an iterable of image batches.

        Args:
            X: e.g. a DataLoader of image tensors; ``len(X)`` is its batch count.
            epochs: number of passes over ``X``.
            batch_size: kept for backward compatibility; noise and targets are
                sized from each actual batch, so a smaller final batch works.
            latent_dim: size of the noise vector, shaped (batch, latent_dim, 1, 1).
            n_disc, n_gen: discriminator / generator steps per batch (at least 1;
                the loss of the last step is recorded).

        After each epoch: prints mean losses, shows 3 samples from fixed noise,
        saves ``./fggan_tmp.pt`` and every 5th epoch a checkpoint plus losses.
        """
        n_batches = len(X)
        # At least 1, so loaders with fewer than 10 batches don't hit `index % 0`.
        batch_print_step = max(1, n_batches // 10)
        print("Training starting....")
        disp_noise = torch.randn(4, latent_dim, 1, 1, device=device).float()

        for epoch in range(epochs):
            g_loss = 0
            d_loss = 0

            print(f"Epoch {epoch}/{epochs}: ", end="")
            for index, batch in enumerate(X):
                batch = batch.to(device)
                current_size = batch.size(0)
                noise = torch.randn(current_size, latent_dim, 1, 1, device=device).float()

                for _ in range(max(1, n_disc)):
                    d_loss_step = self.train_discriminator(batch, noise, current_size)
                d_loss += d_loss_step

                for _ in range(max(1, n_gen)):
                    g_loss_step = self.train_generator(noise, current_size)
                g_loss += g_loss_step

                if index % batch_print_step == 0:
                    print("#", end="")

            g_loss /= max(n_batches, 1)
            d_loss /= max(n_batches, 1)

            self.g_losses.append(g_loss)
            self.d_losses.append(d_loss)

            print(f"\nGenerator loss: {g_loss}  Discriminator loss: {d_loss}")

            with torch.no_grad():
                out = self.generator(disp_noise)
                print(out.shape)
                fig, axs = plt.subplots(1, 3)
                fig.set_figwidth(12)
                fig.set_figheight(4)
                for ax, image in zip(axs, out):
                    # Clamp: to_pil_image casts floats with mul(255).byte().
                    ax.imshow(np.array(F.to_pil_image((image * 0.5 + 0.5).clamp(0, 1))))
                plt.show()

            torch.save(self.state_dict(), './fggan_tmp.pt')
            if (epoch+1) % 5 == 0:
                torch.save(self.state_dict(), f'./fggan_epoch_{epoch+1}.pt')
                with open(f'./fggan_losses_{epoch+1}.npy', 'wb') as f:
                    np.save(f, np.array([self.g_losses, self.d_losses]))

    def weights_init_normal(self, m):
        """Draw Conv*/Linear weights from N(0, 0.02); other modules are untouched."""
        classname = m.__class__.__name__
        # Apply initial weights to convolutional and linear layers
        if classname.find('Conv') != -1 or classname.find('Linear') != -1:
            nn.init.normal_(m.weight.data, 0.0,0.02)
        return m


## CGAN 1

In [None]:
from torch.optim.lr_scheduler import LinearLR

class cFGGAN(nn.Module):
    """Conditional GAN reusing the encoder-decoder ``Generator`` and the ``Discriminator``.

    A class id is embedded (``nn.Embedding`` -> ``nn.Linear``) into a 1x128x128
    map and concatenated to the image as a 4th channel:

    * generator side: [noise image from ``generator.linear1`` | class map] ->
      ``conv_gen`` (in place of ``generator.downsampling_block1``) -> rest of the generator;
    * discriminator side: [image | class map] -> ``conv_disc`` (in place of
      ``discriminator.conv_block1``) -> rest of the discriminator.

    Args:
        generator, discriminator: networks whose sub-blocks are reused.
        n_classes: number of conditioning classes.
        instance_noise_std: std of the Gaussian noise added to real and fake
            images before the discriminator sees them.
    """

    def __init__(self, generator, discriminator, n_classes=2, instance_noise_std=1.0, **kwargs):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.n_classes = n_classes
        self.instance_noise_std = instance_noise_std

        # The conditioning convs must emit as many channels as the first reused
        # block of each network consumes. Read them from the networks (falling back
        # to the old hard-coded 128): with this notebook's Discriminator,
        # conv_block2 takes 64 channels, so a fixed 128 made forward_discriminator
        # fail with a channel mismatch.
        gen_channels = self._reused_in_channels(generator, "downsampling_block2", 128)
        disc_channels = self._reused_in_channels(discriminator, "conv_block2", 128)

        self.generator_conditional_head = self.__conditional_head().to(device)
        self.conv_gen = nn.Sequential(
            nn.Conv2d(4, gen_channels, 4, stride=1, padding=1, bias=False),
            nn.Conv2d(gen_channels, gen_channels, 4, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(gen_channels),
            nn.LeakyReLU(0.2)
        )
        self.discriminator_conditional_head = self.__conditional_head().to(device)
        self.conv_disc = nn.Sequential(
            nn.Conv2d(4, disc_channels, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(disc_channels),
            nn.LeakyReLU()
        )

        self.generator_conditional_head.apply(self.weights_init_normal)
        self.conv_gen.apply(self.weights_init_normal)
        self.discriminator_conditional_head.apply(self.weights_init_normal)
        self.conv_disc.apply(self.weights_init_normal)

        self.d_losses = []
        self.g_losses = []
        self.schedule = False

    @staticmethod
    def _reused_in_channels(network, block_name, default):
        """in_channels of the first layer of ``network.<block_name>``, or ``default``."""
        try:
            return getattr(network, block_name)[0].in_channels
        except (AttributeError, IndexError, TypeError):
            return default

    def __conditional_head(self):
        """Class id -> 25-d embedding -> 128*128 values (reshaped to a 1x128x128 map by callers)."""
        return nn.Sequential(
            nn.Embedding(self.n_classes, 25),
            nn.Linear(25, 128*128*1)
        )

    def get_generator_parameters(self):
        """Parameters updated on the generator side (used for gradient clipping)."""
        return list(self.generator.parameters()) + list(self.conv_gen.parameters()) + list(self.generator_conditional_head.parameters())

    def get_discriminator_parameters(self):
        """Parameters updated on the discriminator side (used for gradient clipping)."""
        return list(self.discriminator.parameters()) + list(self.conv_disc.parameters()) + list(self.discriminator_conditional_head.parameters())

    def forward(self, features, class_):
        """Generate images from latent ``features`` conditioned on integer ids ``class_``.

        ``features`` is fed to ``generator.linear1`` (``fit`` passes shape
        (batch, latent_dim)); ``class_`` holds one class id per sample.
        """
        class_map = self.generator_conditional_head(class_).view(-1, 1, 128, 128).to(device)
        noise_image = self.generator.linear1(features).view(-1, 3, 128, 128).to(device)

        out = self.conv_gen(torch.cat((noise_image, class_map), 1))
        out = self.generator.downsampling_block2(out)
        out = self.generator.downsampling_block3(out)
        out = self.generator.upsampling_block1(out)
        out = self.generator.upsampling_block2(out)
        out = self.generator.upsampling_block3(out)
        return self.generator.out_layer(out)

    def forward_discriminator(self, features, class_):
        """Probability that images ``features`` are real given class ids ``class_``; shape (batch, 1)."""
        class_map = self.discriminator_conditional_head(class_).view(-1, 1, 128, 128).to(device)
        out = self.conv_disc(torch.cat((features.to(device), class_map), 1))
        out = self.discriminator.conv_block2(out)
        out = self.discriminator.conv_block3(out)
        out = self.discriminator.conv_block4(out)
        out = self.discriminator.conv_block5(out)
        return self.discriminator.linear1(out.reshape(out.size(0), -1))

    def compile(self,
            generator_optimizer,
            discriminator_optimizer,
            generator_loss_criterion,
            discriminator_loss_criterion,
            schedule=False
        ):
        """Attach optimizers and loss criteria.

        With ``schedule=True`` both learning rates are multiplied by 0.1 after
        every epoch of ``fit``. Previously the schedulers were stepped after every
        batch, which shrank the rates 10x per batch.
        """
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.generator_loss_criterion = generator_loss_criterion
        self.discriminator_loss_criterion = discriminator_loss_criterion
        self.schedule = schedule
        if schedule:
            self.generator_optimizer_scheduler = ExponentialLR(self.generator_optimizer, gamma=0.1)
            self.discriminator_optimizer_scheduler = ExponentialLR(self.discriminator_optimizer, gamma=0.1)

    def _random_classes(self, batch_size):
        """Uniform random class ids in [0, n_classes) (randint's upper bound is exclusive)."""
        return torch.randint(0, self.n_classes, (batch_size,), device=device).int()

    def train_generator(self, noise, batch_size):
        """One generator step on random classes; target is a smoothed "real" 0.9.

        ``batch_size`` must equal ``noise.size(0)``. Returns the loss as a float.
        """
        self.generator_optimizer.zero_grad()

        # Previously randint(0, 1, ...) always produced class 0.
        random_labels = self._random_classes(batch_size)
        generated_output = self(noise, random_labels)
        fake_output = self.forward_discriminator(generated_output, random_labels)

        generator_labels = torch.full((batch_size,), 0.9, device=device)
        # reshape(-1) instead of squeeze() so a batch of one keeps shape (1,).
        generator_loss = self.generator_loss_criterion(fake_output.reshape(-1), generator_labels)
        generator_loss.backward()

        # Gradient clipping (exploding gradient)
        torch.nn.utils.clip_grad_norm_(self.get_generator_parameters(), 1)
        self.generator_optimizer.step()

        return generator_loss.item()

    def train_discriminator(self, X, labels, noise, batch_size):
        """One discriminator step on real ``(X, labels)`` and fakes from ``noise``.

        Instance noise (std ``instance_noise_std``) is added to both real and fake
        images. The class of 10 random real samples is swapped and those pairs are
        given the "fake" target 0.1. Fakes get target 0.1 and reals 0.9.
        ``batch_size`` must equal ``noise.size(0)``. ``labels`` is not modified.
        Returns the summed loss as a float.
        """
        self.discriminator_optimizer.zero_grad()

        random_labels = self._random_classes(batch_size)
        generated_output = self(noise, random_labels).detach()
        # randn_like keeps the noise on the images' device; the previous code
        # referenced an undefined name `tensor` and raised NameError.
        fake_output = self.forward_discriminator(
            generated_output + self.instance_noise_std * torch.randn_like(generated_output),
            random_labels)

        # Work on a copy: flipping in place would alter the caller's tensor.
        # `* (-1) + 1` assumes binary 0/1 labels.
        labels = labels.clone()
        flip_indices = torch.randperm(len(labels))[:10]
        labels[flip_indices] = labels[flip_indices] * (-1) + 1

        real_output = self.forward_discriminator(
            X + self.instance_noise_std * torch.randn_like(X), labels.int())

        fake_labels = torch.full((batch_size,), 0.1, device=device)
        real_labels = torch.full((len(labels),), 0.9, device=device)
        real_labels[flip_indices] = 0.1

        discriminator_fake_loss = self.discriminator_loss_criterion(fake_output.reshape(-1), fake_labels)
        discriminator_real_loss = self.discriminator_loss_criterion(real_output.reshape(-1), real_labels)
        discriminator_loss = discriminator_real_loss + discriminator_fake_loss
        discriminator_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.get_discriminator_parameters(), 1)
        self.discriminator_optimizer.step()

        return discriminator_loss.item()

    def fit(self, X, epochs=10, batch_size=64, latent_dim=100, n_disc=1):
        """Train for ``epochs`` passes over ``X``, an iterable of ``(images, labels)`` batches.

        Args:
            X: e.g. a DataLoader; ``len(X)`` is its batch count.
            epochs: number of passes over ``X``.
            batch_size: kept for backward compatibility; noise and targets are
                sized from each actual batch.
            latent_dim: size of the (batch, latent_dim) noise vectors.
            n_disc: discriminator steps per batch (at least 1). The previous code
                ignored this and always ran two; pass ``n_disc=2`` to reproduce it.

        After each epoch: steps the LR schedulers (if enabled), prints mean losses,
        shows 4 samples (classes 1, 1, 0, 0) from fixed noise, saves
        ``./cfggan_tmp.pt`` and every 5th epoch a checkpoint plus losses.
        """
        n_batches = len(X)
        # At least 1, so loaders with fewer than 10 batches don't hit `index % 0`.
        batch_print_step = max(1, n_batches // 10)
        print("Training starting....")
        disp_noise = torch.from_numpy(np.random.normal(0, 1, (4, latent_dim))).float().to(device)
        disp_labels = torch.tensor([1, 1, 0, 0], dtype=torch.int32, device=device)

        for epoch in range(epochs):
            g_loss = 0
            d_loss = 0

            print(f"Epoch {epoch}/{epochs}: ", end="")
            for index, (batch, labels) in enumerate(X):
                batch = batch.to(device)
                labels = labels.to(device)
                current_size = batch.size(0)
                noise = torch.randn(current_size, latent_dim, device=device)

                for _ in range(max(1, n_disc)):
                    d_loss_step = self.train_discriminator(batch, labels, noise, current_size)
                d_loss += d_loss_step

                g_loss += self.train_generator(noise, current_size)

                if index % batch_print_step == 0:
                    print("#", end="")

            if self.schedule:
                self.generator_optimizer_scheduler.step()
                self.discriminator_optimizer_scheduler.step()

            g_loss /= max(n_batches, 1)
            d_loss /= max(n_batches, 1)

            self.g_losses.append(g_loss)
            self.d_losses.append(d_loss)

            print(f"\nGenerator loss: {g_loss}  Discriminator loss: {d_loss}")

            with torch.no_grad():
                out = self(disp_noise, disp_labels)
                print(out.shape)
                fig, axs = plt.subplots(1, 4)
                fig.set_figwidth(16)
                fig.set_figheight(4)
                for ax, image in zip(axs, out):
                    # Clamp: to_pil_image casts floats with mul(255).byte().
                    ax.imshow(np.array(F.to_pil_image((image * 0.5 + 0.5).clamp(0, 1))))
                plt.show()

            torch.save(self.state_dict(), './cfggan_tmp.pt')
            if (epoch+1) % 5 == 0:
                torch.save(self.state_dict(), f'./cfggan_epoch_{epoch+1}.pt')
                with open(f'./cfggan_losses_{epoch+1}.npy', 'wb') as f:
                    np.save(f, np.array([self.g_losses, self.d_losses]))

    def weights_init_normal(self, m):
        """Initialise a module in place.

        Conv*/Linear weights and Embedding weights are drawn from N(0, 0.02); an
        Embedding padding row, if any, is zeroed. BatchNorm weights are drawn
        from N(1, 0.02) and BatchNorm biases are zeroed.
        """
        classname = m.__class__.__name__
        # Apply initial weights to convolutional and linear layers
        if classname.find('Conv') != -1 or classname.find('Linear') != -1:
            nn.init.normal_(m.weight.data, 0.0,0.02)
        if isinstance(m, nn.Embedding):
            m.weight.data.normal_(mean=0.0, std=0.02)
            if m.padding_idx is not None:
                m.weight.data[m.padding_idx].zero_()
        if classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

        return m


## CGAN 2

In [None]:
from torch.optim.lr_scheduler import LinearLR

class cFGGAN(nn.Module):
    def __init__(self, generator, discriminator, n_classes=2, n_embbed=32, **kwargs):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.n_classes = n_classes
        self.n_embbed = n_embbed

        self.generator_conditional_head = nn.Embedding(self.n_classes, self.n_embbed).to(device)
        self.linear_gen = nn.Sequential(
            nn.ConvTranspose2d(self.generator.latent_dim+n_embbed, 512, 4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        ).to(device)
        self.discriminator_conditional_head = nn.Embedding(self.n_classes, self.n_embbed).to(device)
        self.linear_disc = nn.Sequential(
            nn.Linear(1024+self.n_embbed, 100),
            nn.LeakyReLU(0.2),
            nn.Linear(100, 1),
            nn.Sigmoid()
        ).to(device)

        self.generator_conditional_head.apply(self.weights_init_normal)
        self.linear_gen.apply(self.weights_init_normal)
        self.discriminator_conditional_head.apply(self.weights_init_normal)
        self.linear_disc.apply(self.weights_init_normal)
        #.generator.apply(self.weights_init_normal)
        #self.discriminator.apply(self.weights_init_normal)
        
        self.d_losses = []
        self.g_losses = []
        self.schedule = False

    def __conditional_head(self):
        return nn.Sequential(
            nn.Embedding(self.n_classes, 25),
            nn.Linear(25, 128*128*1)
        )

    def get_generator_parameters(self):
        return list(self.generator_conditional_head.parameters()) + \
               list(self.linear_gen.parameters()) + \
               list(self.generator.upsampling_block2.parameters()) + \
               list(self.generator.upsampling_block3.parameters()) + \
               list(self.generator.upsampling_block4.parameters()) + \
               list(self.generator.upsampling_block5.parameters()) + \
               list(self.generator.out_layer.parameters())
        

    def get_discriminator_parameters(self):
        return list(self.discriminator_conditional_head.parameters()) + \
               list(self.discriminator.conv_block1.parameters()) + \
               list(self.discriminator.conv_block2.parameters()) + \
               list(self.discriminator.conv_block3.parameters()) + \
               list(self.discriminator.conv_block4.parameters()) + \
               list(self.discriminator.conv_block5.parameters()) + \
               list(self.linear_disc.parameters())

    def forward(self, features, class_):
        # Conditional input
        out = self.generator_conditional_head(class_).view(len(class_), self.n_embbed, 1, 1)
        # GAN input
        out_linear = self.linear_gen(torch.cat((features, out), 1))
        # Upsampling group
        out_upsampling = self.generator.upsampling_block2(out_linear) 
        out_upsampling = self.generator.upsampling_block3(out_upsampling)
        out_upsampling = self.generator.upsampling_block4(out_upsampling)
        out_upsampling = self.generator.upsampling_block5(out_upsampling)
        output = self.generator.out_layer(out_upsampling)

        return output

    def forward_discriminator(self, features, class_):
        # Conditional input
        out = self.discriminator_conditional_head(class_)
        
        out_conv = self.discriminator.conv_block1(features)
        out_conv = self.discriminator.conv_block2(out_conv)
        out_conv = self.discriminator.conv_block3(out_conv)
        out_conv = self.discriminator.conv_block4(out_conv)
        out_conv = self.discriminator.conv_block5(out_conv)
        
        flattened = out_conv.reshape(out_conv.size(0), -1)
        
        output = self.linear_disc(torch.cat((flattened, out), 1))
        
        return output # self.discriminator(conditioned_out)
    
    def compile(self,
            generator_optimizer,
            discriminator_optimizer,
            generator_loss_criterion,
            discriminator_loss_criterion,
            schedule=False
        ):
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.generator_loss_criterion = generator_loss_criterion
        self.discriminator_loss_criterion = discriminator_loss_criterion
        self.schedule = schedule
        if schedule:
            self.generator_optimizer_scheduler = ExponentialLR(self.generator_optimizer, gamma=0.1, last_epoch=-1, verbose=False)
            self.discriminator_optimizer_scheduler = ExponentialLR(self.discriminator_optimizer, gamma=0.1, last_epoch=-1, verbose=False)


    def train_generator(self, noise, batch_size):
        self.generator_optimizer.zero_grad()
        
        #random_labels = torch.zeros(batch_size, self.n_classes, 1, 1)
        random_labels = torch.randint(0, 2, (batch_size,)).to(device)
        generated_output = self(noise, random_labels.int().to(device))
        fake_output = self.forward_discriminator(generated_output, random_labels.int())

        # Calculate loss
        #generator_labels = torch.ones(batch_size).float().to(device)
        generator_labels = torch.from_numpy(np.array([0.9] * batch_size)).float().to(device)
        generator_loss = self.generator_loss_criterion(fake_output.squeeze(), generator_labels)

        # Update gradients
        generator_loss.backward()

        # Gradient clipping (exploding gradient)
        torch.nn.utils.clip_grad_norm_(self.get_generator_parameters(), 1)

        self.generator_optimizer.step()
        if self.schedule:
            self.generator_optimizer_scheduler.step()

        g_loss = generator_loss.item()

        return g_loss
    
    def train_discriminator(self, X, labels, noise, batch_size, n_flip=20):
        """Run one optimisation step of the conditional discriminator.

        Fake pass: images generated from ``noise`` with random class labels,
        with target 0.1.

        Real pass: the real batch with its labels, with target 0.9. Up to
        ``n_flip`` randomly chosen real samples get their class label flipped
        (0 <-> 1). Their target is set to 0.1, so mismatched (image, label)
        pairs are treated as fake.

        Args:
            X: batch of real images.
            labels: class labels (0/1) for ``X``; not modified.
            noise: latent batch for the generator.
            batch_size: number of fake samples; must equal ``noise``'s first
                dimension.
            n_flip: maximum number of real samples whose label is flipped.

        Returns:
            float: real loss + fake loss for this step.
        """
        self.discriminator_optimizer.zero_grad()

        random_labels = torch.randint(0, 2, (batch_size,), device=noise.device).int()
        # The generator is not trained here, so skip building its autograd graph.
        with torch.no_grad():
            generated_output = self(noise, random_labels)
        fake_output = self.forward_discriminator(generated_output, random_labels)

        # Flip on a copy so the caller's label tensor is left untouched.
        cond_labels = labels.clone()
        flip_indices = torch.randperm(len(cond_labels))[:n_flip]
        cond_labels[flip_indices] = cond_labels[flip_indices] * (-1) + 1

        real_output = self.forward_discriminator(X, cond_labels.int())

        # Targets are sized from the actual outputs, so a short final
        # DataLoader batch (len(X) < batch_size) does not cause a size
        # mismatch. reshape(-1) also keeps a batch of one 1-D, unlike squeeze().
        fake_flat = fake_output.reshape(-1)
        real_flat = real_output.reshape(-1)
        fake_labels = torch.full_like(fake_flat, 0.1)
        real_labels = torch.full_like(real_flat, 0.9)
        real_labels[flip_indices] = 0.1

        discriminator_fake_loss = self.discriminator_loss_criterion(fake_flat, fake_labels)
        discriminator_real_loss = self.discriminator_loss_criterion(real_flat, real_labels)
        discriminator_loss = discriminator_real_loss + discriminator_fake_loss

        discriminator_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.get_discriminator_parameters(), 1)

        self.discriminator_optimizer.step()
        if self.schedule:
            # NOTE(review): stepped once per call (per batch, n_disc times);
            # with ExponentialLR(gamma=0.1) this decays the LR very fast --
            # confirm per-epoch decay wasn't intended.
            self.discriminator_optimizer_scheduler.step()

        return discriminator_loss.item()

    def fit(self, X, epochs=10, batch_size=64, latent_dim=100, n_disc=1, n_gen=1):
        """Train the conditional GAN.

        Args:
            X: sized iterable of ``(images, labels)`` batches (e.g. a DataLoader
                over ``FacesDataset``).
            epochs: number of passes over ``X``.
            batch_size: number of latent vectors sampled per step.
            latent_dim: channel count of the (N, latent_dim, 1, 1) noise tensor.
            n_disc: discriminator steps per batch; the last step's loss is logged.
            n_gen: generator steps per batch; the last step's loss is logged.

        Side effects:
            - Appends per-epoch mean losses to ``self.g_losses`` and
              ``self.d_losses``.
            - Shows a 4-sample preview after every epoch.
            - Writes ``./cfggan_tmp.pt`` every epoch.
            - Writes ``./cfggan_epoch_{k}.pt`` and ``./cfggan_losses_{k}.npy``
              every 5 epochs.
        """
        n_batches = len(X)
        # At least 1 so `index % batch_print_step` never divides by zero when
        # the loader has fewer than 10 batches.
        batch_print_step = max(1, n_batches // 10)
        print("Training starting....", batch_size)
        # Fixed preview inputs (two samples per class) so epochs are comparable.
        disp_noise = torch.randn(4, latent_dim, 1, 1, device=device, dtype=torch.float32)
        disp_labels = torch.tensor([1, 1, 0, 0], dtype=torch.int32, device=device)

        for epoch in range(epochs):
            g_loss = 0.0
            d_loss = 0.0

            print(f"Epoch {epoch}/{epochs}: ", end="")
            for index, (batch, labels) in enumerate(X):
                batch = batch.to(device)
                noise = torch.randn(batch_size, latent_dim, 1, 1, device=device, dtype=torch.float32)

                # Defaults keep the accumulators defined when n_disc/n_gen is 0.
                d_loss_step = 0.0
                for _ in range(n_disc):
                    d_loss_step = self.train_discriminator(batch, labels.to(device), noise, batch_size)
                d_loss += d_loss_step

                g_loss_step = 0.0
                for _ in range(n_gen):
                    g_loss_step = self.train_generator(noise, batch_size)
                g_loss += g_loss_step

                if index % batch_print_step == 0:
                    print("#", end="")

            # Guard against an empty loader.
            g_loss /= max(1, n_batches)
            d_loss /= max(1, n_batches)

            self.g_losses.append(g_loss)
            self.d_losses.append(d_loss)

            print(f"\nGenerator loss: {g_loss}  Discriminator loss: {d_loss}")

            with torch.no_grad():
                out = self(disp_noise, disp_labels)
                fig, axs = plt.subplots(1, 4, figsize=(16, 4))
                for ax, sample in zip(axs, out):
                    # Rescale assuming generator output in [-1, 1] -- TODO confirm.
                    ax.imshow(np.array(F.to_pil_image(sample * 0.5 + 0.5)))
                plt.show()
                # Release the figure so previews don't accumulate across epochs.
                plt.close(fig)

            torch.save(self.state_dict(), './cfggan_tmp.pt')
            if (epoch + 1) % 5 == 0:
                torch.save(self.state_dict(), f'./cfggan_epoch_{epoch+1}.pt')
                with open(f'./cfggan_losses_{epoch+1}.npy', 'wb') as f:
                    np.save(f, np.array([self.g_losses, self.d_losses]))
            
    def weights_init_normal(self, m):
        """Re-initialise a module's weights from N(0, 0.02), in place.

        This is presumably applied to every submodule via ``nn.Module.apply``
        (TODO confirm the caller).

        - Modules whose class name contains 'Conv' or 'Linear' and that own a
          ``weight`` get it redrawn.
        - ``nn.Embedding`` weights are redrawn and the padding row, if any, is
          zeroed.
        - Biases are left untouched.

        Returns:
            The same module ``m``.
        """
        classname = type(m).__name__
        weight = getattr(m, 'weight', None)
        # Name matching also hits user-defined container modules (e.g. a
        # "ConvBlock") that have no `weight` of their own; skip those instead
        # of raising AttributeError.
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            nn.init.normal_(weight.data, 0.0, 0.02)
        if isinstance(m, nn.Embedding):
            m.weight.data.normal_(mean=0.0, std=0.02)
            if m.padding_idx is not None:
                m.weight.data[m.padding_idx].zero_()
        return m


# Train

In [None]:
# Build the FGGAN, run a single-sample sanity check, then train.
latent_dim = 100
noise = torch.randn(1, latent_dim, 1, 1, device=device)

# Assumes Generator's first argument is the latent size (it matched the
# literal 100 used for the noise) -- TODO confirm.
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
fggan = FGGAN(generator, discriminator)

# Load the saved model weights
#fggan.load_state_dict(torch.load('/kaggle/input/ap-cfggan-ds/fggan_tmp(4).pt'))

generator_optimizer = torch.optim.Adam(
    generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

# BCELoss has no parameters, so it needs no .to(device).
generator_criterion = nn.BCELoss()
discriminator_criterion = nn.BCELoss()

fggan.compile(generator_optimizer, discriminator_optimizer,
              generator_criterion, discriminator_criterion)

out = fggan(noise)
print(out.shape)

# Map [-1, 1] output to [0, 1]; clip so imshow doesn't warn about out-of-range values.
plt.imshow((out[0].permute(1, 2, 0).cpu().detach().numpy() * 0.5 + 0.5).clip(0, 1))
plt.show()

fggan.fit(dataloader_faces_only, epochs=10, batch_size=batch_size, n_gen=2)

In [None]:
def plot_generated_images(square = 5, epochs = 0):
    """Show a ``square`` x ``square`` grid of samples from ``cfggan``.

    Args:
        square: grid side length; ``square**2`` images are drawn.
        epochs: despite its name, this is the conditioning label passed to
            ``cfggan`` for every image, not an epoch count.
    """
    plt.figure(figsize=(15, 15))
    label = torch.tensor([epochs], device=device)
    # Display only: no autograd graph needed.
    with torch.no_grad():
        for i in range(square * square):
            plt.subplot(square, square, i + 1)
            noise = torch.randn(1, 100, 1, 1, device=device)
            img = cfggan(noise, label)[0].cpu()
            plt.imshow(F.to_pil_image(img * 0.5 + 0.5))
            plt.xticks([])
            plt.yticks([])


plot_generated_images(6, 0)

In [None]:
# Same 6x6 grid, conditioned on label 1.
plot_generated_images(square=6, epochs=1)

In [None]:
def plot_generated_images(square = 5, epochs = 0):
    """Show a ``square`` x ``square`` grid of samples from the unconditional ``fggan``.

    Note: this redefines (and shadows) the earlier conditional version. The
    ``epochs`` parameter is accepted for signature compatibility but is
    ignored.
    """
    plt.figure(figsize=(10, 10))
    # Display only: no autograd graph needed.
    with torch.no_grad():
        for i in range(square * square):
            plt.subplot(square, square, i + 1)
            noise = torch.randn(1, 100, 1, 1, device=device)
            img = fggan(noise)[0].cpu()
            plt.imshow(F.to_pil_image(img * 0.5 + 0.5))
            plt.xticks([])
            plt.yticks([])


plot_generated_images(6)

In [None]:
import gc

# Free the training cell's models. The optimizers hold references to the
# parameters (plus Adam state), and `out` keeps an autograd graph over the
# generator weights. Unless those go too, deleting the models frees nothing.
del generator, discriminator, fggan
del generator_optimizer, discriminator_optimizer, out
# Collect reference cycles first, then hand the now-unused cached CUDA blocks
# back to the driver.
gc.collect()
torch.cuda.empty_cache()

Dataset handling