In [1]:
# ============================================================
# ✅ SETUP
# ============================================================

!git clone https://github.com/NVlabs/edm2.git /kaggle/working/edm2
%cd /kaggle/working/edm2
!pip install click tqdm psutil scipy pillow --quiet
%cd /kaggle/working/edm2

Cloning into '/kaggle/working/edm2'...
remote: Enumerating objects: 60, done.[K
remote: Counting objects: 100% (27/27), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 60 (delta 13), reused 10 (delta 10), pack-reused 33 (from 1)[K
Receiving objects: 100% (60/60), 1.27 MiB | 10.26 MiB/s, done.
Resolving deltas: 100% (24/24), done.
/kaggle/working/edm2
/kaggle/working/edm2


In [2]:
# ============================================================
# ✅ PATCH: KARRAS ρ TRAINING NOISE SCHEDULE (REPLACES EDM2's OWN σ SAMPLING)
# ============================================================
import inspect
from pathlib import Path

import torch
import training.training_loop as loop

def karras_sigma(batch_size, device, sigma_min=0.002, sigma_max=80, rho=7.0, generator=None):
    """
    Draw one noise level σ per sample from the Karras et al. (2022) ρ-schedule.

    σ is interpolated linearly in σ^(1/ρ) space:

        sigma(t) = (sigma_max^(1/rho) + t * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho,
        t ~ Uniform[0, 1)

    Karras et al. introduced this formula as the σ discretization for *sampling*.
    Here it is re-used as the *training* noise distribution, in place of the σ draw
    that EDM2Loss performs itself. For rho > 1 the density of σ is proportional to
    σ^(1/rho - 1). It therefore decreases with σ, so low noise levels are drawn more
    often than high ones.

    Args:
        batch_size: number of noise levels to draw.
        device: device on which to create the result.
        sigma_min: lower end of the σ range (must be > 0).
        sigma_max: upper end of the σ range (must be > sigma_min).
        rho: schedule exponent (must be > 0).
        generator: optional ``torch.Generator`` for reproducible draws. It must live
            on ``device``.

    Returns:
        Tensor of shape (batch_size, 1, 1, 1), so it broadcasts against a 4-D image
        batch. Values lie in [sigma_min, sigma_max] up to floating-point rounding.

    Raises:
        ValueError: if the σ range or rho is invalid.
    """
    # sigma_min == 0 would yield σ == 0 and a division by zero in the EDM loss weight;
    # rho == 0 would raise ZeroDivisionError below.
    if not 0 < sigma_min < sigma_max:
        raise ValueError(
            f"expected 0 < sigma_min < sigma_max, got sigma_min={sigma_min}, sigma_max={sigma_max}"
        )
    if rho <= 0:
        raise ValueError(f"expected rho > 0, got rho={rho}")

    t = torch.rand(batch_size, 1, 1, 1, device=device, generator=generator)

    # Endpoints in σ^(1/ρ) space; interpolate there, then map back with ** rho.
    s_hi = sigma_max ** (1 / rho)
    s_lo = sigma_min ** (1 / rho)
    return (s_hi + t * (s_lo - s_hi)) ** rho


def patched_call(self, net, images, labels=None):
    """
    Replacement for ``EDM2Loss.__call__`` that draws σ from the Karras ρ-schedule.

    Each image in the batch gets its own σ (via ``karras_sigma``) and is corrupted
    with Gaussian noise of standard deviation σ. The network then returns a denoised
    estimate plus a log-variance, and the loss is the uncertainty-weighted squared error

        (w(σ) / exp(logvar)) * (denoised - images)**2 + logvar,
        w(σ) = (σ² + σ_data²) / (σ · σ_data)²

    Args:
        self: the loss object; only ``self.sigma_data`` is read.
        net: denoiser, called as ``net(x, sigma, labels, return_logvar=True)`` and
            expected to return ``(denoised, logvar)``.
        images: clean training batch; ``images.shape[0]`` is taken as the batch size.
        labels: conditioning passed straight through to ``net``.

    Returns:
        The element-wise loss tensor, without any reduction.
    """
    noise_level = karras_sigma(images.shape[0], images.device)
    loss_weight = (noise_level**2 + self.sigma_data**2) / (noise_level * self.sigma_data)**2

    perturbation = torch.randn_like(images) * noise_level
    denoised, logvar = net(images + perturbation, noise_level, labels, return_logvar=True)

    squared_error = (denoised - images)**2
    return (loss_weight / logvar.exp()) * squared_error + logvar


# Inject the patch into this kernel's copy of EDM2 (useful for interactive checks).
loop.EDM2Loss.__call__ = patched_call

# `!torchrun ... train_edm2.py` starts fresh Python processes that re-import
# training/training_loop.py from disk, so the in-kernel assignment above never
# reaches the training workers. Persist the same two functions into that file,
# between markers, so the workers really train with the Karras σ draw. Any copy
# written by an earlier run of this cell is replaced, which keeps the cell idempotent.
_PATCH_BEGIN = "# >>> karras-rho notebook patch (auto-generated) >>>"
_PATCH_END = "# <<< karras-rho notebook patch (auto-generated) <<<"

loop_path = Path(loop.__file__)
loop_src = loop_path.read_text(encoding="utf-8")

head, found, rest = loop_src.partition(_PATCH_BEGIN)
if found:
    loop_src = head.rstrip("\n") + "\n" + rest.partition(_PATCH_END)[2].lstrip("\n")

patch_src = "\n".join([
    "",
    _PATCH_BEGIN,
    "import torch",
    inspect.getsource(karras_sigma),
    inspect.getsource(patched_call),
    "EDM2Loss.__call__ = patched_call",
    _PATCH_END,
    "",
])
loop_path.write_text(loop_src.rstrip("\n") + "\n" + patch_src, encoding="utf-8")

print("✅ Karras ρ schedule injected successfully!")
print(f"   (also persisted to {loop_path} so torchrun workers pick it up)")


✅ Karras ρ schedule injected successfully!


In [3]:
!torchrun --standalone --nproc_per_node=2 train_edm2.py \
    --outdir=/kaggle/working/training-runs/celeba64-karras-rho \
    --data=/kaggle/input/celeva-64x64-dataset/celeba64/train \
    --cond=False \
    --preset=edm2-img64-xs \
    --batch=64 \
    --batch-gpu=32 \
    --duration=2Mi \
    --status=16Ki \
    --snapshot=512Ki \
    --checkpoint=0 \
    --seed=0


W1114 03:04:57.175000 54 torch/distributed/run.py:792] 
W1114 03:04:57.175000 54 torch/distributed/run.py:792] *****************************************
W1114 03:04:57.175000 54 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1114 03:04:57.175000 54 torch/distributed/run.py:792] *****************************************
[W1114 03:05:07.025224929 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1114 03:05:17.031068196 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1114 03:05:29.815892787 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1114 03:05:29.830519255 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1114 03:05:39.826525176 socket.