In [1]:
# ============================================================
# ✅ SETUP
# ============================================================

!git clone https://github.com/NVlabs/edm2.git /kaggle/working/edm2
%cd /kaggle/working/edm2
!pip install click tqdm psutil scipy pillow --quiet
%cd /kaggle/working/edm2


Cloning into '/kaggle/working/edm2'...
remote: Enumerating objects: 60, done.[K
remote: Counting objects: 100% (27/27), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 60 (delta 13), reused 10 (delta 10), pack-reused 33 (from 1)[K
Receiving objects: 100% (60/60), 1.27 MiB | 10.34 MiB/s, done.
Resolving deltas: 100% (24/24), done.
/kaggle/working/edm2
/kaggle/working/edm2


In [2]:
# ============================================================
# ✅ PATCH: LOGARITHMIC TRAINING NOISE SCHEDULE
# ============================================================
import torch
import training.training_loop as loop

def logarithmic_sigma(batch_size, device, sigma_min=0.002, sigma_max=80):
    """
    Draw one noise level per sample, uniformly distributed in log10 space.

    For u ~ U[0, 1) the returned value is
        sigma = 10 ** (log10(sigma_max) + u * (log10(sigma_min) - log10(sigma_max)))
    so u = 0 maps to sigma_max and u -> 1 approaches sigma_min.

    Args:
        batch_size: number of independent sigma values to draw.
        device: torch device on which the random draw and result are created.
        sigma_min: lower end of the sigma range.
        sigma_max: upper end of the sigma range.

    Returns:
        Tensor of shape (batch_size, 1, 1, 1) holding the sampled sigmas.
    """
    # Base-10 exponents of both endpoints, as 0-dim tensors on `device`.
    exponent_hi = torch.tensor(sigma_max, device=device).log10()
    exponent_lo = torch.tensor(sigma_min, device=device).log10()

    # One uniform draw per sample; the trailing singleton dims come from the
    # shape passed to torch.rand.
    u = torch.rand(batch_size, 1, 1, 1, device=device)

    # Linear interpolation between the exponents, then back to linear scale.
    exponent = exponent_hi + u * (exponent_lo - exponent_hi)
    return torch.pow(10.0, exponent)


def patched_call(self, net, images, labels=None):
    """
    Replacement for EDM2Loss.__call__ that draws sigma from the logarithmic
    (log-uniform) schedule implemented by `logarithmic_sigma`.

    Args:
        self: loss object; only `self.sigma_data` is read here.
        net: callable invoked as net(noisy_images, sigma, labels,
            return_logvar=True) and expected to return (denoised, logvar).
        images: clean input batch; its first dimension is the batch size.
        labels: optional conditioning, forwarded to `net` unchanged.

    Returns:
        The element-wise loss tensor (no reduction is applied here).
    """
    # One sigma per sample, created on the same device as the images.
    sigma = logarithmic_sigma(images.shape[0], images.device)

    # Per-sample loss weight derived from sigma and sigma_data.
    sd = self.sigma_data
    weight = (sigma**2 + sd**2) / (sigma * sd)**2

    # Perturb the clean images with Gaussian noise scaled by sigma.
    noisy_images = images + torch.randn_like(images) * sigma

    # The network predicts the clean image plus a log-variance term.
    denoised, logvar = net(noisy_images, sigma, labels, return_logvar=True)

    # Weighted squared error scaled by exp(-logvar), plus logvar itself.
    squared_error = (denoised - images)**2
    return (weight / logvar.exp()) * squared_error + logvar


# Inject patch
# 1) Patch the class inside this kernel, for any in-notebook use of EDM2Loss.
loop.EDM2Loss.__call__ = patched_call

# 2) Persist the same patch into the module file on disk.
#    A `!torchrun ...` shell command starts fresh Python worker processes that
#    import training/training_loop.py from disk, so a monkey-patch that exists
#    only in this kernel never reaches the actual training run.
import inspect
from pathlib import Path

_PATCH_MARKER = "# >>> logarithmic sigma patch (injected from notebook) >>>"
_loop_path = Path(loop.__file__)
_loop_src = _loop_path.read_text(encoding="utf-8")

# Drop any patch appended by a previous run of this cell, so that edits to the
# functions above are picked up and the file never accumulates duplicates.
_loop_src = _loop_src.split(_PATCH_MARKER)[0].rstrip() + "\n"

_patch_src = "\n".join([
    "",
    "",
    _PATCH_MARKER,
    "import torch",
    "",
    inspect.getsource(logarithmic_sigma),
    "",
    inspect.getsource(patched_call),
    "",
    "EDM2Loss.__call__ = patched_call",
    "",
])
_loop_path.write_text(_loop_src + _patch_src, encoding="utf-8")

print("✅ Logarithmic training schedule injected successfully!")


✅ Logarithmic training schedule injected successfully!


In [3]:
# Launch training through torchrun with 2 worker processes (--nproc_per_node=2).
# NOTE(review): these workers are separate Python processes started from the
# shell; objects monkey-patched only inside this notebook kernel are not
# visible to them unless the change also exists in the source files on disk.
# Results are written under the directory given by --outdir below.
!torchrun --standalone --nproc_per_node=2 train_edm2.py \
    --outdir=/kaggle/working/training-runs/celeba64-edm2-logarithmic-noise-inject2 \
    --data=/kaggle/input/celeva-64x64-dataset/celeba64/train \
    --cond=False \
    --preset=edm2-img64-xs \
    --batch=64 \
    --batch-gpu=32 \
    --duration=2Mi \
    --status=16Ki \
    --snapshot=256Ki \
    --checkpoint=0 \
    --seed=0


W1114 02:57:11.076000 54 torch/distributed/run.py:792] 
W1114 02:57:11.076000 54 torch/distributed/run.py:792] *****************************************
W1114 02:57:11.076000 54 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1114 02:57:11.076000 54 torch/distributed/run.py:792] *****************************************
[W1114 02:57:21.904045632 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1114 02:57:31.914050664 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1114 02:57:42.618581984 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1114 02:57:42.619066192 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1114 02:57:52.629052445 socket.

In [4]:
!tail -n 20 /kaggle/working/training-runs/celeba64-edm2-logarithmic/log.txt


tail: cannot open '/kaggle/working/training-runs/celeba64-edm2-logarithmic/log.txt' for reading: No such file or directory
