# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset, random_split
from torch.optim import AdamW
import matplotlib.pyplot as plt
import os
import math
import shutil
from datetime import datetime
from tqdm import tqdm
import sys
from typing import List, Tuple
from diffusers import DDPMScheduler, DDIMScheduler, DPMSolverMultistepScheduler
import numpy as np

# ---- Fixed hyperparameters -------------------------------------------------
scale_bp = 1
scale_dfa = 3
input_ch = 1
beta_start = 1e-4
beta_end = 0.02
epochs = 31 # for DDPM 31 epochs
# epochs = 51 # for DDIM 51 epochs
eval_freq = 1
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
time_embed_dim = 100


print("--- 1. setting ---")

# Number of diffusion timesteps used for training: keep asking until the
# answer is blank (-> 250) or a positive integer.
num_timesteps = None
while num_timesteps is None:
    try:
        raw_steps = input(f"Select train timestep (e.g. 1000, 250) [default: 250]: ").strip()
        if raw_steps == "":
            num_timesteps = 250
            continue
        parsed_steps = int(raw_steps)
    except ValueError:
        print("Invalid input. Please enter a number.")
        continue
    if parsed_steps > 0:
        num_timesteps = parsed_steps
    else:
        print("Invalid input. Please enter a positive number.")
print(f"Train num_timesteps selected: {num_timesteps}")

# Sampling steps per sampler; 'ddpm' walks every training timestep.
SAMPLING_STEPS = {
    'ddpm': num_timesteps,
    'ddim': 40,
    'dpm_solver_2': 40,
    'dpm_solver_pp': 40
}
print(f"Sampling steps set to: {SAMPLING_STEPS}")

selected_dataset = 'mnist'
print(f"Dataset selected: {selected_dataset}")

# Optional single-digit filter: asked once; anything invalid means "all digits".
class_filter = None
try:
    digit_text = input("Enter MNIST digit to use (0-9), or leave blank for all: ").strip()
    if digit_text:
        class_filter = int(digit_text)
        if class_filter < 0 or class_filter > 9:
            print("Invalid digit. Using all digits.")
            class_filter = None
except ValueError:
    print("Invalid input. Using all digits.")
print(f"Class filter set to: {class_filter}")

# Sampler used for the per-epoch visualization; must be a SAMPLING_STEPS key.
EVAL_SAMPLER_TYPE = None
while True:
    sampler_choice = input(f"Select sampler for visualization {list(SAMPLING_STEPS.keys())}: ").strip().lower()
    if sampler_choice in SAMPLING_STEPS:
        EVAL_SAMPLER_TYPE = sampler_choice
        break
    print(f"Invalid input. Please choose from {list(SAMPLING_STEPS.keys())}.")
EVAL_SAMPLING_STEPS = SAMPLING_STEPS[EVAL_SAMPLER_TYPE]
print(f"Using [{EVAL_SAMPLER_TYPE}] sampler with {EVAL_SAMPLING_STEPS} steps for visualization.")

# Training rule: DFA or BP.
training_modes = {'dfa': True, 'bp': False}
USE_DFA = None
while USE_DFA is None:
    USE_DFA = training_modes.get(input("Select training mode (DFA/BP): ").strip().lower())
    if USE_DFA is None:
        print("Invalid input. Please type 'DFA' or 'BP'.")

# Mode-dependent settings: DFA uses wider layers and a smaller batch.
scale = scale_dfa if USE_DFA else scale_bp
batch_size = 8 if USE_DFA else 128
lr = 1e-3
weight_decay = 1e-4
print(f"Mode: {'DFA' if USE_DFA else 'BP'}, Batch size: {batch_size}, scale: {scale}, LR: {lr}, WD: {weight_decay}")

# Which U-Net variant to build.
model_version = None
while model_version is None:
    version_text = input("Select UNet version (1/2): ").strip()
    if version_text in ('1', '2'):
        model_version = int(version_text)
    else:
        print("Invalid input. Please type '1' or '2'.")
print(f"UNet version selected: UNet{model_version}")

image_size = 28
print(f"Image size selected: {image_size}x{image_size} (Fixed for MNIST)")
print("--- 1. setting completed ---")


class Inject_e(torch.autograd.Function):
    """Identity in the forward pass; feedback-projected error in the backward pass.

    ``forward(x, e, B)`` returns ``x`` unchanged. ``backward`` ignores the
    incoming gradient and instead returns ``e`` (flattened per sample)
    multiplied by ``B``, reshaped to the shape ``x`` had in the forward pass.
    ``e`` and ``B`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, x, e, B):
        # Remember the activation shape so the surrogate gradient can match it.
        ctx.grad_output_shape = x.shape
        ctx.save_for_backward(e, B)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is deliberately unused: the upstream gradient is replaced.
        error, feedback = ctx.saved_tensors
        per_sample_error = error.view(error.shape[0], -1)
        surrogate_grad = per_sample_error.mm(feedback).view(ctx.grad_output_shape)
        return surrogate_grad, None, None

def show_images(images, num_rows=3, num_cols=4, title="Generated Images"):
    """Display up to ``num_rows * num_cols`` PIL images in a matplotlib grid.

    Images whose ``mode`` is ``'L'`` are drawn with the gray colormap; all
    others are passed to ``imshow`` unchanged. Prints a message and returns
    without plotting when ``images`` is empty.
    """
    if not images:
        print("No images to display.")
        return

    grid_capacity = num_rows * num_cols
    fig = plt.figure(figsize=(num_cols * 2, num_rows * 2))
    plt.suptitle(title)
    # Pairing with a bounded range stops after the grid is full.
    for cell, picture in zip(range(1, grid_capacity + 1), images):
        ax = fig.add_subplot(num_rows, num_cols, cell)
        imshow_kwargs = {'cmap': 'gray'} if picture.mode == 'L' else {}
        ax.imshow(picture, **imshow_kwargs)
        ax.axis("off")
    # Leave head-room for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

def _pos_encoding(time_idx, output_dim, device='cpu'):
    t, D = time_idx, output_dim
    v = torch.zeros(D, device=device)
    i = torch.arange(0, D, device=device)
    div_term = torch.exp(i / D * math.log(10000))
    v[0::2] = torch.sin(t / div_term[0::2])
    v[1::2] = torch.cos(t / div_term[1::2])
    return v

def pos_encoding(timesteps, output_dim, device='cpu'):
    """Return sinusoidal time embeddings for a batch of timesteps.

    The whole batch is computed in one vectorized pass instead of one Python
    iteration (and several tiny tensor ops) per sample.

    Args:
        timesteps: 1-D tensor or sequence of timestep values, one per sample.
        output_dim: Width ``D`` of each embedding.
        device: Device on which the embedding is built.

    Returns:
        Tensor of shape ``(len(timesteps), output_dim)``. Column ``i`` uses the
        divisor ``exp(i / D * ln(10000))``; even columns hold
        ``sin(t / divisor)`` and odd columns ``cos(t / divisor)``.
    """
    # Column vector of timesteps so the division broadcasts over embedding columns.
    t = torch.as_tensor(timesteps, device=device).to(torch.get_default_dtype()).reshape(-1, 1)
    i = torch.arange(0, output_dim, device=device)
    div_term = torch.exp(i / output_dim * math.log(10000))

    v = torch.zeros(t.shape[0], output_dim, device=device)
    v[:, 0::2] = torch.sin(t / div_term[0::2])
    v[:, 1::2] = torch.cos(t / div_term[1::2])
    return v


class ConvBlock1(nn.Module):
    """Time-conditioned 5x5 convolution block.

    A two-layer MLP maps the time embedding to one value per input channel.
    That value is added to every spatial position of the input, which then
    goes through Conv2d(5x5, padding 2) -> BatchNorm2d -> activation. The
    activation is Tanh when ``use_dfa`` is true and ReLU otherwise; one
    activation module instance is shared by the MLP and the conv stack.
    """

    def __init__(self, input_ch, output_ch, time_embed_dim, use_dfa=True):
        super().__init__()
        if use_dfa:
            act = nn.Tanh()
        else:
            act = nn.ReLU()

        # Conv stack is built before the MLP so parameter initialisation
        # consumes random numbers in the same order as before.
        conv = nn.Conv2d(input_ch, output_ch, 5, padding=2)
        norm = nn.BatchNorm2d(output_ch)
        self.convs = nn.Sequential(conv, norm, act)

        embed_in = nn.Linear(time_embed_dim, input_ch)
        embed_out = nn.Linear(input_ch, input_ch)
        self.mlp = nn.Sequential(embed_in, act, embed_out)

    def forward(self, x, v):
        # x must be 4-D (N, C, H, W); v is the time embedding batch.
        batch, channels, _height, _width = x.shape
        time_bias = self.mlp(v)
        time_bias = time_bias.view(batch, channels, 1, 1)
        conditioned = x + time_bias
        return self.convs(conditioned)

class UNet1(nn.Module):
    """Two-level U-Net noise predictor with optional DFA gradient injection.

    Data flow (one ``ConvBlock1`` per stage): down1 -> maxpool -> down2 ->
    maxpool -> bot1 -> upsample -> concat(down2 output) -> up2 -> upsample ->
    concat(down1 output) -> up1 -> 1x1 conv. Hidden channel widths are
    multiples of the module-level ``scale`` read at construction time. The
    input is pooled by 2 twice and upsampled by 2 twice, so the skip
    concatenations only line up when the spatial size is divisible by 4.

    When ``use_dfa`` is true, ``Inject_e`` is applied to the outputs of down1,
    down2, bot1 and up2. Each application uses the shared error buffer ``e``
    and a fixed random feedback matrix (``Bc1``..``Bc4``, not trainable)
    sized to that stage's flattened output.

    Args:
        input_ch: Channels of the input image and of the predicted noise.
        time_embed_dim: Width of the sinusoidal time embedding.
        image_size: Height and width of the (square) input.
        batch_size: Batch size the DFA error buffer is allocated for.
        device: Device on which the DFA error buffer is created.
        use_dfa: Enable the ``Inject_e`` layers and the error buffer.
    """

    def __init__(self, input_ch=input_ch, time_embed_dim=time_embed_dim, image_size=image_size, batch_size=batch_size, device='cpu', use_dfa=True):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        self.image_size = image_size
        self.batch_size = batch_size
        self.device = device
        self.use_dfa = use_dfa
        # NOTE(review): use_dfa is not forwarded to the ConvBlock1 instances, so
        # they always take their default (Tanh) activation, even in BP mode.
        # Left as-is because changing it would alter existing BP results; confirm intent.
        self.down1 = ConvBlock1(input_ch, 64 * scale, time_embed_dim)
        self.down2 = ConvBlock1(64 * scale, 128 * scale, time_embed_dim)
        self.bot1 = ConvBlock1(128 * scale, 256 * scale, time_embed_dim)
        self.up2 = ConvBlock1(128 * scale + 256 * scale, 128 * scale, time_embed_dim)
        self.up1 = ConvBlock1(128 * scale + 64 * scale, 64, time_embed_dim)
        self.out = nn.Conv2d(64, input_ch, 1)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        if self.use_dfa:
            self.e_channels = input_ch
            self.e_height = image_size
            self.e_width = image_size
            # Output error (prediction - noise) of the latest training forward pass.
            # Registered as a non-persistent buffer so that model.to(...)/.cuda()
            # moves it together with the feedback matrices while the state_dict
            # keeps exactly the same keys as before.
            self.register_buffer(
                'e',
                torch.zeros(self.batch_size, self.e_channels, self.e_height, self.e_width, device=self.device),
                persistent=False,
            )
            e_flat_dim = self.e_channels * self.e_height * self.e_width
            # One feedback matrix per injection point: (flattened error) x (flattened stage output).
            x1_flat_dim = (64 * scale) * self.image_size * self.image_size
            self.Bc1 = nn.Parameter(torch.randn(e_flat_dim, x1_flat_dim), requires_grad=False)
            x2_flat_dim = (128 * scale) * (self.image_size // 2) * (self.image_size // 2)
            self.Bc2 = nn.Parameter(torch.randn(e_flat_dim, x2_flat_dim), requires_grad=False)
            x3_flat_dim = (256 * scale) * (self.image_size // 4) * (self.image_size // 4)
            self.Bc3 = nn.Parameter(torch.randn(e_flat_dim, x3_flat_dim), requires_grad=False)
            x4_flat_dim = (128 * scale) * (self.image_size // 2) * (self.image_size // 2)
            self.Bc4 = nn.Parameter(torch.randn(e_flat_dim, x4_flat_dim), requires_grad=False)

    def forward(self, x, timesteps, noise=None):
        """Predict the noise contained in ``x`` at the given ``timesteps``.

        Args:
            x: Noisy input batch, (N, input_ch, image_size, image_size).
            timesteps: One timestep per sample, turned into a time embedding.
            noise: Target noise. When given in DFA mode, ``prediction - noise``
                is stored in ``self.e`` for the following backward pass; its
                batch size must match the buffer's ``batch_size``.

        Returns:
            The predicted noise, same shape as ``x``.
        """
        v = pos_encoding(timesteps, self.time_embed_dim, x.device)
        # Encoder
        x1_pre_inject = self.down1(x, v)
        x1 = Inject_e.apply(x1_pre_inject, self.e, self.Bc1) if self.use_dfa else x1_pre_inject
        x_down1_pooled = self.maxpool(x1)
        x2_pre_inject = self.down2(x_down1_pooled, v)
        x2 = Inject_e.apply(x2_pre_inject, self.e, self.Bc2) if self.use_dfa else x2_pre_inject
        x_down2_pooled = self.maxpool(x2)
        # Bottleneck
        x_bot_pre_inject = self.bot1(x_down2_pooled, v)
        x_bot = Inject_e.apply(x_bot_pre_inject, self.e, self.Bc3) if self.use_dfa else x_bot_pre_inject
        # Decoder with skip connections
        x_up2_upsampled = self.upsample(x_bot)
        x_up2_concat = torch.cat([x_up2_upsampled, x2], dim=1)
        x_up2_pre_inject = self.up2(x_up2_concat, v)
        x_up2 = Inject_e.apply(x_up2_pre_inject, self.e, self.Bc4) if self.use_dfa else x_up2_pre_inject
        x_up1_upsampled = self.upsample(x_up2)
        x_up1_concat = torch.cat([x_up1_upsampled, x1], dim=1)
        # The last block and the output conv have no injection after them.
        x_final_conv = self.up1(x_up1_concat, v)
        output_noise_pred = self.out(x_final_conv)
        if self.use_dfa and noise is not None:
            # Written through .data so the update stays outside autograd tracking:
            # the Inject_e nodes saved this same tensor during the forward pass and
            # read the refreshed values when backward() runs.
            self.e.data.copy_(output_noise_pred - noise)
        return output_noise_pred

class ConvBlock2(nn.Module):
    """Time-conditioned 3x3 convolution block.

    A two-layer MLP maps the time embedding to one value per input channel.
    That value is added to every spatial position of the input, which then
    goes through Conv2d(3x3, padding 1) -> BatchNorm2d -> activation. The
    activation is Tanh when ``use_dfa`` is true and ReLU otherwise; one
    activation module instance is shared by the MLP and the conv stack.
    """

    def __init__(self, input_ch, output_ch, time_embed_dim, use_dfa=True):
        super().__init__()
        if use_dfa:
            act = nn.Tanh()
        else:
            act = nn.ReLU()

        # Conv stack is built before the MLP so parameter initialisation
        # consumes random numbers in the same order as before.
        conv = nn.Conv2d(input_ch, output_ch, 3, padding=1)
        norm = nn.BatchNorm2d(output_ch)
        self.convs = nn.Sequential(conv, norm, act)

        embed_in = nn.Linear(time_embed_dim, input_ch)
        embed_out = nn.Linear(input_ch, input_ch)
        self.mlp = nn.Sequential(embed_in, act, embed_out)

    def forward(self, x, v):
        # x must be 4-D (N, C, H, W); v is the time embedding batch.
        batch, channels, _height, _width = x.shape
        time_bias = self.mlp(v)
        time_bias = time_bias.view(batch, channels, 1, 1)
        conditioned = x + time_bias
        return self.convs(conditioned)

class UNet2(nn.Module):
    """Two-level U-Net noise predictor with two conv blocks per stage and optional DFA.

    Data flow (``ConvBlock2`` stages): down1 -> down1_2 -> maxpool -> down2 ->
    down2_2 -> maxpool -> bot1 -> bot1_2 -> upsample -> concat(down2_2 output)
    -> up2 -> up2_2 -> upsample -> concat(down1_2 output) -> up1 -> up1_2 ->
    1x1 conv. Hidden channel widths are multiples of the module-level
    ``scale`` read at construction time. The input is pooled by 2 twice and
    upsampled by 2 twice, so the skip concatenations only line up when the
    spatial size is divisible by 4.

    When ``use_dfa`` is true, ``Inject_e`` is applied after every block except
    up1 and up1_2. Each application uses the shared error buffer ``e`` and its
    own fixed random feedback matrix (``Bc1``..``Bc4_2``, not trainable).

    Args:
        input_ch: Channels of the input image and of the predicted noise.
        time_embed_dim: Width of the sinusoidal time embedding.
        image_size: Height and width of the (square) input.
        batch_size: Batch size the DFA error buffer is allocated for.
        device: Device on which the DFA error buffer is created.
        use_dfa: Enable the ``Inject_e`` layers and the error buffer.
    """

    def __init__(self, input_ch=input_ch, time_embed_dim=time_embed_dim, image_size=image_size, batch_size=batch_size, device='cpu', use_dfa=True):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        self.image_size = image_size
        self.batch_size = batch_size
        self.device = device
        self.use_dfa = use_dfa
        # NOTE(review): use_dfa is not forwarded to the ConvBlock2 instances, so
        # they always take their default (Tanh) activation, even in BP mode.
        # Left as-is because changing it would alter existing BP results; confirm intent.
        self.down1 = ConvBlock2(input_ch, 64 * scale, time_embed_dim)
        self.down1_2 = ConvBlock2(64 * scale, 64 * scale, time_embed_dim)
        self.down2 = ConvBlock2(64 * scale, 128 * scale, time_embed_dim)
        self.down2_2 = ConvBlock2(128 * scale, 128 * scale, time_embed_dim)
        self.bot1 = ConvBlock2(128 * scale, 256 * scale, time_embed_dim)
        self.bot1_2 = ConvBlock2(256 * scale, 256 * scale, time_embed_dim)
        self.up2 = ConvBlock2(128 * scale + 256 * scale, 128 * scale, time_embed_dim)
        self.up2_2 = ConvBlock2(128 * scale, 128 * scale, time_embed_dim)
        self.up1 = ConvBlock2(128 * scale + 64 * scale, 64, time_embed_dim)
        self.up1_2 = ConvBlock2(64, 64, time_embed_dim)
        self.out = nn.Conv2d(64, input_ch, 1)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        if self.use_dfa:
            self.e_channels = input_ch
            self.e_height = image_size
            self.e_width = image_size
            # Output error (prediction - noise) of the latest training forward pass.
            # Registered as a non-persistent buffer so that model.to(...)/.cuda()
            # moves it together with the feedback matrices while the state_dict
            # keeps exactly the same keys as before.
            self.register_buffer(
                'e',
                torch.zeros(self.batch_size, self.e_channels, self.e_height, self.e_width, device=self.device),
                persistent=False,
            )
            e_flat_dim = self.e_channels * self.e_height * self.e_width
            # One feedback matrix per injection point: (flattened error) x (flattened stage output).
            x1_flat_dim = (64 * scale) * self.image_size * self.image_size
            self.Bc1 = nn.Parameter(torch.randn(e_flat_dim, x1_flat_dim), requires_grad=False)
            self.Bc1_2 = nn.Parameter(torch.randn(e_flat_dim, x1_flat_dim), requires_grad=False)
            x2_flat_dim = (128 * scale) * (self.image_size // 2) * (self.image_size // 2)
            self.Bc2 = nn.Parameter(torch.randn(e_flat_dim, x2_flat_dim), requires_grad=False)
            self.Bc2_2 = nn.Parameter(torch.randn(e_flat_dim, x2_flat_dim), requires_grad=False)
            x3_flat_dim = (256 * scale) * (self.image_size // 4) * (self.image_size // 4)
            self.Bc3 = nn.Parameter(torch.randn(e_flat_dim, x3_flat_dim), requires_grad=False)
            self.Bc3_2 = nn.Parameter(torch.randn(e_flat_dim, x3_flat_dim), requires_grad=False)
            x4_flat_dim = (128 * scale) * (self.image_size // 2) * (self.image_size // 2)
            self.Bc4 = nn.Parameter(torch.randn(e_flat_dim, x4_flat_dim), requires_grad=False)
            self.Bc4_2 = nn.Parameter(torch.randn(e_flat_dim, x4_flat_dim), requires_grad=False)

    def forward(self, x, timesteps, noise=None):
        """Predict the noise contained in ``x`` at the given ``timesteps``.

        Args:
            x: Noisy input batch, (N, input_ch, image_size, image_size).
            timesteps: One timestep per sample, turned into a time embedding.
            noise: Target noise. When given in DFA mode, ``prediction - noise``
                is stored in ``self.e`` for the following backward pass; its
                batch size must match the buffer's ``batch_size``.

        Returns:
            The predicted noise, same shape as ``x``.
        """
        v = pos_encoding(timesteps, self.time_embed_dim, x.device)
        # Encoder, level 1
        x1_pre_inject = self.down1(x, v)
        x1 = Inject_e.apply(x1_pre_inject, self.e, self.Bc1) if self.use_dfa else x1_pre_inject
        x1_conv2 = self.down1_2(x1, v)
        x1_conv2_injected = Inject_e.apply(x1_conv2, self.e, self.Bc1_2) if self.use_dfa else x1_conv2
        x_down1_pooled = self.maxpool(x1_conv2_injected)
        # Encoder, level 2
        x2_pre_inject = self.down2(x_down1_pooled, v)
        x2 = Inject_e.apply(x2_pre_inject, self.e, self.Bc2) if self.use_dfa else x2_pre_inject
        x2_conv2 = self.down2_2(x2, v)
        x2_conv2_injected = Inject_e.apply(x2_conv2, self.e, self.Bc2_2) if self.use_dfa else x2_conv2
        x_down2_pooled = self.maxpool(x2_conv2_injected)
        # Bottleneck
        x_bot_pre_inject = self.bot1(x_down2_pooled, v)
        x_bot = Inject_e.apply(x_bot_pre_inject, self.e, self.Bc3) if self.use_dfa else x_bot_pre_inject
        x_bot2_pre_inject = self.bot1_2(x_bot, v)
        x_bot2 = Inject_e.apply(x_bot2_pre_inject, self.e, self.Bc3_2) if self.use_dfa else x_bot2_pre_inject
        # Decoder, level 2 (skip from the second level-2 encoder block)
        x_up2_upsampled = self.upsample(x_bot2)
        x_up2_concat = torch.cat([x_up2_upsampled, x2_conv2_injected], dim=1)
        x_up2_pre_inject = self.up2(x_up2_concat, v)
        x_up2 = Inject_e.apply(x_up2_pre_inject, self.e, self.Bc4) if self.use_dfa else x_up2_pre_inject
        x_up2_conv2 = self.up2_2(x_up2, v)
        x_up2_conv2_injected = Inject_e.apply(x_up2_conv2, self.e, self.Bc4_2) if self.use_dfa else x_up2_conv2
        # Decoder, level 1 (skip from the second level-1 encoder block)
        x_up1_upsampled = self.upsample(x_up2_conv2_injected)
        x_up1_concat = torch.cat([x_up1_upsampled, x1_conv2_injected], dim=1)
        # The last two blocks and the output conv have no injection after them.
        x_final_conv1 = self.up1(x_up1_concat, v)
        x_final_conv2 = self.up1_2(x_final_conv1, v)
        output_noise_pred = self.out(x_final_conv2)
        if self.use_dfa and noise is not None:
            # Written through .data so the update stays outside autograd tracking:
            # the Inject_e nodes saved this same tensor during the forward pass and
            # read the refreshed values when backward() runs.
            self.e.data.copy_(output_noise_pred - noise)
        return output_noise_pred

class Diffuser:
    """Forward-noising helper plus a set of diffusers samplers for generation.

    Training side: a linear beta schedule of ``num_timesteps`` steps is
    precomputed, and ``add_noise`` uses 1-based timesteps ``t`` in [1, T],
    reading ``alpha_bars[t - 1]``.

    Sampling side: ``self.schedulers`` maps a sampler name to a diffusers
    scheduler configured with the same linear beta schedule.

    Args:
        num_timesteps: Number of training diffusion steps T.
        beta_start: First beta of the linear schedule.
        beta_end: Last beta of the linear schedule.
        device: Device for the schedule tensors and for sampling.
    """

    def __init__(self, num_timesteps=num_timesteps, beta_start=beta_start, beta_end=beta_end, device=device):
        self.num_timesteps = num_timesteps
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.schedulers = {}

        # 1. DDPM (ancestral sampling over the training schedule)
        self.schedulers['ddpm'] = DDPMScheduler(
            num_train_timesteps=self.num_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule='linear',
            prediction_type='epsilon',
            clip_sample=True, 
        )

        # 2. 'ddim' entry: first-order DPM-Solver (solver_order=1).
        # NOTE(review): this is DPMSolverMultistepScheduler, not diffusers'
        # DDIMScheduler (which is imported by the file but not used here).
        self.schedulers['ddim'] =  DPMSolverMultistepScheduler(
            num_train_timesteps=self.num_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule='linear',
            algorithm_type="dpmsolver",
            solver_order=1,
            final_sigmas_type="sigma_min",
            timestep_spacing='linspace',
        )

        # 3. DPM-Solver-2 (second order)
        self.schedulers['dpm_solver_2'] = DPMSolverMultistepScheduler(
            num_train_timesteps=self.num_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule='linear',
            algorithm_type="dpmsolver",
            solver_order=2,
            final_sigmas_type="sigma_min",
            timestep_spacing='linspace',
        )
        
        # 4. DPM-Solver++ (second order, Karras sigmas)
        self.schedulers['dpm_solver_pp'] = DPMSolverMultistepScheduler(
            num_train_timesteps=self.num_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule='linear',
            algorithm_type="dpmsolver++",
            solver_order=2,
            use_karras_sigmas=True,
            timestep_spacing='linspace',
        )
        

        print(f"Diffuser initialized with schedulers: {list(self.schedulers.keys())}")

    def add_noise(self, x_0, t):
        """Diffuse clean images ``x_0`` to timestep ``t``.

        Args:
            x_0: Clean image batch, (N, C, H, W).
            t: Integer tensor of N timesteps, each in [1, num_timesteps].

        Returns:
            ``(x_t, noise)``: the noised batch and the Gaussian noise that was
            mixed in (the training target).
        """
        T = self.num_timesteps
        assert (t >= 1).all() and (t <= T).all()
        # Timesteps are 1-based; the schedule tensors are 0-based.
        t_idx = t - 1
        alpha_bar = self.alpha_bars[t_idx]
        N = alpha_bar.size(0)
        # Reshape to broadcast one coefficient per sample over (C, H, W).
        alpha_bar = alpha_bar.view(N, 1, 1, 1)
        noise = torch.randn_like(x_0, device=self.device)
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
        return x_t, noise

    def reverse_to_img(self, x):
        """Convert one image tensor with values nominally in [0, 1] to a PIL image.

        Values are scaled by 255, clamped to [0, 255] and cast to uint8.
        """
        x = x * 255
        x = x.clamp(0, 255)
        x = x.to(torch.uint8)
        x = x.cpu()
        to_pil = transforms.ToPILImage()
        return to_pil(x)

    def sample(self, model, sampler_type='ddpm', x_shape=(20, input_ch, image_size, image_size), num_sampling_steps=None, show_progress=True):        
        """Generate images by running the reverse process from Gaussian noise.

        Args:
            model: Noise predictor called as ``model(x, t, noise=None)``.
            sampler_type: Key of ``self.schedulers``; unknown names fall back
                to ``'ddpm'`` with a warning.
            x_shape: Shape of the generated batch, (N, C, H, W).
            num_sampling_steps: Number of sampler steps; defaults to the
                ``SAMPLING_STEPS`` entry for the sampler.
            show_progress: Wrap the step loop in a tqdm progress bar.

        Returns:
            A list of N PIL images.

        The model is put in eval mode for the duration of sampling and its
        previous train/eval mode is restored afterwards, also on error.
        """
        if sampler_type not in self.schedulers:
            print(f"Warning: Sampler '{sampler_type}' not found. Defaulting to 'ddpm'.")
            sampler_type = 'ddpm'
        scheduler = self.schedulers[sampler_type]
        
        if num_sampling_steps is None:
            num_sampling_steps = SAMPLING_STEPS.get(sampler_type, self.num_timesteps)

        batch_size = x_shape[0]
        x = torch.randn(x_shape, device=self.device)
        scheduler.set_timesteps(num_sampling_steps, device=self.device)
        timesteps = scheduler.timesteps

        desc = f"Sampling ({sampler_type}, {num_sampling_steps} steps)"
        iterator = tqdm(timesteps, desc=desc) if show_progress else timesteps

        was_training = model.training
        model.eval()
        try:
            for t in iterator:
                # The model is trained with 1-based timesteps: label t means noise
                # level alpha_bars[t - 1] (see add_noise). Scheduler timesteps are
                # 0-based indices into the same cumulative schedule, so shift by
                # one to condition the model on the matching noise level.
                t_input = torch.full((batch_size,), t.item() + 1, device=self.device, dtype=torch.long)
                with torch.no_grad():
                    noise_pred = model(x, t_input, noise=None)
                # The scheduler itself keeps its own 0-based timestep.
                scheduler_output = scheduler.step(noise_pred, t, x, return_dict=False)
                x = scheduler_output[0]
        finally:
            # Hand the model back in the mode the caller had it in.
            model.train(was_training)

        images = [self.reverse_to_img(x[i]) for i in range(batch_size)]

        return images

def load_dataset(dataset_name, transform, class_to_keep=None):
    """Load the MNIST training set, optionally keep one digit, and split 80/20.

    Args:
        dataset_name: Only used in the log message; MNIST is always loaded.
        transform: Transform applied to each image by the dataset.
        class_to_keep: If not None, keep only samples with this label.

    Returns:
        ``(train_dataset, val_dataset, input_ch)``. ``val_dataset`` is None
        when the (filtered) dataset is too small to split, in which case the
        whole dataset is returned as the training set. ``input_ch`` is 1.
    """
    print(f"Loading {dataset_name} dataset...")

    full_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    input_ch = 1
    if class_to_keep is not None:
        print(f"Filtering MNIST for class: {class_to_keep}...")
        # Read the labels from the dataset's label tensor instead of iterating
        # the dataset, which would decode and transform every image just to
        # look at its label.
        targets = torch.as_tensor(full_dataset.targets)
        indices = torch.nonzero(targets == class_to_keep).flatten().tolist()
        full_dataset = Subset(full_dataset, indices)
        print(f"Filtered dataset size: {len(full_dataset)}")

    print(f"Total images loaded: {len(full_dataset)}, Input channels: {input_ch}")

    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    if train_size == 0 or val_size == 0:
        print("Dataset too small to split. Using all data for training.")
        return full_dataset, None, input_ch

    print(f"Splitting dataset into train ({train_size}) and validation ({val_size})...")
    # Fixed seed so the train/validation split is the same on every run.
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))
    return train_dataset, val_dataset, input_ch



print("\n--- 2. preparing for dataset ---")
# Resize / crop to the working resolution, then convert to a [0, 1] tensor.
preprocess_steps = [
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
]
preprocess_tensor = transforms.Compose(preprocess_steps)

try:
    train_dataset, val_dataset, input_ch = load_dataset('mnist', preprocess_tensor, class_filter)
    if val_dataset is not None:
        val_len = len(val_dataset)
    else:
        val_len = 0
    print(f"training images : ({len(train_dataset)}), validation images : ({val_len})")
except Exception as e:
    # Nothing can run without data: report and stop.
    print(f"Error loading dataset: {e}")
    sys.exit()


print("\n--- 3. preparing for training ---")
method_name = "DFA" if USE_DFA else "BP"
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
run_name_base = f"{selected_dataset}_UNet{model_version}_{method_name}_size{image_size}_batch{batch_size}_steps{num_timesteps}"

# The run name encodes the configuration (plus the digit filter, if any) and a timestamp.
run_name = (
    f"{run_name_base}_class{class_filter}_{timestamp}"
    if class_filter is not None
    else f"{run_name_base}_{timestamp}"
)

MODEL_SAVE_DIR = f"./models/{run_name}"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
print(f"Results will be saved in '{MODEL_SAVE_DIR}'")

# Use half of the available cores for loading; fall back to 0 workers
# when the core count cannot be determined.
available_cpus = os.cpu_count()
loader_workers = available_cpus // 2 if available_cpus else 0
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=loader_workers,
)

diffuser = Diffuser(num_timesteps=num_timesteps, beta_start=beta_start, beta_end=beta_end, device=device)

# Both U-Net variants share the same constructor signature.
unet_class = UNet1 if model_version == 1 else UNet2
model = unet_class(
    input_ch=input_ch,
    time_embed_dim=time_embed_dim,
    image_size=image_size,
    batch_size=batch_size,
    device=device,
    use_dfa=USE_DFA,
)
model.to(device)
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)



# Print a summary of the run configuration before training starts.
print(f"\n--- 4. Starting Training Loop ---")
print(f"Dataset: {selected_dataset} (Class: {class_filter if class_filter is not None else 'All'})")
print(f"Model: UNet{model_version} (scale: {scale})")
print(f"Mode: {method_name}")
print(f"Image Size: {image_size}x{image_size}")
print(f"Train Timesteps: {num_timesteps}")
print(f"Epochs: {epochs}, Batch Size: {batch_size}, LR: {lr}, Weight Decay: {weight_decay}")
print(f"Visualization Sampler: {EVAL_SAMPLER_TYPE} @ {EVAL_SAMPLING_STEPS} steps")
print(f"Results will be saved to: {MODEL_SAVE_DIR}")
print(f"----------------------------------")

# Per-epoch average training loss (None for an epoch with no successful batch).
losses = []
# Best (lowest) average epoch loss seen so far and the epoch it occurred in.
best_scores = {
    'loss': {'score': float('inf'), 'epoch': -1}
}

for epoch in range(epochs):
    loss_sum = 0.0
    cnt = 0
    
    # Sampling runs at the START of every epoch, i.e. before this epoch's
    # training, so epoch_0000.png shows the output of the untrained model.
    print(f"\nEpoch {epoch}/{epochs-1}: Sampling images for visualization...")
    model.eval()
    with torch.no_grad():
        try:
            # 20 samples, saved below as a grid with 4 images per row.
            sampled_images_pil = diffuser.sample(model, 
                                                 sampler_type=EVAL_SAMPLER_TYPE,
                                                 x_shape=(20, input_ch, image_size, image_size), 
                                                 num_sampling_steps=EVAL_SAMPLING_STEPS,
                                                 show_progress=True)
            
            vis_dir = os.path.join(MODEL_SAVE_DIR, "visualizations")
            os.makedirs(vis_dir, exist_ok=True)
            vis_grid_tensor = torchvision.utils.make_grid([transforms.ToTensor()(img) for img in sampled_images_pil], nrow=4)
            vis_pil = transforms.ToPILImage()(vis_grid_tensor)
            vis_pil.save(os.path.join(vis_dir, f"epoch_{epoch:04d}.png"))
            print(f"Saved visualization grid to {vis_dir}/epoch_{epoch:04d}.png")
            
        except Exception as e:
            # Visualization is best-effort: a failure is reported but does not stop training.
            print(f"Visualization sampling failed: {e}")
    model.train()
    

    pbar_train = tqdm(dataloader, desc=f"Epoch {epoch} Training", leave=False)
    for images, labels in pbar_train:
        optimizer.zero_grad()
        x = images.to(device)
        # One random timestep per sample, drawn from 1..num_timesteps inclusive.
        t = torch.randint(1, num_timesteps + 1, (len(x),), device=device)

        try:
            x_noisy, noise = diffuser.add_noise(x, t)
            # In DFA mode the target noise is handed to the model so it can record
            # its output error for the injected backward pass; BP passes None.
            noise_pred = model(x_noisy, t, noise if USE_DFA else None)

            loss = F.mse_loss(noise, noise_pred)
            loss.backward()
            optimizer.step()

            loss_item = loss.item()
            loss_sum += loss_item
            cnt += 1
            pbar_train.set_postfix(loss=f"{loss_item:.4f}")
        except Exception as e:
            # A failing batch is reported and skipped; it does not count towards the average.
            print(f"Error during training step: {e}")
            continue

    if cnt > 0:
        loss_avg = loss_sum / cnt
        losses.append(loss_avg)
        print(f'Epoch {epoch} finished | Average Loss: {loss_avg:.4f}')
        # Checkpoint whenever the epoch's average training loss improves;
        # the file is overwritten each time.
        if loss_avg < best_scores['loss']['score']:
            best_scores['loss']['score'] = loss_avg
            best_scores['loss']['epoch'] = epoch
            torch.save(model.state_dict(), os.path.join(MODEL_SAVE_DIR, "best_loss_model.pth"))
            print(f"New best loss model saved...")
    else:
        print(f"Epoch {epoch} had no successful batches.")
        # Keep one entry per epoch so list indices stay aligned with epoch numbers.
        losses.append(None)

print("\n--- 5. Training Finished ---")
print(f"Results saved in {MODEL_SAVE_DIR}")
print(f"Best Loss: {best_scores['loss']['score']:.4f} at Epoch {best_scores['loss']['epoch']}")


print("\n--- 6. Generating Result Plots ---")
epochs_list = list(range(epochs))
# Keep only epochs that produced a finite average loss, paired with their epoch number.
plottable_points = [
    (epochs_list[position], epoch_loss)
    for position, epoch_loss in enumerate(losses)
    if epoch_loss is not None and epoch_loss != float('inf')
]
valid_loss_epochs = [epoch_number for epoch_number, _ in plottable_points]
valid_losses = [epoch_loss for _, epoch_loss in plottable_points]


if valid_losses:
    try:
        plt.figure(figsize=(12, 8))
        plt.plot(valid_loss_epochs, valid_losses, label='Training Loss', marker='.')
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title(f"Training Loss Curve\n{run_name}")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        save_filename = f"{run_name}_loss_curve.png"
        curve_path = os.path.join(MODEL_SAVE_DIR, save_filename)
        plt.savefig(curve_path)
        plt.close()
    except Exception as e:
        # Plotting is best-effort; the training results are already saved.
        print(f"Failed to generate/save loss curve: {e}")

print(f"\nAll processes finished for {run_name}.")
print(f"## {MODEL_SAVE_DIR}")
print(f"losses:{losses}")
print(f"Best Loss: {best_scores['loss']['score']:.4f} at Epoch {best_scores['loss']['epoch']}")