In [None]:
import torch
import random
import numpy as np
import os
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torchvision.utils import save_image
import matplotlib.pyplot as plt
from math import log2


### Model 

In [None]:

# Per-stage channel multipliers applied to `in_channels` by Generator and
# Discriminator: the first four stages keep the full width, and each of the
# five later stages halves the previous multiplier (1/2 ... 1/32).
factors = [1, 1, 1, 1] + [1 / 2 ** power for power in range(1, 6)]


class WSConv2d(nn.Module):
    """2-D convolution whose input is rescaled at run time.

    The input is multiplied by ``sqrt(gain / (in_channels * kernel_size ** 2))``
    before the convolution. The bias is detached from the inner ``nn.Conv2d``
    and added after the convolution so that it is not affected by the scale.
    Weights start from a standard normal distribution and the bias from zeros.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Integer kernel size (it is squared to compute the scale).
        stride: Convolution stride.
        padding: Convolution padding.
        gain: Numerator of the scale factor.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

        fan_in = in_channels * (kernel_size ** 2)
        self.scale = (gain / fan_in) ** 0.5

        # Re-home the bias on this module and strip it from the inner conv so
        # the conv itself is bias-free.
        self.bias = self.conv.bias
        self.conv.bias = None

        # Initialise the conv weights and the detached bias.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``conv(x * scale)`` plus the per-channel bias."""
        scaled_input = x * self.scale
        channel_bias = self.bias[None, :, None, None]
        return self.conv(scaled_input) + channel_bias

class WSLinear(nn.Module):
    """Linear layer with a detached bias and small-normal weight init.

    The bias of the inner ``nn.Linear`` is moved onto this module and added
    after the matrix multiply. Weights are drawn from N(0, 0.01**2) and the
    bias starts at zero.

    NOTE(review): unlike ``WSConv2d`` no run-time scale factor is applied
    here — confirm that this is intended.

    Args:
        in_channels: Size of each input vector.
        out_channels: Size of each output vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

        # Keep the bias as a parameter of this module; the inner layer is
        # left bias-free.
        self.bias = self.lin.bias
        self.lin.bias = None

        nn.init.normal_(self.lin.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``lin(x)`` plus the bias broadcast over the batch axis."""
        projected = self.lin(x)
        return projected + self.bias.unsqueeze(0)
    
        
    
class PixelNorm(nn.Module):
    """Normalise each position to unit root-mean-square across dimension 1."""

    def __init__(self):
        super().__init__()
        # Added under the square root to avoid division by zero.
        self.epsilon = 1e-8

    def forward(self, x):
        """Divide ``x`` by ``sqrt(mean(x ** 2, dim=1) + epsilon)``."""
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()


class ConvBlock(nn.Module):
    """Two ``WSConv2d`` layers, each followed by LeakyReLU(0.2) and optional PixelNorm.

    Args:
        in_channels: Channels entering the first convolution.
        out_channels: Channels produced by both convolutions.
        use_pixelnorm: When true, PixelNorm is applied after each activation.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        """Run conv -> LeakyReLU -> (PixelNorm) twice and return the result."""
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x

class LinearBlock(nn.Module):
    """``WSLinear`` followed by ReLU and, optionally, PixelNorm.

    Args:
        in_channels: Size of each input vector.
        out_channels: Size of each output vector.
        use_pixelnorm: When true, PixelNorm is applied after the ReLU.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.lin = WSLinear(in_channels, out_channels)
        self.relu = nn.ReLU()
        self.pn = PixelNorm()
        self.use_pn = use_pixelnorm

    def forward(self, x):
        """Return ``relu(lin(x))``, pixel-normalised when enabled."""
        activated = self.relu(self.lin(x))
        if not self.use_pn:
            return activated
        return self.pn(activated)
        


class Generator(nn.Module):
    """Progressive-growing generator.

    ``initial`` turns a ``(N, z_dim, 1, 1)`` latent into a 4x4 feature map.
    Each entry of ``prog_blocks`` is applied after a 2x nearest-neighbour
    upsample, and ``rgb_layers[k]`` converts the features available after
    ``k`` progressive blocks into an image (``rgb_layers[0]`` is
    ``initial_rgb``).

    Args:
        z_dim: Channel count of the latent input.
        in_channels: Base feature width; later blocks use
            ``int(in_channels * factors[i])``.
        img_channels: Channels of the generated image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()

        # Latent 1x1 -> 4x4 feature map.
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks = nn.ModuleList([])
        self.rgb_layers = nn.ModuleList([self.initial_rgb])

        # One ConvBlock and one to-RGB layer per pair of consecutive factors.
        for factor_in, factor_out in zip(factors[:-1], factors[1:]):
            block_in = int(in_channels * factor_in)
            block_out = int(in_channels * factor_out)
            self.prog_blocks.append(ConvBlock(block_in, block_out))
            self.rgb_layers.append(
                WSConv2d(block_out, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Blend the two RGB images with weight ``alpha`` and squash with tanh."""
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)

    def forward(self, x, alpha, steps):
        """Generate an image after ``steps`` progressive blocks.

        Args:
            x: Latent input passed to ``initial``.
            alpha: Fade-in weight given to the newest block's output.
            steps: Number of progressive blocks to apply (0 yields 4x4).

        Returns:
            The RGB image. For ``steps == 0`` this is the raw ``initial_rgb``
            output (no tanh); otherwise it is the tanh-blended fade-in result.
        """
        out = self.initial(x)

        if steps == 0:
            return self.initial_rgb(out)

        for idx in range(steps):
            # `upscaled` keeps the input of the most recent block so it can be
            # converted with the previous stage's to-RGB layer below.
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[idx](upscaled)

        previous_rgb = self.rgb_layers[steps - 1](upscaled)
        current_rgb = self.rgb_layers[steps](out)
        return self.fade_in(alpha, previous_rgb, current_rgb)


class Discriminator(nn.Module):
    """Progressive-growing critic that mirrors ``Generator``.

    ``prog_blocks`` and ``rgb_layers`` are stored from the highest resolution
    down to the lowest: ``rgb_layers[i]`` maps an image to the input width of
    ``prog_blocks[i]``, and the last ``rgb_layers`` entry (``initial_rgb``)
    feeds ``final_block`` directly for 4x4 images.

    Args:
        z_dim: Unused; kept so the constructor matches ``Generator``.
        in_channels: Base feature width.
        img_channels: Channels of the input image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Walk the factors backwards so index 0 is the highest resolution.
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # From-RGB layer for the 4x4 resolution.
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Takes in_channels + 1 because minibatch_std appends one channel.
        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1),
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean standard deviation over the batch.

        The statistic is ``torch.std(x, dim=0).mean()`` broadcast to
        ``(N, 1, H, W)``. For a single-sample batch the standard deviation is
        undefined (it evaluates to NaN), so zero is used instead to keep the
        critic output finite.
        """
        if x.shape[0] > 1:
            statistic = torch.std(x, dim=0).mean()
        else:
            # A batch of one has no spread; avoid propagating NaN.
            statistic = x.new_zeros(())
        batch_statistics = statistic.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        """Score an image produced at progressive depth ``steps``.

        Args:
            x: Input image batch.
            alpha: Fade-in weight given to the newest block's output.
            steps: Progressive depth of the image (0 means 4x4).

        Returns:
            Tensor of shape ``(N, -1)`` with the critic scores.
        """
        # Blocks are stored high-res first, so larger `steps` means a smaller
        # starting index.
        cur_step = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e, image is 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # Two paths into the next-lower resolution: the pooled image through
        # the lower from-RGB layer, and the new block followed by pooling.
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)


class VariationalMiner(nn.Module):
    """Maps input noise to a latent sample for the generator.

    Three pixel-normalised ``LinearBlock`` layers (z_dim -> 2*z_dim -> 2*z_dim
    -> z_dim) feed two heads that produce a mean and a log-variance, from
    which a latent is sampled with the reparameterisation trick.

    NOTE(review): both heads are ``LinearBlock`` instances and therefore end
    in ReLU, so ``mu`` and ``logvar`` are always non-negative — confirm that
    this is intended.

    Args:
        z_dim: Size of the input noise and of the produced latent.
    """

    def __init__(self, z_dim):
        super().__init__()
        hidden_dim = z_dim * 2
        self.l_block1 = LinearBlock(z_dim, hidden_dim, use_pixelnorm=True)
        self.l_block2 = LinearBlock(hidden_dim, hidden_dim, use_pixelnorm=True)
        self.l_block3 = LinearBlock(hidden_dim, z_dim, use_pixelnorm=True)
        self.fc_mu = LinearBlock(z_dim, z_dim, use_pixelnorm=False)  # mean head
        self.fc_logvar = LinearBlock(z_dim, z_dim, use_pixelnorm=False)  # log-variance head

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps`` ~ N(0, 1)."""
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Sample a latent for ``x``.

        Returns:
            Tuple ``(z, mu, logvar)`` where ``z`` has two trailing singleton
            axes appended for the generator, and ``mu`` / ``logvar`` are the
            raw head outputs.
        """
        features = x
        for block in (self.l_block1, self.l_block2, self.l_block3):
            features = block(features)

        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)

        latent = self.reparameterize(mu, logvar)

        # Append two singleton axes: (N, z_dim) -> (N, z_dim, 1, 1).
        return latent[..., None, None], mu, logvar

        
        
# Shape smoke test: push one random latent through miner -> generator -> critic
# at every supported resolution and check the output shapes.
Z_DIM = 256
IN_CHANNELS = 256
gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
critic = Discriminator(Z_DIM, IN_CHANNELS, img_channels=3)
miner = VariationalMiner(Z_DIM)

for img_size in [4 * 2 ** exponent for exponent in range(9)]:  # 4, 8, ..., 1024
    num_steps = int(log2(img_size / 4))
    latent_noise = torch.randn((1, Z_DIM))
    mined_latent, _, _ = miner(latent_noise)
    fake_image = gen(mined_latent, 0.5, steps=num_steps)
    assert fake_image.shape == (1, 3, img_size, img_size)
    critic_score = critic(fake_image, alpha=0.5, steps=num_steps)
    assert critic_score.shape == (1, 1)
    print(f"Success! At img size: {img_size}")


### Helper functions

In [2]:

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """Gradient penalty on random interpolations between real and fake images.

    Args:
        critic: Callable invoked as ``critic(images, alpha, train_step)``.
        real: Batch of real images with shape ``(N, C, H, W)``.
        fake: Batch of generated images; it is detached before mixing.
        alpha: Fade-in weight forwarded to the critic.
        train_step: Progressive step forwarded to the critic.
        device: Device the mixing coefficients are moved to.

    Returns:
        Scalar tensor: the batch mean of ``(||grad||_2 - 1) ** 2``.
    """
    batch, channels, height, width = real.shape

    # One mixing coefficient per sample, repeated over channels and pixels.
    mix = torch.rand((batch, 1, 1, 1)).repeat(1, channels, height, width).to(device)
    blended = real * mix + fake.detach() * (1 - mix)
    blended.requires_grad_(True)

    # Critic scores for the interpolated images.
    scores = critic(blended, alpha, train_step)

    # Gradient of the scores with respect to the interpolated images; the
    # graph is kept so the penalty itself can be back-propagated.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)


def save_checkpoint(model, optimizer, filename):
    """Write the model and optimizer state dicts to ``filename`` with ``torch.save``.

    The saved object is a dict with the keys ``"state_dict"`` and
    ``"optimizer"``.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr, device=None):
    """Restore model and optimizer state from a checkpoint and reset the learning rate.

    Args:
        checkpoint_file: Path or file-like object accepted by ``torch.load``.
            The stored object must hold ``"state_dict"`` and ``"optimizer"``.
        model: Module whose ``load_state_dict`` receives the saved weights.
        optimizer: Optimizer whose ``load_state_dict`` receives the saved state.
        lr: Learning rate written into every parameter group after loading.
        device: ``map_location`` for ``torch.load``. When ``None``, CUDA is
            used if available and CPU otherwise, so loading also works on
            machines without a GPU.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Loading the optimizer state also restores the checkpoint's learning
    # rate, so overwrite it with the requested value.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    
def plot_images(images, epoch):
    """Show a batch of images side by side in a single row.

    Args:
        images: Tensor of shape ``(N, C, H, W)``; it is moved to the CPU,
            detached and transposed to channels-last before plotting.
        epoch: Currently unused; kept for interface compatibility.
    """
    images = images.cpu().detach().numpy().transpose(0, 2, 3, 1)
    if len(images) == 0:
        # Nothing to draw; plt.subplots cannot create a grid with zero columns.
        return

    # squeeze=False keeps `axes` two-dimensional so a single image is handled
    # the same way as several.
    fig, axes = plt.subplots(1, len(images), figsize=(15, 15), squeeze=False)
    for ax, img in zip(axes[0], images):
        ax.imshow(img)
        ax.axis("off")
    plt.show()

import os
import shutil

def save_real_images(loader, output_dir="real_images", num_images=500):
    """Dump up to ``num_images`` images from ``loader`` into ``output_dir``.

    ``output_dir`` is removed and recreated first. Each image is mapped with
    ``image * 0.5 + 0.5`` before saving and written as ``real_<index>.png``.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs.
        output_dir: Destination folder (wiped before writing).
        num_images: Maximum number of images to write.
    """
    # Start from an empty folder every time.
    shutil.rmtree(output_dir, ignore_errors=True)
    os.makedirs(output_dir, exist_ok=True)

    written = 0
    for batch, _labels in loader:
        for image in batch:
            if written >= num_images:
                return
            destination = f"{output_dir}/real_{written}.png"
            save_image(image * 0.5 + 0.5, destination)
            written += 1
            
def save_fake_images(
    output_dir="fake_images",
    num_images=500,
    gen_checkpoint_path="/kaggle/working/generator_finetuned_st2.pth",
    miner_checkpoint_path="/kaggle/working/miner_finetuned_st2.pth",
):
    """Generate and save fake images from the generator to a specified folder.

    ``output_dir`` is removed and recreated, the fine-tuned generator and
    miner are restored from their checkpoints, and ``num_images`` images are
    written as ``fake_<index>.png``.

    Args:
        output_dir: Destination folder (wiped before writing).
        num_images: Number of images to generate.
        gen_checkpoint_path: Checkpoint holding the generator ``"state_dict"``.
        miner_checkpoint_path: Checkpoint holding the miner ``"state_dict"``.
    """
    shutil.rmtree(output_dir, ignore_errors=True)
    os.makedirs(output_dir, exist_ok=True)

    # map_location keeps loading working when the checkpoint was written on a
    # different device than the one selected in the configuration.
    gen = Generator(config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
    gen_checkpoint = torch.load(gen_checkpoint_path, map_location=config.DEVICE)
    gen.load_state_dict(gen_checkpoint["state_dict"])
    gen.eval()

    miner = VariationalMiner(config.Z_DIM).to(config.DEVICE)
    miner_checkpoint = torch.load(miner_checkpoint_path, map_location=config.DEVICE)
    miner.load_state_dict(miner_checkpoint["state_dict"])
    miner.eval()

    with torch.no_grad():
        # The mean and log-variance come from ONE noise vector; every image
        # below is a fresh sample from that single distribution.
        NOISE = torch.randn(1, config.Z_DIM).to(config.DEVICE)
        _, mu, logvar = miner(NOISE)
        std = torch.exp(0.5 * logvar)
        for i in range(num_images):
            epsilon = torch.randn_like(std)
            z = mu + epsilon * std
            # The generator's first layer consumes Z_DIM channels.
            z_reshaped = z.view(-1, config.Z_DIM, 1, 1)
            image_fakes = gen(z_reshaped, alpha=1, steps=config.STEP) * 0.5 + 0.5  # Scale to [0, 1]
            save_image(image_fakes, f"{output_dir}/fake_{i}.png")

### Configuration

In [3]:
import cv2
import torch
from math import log2

class configuration:
    """Hyper-parameters and paths shared by the notebook cells."""

    def __init__(self):
        # --- data -----------------------------------------------------------
        self.DATASET = '/kaggle/input/shape-dataset'
        self.DATASET_SAMPLE = 800   # number of images sampled from the dataset
        self.IMAGE_SIZE = 64
        self.CHANNELS_IMG = 3
        self.BATCH_SIZE = 32
        self.NUM_WORKERS = 4

        # --- pretrained checkpoints -----------------------------------------
        self.CHECKPOINT_GEN = "/kaggle/input/progan_multi/pytorch/default/1/generator.pth"
        self.CHECKPOINT_CRITIC = "/kaggle/input/progan_multi/pytorch/default/1/critic.pth"

        # --- model ----------------------------------------------------------
        self.STEP = 4               # progressive depth: 4 * 2 ** 4 = 64 pixels
        self.Z_DIM = 256
        self.IN_CHANNELS = 256
        self.MINER_LAYERS = 4       # not referenced in the visible cells

        # --- optimisation ---------------------------------------------------
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.LEARNING_RATE = 1e-3
        self.CRITIC_ITERATIONS = 1  # not referenced in the visible cells
        self.LAMBDA_GP = 10         # weight of the gradient penalty
        self.EPOCHS = 500           # not referenced in the visible cells

        # Fixed latent batch (depends on Z_DIM and DEVICE above).
        self.FIXED_NOISE = torch.randn(8, self.Z_DIM).to(self.DEVICE)


config = configuration()

### Data Loader

In [4]:
import torch
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

from math import log2
from tqdm import tqdm

# Enable cuDNN benchmark mode. The flag is named `benchmark` (singular); the
# previous `benchmarks` spelling did not turn the feature on.
torch.backends.cudnn.benchmark = True



def get_loader(image_size):
    """Build a shuffled DataLoader over a random subset of the image folder.

    Images are resized to ``image_size`` x ``image_size``, converted to
    tensors and normalised with mean 0.5 and std 0.5 per channel.

    Args:
        image_size: Side length the images are resized to.

    Returns:
        Tuple ``(loader, dataset)`` where ``dataset`` is the sampled ``Subset``.
    """
    channel_stats = [0.5 for _ in range(config.CHANNELS_IMG)]
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(channel_stats, channel_stats),
        ]
    )
    batch_size = config.BATCH_SIZE

    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)

    # random.sample raises ValueError when asked for more items than exist,
    # so fall back to the whole dataset when it is smaller than the request.
    sample_size = min(config.DATASET_SAMPLE, len(dataset))
    indices = random.sample(range(len(dataset)), sample_size)
    dataset = Subset(dataset, indices)
    print(len(dataset))
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset

### Training Setup

In [5]:
def train_fn_stage1(
    critic,
    gen,
    miner,
    loader,
    step,
    alpha,
    opt_critic,
    opt_miner,
    scaler_miner,
    scaler_critic,
):
    '''Run one epoch of stage 1: update the critic and the miner.

    Only ``opt_critic`` and ``opt_miner`` are stepped here, so the generator
    weights are left unchanged. Both updates run under CUDA autocast with
    their own gradient scaler.
    '''
    progress = tqdm(loader, leave=True)
    for real, _ in progress:
        real = real.to(config.DEVICE)
        batch_len = real.shape[0]
        noise = torch.randn(batch_len, config.Z_DIM).to(config.DEVICE)

        # ---- critic update: Wasserstein loss + gradient penalty + drift ----
        with torch.cuda.amp.autocast():
            latent, _, _ = miner(noise)
            fake = gen(latent, alpha, step)
            score_real = critic(real, alpha, step)
            score_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
            wasserstein_gap = torch.mean(score_real) - torch.mean(score_fake)
            drift_term = 0.001 * torch.mean(score_real ** 2)
            loss_critic = -wasserstein_gap + config.LAMBDA_GP * gp + drift_term

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # ---- miner update: raise the critic score of the fakes --------------
        with torch.cuda.amp.autocast():
            score_for_miner = critic(fake, alpha, step)
            loss_gen = -torch.mean(score_for_miner)

        opt_miner.zero_grad()
        scaler_miner.scale(loss_gen).backward()
        scaler_miner.step(opt_miner)
        scaler_miner.update()

In [6]:
def train_fn_stage2(
    critic,
    gen,
    miner,
    loader,
    step,
    alpha,
    opt_critic,
    opt_gen,
    scaler_gen,
    scaler_critic,
):
    '''Run one epoch of stage 2: update the critic, then step ``opt_gen``.

    The loop mirrors stage 1 except that the second update is applied through
    ``opt_gen`` / ``scaler_gen``. Both updates run under CUDA autocast.
    '''
    progress = tqdm(loader, leave=True)
    for real, _ in progress:
        real = real.to(config.DEVICE)
        batch_len = real.shape[0]
        noise = torch.randn(batch_len, config.Z_DIM).to(config.DEVICE)

        # ---- critic update: Wasserstein loss + gradient penalty + drift ----
        with torch.cuda.amp.autocast():
            latent, _, _ = miner(noise)
            fake = gen(latent, alpha, step)
            score_real = critic(real, alpha, step)
            score_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
            wasserstein_gap = torch.mean(score_real) - torch.mean(score_fake)
            drift_term = 0.001 * torch.mean(score_real ** 2)
            loss_critic = -wasserstein_gap + config.LAMBDA_GP * gp + drift_term

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # ---- generator-side update: raise the critic score of the fakes -----
        with torch.cuda.amp.autocast():
            score_for_gen = critic(fake, alpha, step)
            loss_gen = -torch.mean(score_for_gen)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

In [None]:
import shutil
def _plot_preview(gen, miner, alpha, step, epoch, num_samples=8):
    """Plot `num_samples` images generated from fresh noise passed through the miner."""
    with torch.no_grad():
        noise = torch.randn(num_samples, config.Z_DIM).to(config.DEVICE)
        latent, _, _ = miner(noise)
        fixed_fakes = gen(latent, alpha, step) * 0.5 + 0.5  # Scale to [0, 1]
        plot_images(fixed_fakes, epoch)


def main():
    """Fine-tune a pretrained generator/critic pair in two stages.

    Stage 1 runs ``train_fn_stage1`` (critic and miner optimizers) for 100
    epochs; stage 2 runs ``train_fn_stage2`` (critic optimizer plus a joint
    generator+miner optimizer) for 500 epochs. Preview images are plotted
    every 20 epochs. Checkpoints are written once after stage 1 and after
    every epoch of stage 2.
    """
    stage1_epochs = 100
    stage2_epochs = 500

    # WARNING: this removes everything under /kaggle/working/ before training.
    shutil.rmtree("/kaggle/working/", ignore_errors=True)
    gen = Generator(config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
    critic = Discriminator(config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
    miner = VariationalMiner(config.Z_DIM).to(config.DEVICE)

    # map_location lets checkpoints load on whichever device was selected in
    # the configuration, including CPU-only hosts.
    gen_checkpoint = torch.load(config.CHECKPOINT_GEN, map_location=config.DEVICE)
    gen.load_state_dict(gen_checkpoint["state_dict"])

    critic_checkpoint = torch.load(config.CHECKPOINT_CRITIC, map_location=config.DEVICE)
    critic.load_state_dict(critic_checkpoint["state_dict"])

    # initialize optimizers and scalers
    opt_gen = optim.Adam(list(gen.parameters()) + list(miner.parameters()), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
    opt_miner = optim.Adam(miner.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))

    scaler_critic = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()
    scaler_miner = torch.cuda.amp.GradScaler()

    # start at step that corresponds to img size that we set in config
    step = config.STEP
    alpha = 1

    loader, dataset = get_loader(image_size=config.IMAGE_SIZE)  # Load real images

    # Save real images to the "real_images" folder for evaluation purpose
    save_real_images(loader, output_dir="real_images", num_images=config.DATASET_SAMPLE)
    print(f'real_images_saved: {len(os.listdir("real_images"))}')

    print("STARTING STAGE 1")

    for epoch in range(stage1_epochs):
        print(f"Epoch [{epoch+1}/{stage1_epochs}]")
        critic.train()
        miner.train()
        gen.eval()
        train_fn_stage1(
            critic,
            gen,
            miner,
            loader,
            step,
            alpha,
            opt_critic,
            opt_miner,
            scaler_miner,
            scaler_critic,
        )

        if epoch % 20 == 0:
            miner.eval()
            _plot_preview(gen, miner, alpha, step, epoch)
            miner.train()

    save_checkpoint(gen, opt_gen, filename="generator_finetuned_st1.pth")
    save_checkpoint(critic, opt_critic, filename="critic_finetuned_st1.pth")
    save_checkpoint(miner, opt_miner, filename="miner_finetuned_st1.pth")

    print("STARTING STAGE 2")
    for epoch in range(stage2_epochs):
        print(f"Epoch [{epoch+1}/{stage2_epochs}]")
        critic.train()
        miner.train()
        gen.train()

        train_fn_stage2(
            critic,
            gen,
            miner,
            loader,
            step,
            alpha,
            opt_critic,
            opt_gen,
            scaler_gen,
            scaler_critic,
        )

        if epoch % 20 == 0:
            miner.eval()
            gen.eval()
            _plot_preview(gen, miner, alpha, step, epoch)
            miner.train()
            gen.train()

        # Overwrite the stage-2 checkpoints after every epoch.
        # NOTE(review): the miner is stepped through opt_gen in this stage,
        # yet its checkpoint stores opt_miner's state — confirm this is intended.
        save_checkpoint(gen, opt_gen, filename="generator_finetuned_st2.pth")
        save_checkpoint(critic, opt_critic, filename="critic_finetuned_st2.pth")
        save_checkpoint(miner, opt_miner, filename="miner_finetuned_st2.pth")


if __name__ == "__main__":
    main()


### Eval

In [None]:
!pip -q install torch-fidelity

In [None]:
from torch_fidelity import calculate_metrics

def calculate_fid(real_images_dir, fake_images_dir):
    """Compute Inception Score and FID between two image folders.

    Args:
        real_images_dir: Folder passed to ``calculate_metrics`` as ``input1``.
        fake_images_dir: Folder passed to ``calculate_metrics`` as ``input2``.

    Returns:
        Whatever ``calculate_metrics`` returns for ``isc=True, fid=True``.

    NOTE(review): presumably single-input metrics such as the Inception Score
    are computed on ``input1`` only — verify against the torch-fidelity
    documentation before relying on the argument order.
    """
    metrics = calculate_metrics(
        input1=real_images_dir,
        input2=fake_images_dir,
        cuda=torch.cuda.is_available(),  # Use the GPU only if one is available
        isc=True,
        fid=True,
    )

    return metrics

# Inference: regenerate the fake-image folder, report how many files it holds,
# then compute the metrics against the saved real images.
save_fake_images(output_dir="fake_images", num_images=1000)
print(f'num_images: {len(os.listdir("fake_images"))}')
# NOTE(review): the generated folder is bound to `real_images_dir` and the real
# folder to `fake_images_dir` — confirm that this ordering is intended.
metrics = calculate_fid(
    real_images_dir="/kaggle/working/fake_images",
    fake_images_dir="/kaggle/working/real_images",
)

### Visualize the distribution after Mining

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

z_dim = config.Z_DIM  # The dimension of latent space
batch_size = 2000  # Number of points to visualize

# Restore the fine-tuned miner from its stage-2 checkpoint. The module-level
# `miner` created by the shape smoke test is untrained, and the trained one
# only exists inside main().
miner = VariationalMiner(z_dim).to(config.DEVICE)
miner_checkpoint = torch.load(
    "/kaggle/working/miner_finetuned_st2.pth", map_location=config.DEVICE
)
miner.load_state_dict(miner_checkpoint["state_dict"])
miner.eval()

noise = torch.randn((batch_size, z_dim)).to(config.DEVICE)
with torch.no_grad():
    _, mu, logvar = miner(noise)
    std = torch.exp(0.5 * logvar)
    epsilon = torch.randn_like(std)
    z = mu + epsilon * std

# t-SNE expects a 2-D (n_samples, n_features) array on the CPU.
latent_points = z.cpu().numpy()

# Using t-SNE for 2D visualization
tsne = TSNE(n_components=2, random_state=42)
latent_2d = tsne.fit_transform(latent_points)

# Plot the 2D latent space
plt.figure(figsize=(8, 6))
plt.scatter(latent_2d[:, 0], latent_2d[:, 1], c='blue', alpha=0.6)
plt.title("t-SNE Visualization of Latent Space After Passing Through Miner")
plt.xlabel("Dimension 1")
plt.ylabel("Dimension 2")
plt.show()
