In [2]:
import torch

# Hand cached-but-unused GPU blocks back to the driver before (re)starting training.
# (empty_cache is a no-op unless CUDA has been initialised, so the guard changes nothing.)
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import vgg16, VGG16_Weights
from PIL import Image
import os
from tqdm import tqdm
import math
import itertools
import random

# For saving generated images and AMP
from torchvision.utils import save_image
from torch.cuda.amp import GradScaler, autocast

# =================================================================================
# 1. Configuration & Hyperparameters
# =================================================================================
class Config:
    """Configuration class for model hyperparameters and paths.

    Every setting is a class attribute, so ``Config.X`` and ``Config().X``
    read the same value.
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Checkpointing & Resuming ---
    # Raw string, like the other Windows paths below: a plain literal only worked
    # by accident because "\D", "\S", "\M", "\A", "\c" are invalid escapes
    # (SyntaxWarning on Python 3.12+), and would silently corrupt a path whose
    # folder names start with e.g. "n" or "t". Set to "" or None to train from scratch.
    RESUME_CHECKPOINT_PATH = r"D:\DAIICT\Sem 3\Major Project 1\AG-CycleDiffusion\checkpoints\checkpoint_step_10000.pth"

    # --- Paths ---
    LAND_IMG_PATH = r"D:\DAIICT\Sem 3\Major Project 1\AG-CycleDiffusion\Land"
    WATER_IMG_PATH = r"D:\DAIICT\Sem 3\Major Project 1\AG-CycleDiffusion\Underwater"

    # --- Training Hyperparameters ---
    IMG_SIZE = 256  # must be divisible by 8 (ConditionalUNet pools three times)
    BATCH_SIZE = 1
    # NOTE(review): a resumed run restores its LR-scheduler state from the
    # checkpoint, so changing TOTAL_ITERS / WARMUP_ITERS does not reshape the
    # LR curve of a resumed run (only the loop length).
    TOTAL_ITERS = 60000
    SEED = 42

    # --- Optimizers ---
    LR_U = 2e-4
    LR_D = 1e-4
    BETA1 = 0.9
    BETA2 = 0.999

    # --- Scheduler Hyperparameters ---
    WARMUP_ITERS = 2000

    # --- Loss Weights ---
    LAMBDA_CYCLE = 10.0
    LAMBDA_PERCEPTUAL = 1.0
    LAMBDA_GAN = 1.0

    # --- Phased Training Curriculum (in iterations) ---
    PHASE1_ITERS = 30000  # Cycle + Perceptual pretraining
    PHASE2_RAMP_ITERS = 10000  # Steps to ramp up GAN loss in Phase 2

    # --- Diffusion Hyperparameters ---
    TIMESTEPS = 200

    # --- Training Mechanics ---
    USE_AMP = True
    EMA_DECAY = 0.999

    # --- Checkpointing & Sampling ---
    CHECKPOINT_DIR = "checkpoints"
    SAMPLE_DIR = "samples"
    SAVE_CHECKPOINT_STEP = 2000
    SAVE_IMAGE_STEP = 500

cfg = Config()
os.makedirs(cfg.SAMPLE_DIR, exist_ok=True)
os.makedirs(cfg.CHECKPOINT_DIR, exist_ok=True)


def set_seed(s=42):
    """Seed Python's ``random`` and PyTorch (CPU, plus every CUDA device on GPU).

    On CUDA this also enables cuDNN autotuning, which trades bit-exact
    reproducibility of convolutions for speed.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(s)
    if cfg.DEVICE != "cuda":
        return
    torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.benchmark = True


set_seed(cfg.SEED)


# =================================================================================
# 2. Diffusion Logic & Helpers
# =================================================================================
def cosine_beta_schedule(timesteps, s=0.008, device=cfg.DEVICE):
    """Return ``timesteps`` betas from the cosine noise schedule (Improved DDPM).

    alpha_bar(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), normalised so that
    alpha_bar(0) = 1; beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clamped to
    [1e-6, 0.999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, device=device)
    f = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alpha_bar = f / f[0]
    ratios = alpha_bar[1:] / alpha_bar[:-1]
    return (1 - ratios).clamp(1e-6, 0.999)

class Diffusion:
    """Cosine noise schedule with forward noising and a deterministic DDIM-style sampler.

    The sampler treats the model output as a prediction of the clean image x0.
    """
    def __init__(self, timesteps=cfg.TIMESTEPS, device=cfg.DEVICE):
        self.timesteps, self.device = timesteps, device
        self.betas = cosine_beta_schedule(timesteps, device=device)
        self.alphas = 1. - self.betas
        # alpha_bar_t = prod_{s <= t} alpha_s, indexed by integer timestep 0..timesteps-1.
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    def noise_images(self, x0, t):
        """Draw x_t ~ q(x_t | x_0) for a batch.

        Args:
            x0: clean images; the per-sample scales are viewed as (-1, 1, 1, 1),
                so x0 is assumed to be (batch, C, H, W).
            t: integer timesteps, one per sample, on the schedule's device.

        Returns:
            ``(x_t, noise)`` with x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.
        """
        sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod[t]).view(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod[t]).view(-1, 1, 1, 1)
        noise = torch.randn_like(x0)
        return sqrt_alphas_cumprod * x0 + sqrt_one_minus_alphas_cumprod * noise, noise

    @torch.no_grad()
    def sample(self, model, n, condition_tensor, n_steps=25):
        """Generate ``n`` images conditioned on ``condition_tensor`` with deterministic DDIM steps.

        Args:
            model: callable ``model(x_t, t, condition) -> predicted x0``.
            n: number of images; should match the batch size of ``condition_tensor``.
            condition_tensor: conditioning images; their spatial size sets the
                output size (previously fixed to ``cfg.IMG_SIZE``).
            n_steps: number of denoising steps spread evenly over the schedule.

        Returns:
            Images scaled to [0, 1], shape (n, 3, H, W).

        The model's train/eval mode is restored on return (even on error).
        """
        was_training = model.training
        model.eval()
        try:
            height, width = condition_tensor.shape[-2:]
            x_t = torch.randn((n, 3, height, width), device=self.device)
            # Evenly spaced, integer-truncated timesteps from T-1 down to 0.
            ts_vec = torch.linspace(self.timesteps - 1, 0, n_steps + 1).long().to(self.device)
            for i in range(n_steps):
                t = ts_vec[i].expand(n)
                pred_x0 = model(x_t, t, condition_tensor)

                alpha_cumprod = self.alphas_cumprod[t].view(-1, 1, 1, 1)
                if i == n_steps - 1:
                    # Final step lands on the clean image (alpha_bar_prev = 1), so the
                    # output is pred_x0. The previous `ts_vec[i+1] >= 0` test could never
                    # be false (ts_vec ends at 0), so this branch was unreachable and
                    # every sample kept sqrt(1 - alpha_bar_0) of predicted noise.
                    alpha_cumprod_prev = torch.ones_like(alpha_cumprod)
                else:
                    alpha_cumprod_prev = self.alphas_cumprod[ts_vec[i + 1]].view(-1, 1, 1, 1)

                # Noise implied by the x0 prediction, re-used for the deterministic (eta=0) update.
                pred_noise = (x_t - torch.sqrt(alpha_cumprod) * pred_x0) / torch.sqrt(1. - alpha_cumprod)
                dir_xt = torch.sqrt(1. - alpha_cumprod_prev) * pred_noise
                x_t = torch.sqrt(alpha_cumprod_prev) * pred_x0 + dir_xt
        finally:
            model.train(was_training)
        return (x_t.clamp(-1, 1) + 1) / 2

class EMA:
    """Exponential moving average of a model's ``state_dict`` entries.

    Floating-point entries are updated as
    ``shadow = decay * shadow + (1 - decay) * value``. Integer entries (e.g.
    BatchNorm's ``num_batches_tracked`` counter) cannot be scaled by a float
    in place, so they are copied verbatim instead of averaged.
    """
    def __init__(self, model, decay):
        self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items()}
        self.decay = decay

    @torch.no_grad()
    def update(self, model):
        """Blend the current weights of ``model`` into the shadow copy."""
        for k, v in model.state_dict().items():
            shadow = self.shadow[k]
            if shadow.dtype.is_floating_point or shadow.dtype.is_complex:
                shadow.mul_(self.decay).add_(v, alpha=1 - self.decay)
            else:
                # In-place mul_ by a float raises on integer tensors.
                shadow.copy_(v)

    def copy_to(self, model):
        """Load the averaged weights into ``model`` (keys must match exactly)."""
        model.load_state_dict(self.shadow, strict=True)

# =================================================================================
# 3. Model Architectures
# =================================================================================
class Block(nn.Module):
    """Two 3x3 conv stages (conv -> ReLU -> GroupNorm) with a time-embedding bias between them.

    Spatial size is preserved; channels go ``in_ch -> out_ch``. The same
    ``GroupNorm`` module (and its affine parameters) is applied after both convs,
    and dropout is applied to the output.
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, dropout_prob=0.1):
        super().__init__()
        # Creation order is significant: it fixes parameter-init RNG order and the
        # parameter ordering stored in optimizer checkpoints.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.relu = nn.ReLU()
        self.norm = nn.GroupNorm(8, out_ch)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x, t):
        features = self.norm(self.relu(self.conv1(x)))
        # Project the time embedding to a per-channel bias broadcast over H and W.
        time_bias = self.relu(self.time_mlp(t))[..., None, None]
        features += time_bias
        activated = self.relu(self.conv2(features))
        return self.dropout(self.norm(activated))

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D batch of (integer or float) timesteps to ``dim``-wide sin/cos embeddings.

    Output is ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]`` with
    ``h = dim // 2`` and frequencies ``f_i = 10000^(-i / (h - 1))``. For an odd
    ``dim`` a zero column is appended so the width is always exactly ``dim``.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim is 2 or 3 (half_dim == 1).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # Odd dim: pad so downstream Linear(dim, ...) layers get the width they expect.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

class ConditionalUNet(nn.Module):
    """Three-level U-Net that maps a noisy image ``x`` plus a conditioning image to an output image.

    ``x`` and ``condition`` are concatenated along channels, so ``in_channels``
    must equal their combined channel count (6 by default). Height and width
    must be divisible by 8: three 2x max-pools are undone by three stride-2
    transposed convs before each skip concatenation.
    """
    def __init__(self, in_channels=6, out_channels=3, time_emb_dim=256):
        super().__init__()
        # Attribute creation order mirrors the original layout, keeping init RNG
        # order, state_dict keys and optimizer parameter order unchanged.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )
        self.down1 = Block(in_channels, 64, time_emb_dim)
        self.down2 = Block(64, 128, time_emb_dim)
        self.down3 = Block(128, 256, time_emb_dim)
        self.pool = nn.MaxPool2d(2)
        self.bot1 = Block(256, 512, time_emb_dim)
        self.up_trans_1 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.up_conv_1 = Block(512, 256, time_emb_dim)
        self.up_trans_2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.up_conv_2 = Block(256, 128, time_emb_dim)
        self.up_trans_3 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.up_conv_3 = Block(128, 64, time_emb_dim)
        self.out = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, t, condition):
        t_emb = self.time_mlp(t)
        h = torch.cat([x, condition], dim=1)

        skip1 = self.down1(h, t_emb)
        skip2 = self.down2(self.pool(skip1), t_emb)
        skip3 = self.down3(self.pool(skip2), t_emb)
        h = self.bot1(self.pool(skip3), t_emb)

        decoder = (
            (self.up_trans_1, self.up_conv_1, skip3),
            (self.up_trans_2, self.up_conv_2, skip2),
            (self.up_trans_3, self.up_conv_3, skip1),
        )
        for up_trans, up_conv, skip in decoder:
            h = up_conv(torch.cat([up_trans(h), skip], dim=1), t_emb)
        return self.out(h)

class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a 1-channel map of real/fake scores.

    Four stride-2 4x4 convolutions (64 -> 128 -> 256 -> 512 channels, each
    followed by LeakyReLU(0.2); InstanceNorm on all but the first) and a final
    stride-1 4x4 convolution to one channel. A 256x256 input yields a
    (N, 1, 15, 15) map.
    """
    def __init__(self, in_channels=3):
        super().__init__()
        widths = [in_channels, 64, 128, 256, 512]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                # The first stage is deliberately left unnormalised.
                nn.InstanceNorm2d(c_out) if stage else nn.Identity(),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, 4, 1, 1))
        # A flat Sequential keeps the state_dict keys model.0, model.3, ..., model.12.
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 feature maps of two images in [-1, 1].

    Uses the first 16 modules of torchvision's ImageNet VGG16 ``features``
    (presumably up to relu3_3 in torchvision's layout — confirm if it matters).
    """
    def __init__(self):
        super().__init__()
        backbone = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(cfg.DEVICE).eval()
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:16])
        for weight in self.feature_extractor.parameters():
            weight.requires_grad = False
        self.l1 = nn.L1Loss()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def forward(self, gen, real):
        def vgg_features(img):
            # [-1, 1] -> [0, 1] -> 224x224 (F.interpolate default, i.e. nearest)
            # -> ImageNet normalisation -> VGG feature maps.
            unit_range = (img.clamp(-1, 1) + 1) * 0.5
            resized = F.interpolate(unit_range, (224, 224))
            return self.feature_extractor(self.normalize(resized))

        return self.l1(vgg_features(gen), vgg_features(real))

# =================================================================================
# 4. Data Loading
# =================================================================================
class ImageDataset(Dataset):
    """Unpaired underwater/land image dataset.

    Item ``index`` pairs water image ``index % len(water)`` with a land image
    picked uniformly at random, so the two domains are never aligned.
    ``len()`` is the size of the larger folder.

    Args:
        root_water: folder containing underwater images.
        root_land: folder containing land images.
        transform: callable applied to each RGB PIL image. Defaults to
            ``transforms.ToTensor()``; a ``None`` transform would otherwise
            fail inside ``__getitem__``.

    Raises:
        ValueError: if either folder contains no .png/.jpg/.jpeg files.
    """

    # Matched case-insensitively so that e.g. "IMG_001.JPG" is not skipped.
    _EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_water, root_land, transform=None):
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.files_water = self._list_images(root_water)
        self.files_land = self._list_images(root_land)
        self.root_water, self.root_land = root_water, root_land
        # Fail early with a clear message rather than a ZeroDivisionError /
        # randint ValueError on the first __getitem__.
        for domain, files, root in (("water", self.files_water, root_water),
                                    ("land", self.files_land, root_land)):
            if not files:
                raise ValueError(f"No {domain} images ({', '.join(self._EXTENSIONS)}) found in {root!r}")

    @classmethod
    def _list_images(cls, root):
        """Return the sorted names of image files directly inside *root*."""
        return sorted(f for f in os.listdir(root) if f.lower().endswith(cls._EXTENSIONS))

    @staticmethod
    def _load_rgb(path):
        """Open *path*, convert it to RGB and close the underlying file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        water_name = self.files_water[index % len(self.files_water)]
        land_name = random.choice(self.files_land)
        img_w = self._load_rgb(os.path.join(self.root_water, water_name))
        img_l = self._load_rgb(os.path.join(self.root_land, land_name))
        return {"water": self.transform(img_w), "land": self.transform(img_l)}

    def __len__(self):
        return max(len(self.files_water), len(self.files_land))

# =================================================================================
# 5. Training Loop
# =================================================================================
def train():
    """Train the water<->land AG-CycleDiffusion translators, optionally resuming.

    Builds two conditional U-Nets (``U_w2l``, ``U_l2w``), two discriminators,
    EMA copies of the U-Nets and the diffusion schedule, then runs iterations
    ``start_iter .. cfg.TOTAL_ITERS - 1``:

    * Phase 1 (``step < cfg.PHASE1_ITERS``): cycle L1 + perceptual loss only.
    * Phase 2: adds a least-squares (MSE) GAN loss whose weight ramps linearly
      to ``cfg.LAMBDA_GAN`` over ``cfg.PHASE2_RAMP_ITERS`` steps, and trains
      the discriminators.

    Every ``cfg.SAVE_IMAGE_STEP`` steps an EMA sample is saved to
    ``cfg.SAMPLE_DIR``; every ``cfg.SAVE_CHECKPOINT_STEP`` steps the models,
    optimizers, schedulers, scalers and EMA shadows are saved to
    ``cfg.CHECKPOINT_DIR``.
    """
    # --- Device / AMP setup ---
    device_type = torch.device(cfg.DEVICE).type
    # Mixed precision only on CUDA. On CPU, autocast and the grad scalers are
    # simply disabled, instead of disabling themselves with a warning.
    use_amp = cfg.USE_AMP and device_type == "cuda"

    # --- Initialization ---
    U_w2l, U_l2w = ConditionalUNet().to(cfg.DEVICE), ConditionalUNet().to(cfg.DEVICE)
    D_land, D_water = Discriminator().to(cfg.DEVICE), Discriminator().to(cfg.DEVICE)
    ema_w2l, ema_l2w = EMA(U_w2l, cfg.EMA_DECAY), EMA(U_l2w, cfg.EMA_DECAY)
    diffusion = Diffusion()

    optimizer_U = torch.optim.AdamW(itertools.chain(U_w2l.parameters(), U_l2w.parameters()), lr=cfg.LR_U, betas=(cfg.BETA1, cfg.BETA2))
    optimizer_D = torch.optim.AdamW(itertools.chain(D_land.parameters(), D_water.parameters()), lr=cfg.LR_D, betas=(cfg.BETA1, cfg.BETA2))

    crit_GAN, crit_cycle, crit_perc = nn.MSELoss(), nn.L1Loss(), PerceptualLoss()
    scaler_U, scaler_D = GradScaler(enabled=use_amp), GradScaler(enabled=use_amp)

    transform = transforms.Compose([
        transforms.Resize((cfg.IMG_SIZE, cfg.IMG_SIZE)), transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataloader = DataLoader(
        ImageDataset(cfg.WATER_IMG_PATH, cfg.LAND_IMG_PATH, transform=transform),
        batch_size=cfg.BATCH_SIZE, shuffle=True, num_workers=0,
        pin_memory=(device_type == "cuda"),  # pinned host memory only helps host->GPU copies
    )

    def _warmup_then_cosine(optimizer):
        """Linear warm-up for WARMUP_ITERS steps, then cosine decay until TOTAL_ITERS."""
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            [
                torch.optim.lr_scheduler.LinearLR(optimizer, 1e-6, 1.0, cfg.WARMUP_ITERS),
                torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.TOTAL_ITERS - cfg.WARMUP_ITERS),
            ],
            [cfg.WARMUP_ITERS],
        )

    sched_U = _warmup_then_cosine(optimizer_U)
    sched_D = _warmup_then_cosine(optimizer_D)

    start_iter = 0
    if cfg.RESUME_CHECKPOINT_PATH:
        print(f"🔄 Resuming training from {cfg.RESUME_CHECKPOINT_PATH}...")
        # torch.load unpickles the file: only resume from checkpoints you produced.
        ckpt = torch.load(cfg.RESUME_CHECKPOINT_PATH, map_location=cfg.DEVICE)
        U_w2l.load_state_dict(ckpt['U_w2l']); U_l2w.load_state_dict(ckpt['U_l2w'])
        D_land.load_state_dict(ckpt['D_land']); D_water.load_state_dict(ckpt['D_water'])
        optimizer_U.load_state_dict(ckpt['opt_U']); optimizer_D.load_state_dict(ckpt['opt_D'])
        sched_U.load_state_dict(ckpt['sched_U']); sched_D.load_state_dict(ckpt['sched_D'])
        scaler_U.load_state_dict(ckpt['scaler_U']); scaler_D.load_state_dict(ckpt['scaler_D'])
        ema_w2l.shadow = ckpt['ema_w2l']; ema_l2w.shadow = ckpt['ema_l2w']
        # Checkpoints store the last finished step, so continue with the next one.
        start_iter = ckpt['iter'] + 1
        print(f"✅ Resumed successfully from iteration {start_iter}.")

    print("🚀 Starting Simplified Training for AG-CycleDiffusion...")
    pbar = tqdm(range(start_iter, cfg.TOTAL_ITERS), initial=start_iter, total=cfg.TOTAL_ITERS)
    data_iter = iter(dataloader)

    for step in pbar:
        try:
            batch = next(data_iter)
        except StopIteration:
            # Epoch exhausted: restart the (reshuffled) loader.
            data_iter = iter(dataloader)
            batch = next(data_iter)
        real_water, real_land = batch["water"].to(cfg.DEVICE), batch["land"].to(cfg.DEVICE)

        # --- Determine Training Phase ---
        phase = 1
        lambda_gan_current = 0.0
        if step >= cfg.PHASE1_ITERS:
            phase = 2
            lambda_gan_current = min(1.0, (step - cfg.PHASE1_ITERS) / cfg.PHASE2_RAMP_ITERS) * cfg.LAMBDA_GAN

        # --- Train U-Nets (Generators) ---
        optimizer_U.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device_type, enabled=use_amp):
            t = torch.randint(0, cfg.TIMESTEPS, (real_water.size(0),), device=cfg.DEVICE).long()

            # -- Symmetrical Cycles: Predict clean images directly --
            # NOTE(review): in the translation passes x_t is a noised image from the
            # generator's *output* domain (U_w2l gets noisy land), but in the
            # reconstruction passes x_t is the noised translated image from the
            # *input* domain (U_l2w gets noisy fake_land). Confirm this is intended.
            noisy_land, _ = diffusion.noise_images(real_land, t)
            fake_land = U_w2l(noisy_land, t, real_water)

            noisy_fake_land, _ = diffusion.noise_images(fake_land, t)
            reconstructed_water = U_l2w(noisy_fake_land, t, fake_land)

            noisy_water, _ = diffusion.noise_images(real_water, t)
            fake_water = U_l2w(noisy_water, t, real_land)

            noisy_fake_water, _ = diffusion.noise_images(fake_water, t)
            reconstructed_land = U_w2l(noisy_fake_water, t, fake_water)

            # -- Cycle & Perceptual Loss --
            loss_cycle = crit_cycle(reconstructed_water, real_water) + crit_cycle(reconstructed_land, real_land)
            loss_perceptual = crit_perc(reconstructed_water, real_water) + crit_perc(reconstructed_land, real_land)

            # -- Total U-Net Loss Calculation --
            loss_U = cfg.LAMBDA_CYCLE * loss_cycle + cfg.LAMBDA_PERCEPTUAL * loss_perceptual

            if lambda_gan_current > 0:
                # LSGAN generator objective: push D's scores on fakes towards 1.
                # (These terms also leave grads on D's parameters; they are cleared by
                # optimizer_D.zero_grad below before the discriminator step.)
                pred_fake_land = D_land(fake_land)
                loss_GAN_w2l = crit_GAN(pred_fake_land, torch.ones_like(pred_fake_land))

                pred_fake_water = D_water(fake_water)
                loss_GAN_l2w = crit_GAN(pred_fake_water, torch.ones_like(pred_fake_water))
                loss_U += lambda_gan_current * (loss_GAN_w2l + loss_GAN_l2w)

        scaler_U.scale(loss_U).backward()
        scaler_U.step(optimizer_U)
        scaler_U.update()
        ema_w2l.update(U_w2l)
        ema_l2w.update(U_l2w)

        # --- Train Discriminators ---
        if lambda_gan_current > 0:
            optimizer_D.zero_grad(set_to_none=True)
            with torch.autocast(device_type=device_type, enabled=use_amp):
                pred_real_land = D_land(real_land)
                pred_fake_land = D_land(fake_land.detach())
                loss_D_land = (crit_GAN(pred_real_land, torch.ones_like(pred_real_land)) + crit_GAN(pred_fake_land, torch.zeros_like(pred_fake_land))) / 2

                pred_real_water = D_water(real_water)
                pred_fake_water = D_water(fake_water.detach())
                loss_D_water = (crit_GAN(pred_real_water, torch.ones_like(pred_real_water)) + crit_GAN(pred_fake_water, torch.zeros_like(pred_fake_water))) / 2
                total_loss_D = (loss_D_land + loss_D_water) / 2

            scaler_D.scale(total_loss_D).backward()
            scaler_D.step(optimizer_D)
            scaler_D.update()
        else:
            total_loss_D = torch.tensor(0.0, device=cfg.DEVICE)

        # Both schedules advance every step, including phase 1 when D is idle.
        sched_U.step()
        sched_D.step()
        pbar.set_postfix({"U Loss": loss_U.item(), "D Loss": total_loss_D.item(), "LR": optimizer_U.param_groups[0]['lr'], "Phase": phase})

        # --- Sampling and Checkpointing ---
        if (step + 1) % cfg.SAVE_IMAGE_STEP == 0:
            print("📸 Sampling images...")
            # Sample from a throwaway model holding the EMA weights; the live U-Net is untouched.
            ema_U_w2l_sample = ConditionalUNet().to(cfg.DEVICE)
            ema_w2l.copy_to(ema_U_w2l_sample)
            ema_U_w2l_sample.eval()
            sampled_land = diffusion.sample(ema_U_w2l_sample, n=1, condition_tensor=real_water[:1])
            # Top: input water image mapped from [-1, 1] to [0, 1]; bottom: sampled land.
            img_sample = torch.cat((real_water[:1].add(1).mul(0.5), sampled_land), 0)
            save_image(img_sample, f"{cfg.SAMPLE_DIR}/step_{step+1}.png", nrow=1)
            del ema_U_w2l_sample

        if (step + 1) % cfg.SAVE_CHECKPOINT_STEP == 0:
            print(f"💾 Saving checkpoint for step {step+1}...")
            torch.save({
                'iter': step, 'U_w2l': U_w2l.state_dict(), 'U_l2w': U_l2w.state_dict(),
                'D_land': D_land.state_dict(), 'D_water': D_water.state_dict(),
                'opt_U': optimizer_U.state_dict(), 'opt_D': optimizer_D.state_dict(),
                'sched_U': sched_U.state_dict(), 'sched_D': sched_D.state_dict(),
                'scaler_U': scaler_U.state_dict(), 'scaler_D': scaler_D.state_dict(),
                'ema_w2l': ema_w2l.shadow, 'ema_l2w': ema_l2w.shadow
            }, f"{cfg.CHECKPOINT_DIR}/checkpoint_step_{step+1}.pth")

if __name__ == "__main__":
    # Entry point; a notebook kernel also executes cells with __name__ == "__main__".
    train()

  scaler_U, scaler_D = GradScaler(enabled=cfg.USE_AMP), GradScaler(enabled=cfg.USE_AMP)
  ckpt = torch.load(cfg.RESUME_CHECKPOINT_PATH, map_location=cfg.DEVICE)


🔄 Resuming training from D:\DAIICT\Sem 3\Major Project 1\AG-CycleDiffusion\checkpoints\checkpoint_step_10000.pth...
✅ Resumed successfully from iteration 10000.
🚀 Starting Simplified Training for AG-CycleDiffusion...


  with autocast(enabled=cfg.USE_AMP):
 10%|█         | 10499/100000 [01:17<3:49:57,  6.49it/s, U Loss=2.51, D Loss=0, LR=0.000196, Phase=1]

📸 Sampling images...


 11%|█         | 10999/100000 [02:37<4:02:39,  6.11it/s, U Loss=2.79, D Loss=0, LR=0.000196, Phase=1]

📸 Sampling images...


 11%|█▏        | 11499/100000 [03:56<3:44:48,  6.56it/s, U Loss=2.12, D Loss=0, LR=0.000195, Phase=1]

📸 Sampling images...


 12%|█▏        | 11999/100000 [05:15<3:54:19,  6.26it/s, U Loss=2.27, D Loss=0, LR=0.000195, Phase=1]

📸 Sampling images...
💾 Saving checkpoint for step 12000...


 12%|█▏        | 12499/100000 [06:39<4:10:48,  5.81it/s, U Loss=2.07, D Loss=0, LR=0.000194, Phase=1] 

📸 Sampling images...


 13%|█▎        | 12999/100000 [08:01<3:55:34,  6.16it/s, U Loss=1.86, D Loss=0, LR=0.000194, Phase=1]

📸 Sampling images...


 13%|█▎        | 13499/100000 [09:24<4:03:04,  5.93it/s, U Loss=2.28, D Loss=0, LR=0.000193, Phase=1]

📸 Sampling images...


 14%|█▍        | 13999/100000 [10:47<4:01:23,  5.94it/s, U Loss=1.8, D Loss=0, LR=0.000193, Phase=1] 

📸 Sampling images...
💾 Saving checkpoint for step 14000...


 14%|█▍        | 14499/100000 [12:07<3:38:02,  6.54it/s, U Loss=1.82, D Loss=0, LR=0.000192, Phase=1] 

📸 Sampling images...


 15%|█▍        | 14999/100000 [13:25<3:45:59,  6.27it/s, U Loss=1.58, D Loss=0, LR=0.000191, Phase=1]

📸 Sampling images...


 15%|█▌        | 15499/100000 [14:45<3:45:36,  6.24it/s, U Loss=2.15, D Loss=0, LR=0.000191, Phase=1]

📸 Sampling images...


 16%|█▌        | 15999/100000 [16:04<3:40:47,  6.34it/s, U Loss=1.83, D Loss=0, LR=0.00019, Phase=1] 

📸 Sampling images...
💾 Saving checkpoint for step 16000...


 16%|█▋        | 16499/100000 [17:22<3:39:53,  6.33it/s, U Loss=1.78, D Loss=0, LR=0.000189, Phase=1]

📸 Sampling images...


 17%|█▋        | 16999/100000 [18:41<3:39:06,  6.31it/s, U Loss=1.86, D Loss=0, LR=0.000189, Phase=1]

📸 Sampling images...


 17%|█▋        | 17499/100000 [20:01<3:33:09,  6.45it/s, U Loss=1.71, D Loss=0, LR=0.000188, Phase=1]

📸 Sampling images...


 18%|█▊        | 17999/100000 [21:20<3:33:57,  6.39it/s, U Loss=1.86, D Loss=0, LR=0.000187, Phase=1]

📸 Sampling images...
💾 Saving checkpoint for step 18000...


 18%|█▊        | 18499/100000 [22:40<3:40:54,  6.15it/s, U Loss=1.66, D Loss=0, LR=0.000186, Phase=1]

📸 Sampling images...


 19%|█▉        | 18999/100000 [23:59<3:31:34,  6.38it/s, U Loss=1.57, D Loss=0, LR=0.000186, Phase=1]

📸 Sampling images...


 19%|█▉        | 19499/100000 [25:19<3:28:31,  6.43it/s, U Loss=1.65, D Loss=0, LR=0.000185, Phase=1]

📸 Sampling images...


 20%|█▉        | 19999/100000 [26:38<3:27:31,  6.43it/s, U Loss=1.31, D Loss=0, LR=0.000184, Phase=1]

📸 Sampling images...
💾 Saving checkpoint for step 20000...


 20%|██        | 20499/100000 [27:58<3:31:44,  6.26it/s, U Loss=1.4, D Loss=0, LR=0.000183, Phase=1] 

📸 Sampling images...


 21%|██        | 20999/100000 [29:20<3:32:34,  6.19it/s, U Loss=1.41, D Loss=0, LR=0.000182, Phase=1]

📸 Sampling images...


 21%|██▏       | 21499/100000 [30:43<3:38:07,  6.00it/s, U Loss=1.37, D Loss=0, LR=0.000181, Phase=1]

📸 Sampling images...


 22%|██▏       | 21999/100000 [32:06<3:35:01,  6.05it/s, U Loss=1.72, D Loss=0, LR=0.00018, Phase=1] 

📸 Sampling images...
💾 Saving checkpoint for step 22000...


 22%|██▏       | 22499/100000 [33:29<3:32:27,  6.08it/s, U Loss=1.35, D Loss=0, LR=0.000179, Phase=1]

📸 Sampling images...


 23%|██▎       | 22999/100000 [34:52<3:37:07,  5.91it/s, U Loss=1.71, D Loss=0, LR=0.000178, Phase=1]

📸 Sampling images...


 23%|██▎       | 23499/100000 [36:15<3:30:31,  6.06it/s, U Loss=1.5, D Loss=0, LR=0.000177, Phase=1] 

📸 Sampling images...


 24%|██▍       | 23999/100000 [37:40<3:33:34,  5.93it/s, U Loss=1.95, D Loss=0, LR=0.000176, Phase=1] 

📸 Sampling images...
💾 Saving checkpoint for step 24000...


 24%|██▍       | 24499/100000 [39:10<3:52:37,  5.41it/s, U Loss=1.31, D Loss=0, LR=0.000175, Phase=1]

📸 Sampling images...


 25%|██▍       | 24999/100000 [40:39<3:39:34,  5.69it/s, U Loss=1.88, D Loss=0, LR=0.000174, Phase=1]

📸 Sampling images...


 25%|██▌       | 25499/100000 [42:04<3:31:41,  5.87it/s, U Loss=1.48, D Loss=0, LR=0.000173, Phase=1] 

📸 Sampling images...


 26%|██▌       | 25999/100000 [43:31<3:20:57,  6.14it/s, U Loss=1.39, D Loss=0, LR=0.000172, Phase=1]

📸 Sampling images...
💾 Saving checkpoint for step 26000...


 26%|██▋       | 26499/100000 [44:53<3:23:05,  6.03it/s, U Loss=1.16, D Loss=0, LR=0.000171, Phase=1] 

📸 Sampling images...


 27%|██▋       | 26999/100000 [46:16<3:26:39,  5.89it/s, U Loss=1.49, D Loss=0, LR=0.00017, Phase=1] 

📸 Sampling images...


 27%|██▋       | 27499/100000 [47:39<3:26:02,  5.86it/s, U Loss=1.54, D Loss=0, LR=0.000168, Phase=1]

📸 Sampling images...


 28%|██▊       | 27999/100000 [49:03<3:24:46,  5.86it/s, U Loss=2.05, D Loss=0, LR=0.000167, Phase=1]

📸 Sampling images...
💾 Saving checkpoint for step 28000...


 28%|██▊       | 28499/100000 [50:27<3:15:45,  6.09it/s, U Loss=1.58, D Loss=0, LR=0.000166, Phase=1] 

📸 Sampling images...


 29%|██▉       | 28999/100000 [51:50<3:14:45,  6.08it/s, U Loss=1.42, D Loss=0, LR=0.000165, Phase=1]

📸 Sampling images...


 29%|██▉       | 29499/100000 [53:12<3:13:31,  6.07it/s, U Loss=1.78, D Loss=0, LR=0.000164, Phase=1] 

📸 Sampling images...


 30%|██▉       | 29999/100000 [54:36<3:12:23,  6.06it/s, U Loss=1.5, D Loss=0, LR=0.000162, Phase=1] 

📸 Sampling images...
💾 Saving checkpoint for step 30000...


  with autocast(enabled=cfg.USE_AMP):
 30%|███       | 30499/100000 [56:08<3:32:15,  5.46it/s, U Loss=1.49, D Loss=0.0167, LR=0.000161, Phase=2] 

📸 Sampling images...


 31%|███       | 30999/100000 [57:38<3:31:07,  5.45it/s, U Loss=1.76, D Loss=0.023, LR=0.00016, Phase=2]   

📸 Sampling images...


 31%|███▏      | 31499/100000 [59:10<3:21:17,  5.67it/s, U Loss=1.9, D Loss=0.00603, LR=0.000159, Phase=2] 

📸 Sampling images...


 32%|███▏      | 31999/100000 [1:00:40<3:27:08,  5.47it/s, U Loss=2.6, D Loss=0.0157, LR=0.000157, Phase=2]  

📸 Sampling images...
💾 Saving checkpoint for step 32000...


 32%|███▏      | 32499/100000 [1:02:11<3:30:57,  5.33it/s, U Loss=2.21, D Loss=0.131, LR=0.000156, Phase=2]  

📸 Sampling images...


 33%|███▎      | 32999/100000 [1:03:42<3:20:00,  5.58it/s, U Loss=2.37, D Loss=0.0659, LR=0.000155, Phase=2]

📸 Sampling images...


 33%|███▎      | 33499/100000 [1:05:13<3:19:22,  5.56it/s, U Loss=2.75, D Loss=0.182, LR=0.000153, Phase=2] 

📸 Sampling images...


 34%|███▍      | 33999/100000 [1:06:43<3:21:51,  5.45it/s, U Loss=1.62, D Loss=0.262, LR=0.000152, Phase=2] 

📸 Sampling images...
💾 Saving checkpoint for step 34000...


 34%|███▍      | 34499/100000 [1:08:14<3:17:27,  5.53it/s, U Loss=2.93, D Loss=0.166, LR=0.00015, Phase=2]  

📸 Sampling images...


 35%|███▍      | 34999/100000 [1:09:46<3:18:11,  5.47it/s, U Loss=1.79, D Loss=0.198, LR=0.000149, Phase=2] 

📸 Sampling images...


 35%|███▌      | 35499/100000 [1:11:14<3:11:59,  5.60it/s, U Loss=2.15, D Loss=0.0744, LR=0.000148, Phase=2]

📸 Sampling images...


 36%|███▌      | 35999/100000 [1:12:43<3:10:51,  5.59it/s, U Loss=2.85, D Loss=0.033, LR=0.000146, Phase=2] 

📸 Sampling images...
💾 Saving checkpoint for step 36000...


 36%|███▋      | 36499/100000 [1:14:12<3:01:53,  5.82it/s, U Loss=3.19, D Loss=0.049, LR=0.000145, Phase=2] 

📸 Sampling images...


 37%|███▋      | 36999/100000 [1:15:40<3:06:14,  5.64it/s, U Loss=2.85, D Loss=0.0936, LR=0.000143, Phase=2]

📸 Sampling images...


 37%|███▋      | 37499/100000 [1:17:09<3:06:20,  5.59it/s, U Loss=2.82, D Loss=0.092, LR=0.000142, Phase=2] 

📸 Sampling images...


 38%|███▊      | 37999/100000 [1:18:38<2:58:19,  5.79it/s, U Loss=2.66, D Loss=0.115, LR=0.00014, Phase=2]  

📸 Sampling images...
💾 Saving checkpoint for step 38000...


 38%|███▊      | 38499/100000 [1:20:07<3:00:49,  5.67it/s, U Loss=3.31, D Loss=0.15, LR=0.000139, Phase=2]  

📸 Sampling images...


 39%|███▉      | 38999/100000 [1:21:36<3:12:43,  5.28it/s, U Loss=3.37, D Loss=0.0501, LR=0.000138, Phase=2]

📸 Sampling images...


 39%|███▉      | 39499/100000 [1:23:04<3:10:55,  5.28it/s, U Loss=3.66, D Loss=0.027, LR=0.000136, Phase=2] 

📸 Sampling images...


 40%|███▉      | 39999/100000 [1:24:33<2:54:38,  5.73it/s, U Loss=3.26, D Loss=0.0377, LR=0.000135, Phase=2]

📸 Sampling images...
💾 Saving checkpoint for step 40000...


 40%|████      | 40499/100000 [1:26:02<2:56:42,  5.61it/s, U Loss=3.77, D Loss=0.0274, LR=0.000133, Phase=2] 

📸 Sampling images...


 41%|████      | 40999/100000 [1:27:30<2:52:21,  5.71it/s, U Loss=2.79, D Loss=0.327, LR=0.000132, Phase=2]  

📸 Sampling images...


 41%|████▏     | 41499/100000 [1:28:58<2:49:29,  5.75it/s, U Loss=4.41, D Loss=0.0195, LR=0.00013, Phase=2] 

📸 Sampling images...


 42%|████▏     | 41999/100000 [1:30:27<2:59:29,  5.39it/s, U Loss=3.59, D Loss=0.111, LR=0.000128, Phase=2] 

📸 Sampling images...
💾 Saving checkpoint for step 42000...


 42%|████▏     | 42499/100000 [1:31:56<2:50:26,  5.62it/s, U Loss=2.61, D Loss=0.131, LR=0.000127, Phase=2] 

📸 Sampling images...


 43%|████▎     | 42999/100000 [1:33:24<2:56:41,  5.38it/s, U Loss=3.33, D Loss=0.0707, LR=0.000125, Phase=2]

📸 Sampling images...


 43%|████▎     | 43499/100000 [1:34:54<2:54:13,  5.40it/s, U Loss=2.85, D Loss=0.159, LR=0.000124, Phase=2] 

📸 Sampling images...


 44%|████▍     | 43999/100000 [1:36:22<2:41:11,  5.79it/s, U Loss=3.76, D Loss=0.0703, LR=0.000122, Phase=2] 

📸 Sampling images...
💾 Saving checkpoint for step 44000...


 44%|████▍     | 44499/100000 [1:37:51<2:41:39,  5.72it/s, U Loss=4.18, D Loss=0.138, LR=0.000121, Phase=2]  

📸 Sampling images...


 45%|████▍     | 44999/100000 [1:39:20<2:38:06,  5.80it/s, U Loss=3.12, D Loss=0.144, LR=0.000119, Phase=2]  

📸 Sampling images...


 45%|████▌     | 45499/100000 [1:40:50<2:42:25,  5.59it/s, U Loss=3.43, D Loss=0.0468, LR=0.000118, Phase=2] 

📸 Sampling images...


 46%|████▌     | 45999/100000 [1:42:18<2:44:43,  5.46it/s, U Loss=3.16, D Loss=0.134, LR=0.000116, Phase=2]  

📸 Sampling images...
💾 Saving checkpoint for step 46000...


 46%|████▋     | 46499/100000 [1:43:47<2:34:21,  5.78it/s, U Loss=3.55, D Loss=0.117, LR=0.000114, Phase=2]  

📸 Sampling images...


 47%|████▋     | 46999/100000 [1:45:16<2:34:20,  5.72it/s, U Loss=2.49, D Loss=0.149, LR=0.000113, Phase=2]  

📸 Sampling images...


 47%|████▋     | 47499/100000 [1:46:45<2:32:36,  5.73it/s, U Loss=3.55, D Loss=0.039, LR=0.000111, Phase=2] 

📸 Sampling images...


 48%|████▊     | 47999/100000 [1:48:14<2:32:39,  5.68it/s, U Loss=2.99, D Loss=0.122, LR=0.00011, Phase=2]  

📸 Sampling images...
💾 Saving checkpoint for step 48000...


 48%|████▊     | 48499/100000 [1:49:43<2:31:56,  5.65it/s, U Loss=3.19, D Loss=0.118, LR=0.000108, Phase=2]  

📸 Sampling images...


 49%|████▉     | 48999/100000 [1:51:13<2:28:56,  5.71it/s, U Loss=3.17, D Loss=0.0309, LR=0.000106, Phase=2] 

📸 Sampling images...


 49%|████▉     | 49499/100000 [1:52:42<2:25:54,  5.77it/s, U Loss=3.45, D Loss=0.178, LR=0.000105, Phase=2]  

📸 Sampling images...


 50%|████▉     | 49999/100000 [1:54:11<2:38:40,  5.25it/s, U Loss=3.89, D Loss=0.0227, LR=0.000103, Phase=2] 

📸 Sampling images...
💾 Saving checkpoint for step 50000...


 50%|█████     | 50499/100000 [1:55:40<2:28:14,  5.57it/s, U Loss=4.67, D Loss=0.0197, LR=0.000102, Phase=2]

📸 Sampling images...


 51%|█████     | 50999/100000 [1:57:09<2:20:52,  5.80it/s, U Loss=3.69, D Loss=0.0378, LR=0.0001, Phase=2]  

📸 Sampling images...


 51%|█████▏    | 51499/100000 [1:58:37<2:22:26,  5.68it/s, U Loss=3.68, D Loss=0.0647, LR=9.84e-5, Phase=2] 

📸 Sampling images...


 52%|█████▏    | 51999/100000 [2:00:07<2:22:10,  5.63it/s, U Loss=3.3, D Loss=0.0652, LR=9.68e-5, Phase=2]  

📸 Sampling images...
💾 Saving checkpoint for step 52000...


 52%|█████▏    | 52499/100000 [2:01:36<2:20:40,  5.63it/s, U Loss=3.89, D Loss=0.0564, LR=9.52e-5, Phase=2] 

📸 Sampling images...


 53%|█████▎    | 52999/100000 [2:03:06<2:17:46,  5.69it/s, U Loss=3.42, D Loss=0.0613, LR=9.36e-5, Phase=2] 

📸 Sampling images...


 53%|█████▎    | 53499/100000 [2:04:36<2:15:54,  5.70it/s, U Loss=2.88, D Loss=0.132, LR=9.2e-5, Phase=2]   

📸 Sampling images...


 54%|█████▍    | 53999/100000 [2:06:09<2:17:56,  5.56it/s, U Loss=4.32, D Loss=0.058, LR=9.04e-5, Phase=2]  

📸 Sampling images...
💾 Saving checkpoint for step 54000...


 54%|█████▍    | 54499/100000 [2:07:45<2:27:54,  5.13it/s, U Loss=4.25, D Loss=0.0249, LR=8.88e-5, Phase=2] 

📸 Sampling images...


 55%|█████▍    | 54999/100000 [2:09:19<2:19:27,  5.38it/s, U Loss=4.79, D Loss=0.02, LR=8.72e-5, Phase=2]   

📸 Sampling images...


 55%|█████▌    | 55499/100000 [2:10:56<2:19:41,  5.31it/s, U Loss=3.64, D Loss=0.00749, LR=8.56e-5, Phase=2]

📸 Sampling images...


 56%|█████▌    | 55999/100000 [2:12:36<2:28:57,  4.92it/s, U Loss=3.12, D Loss=0.04, LR=8.4e-5, Phase=2]    

📸 Sampling images...
💾 Saving checkpoint for step 56000...


 56%|█████▋    | 56499/100000 [2:14:09<2:11:34,  5.51it/s, U Loss=2.96, D Loss=0.115, LR=8.25e-5, Phase=2]  

📸 Sampling images...


 57%|█████▋    | 56999/100000 [2:15:48<2:21:35,  5.06it/s, U Loss=4.35, D Loss=0.0188, LR=8.09e-5, Phase=2] 

📸 Sampling images...


 57%|█████▋    | 57499/100000 [2:17:26<2:11:24,  5.39it/s, U Loss=3.55, D Loss=0.0272, LR=7.93e-5, Phase=2] 

📸 Sampling images...


 58%|█████▊    | 57999/100000 [2:18:57<2:05:01,  5.60it/s, U Loss=3.64, D Loss=0.0273, LR=7.77e-5, Phase=2] 

📸 Sampling images...
💾 Saving checkpoint for step 58000...


 58%|█████▊    | 58499/100000 [2:20:27<2:02:16,  5.66it/s, U Loss=3.62, D Loss=0.0175, LR=7.62e-5, Phase=2] 

📸 Sampling images...


 59%|█████▉    | 58999/100000 [2:21:57<2:00:27,  5.67it/s, U Loss=3.99, D Loss=0.0239, LR=7.46e-5, Phase=2] 

📸 Sampling images...


 59%|█████▉    | 59499/100000 [2:23:26<2:00:15,  5.61it/s, U Loss=4.22, D Loss=0.0112, LR=7.31e-5, Phase=2] 

📸 Sampling images...


 60%|█████▉    | 59999/100000 [2:24:56<1:59:32,  5.58it/s, U Loss=3.88, D Loss=0.022, LR=7.15e-5, Phase=2]  

📸 Sampling images...
💾 Saving checkpoint for step 60000...


 60%|██████    | 60499/100000 [2:26:24<1:53:13,  5.81it/s, U Loss=4.06, D Loss=0.0382, LR=7e-5, Phase=2]    

📸 Sampling images...


 61%|██████    | 60999/100000 [2:27:54<1:55:07,  5.65it/s, U Loss=4.21, D Loss=0.0765, LR=6.85e-5, Phase=2] 

📸 Sampling images...


 61%|██████▏   | 61499/100000 [2:29:24<1:54:40,  5.60it/s, U Loss=3.86, D Loss=0.00592, LR=6.7e-5, Phase=2] 

📸 Sampling images...


 62%|██████▏   | 61839/100000 [2:30:24<1:50:43,  5.74it/s, U Loss=4.14, D Loss=0.0176, LR=6.59e-5, Phase=2] 


KeyboardInterrupt: 