In [None]:
import os
import sys

# Make the project-local helper modules (e.g. fid_custom) importable.
root = '/blue/prabhat/parvath.harikris/gen-ai-bias/boosting'
sys.path.append(root)
import fid_custom

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import torchvision.utils as vutils
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.utils.data import Subset
import numpy as np

# Configuration
# Compute device: CUDA when torch can see a GPU, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation settings
LEARNING_RATE = 2e-4
BATCH_SIZE = 128
NUM_EPOCHS = 50

# Image and network dimensions
IMAGE_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
FEATURES_DISC = 64
FEATURES_GEN = 64

# Dataset and Dataloader
# Every image is resized to IMAGE_SIZE x IMAGE_SIZE, converted to a tensor and
# normalised per channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

dataset = datasets.ImageFolder(
    root="/blue/prabhat/parvath.harikris/gen-ai-bias/boosting/biased_datasets/Smiling_Young_Male_99_Not_Male_1_N15000/images",
    transform=transform,
)

# Shuffled loader; the worker count follows torch.get_num_threads().
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=torch.get_num_threads(),
    pin_memory=True,
)

# Visualize a batch of training images
# Take one shuffled batch and tile its first 64 images into a single grid.
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(
    real_batch[0].to(device)[:64],
    padding=2,
    normalize=True,
)

plt.figure(figsize=(7, 7))
plt.axis("off")
plt.title("Training Images")
# make_grid returns channels-first data; imshow wants channels last.
plt.imshow(np.transpose(preview_grid.cpu(), (1, 2, 0)));

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import Inception_V3_Weights

# Device configuration (GPU if torch reports one, CPU otherwise)
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Load pre-trained Inception v3 model, move it to the device and switch it
# to evaluation mode.
inception = models.inception_v3(weights=Inception_V3_Weights.DEFAULT)
inception.eval()
inception.to(device)

# Truncated copy of a network that returns the activations of one named stage.
class FeatureExtractor(nn.Module):
    """Run the leading child modules of ``model`` and return the last one's output.

    The children are chosen by *name* rather than by position: every direct
    child of ``model`` is kept, in registration order, up to and including
    ``last_layer``.  A positional slice is fragile here because the number of
    children in front of ``Mixed_6e`` depends on how the model registers its
    pooling layers; with the pooling layers registered as children, the first
    14 children stop at ``Mixed_6d``, one stage short of ``Mixed_6e``.

    NOTE(review): only the child modules are chained, so any work done in the
    wrapped model's own ``forward`` outside its children (e.g. input
    preprocessing) is not applied -- confirm this is intended.

    Args:
        model: Module whose direct children form the feature pipeline.
        last_layer: Name of the final child to include (default ``"Mixed_6e"``).

    Raises:
        ValueError: If ``model`` has no direct child named ``last_layer``.
    """

    def __init__(self, model, last_layer="Mixed_6e"):
        super(FeatureExtractor, self).__init__()
        selected = []
        for name, child in model.named_children():
            selected.append(child)
            if name == last_layer:
                break
        else:
            # Loop finished without hitting `break`: the stage does not exist.
            raise ValueError(f"model has no child module named {last_layer!r}")
        self.features = nn.Sequential(*selected)

    def forward(self, x):
        """Return the activations of the last selected child for input ``x``."""
        return self.features(x)

# Create the feature extractor and use it strictly as a fixed network:
# eval() keeps normalisation/dropout layers in inference behaviour even if the
# wrapper is the object a later caller touches, and freezing the parameters
# means a backward pass through the extractor (needed for a perceptual loss to
# reach the generator) does not compute or accumulate weight gradients.
inception_mixed6e = FeatureExtractor(inception).to(device).eval()
for frozen_param in inception_mixed6e.parameters():
    frozen_param.requires_grad_(False)


In [None]:
# Generator Network
class Generator(nn.Module):
    """Transposed-convolution generator.

    Maps noise of shape (N, z_dim, 1, 1) to images of shape
    (N, channels_img, 64, 64); the final Tanh bounds values to [-1, 1].

    Args:
        z_dim: Number of channels of the input noise.
        channels_img: Number of channels of the generated image.
        features_g: Base width; hidden stages use 16x, 8x, 4x and 2x this value.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()
        widths = [features_g * mult for mult in (16, 8, 4, 2)]

        # First stage expands the 1x1 noise to 4x4 (stride 1, no padding).
        stages = [self._block(z_dim, widths[0], 4, 1, 0)]
        # Each following stage doubles the spatial size: 4 -> 8 -> 16 -> 32.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(self._block(c_in, c_out, 4, 2, 1))
        # Output layer: 32 -> 64, no normalisation, Tanh activation.
        stages.append(
            nn.ConvTranspose2d(widths[-1], channels_img, kernel_size=4, stride=2, padding=1)
        )
        stages.append(nn.Tanh())

        self.net = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU stage."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        normalise = nn.BatchNorm2d(out_channels)
        return nn.Sequential(upsample, normalise, nn.ReLU())

    def forward(self, x):
        """Generate images from the noise tensor ``x``."""
        return self.net(x)

# Discriminator Network
class Discriminator(nn.Module):
    """Strided-convolution discriminator.

    Maps images of shape (N, channels_img, 64, 64) to a probability tensor of
    shape (N, 1, 1, 1); the final Sigmoid bounds values to [0, 1].

    Args:
        channels_img: Number of channels of the input image.
        features_d: Base width; stages use 1x, 2x, 4x and 8x this value.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        # Input stage: plain convolution (with bias, no normalisation), 64 -> 32.
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three normalised stages, each halving the spatial size: 32 -> 16 -> 8 -> 4.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))
        # Output stage collapses the 4x4 map to a single value per image.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())

        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2) stage."""
        downsample = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        normalise = nn.BatchNorm2d(out_channels)
        return nn.Sequential(downsample, normalise, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Score the image batch ``x``."""
        return self.disc(x)

# Weight Initialization
def initialize_weights(model):
    """Re-initialise every convolution and batch-norm layer inside ``model``.

    Layers are recognised by class name, as in the original check:
      * names containing ``Conv``      -> weight ~ N(0.0, 0.02)
      * names containing ``BatchNorm`` -> weight ~ N(1.0, 0.02), bias = 0

    The function walks ``model.modules()`` itself, so it works when called
    directly on a whole network (``initialize_weights(gen)``).  Checking only
    the class name of the object passed in would match neither branch for a
    container such as ``Generator`` and leave all weights untouched.  It is
    still safe to use through ``model.apply(initialize_weights)``.

    Args:
        model: Any ``nn.Module``; it and all of its sub-modules are visited.
    """
    for module in model.modules():
        classname = module.__class__.__name__
        # Skip containers/activations and affine-less norm layers.
        weight = getattr(module, "weight", None)
        if weight is None:
            continue
        if classname.find('Conv') != -1:
            nn.init.normal_(weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(weight.data, 1.0, 0.02)
            bias = getattr(module, "bias", None)
            if bias is not None:
                nn.init.constant_(bias.data, 0)

# Initialize models on the selected device
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)

# Run initialize_weights on each network, generator first.
for network in (gen, disc):
    initialize_weights(network)

# Optimizers and Loss Function
# Both networks use Adam with the same learning rate and betas.
ADAM_BETAS = (0.5, 0.999)
opt_gen = optim.Adam(params=gen.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
opt_disc = optim.Adam(params=disc.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
criterion = nn.BCELoss()

# Fixed Noise for Visualization
# The same 64 latent vectors are reused whenever sample images are drawn.
fixed_noise = torch.randn((64, Z_DIM, 1, 1)).to(device=device)

# Define perceptual loss function
def perceptual_loss(fake, real, feature_extractor=None):
    """Mean-squared error between feature maps of ``fake`` and ``real``.

    Only the *real* branch runs under ``torch.no_grad()``: it is a constant
    target.  The *fake* branch must stay on the autograd graph, otherwise the
    returned loss has no gradient and adding it to the generator loss changes
    nothing during training.

    Args:
        fake: Generated image batch; gradients flow back through this input.
        real: Reference image batch, same shape as ``fake``.
        feature_extractor: Callable mapping an image batch to features.
            Defaults to the notebook-level ``inception_mixed6e``.

    Returns:
        Scalar tensor holding the MSE between the two feature tensors.
    """
    if feature_extractor is None:
        feature_extractor = inception_mixed6e
    with torch.no_grad():
        real_features = feature_extractor(real)
    fake_features = feature_extractor(fake)
    return torch.nn.functional.mse_loss(fake_features, real_features)

LAMBDA_PERC = 0.1  # weight of the perceptual term in the generator loss (see gen_loss below)

# Training Loop
# Every iteration performs one discriminator update followed by one generator
# update, both using the same batch of generated images. After each epoch a
# grid of samples drawn from fixed_noise is displayed.
for epoch in range(NUM_EPOCHS):
    # Progress bar over the batches of this epoch; the label shows the 0-based epoch index.
    loop = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch [{epoch}/{NUM_EPOCHS}]")
    for batch_idx, (real, _) in loop:  # the class label from ImageFolder is ignored
        real = real.to(device)
        batch_size = real.shape[0]  # size of this batch, read from the tensor rather than BATCH_SIZE
        
        # Labels for real and fake images: ones / zeros of shape (batch_size, 1, 1, 1),
        # compared element-wise with the discriminator's predictions by `criterion`.
        real_labels = torch.ones(batch_size, 1, 1, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1, 1, 1).to(device)
        
        # Train Discriminator
        opt_disc.zero_grad()
        
        # Real images should be scored as 1.
        real_pred = disc(real)
        real_loss = criterion(real_pred, real_labels)
        
        # Fake images should be scored as 0.
        noise = torch.randn(batch_size, Z_DIM, 1, 1).to(device)
        fake = gen(noise)
        fake_pred = disc(fake.detach())  # detach: the discriminator step must not backpropagate into gen
        fake_loss = criterion(fake_pred, fake_labels)
        
        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        opt_disc.step()
        
        # Train Generator
        opt_gen.zero_grad()
        
        # Score the same fakes again, this time without detach and with the
        # just-updated discriminator, so gradients reach the generator.
        fake_pred = disc(fake)
        adv_loss = criterion(fake_pred, real_labels)  # Adversarial loss: fakes are labelled as real
        perc_loss = perceptual_loss(fake, real)  # Perceptual loss between this fake batch and this real batch
        
        gen_loss = adv_loss + LAMBDA_PERC * perc_loss  # Weighted sum
        gen_loss.backward()
        opt_gen.step()
        
        # Show the current losses next to the progress bar.
        loop.set_postfix({
            "D_loss": f"{disc_loss.item():.4f}",
            "G_loss": f"{gen_loss.item():.4f}",
            "perc_loss": f"{perc_loss.item():.4f}"
        })

    # Generate Images for Visualization (once per epoch, always from fixed_noise)
    # NOTE(review): no gen.eval() call precedes this, so gen presumably runs in
    # training mode here -- confirm that is intended for the preview.
    with torch.no_grad():
        fake_images = gen(fixed_noise).detach().cpu()
        img_grid = make_grid(fake_images, nrow=8, normalize=True)  # 8 images per row
        plt.figure(figsize=(8,8))
        plt.imshow(np.transpose(img_grid, (1, 2, 0)))  # channels-first -> channels-last for imshow
        plt.axis('off')
        plt.show()



In [None]:
# Generate a random noise vector
random_noise = torch.randn(1, Z_DIM, 1, 1).to(device)

# Generate a random image.
# The generator is switched to evaluation mode first, as the export cell does:
# in training mode a batch of one would be normalised with its own statistics
# and the forward pass would also update the BatchNorm running statistics,
# silently changing the model that is exported afterwards.
gen.eval()
with torch.no_grad():
    random_image = gen(random_noise).detach().cpu()

# Convert the tensor to a grid and display the image
img_grid = make_grid(random_image, nrow=1, normalize=True)
plt.figure(figsize=(2,2))
plt.imshow(np.transpose(img_grid, (1, 2, 0)))  # channels-first -> channels-last for imshow
plt.axis('off')
plt.show()


In [None]:
# Directory to save images
n = 5000  # number of images to generate
save_dir = f"generated_images_gan_perceptual_{n}"
os.makedirs(save_dir, exist_ok=True)

gen.eval()  # Set to evaluation mode
with torch.no_grad():
    for i in range(n):
        noise = torch.randn(1, Z_DIM, 1, 1).to(device)  # Generate random noise
        fake_image = gen(noise)  # Generate image
        # save_image is reached through the already-imported `vutils` alias;
        # the bare name `save_image` is never imported in this notebook.
        # Files are numbered from 1: generated_1.png ... generated_{n}.png.
        vutils.save_image(fake_image, os.path.join(save_dir, f"generated_{i+1}.png"), normalize=True)

# Hand the image folder to the project helper.
# NOTE(review): presumably this extracts features from the saved images and
# writes them under features/gan_prec_{n} -- confirm in fid_custom.
fid_custom.extract_and_save_features(save_dir, f'features/gan_prec_{n}')