In [1]:
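# StyleGAN-inspired GAN notebook for BTD (Brain Tumor Deepfake):
# mapping network + AdaIN style blocks, with generator EMA and style mixing.
# NOTE: the truncation trick (`truncation_psi`) and blur regularization are
# stated goals but are not implemented in the cells shown here.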

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import copy
import os
from PIL import Image

In [2]:
# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [8]:
# Hyperparameters
seed = 42               # seeds the Python, NumPy and PyTorch RNGs so reruns are comparable
image_size = 64         # side length in pixels that every training image is resized to
latent_dim = 512        # width of the latent vector z and of the mapped style vector w
batch_size = 32
epochs = 500
lr = 0.0002             # learning rate handed to both Adam optimizers
ema_beta = 0.999        # decay of the moving average kept over the generator weights
truncation_psi = 0.7    # NOTE(review): defined but not referenced by any cell shown here

# Seed every RNG this notebook draws from (style-mixing coin flips, latent
# noise, weight init, DataLoader shuffling).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Dataset: an ImageFolder rooted at ./data/generation_data/Training/.
# The class labels it yields are discarded by the training loop.
#
# Preprocessing: force a single channel, resize to image_size x image_size,
# convert to a tensor and normalize with mean 0.5 / std 0.5 so pixel values
# land in [-1, 1], the range of the generator's tanh output.
# NOTE: Generator and Discriminator are hard-wired for 64x64 inputs, so
# image_size must stay 64 unless those classes are changed as well.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

dataset = datasets.ImageFolder("./data/generation_data/Training/", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [9]:
# Mapping network (z -> w)
def make_mapping_network():
    """Build the MLP that maps a latent vector z to a style vector w.

    The network is three ``latent_dim -> latent_dim`` linear layers; the
    first two are each followed by ``LeakyReLU(0.2)`` and the last one has
    no activation.

    Returns:
        nn.Sequential: the freshly initialized mapping network.
    """
    stages = []
    for _ in range(2):
        stages.append(nn.Linear(latent_dim, latent_dim))
        stages.append(nn.LeakyReLU(0.2))
    stages.append(nn.Linear(latent_dim, latent_dim))
    return nn.Sequential(*stages)

# Style-modulated convolution block (AdaIN)
class StyleBlock(nn.Module):
    """3x3 convolution followed by instance norm and a style-driven affine.

    The style vector ``w`` is projected by two linear layers into a
    per-channel scale and a per-channel shift, which are applied to the
    normalized feature map.

    NOTE(review): there is no activation function anywhere in this block.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.InstanceNorm2d(out_channels)
        self.style_scale = nn.Linear(latent_dim, out_channels)
        self.style_shift = nn.Linear(latent_dim, out_channels)

    def forward(self, x, w):
        """Return ``norm(conv(x)) * (scale(w) + 1) + shift(w)``.

        Args:
            x: feature map fed to the convolution.
            w: style vectors with ``latent_dim`` features in the last axis.
        """
        normalized = self.norm(self.conv(x))
        # Append two singleton axes so the per-channel terms broadcast over H and W.
        gain = self.style_scale(w)[:, :, None, None] + 1
        bias = self.style_shift(w)[:, :, None, None]
        return normalized * gain + bias

# Generator (synthesis network + mapping network)
class Generator(nn.Module):
    """Style-based generator producing 1-channel 64x64 images.

    A learned 1x512x4x4 constant is upsampled by 2 before each of four
    ``StyleBlock`` stages (4 -> 8 -> 16 -> 32 -> 64 pixels), then a 1x1
    convolution reduces the result to one channel and ``tanh`` squashes it.
    """

    def __init__(self):
        super().__init__()
        self.initial = nn.Parameter(torch.randn(1, 512, 4, 4))
        self.mapping = make_mapping_network()
        widths = (512, 512, 256, 128, 64)
        stack = [StyleBlock(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        stack.append(nn.Conv2d(64, 1, 1))
        self.layers = nn.ModuleList(stack)

    def forward(self, z1, z2=None, mixing_cutoff=None):
        """Generate images from latent codes, optionally mixing two styles.

        Args:
            z1: latent batch mapped to the primary style.
            z2: optional second latent batch, mapped to the secondary style.
            mixing_cutoff: when given together with ``z2``, every layer whose
                index is strictly greater than this value uses the secondary
                style; all other layers use the primary one.

        Returns:
            Tensor of ``tanh``-squashed images.
        """
        w_primary = self.mapping(z1)
        second_given = z2 is not None
        if second_given:
            w_secondary = self.mapping(z2)
        mixing = second_given and mixing_cutoff is not None

        # One style per entry of self.layers; the entry for the final 1x1
        # conv is never consumed because zip() below stops at the style blocks.
        per_layer = [
            w_secondary if (mixing and idx > mixing_cutoff) else w_primary
            for idx in range(len(self.layers))
        ]

        *style_blocks, to_image = self.layers
        x = self.initial.expand(per_layer[0].size(0), -1, -1, -1)
        for block, style in zip(style_blocks, per_layer):
            x = block(F.interpolate(x, scale_factor=2), style)
        return torch.tanh(to_image(x))

# Discriminator (real/fake classifier)
class Discriminator(nn.Module):
    """Convolutional classifier scoring 1-channel 64x64 images.

    Four stride-2 4x4 convolutions halve the resolution each time
    (64 -> 32 -> 16 -> 8 -> 4) while widening the channels to 512; the
    flattened 4x4x512 map is reduced to one sigmoid probability per image.
    Batch norm is applied after every convolution except the first.
    """

    def __init__(self):
        super().__init__()
        stages = [nn.Conv2d(1, 64, 4, 2, 1), nn.LeakyReLU(0.2)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            stages += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        stages += [nn.Flatten(), nn.Linear(4 * 4 * 512, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return one value in (0, 1) per input image, shape ``(batch, 1)``."""
        score = self.model(x)
        return score

In [10]:
# Initialize models
G = Generator().to(device)
D = Discriminator().to(device)
# Moving-average copy of the generator. It is only read for sampling and is
# never optimized, so it stays in eval mode with gradients switched off.
G_ema = copy.deepcopy(G).eval().requires_grad_(False)

opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.BCELoss()
G_losses, D_losses = [], []

# save_image does not create missing directories; make sure the target exists
# before the first snapshot is written.
os.makedirs("samples", exist_ok=True)

# G.layers holds the style blocks plus the final 1x1 to-image conv; only the
# style blocks consume a style vector.
num_style_blocks = len(G.layers) - 1


def update_ema(ema_model, model, beta):
    """Blend `model`'s parameters into `ema_model` in place.

    Each EMA parameter becomes ``beta * ema + (1 - beta) * current``.
    """
    with torch.no_grad():
        for p, ema_p in zip(model.parameters(), ema_model.parameters()):
            ema_p.mul_(beta).add_(p, alpha=1 - beta)


# Training loop
for epoch in range(epochs):
    for real, _ in dataloader:
        real = real.to(device)
        b = real.size(0)

        # Train Discriminator
        z = torch.randn(b, latent_dim, device=device)
        # The discriminator step needs no generator gradients, so skip
        # building the generator's autograd graph altogether.
        with torch.no_grad():
            fake = G(z)
        d_real = D(real)
        d_fake = D(fake)
        d_loss = criterion(d_real, torch.ones_like(d_real)) + criterion(d_fake, torch.zeros_like(d_fake))

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # Train Generator (style mixing on ~90% of the steps)
        if random.random() < 0.9:
            z1 = torch.randn(b, latent_dim, device=device)
            z2 = torch.randn(b, latent_dim, device=device)
            # The generator switches to the second style for block indices
            # strictly greater than the cutoff, so the cutoff must stay below
            # the last style-block index for any mixing to happen. The former
            # upper bound (len(G.layers) - 1) also counted the to-image conv
            # and produced cutoffs that never mixed.
            cutoff = random.randint(1, num_style_blocks - 2)
            fake = G(z1, z2, cutoff)
        else:
            z = torch.randn(b, latent_dim, device=device)
            fake = G(z)

        d_fake = D(fake)
        g_loss = criterion(d_fake, torch.ones_like(d_fake))

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

        # EMA update
        update_ema(G_ema, G, ema_beta)

    # Per-epoch history keeps the losses of the epoch's final batch.
    G_losses.append(g_loss.item())
    D_losses.append(d_loss.item())

    print(f"Epoch [{epoch+1}/{epochs}]  D Loss: {d_loss.item():.4f}  G Loss: {g_loss.item():.4f}")

    # Every 10 epochs, write an 8x8 grid sampled from the EMA generator.
    if (epoch+1) % 10 == 0:
        z = torch.randn(64, latent_dim, device=device)
        with torch.no_grad():
            samples = G_ema(z).cpu()
        utils.save_image(samples, f"samples/epoch_{epoch+1}.png", nrow=8, normalize=True)

Epoch [1/500]  D Loss: 0.2833  G Loss: 2.6474


KeyboardInterrupt: 

In [None]:
# Loss plot: one point per epoch, taken from that epoch's final batch.
fig, ax = plt.subplots()
ax.plot(G_losses, label="Generator Loss")
ax.plot(D_losses, label="Discriminator Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("BCE loss")
ax.set_title("GAN training losses")
ax.legend()
plt.show()

In [None]:
# Save models
torch.save(G.state_dict(), "generator.pth")
torch.save(D.state_dict(), "discriminator.pth")
# The EMA generator is the network the sample grids are drawn from; persist
# it too so those images can be reproduced after the kernel is gone.
torch.save(G_ema.state_dict(), "generator_ema.pth")