In [1]:
import os

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from Discriminator import ProjectedDiscriminator
from Generator import Generator
from dataset_creation import StyleDataset
from helper import fetch_data
from loss import ProjectedGANLoss

In [2]:
Z_DIM          = 64           # latent size: generator z_dim and width of the sampled z batches
C_DIM          = 768          # passed to ProjectedDiscriminator as c_dim (conditioning size)
BATCH_SIZE     = 1
EPOCHS         = 3
DEVICE         = "cpu"
IMG_RESOLUTION = 128          # generator output resolution
CLIP_WEIGHT    = 0.2          # passed to ProjectedGANLoss as clip_weight
CLIP_wWEIGHT   = CLIP_WEIGHT  # legacy misspelled alias, kept because later cells still read it
SEED           = 0

# Seed torch's RNGs before the networks are built so weight init, latent sampling
# and DataLoader shuffling are reproducible across Restart & Run All.
torch.manual_seed(SEED)

In [3]:
# Build both networks on the configured device (must happen before the optimizers are created).
generator     = Generator(z_dim = Z_DIM, conditional = True, img_resolution = IMG_RESOLUTION).to(DEVICE)
discriminator = ProjectedDiscriminator(c_dim = C_DIM).to(DEVICE)

# Phase tags read by `partial_freeze` and the training loop.
discriminator.name = "D"
generator.name     = "G"

# The training loop calls `phase.opt.step()`, so each network's optimizer must wrap
# that network's OWN parameters.
generator.opt     = torch.optim.Adam(generator.parameters(),     lr = 0.002, betas = (0, 0.99))
discriminator.opt = torch.optim.Adam(discriminator.parameters(), lr = 0.002, betas = (0, 0.99))

  model = create_fn(


In [4]:
def partial_freeze(gen_or_disc: nn.Module) -> None:
    """Configure ``requires_grad`` on a network for its training phase.

    The phase is read from ``gen_or_disc.name``:

    * ``"G"``: freeze every parameter, then unfreeze the parameters of each module
      whose qualified name contains any substring in ``gen_or_disc.trainable_layers``.
    * ``"D"``: freeze only the ``dino`` sub-module. All other parameters keep their
      current flag; the training loop unfreezes the whole network before calling this.

    Args:
        gen_or_disc: module with a ``name`` attribute ("G" or "D"). A "G" module must
            also expose ``trainable_layers`` (an iterable of name substrings).

    Raises:
        NotImplementedError: if ``gen_or_disc.name`` is neither "G" nor "D".
    """
    phase = gen_or_disc.name

    if phase == "G":
        # Materialise once: the membership test below runs for every module.
        trainable_layers = tuple(gen_or_disc.trainable_layers)

        # Freeze all layers first.
        gen_or_disc.requires_grad_(False)

        # Each parameter is decided by the module that directly owns it (substring match
        # on that module's qualified name). Touching only direct parameters avoids
        # re-walking every subtree once per ancestor.
        for name, module in gen_or_disc.named_modules():
            should_train = any(layer_type in name for layer_type in trainable_layers)
            for param in module.parameters(recurse=False):
                param.requires_grad_(should_train)

    elif phase == "D":
        # Keep the DINO backbone fixed; only the remaining discriminator parts train.
        gen_or_disc.dino.requires_grad_(False)

    else:
        # `NotImplemented` is a sentinel value, not an exception; raising it is a TypeError.
        raise NotImplementedError(
            f"partial_freeze: unknown phase {phase!r} (expected 'G' or 'D')"
        )
    

In [5]:
# CUDA backend flags. They only affect CUDA kernels, so with DEVICE = "cpu" they are inert.
cudnn_benchmark = True

# Let cuDNN autotune convolution algorithms (faster when input shapes stay fixed).
torch.backends.cudnn.benchmark = cudnn_benchmark

# Keep full float32 precision: disable TF32 for both cuBLAS matmuls and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Not enabled here: conv2d_gradfix.enabled = True  (would speed up training).

In [6]:
# EMA gives more weight to recent values but still considers past history. 
# It’s like a “soft average” that forgets old data slowly.
# EMA(t) = beta * EMA(t-1) + (1 - beta) * x_t

#  beta in [0, 1): decay rate (e.g., 0.99 or 0.999)
#  x_t           : the current value (e.g., a model parameter or loss)
#  EMA(t)        : the new smoothed value
#  EMA(t-1)      : the previous smoothed value

In [7]:
# Projected-GAN objective shared by both phases. blur_fade_kimg is measured in thousands
# of images, so 0.1 means the blur applied to images has faded to 0 after ~100 images.
loss = ProjectedGANLoss(
    G=generator,
    D=discriminator,
    blur_fade_kimg=0.1,
    clip_weight=CLIP_wWEIGHT,
    device=DEVICE,
)

In [8]:
# Dataset location. Override it with the STYLEGAN_T_DATASET environment variable. Otherwise
# "dataset/" resolves against the kernel's working directory, presumably this notebook's
# folder (notebooks/Networks/) — TODO confirm.
DATASET_DIR = os.environ.get("STYLEGAN_T_DATASET", "dataset/")

# NOTE(review): real images are loaded at 224 px while the generator is built with
# IMG_RESOLUTION = 128. Confirm that the loss/discriminator resizes both to a common size.
sd  = StyleDataset(path=DATASET_DIR, resolution=224)
sdl = DataLoader(sd, batch_size=BATCH_SIZE, shuffle=True)

In [9]:
# One Adam optimizer per network, each over that network's own parameters.
optimizer_gen = torch.optim.Adam(generator.parameters(),     lr = 0.002, betas = (0, 0.99))
optimizer_dis = torch.optim.Adam(discriminator.parameters(), lr = 0.002, betas = (0, 0.99))

# The training loop steps `phase.opt`, so attach these optimizers to the matching
# networks; otherwise the variables above would be dead code.
generator.opt     = optimizer_gen
discriminator.opt = optimizer_dis

In [10]:
# Counters live across epochs: cur_nimg is the total number of real images seen and
# is passed to the loss (which uses it for the blur fade-out); chance alternates the
# phases strictly (even -> discriminator step, odd -> generator step).
cur_nimg = 0
chance   = 0

for epoch in tqdm(range(EPOCHS)):
    for real_images, real_labels in sdl:
        # Map pixel values from [0, 255] to [-1, 1]
        # (assumes the dataset yields 0-255 tensors — TODO confirm).
        real_images = real_images.to(DEVICE) / (255 / 2) - 1
        batch_size  = real_images.shape[0]
        all_gen_z   = torch.randn([batch_size, Z_DIM], device = DEVICE)

        phase = discriminator if chance % 2 == 0 else generator

        # Unfreeze the active network, then re-freeze the parts that must stay fixed.
        phase.requires_grad_(True)
        partial_freeze(phase)

        # Drop stale gradients (e.g. left on G by an earlier D-phase backward) before
        # accumulating this step's gradients.
        phase.opt.zero_grad(set_to_none = True)
        loss.accumulate_gradients(phase = phase.name, cur_nimg = cur_nimg, real_imgs = real_images,
                                  c_raw = real_labels, gen_z = all_gen_z, verbose = False)

        training_stats = loss.training_stats
        if phase.name == "G": print("Generator Status")
        if phase.name == "D": print("Discriminator Status")
        print('-'*20)
        for key in training_stats:
            print(f"{key}: {training_stats[key]}", end = " || ")
        print()  # terminate the stats line so the next status header starts on its own line

        phase.opt.step()

        # Freeze the network again so the other phase's backward pass leaves it untouched.
        phase.requires_grad_(False)

        chance   += 1
        cur_nimg += batch_size

  return torch._native_multi_head_attention(
  0%|          | 0/3 [00:03<?, ?it/s]


TypeError: ProjectedGANLoss.blur() got multiple values for argument 'blur_sigma'

In [None]:
# Scratch check of the generator's phase tag; Jupyter echoes the last expression.
phase_tag = generator.name
phase_tag.upper()

'G'

In [None]:
# Scratch: preview the training loop's stats formatting using dummy values.
training_stats = {"Generator Loss": 4, "CLIP Loss": 5, "Generator Total Loss": 6}

print("Discriminator Status")
print("-" * 20)
stats_line = "".join(f"{name}: {value} || " for name, value in training_stats.items())
print(stats_line, end="")

Discriminator Status
--------------------
Generator Loss: 4 || CLIP Loss: 5 || Generator Total Loss: 6 || 