# In [None]:
# 1. Module
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset, random_split
from torch.optim import AdamW
import matplotlib.pyplot as plt
import os
import math
import shutil
from datetime import datetime
from tqdm import tqdm
import sys
import torch_fidelity
from typing import List, Tuple
from diffusers import DPMSolverMultistepScheduler, DDPMScheduler
import numpy as np

# 2. Hyperparameters
# Default values. num_timesteps, batch_size, scale and image_size are
# re-assigned by the interactive prompts in section 7.
image_size = 8
batch_size = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_timesteps = 250  # default: 250 diffusion timesteps
beta_start, beta_end = 1e-4, 0.02  # endpoints of the linear beta schedule
DPM_SOLVER_STEPS = 40  # default sampling steps for the DPM-Solver family
scale = 3  # channel-width multiplier used by the U-Net blocks (64 * scale, ...)
input_ch = 3
time_embed_dim = 100
epochs = 51
lr, weight_decay = 1e-3, 1e-4
eval_freq = 1  # FID/KID/IS are computed when (epoch + 1) % eval_freq == 0
num_eval_samples = 500
REAL_DIR = "./real_images"
GEN_DIR = "./generated_images"
# NOTE: built from the default image_size above; the value is captured here
# and does not follow later re-assignments of image_size.
preprocess_tensor = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ]
)

# 3. for DFA
class Inject_e(torch.autograd.Function):
    """Identity in the forward pass; Direct Feedback Alignment signal in backward.

    ``forward(x, e, B)`` returns ``x`` unchanged. In backward, the incoming
    gradient is ignored. It is replaced by ``e`` flattened per sample
    (``e.shape[0]`` rows), multiplied by the fixed matrix ``B`` and reshaped
    to ``x``'s shape. ``e`` and ``B`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, x, e, B):
        # e is saved by reference: if it is overwritten in place via `.data`
        # before backward runs, backward sees the new values.
        ctx.save_for_backward(e, B)
        ctx.x_shape = x.shape
        return x

    @staticmethod
    def backward(ctx, grad_output):
        err, feedback = ctx.saved_tensors
        # (N, ...) -> (N, prod(...)), project through B, then back to x's shape.
        projected = err.view(err.shape[0], -1) @ feedback
        return projected.view(ctx.x_shape), None, None

# 4. Helper function
def show_images(images, num_rows=3, num_cols=4, title="Generated Images"):
    """Show up to ``num_rows * num_cols`` images in a grid titled ``title``.

    ``images`` is iterated in order; each item is passed to ``Axes.imshow``
    (the callers in this file pass PIL images). Prints a message and returns
    without plotting when ``images`` is empty.
    """
    if not images:
        print("No images to display.")
        return
    capacity = num_rows * num_cols
    fig = plt.figure(figsize=(num_cols * 2, num_rows * 2))
    plt.suptitle(title)
    # zip stops as soon as the grid is full, without pulling an extra image.
    for slot, picture in zip(range(capacity), images):
        ax = fig.add_subplot(num_rows, num_cols, slot + 1)
        ax.imshow(picture)
        ax.axis("off")
    # Leave space at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

def _pos_encoding(time_idx, output_dim, device='cpu'):
    t, D = time_idx, output_dim
    v = torch.zeros(D, device=device)
    i = torch.arange(0, D, device=device)
    div_term = torch.exp(i / D * math.log(10000))
    v[0::2] = torch.sin(t / div_term[0::2])
    v[1::2] = torch.cos(t / div_term[1::2])
    return v

def pos_encoding(timesteps, output_dim, device='cpu'):
    """Return sinusoidal embeddings for a batch of timesteps.

    Args:
        timesteps: tensor with one timestep value per sample (length B).
        output_dim: embedding width D.
        device: ignored; the result is created on ``timesteps.device``. The
            parameter is kept so existing call sites stay valid.

    Returns:
        Tensor of the default float dtype with shape (B, D). Row k holds
        sin(t_k / 10000**(i/D)) at even columns i and cos(t_k / 10000**(i/D))
        at odd columns i.
    """
    device = timesteps.device
    # (B, 1) column so it broadcasts against the per-column frequencies.
    t = timesteps.reshape(-1, 1)
    idx = torch.arange(0, output_dim, device=device)
    div_term = torch.exp(idx / output_dim * math.log(10000))
    v = torch.zeros(len(timesteps), output_dim, device=device)
    # Vectorized over the batch instead of one Python iteration (and several
    # tiny kernel launches) per sample; values are identical element-wise.
    v[:, 0::2] = torch.sin(t / div_term[0::2])
    v[:, 1::2] = torch.cos(t / div_term[1::2])
    return v

# 5-1. U-Net1 architecture
class ConvBlock1(nn.Module):
    """5x5 conv -> BatchNorm -> activation, conditioned on a time embedding.

    A small MLP maps the time embedding ``v`` to one value per input channel.
    That value is added to ``x`` (broadcast over height and width) before the
    convolution. The activation is Tanh when ``use_dfa`` is True, else ReLU.
    """

    def __init__(self, input_ch, output_ch, time_embed_dim, use_dfa=True):
        super().__init__()
        act = nn.Tanh() if use_dfa else nn.ReLU()
        conv = nn.Conv2d(input_ch, output_ch, 5, padding=2)
        norm = nn.BatchNorm2d(output_ch)
        self.convs = nn.Sequential(conv, norm, act)
        self.mlp = nn.Sequential(
            nn.Linear(time_embed_dim, input_ch),
            act,
            nn.Linear(input_ch, input_ch),
        )

    def forward(self, x, v):
        batch, channels, _height, _width = x.shape
        shift = self.mlp(v).view(batch, channels, 1, 1)  # per-channel offset
        return self.convs(x + shift)

class UNet1(nn.Module):
    """Two-level U-Net (one ConvBlock1 per stage) that predicts the noise in ``x``.

    When ``use_dfa`` is True, the outputs of down1, down2, bot1 and up2 are
    wrapped with ``Inject_e``. During backward, their gradient is replaced by
    the stored output error ``self.e`` projected through the fixed random
    matrices ``Bc1``..``Bc4`` (Direct Feedback Alignment). ``self.e`` has a
    fixed leading dimension of ``batch_size``, so DFA training batches must
    have exactly that many samples.
    """

    def __init__(self, input_ch=input_ch, time_embed_dim=time_embed_dim, image_size=image_size, batch_size=batch_size, device='cpu', use_dfa=True):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        self.image_size = image_size
        self.batch_size = batch_size
        self.device = device
        self.use_dfa = use_dfa
        # Forward use_dfa so every block picks its activation (Tanh for DFA,
        # ReLU for BP); without it the blocks always used their Tanh default.
        self.down1 = ConvBlock1(input_ch, 64 * scale, time_embed_dim, use_dfa=use_dfa)
        self.down2 = ConvBlock1(64 * scale, 128 * scale, time_embed_dim, use_dfa=use_dfa)
        self.bot1 = ConvBlock1(128 * scale, 256 * scale, time_embed_dim, use_dfa=use_dfa)
        self.up2 = ConvBlock1(128 * scale + 256 * scale, 128 * scale, time_embed_dim, use_dfa=use_dfa)
        self.up1 = ConvBlock1(128 * scale + 64 * scale, 64, time_embed_dim, use_dfa=use_dfa)
        self.out = nn.Conv2d(64, input_ch, 1)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        if self.use_dfa:
            self.e_channels = input_ch
            self.e_height = image_size
            self.e_width = image_size
            # Non-persistent buffer: follows model.to(device) but stays out of
            # state_dict, so existing checkpoints load unchanged.
            self.register_buffer(
                'e',
                torch.zeros(self.batch_size, self.e_channels, self.e_height, self.e_width, device=self.device),
                persistent=False,
            )
            e_flat_dim = self.e_channels * self.e_height * self.e_width
            # Fixed (non-trainable) feedback matrices: flattened error -> flattened activation.
            x1_flat_dim = (64 * scale) * self.image_size * self.image_size
            self.Bc1 = nn.Parameter(torch.randn(e_flat_dim, x1_flat_dim), requires_grad=False)
            x2_flat_dim = (128 * scale) * (self.image_size // 2) * (self.image_size // 2)
            self.Bc2 = nn.Parameter(torch.randn(e_flat_dim, x2_flat_dim), requires_grad=False)
            x3_flat_dim = (256 * scale) * (self.image_size // 4) * (self.image_size // 4)
            self.Bc3 = nn.Parameter(torch.randn(e_flat_dim, x3_flat_dim), requires_grad=False)
            x4_flat_dim = (128 * scale) * (self.image_size // 2) * (self.image_size // 2)
            self.Bc4 = nn.Parameter(torch.randn(e_flat_dim, x4_flat_dim), requires_grad=False)

    def forward(self, x, timesteps, noise=None):
        """Predict the noise component of ``x``.

        Args:
            x: batch of noisy images, presumably (N, input_ch, image_size, image_size).
            timesteps: one timestep per sample (length N).
            noise: the true noise that was added. When given and ``use_dfa`` is
                True, ``output - noise`` is written into ``self.e`` for the DFA
                backward pass.

        Returns:
            Tensor of shape (N, input_ch, H, W).
        """
        v = pos_encoding(timesteps, self.time_embed_dim, x.device)
        x1_pre_inject = self.down1(x, v)
        x1 = Inject_e.apply(x1_pre_inject, self.e, self.Bc1) if self.use_dfa else x1_pre_inject
        x_down1_pooled = self.maxpool(x1)
        x2_pre_inject = self.down2(x_down1_pooled, v)
        x2 = Inject_e.apply(x2_pre_inject, self.e, self.Bc2) if self.use_dfa else x2_pre_inject
        x_down2_pooled = self.maxpool(x2)
        x_bot_pre_inject = self.bot1(x_down2_pooled, v)
        x_bot = Inject_e.apply(x_bot_pre_inject, self.e, self.Bc3) if self.use_dfa else x_bot_pre_inject
        x_up2_upsampled = self.upsample(x_bot)
        x_up2_concat = torch.cat([x_up2_upsampled, x2], dim=1)
        x_up2_pre_inject = self.up2(x_up2_concat, v)
        x_up2 = Inject_e.apply(x_up2_pre_inject, self.e, self.Bc4) if self.use_dfa else x_up2_pre_inject
        x_up1_upsampled = self.upsample(x_up2)
        x_up1_concat = torch.cat([x_up1_upsampled, x1], dim=1)
        x_final_conv = self.up1(x_up1_concat, v)
        output_noise_pred = self.out(x_final_conv)
        if self.use_dfa and noise is not None:
            # Overwrite in place via .data (outside autograd). The Inject_e
            # calls above saved this same tensor, so their backward reads
            # this error.
            self.e.data.copy_(output_noise_pred - noise)
        return output_noise_pred

# 5-2. U-Net2 architecture
class ConvBlock2(nn.Module):
    """3x3 conv -> BatchNorm -> activation, conditioned on a time embedding.

    A small MLP maps the time embedding ``v`` to one value per input channel.
    That value is added to ``x`` (broadcast over height and width) before the
    convolution. The activation is Tanh when ``use_dfa`` is True, else ReLU.
    """

    def __init__(self, input_ch, output_ch, time_embed_dim, use_dfa=True):
        super().__init__()
        act = nn.Tanh() if use_dfa else nn.ReLU()
        conv = nn.Conv2d(input_ch, output_ch, 3, padding=1)
        norm = nn.BatchNorm2d(output_ch)
        self.convs = nn.Sequential(conv, norm, act)
        self.mlp = nn.Sequential(
            nn.Linear(time_embed_dim, input_ch),
            act,
            nn.Linear(input_ch, input_ch),
        )

    def forward(self, x, v):
        batch, channels, _height, _width = x.shape
        shift = self.mlp(v).view(batch, channels, 1, 1)  # per-channel offset
        return self.convs(x + shift)

class UNet2(nn.Module):
    """Two-level U-Net with two ConvBlock2 layers per stage; predicts the noise in ``x``.

    When ``use_dfa`` is True, every block output of the down, bottleneck and
    up2 stages is wrapped with ``Inject_e`` and its own fixed random feedback
    matrix (``Bc1``, ``Bc1_2``, ..., ``Bc4_2``). The final up1/up1_2 stage is
    not wrapped. ``self.e`` has a fixed leading dimension of ``batch_size``,
    so DFA training batches must have exactly that many samples.
    """

    def __init__(self, input_ch=input_ch, time_embed_dim=time_embed_dim, image_size=image_size, batch_size=batch_size, device='cpu', use_dfa=True):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        self.image_size = image_size
        self.batch_size = batch_size
        self.device = device
        self.use_dfa = use_dfa
        # Forward use_dfa so every block picks its activation (Tanh for DFA,
        # ReLU for BP); without it the blocks always used their Tanh default.
        self.down1 = ConvBlock2(input_ch, 64 * scale, time_embed_dim, use_dfa=use_dfa)
        self.down1_2 = ConvBlock2(64 * scale, 64 * scale, time_embed_dim, use_dfa=use_dfa)
        self.down2 = ConvBlock2(64 * scale, 128 * scale, time_embed_dim, use_dfa=use_dfa)
        self.down2_2 = ConvBlock2(128 * scale, 128 * scale, time_embed_dim, use_dfa=use_dfa)
        self.bot1 = ConvBlock2(128 * scale, 256 * scale, time_embed_dim, use_dfa=use_dfa)
        self.bot1_2 = ConvBlock2(256 * scale, 256 * scale, time_embed_dim, use_dfa=use_dfa)
        self.up2 = ConvBlock2(128 * scale + 256 * scale, 128 * scale, time_embed_dim, use_dfa=use_dfa)
        self.up2_2 = ConvBlock2(128 * scale, 128 * scale, time_embed_dim, use_dfa=use_dfa)
        self.up1 = ConvBlock2(128 * scale + 64 * scale, 64, time_embed_dim, use_dfa=use_dfa)
        self.up1_2 = ConvBlock2(64, 64, time_embed_dim, use_dfa=use_dfa)
        self.out = nn.Conv2d(64, input_ch, 1)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        if self.use_dfa:
            self.e_channels = input_ch
            self.e_height = image_size
            self.e_width = image_size
            # Non-persistent buffer: follows model.to(device) but stays out of
            # state_dict, so existing checkpoints load unchanged.
            self.register_buffer(
                'e',
                torch.zeros(self.batch_size, self.e_channels, self.e_height, self.e_width, device=self.device),
                persistent=False,
            )
            e_flat_dim = self.e_channels * self.e_height * self.e_width
            # Fixed (non-trainable) feedback matrices: flattened error -> flattened activation.
            x1_flat_dim = (64 * scale) * self.image_size * self.image_size
            self.Bc1 = nn.Parameter(torch.randn(e_flat_dim, x1_flat_dim), requires_grad=False)
            self.Bc1_2 = nn.Parameter(torch.randn(e_flat_dim, x1_flat_dim), requires_grad=False)
            x2_flat_dim = (128 * scale) * (self.image_size // 2) * (self.image_size // 2)
            self.Bc2 = nn.Parameter(torch.randn(e_flat_dim, x2_flat_dim), requires_grad=False)
            self.Bc2_2 = nn.Parameter(torch.randn(e_flat_dim, x2_flat_dim), requires_grad=False)
            x3_flat_dim = (256 * scale) * (self.image_size // 4) * (self.image_size // 4)
            self.Bc3 = nn.Parameter(torch.randn(e_flat_dim, x3_flat_dim), requires_grad=False)
            self.Bc3_2 = nn.Parameter(torch.randn(e_flat_dim, x3_flat_dim), requires_grad=False)
            x4_flat_dim = (128 * scale) * (self.image_size // 2) * (self.image_size // 2)
            self.Bc4 = nn.Parameter(torch.randn(e_flat_dim, x4_flat_dim), requires_grad=False)
            self.Bc4_2 = nn.Parameter(torch.randn(e_flat_dim, x4_flat_dim), requires_grad=False)

    def forward(self, x, timesteps, noise=None):
        """Predict the noise component of ``x``.

        Args:
            x: batch of noisy images, presumably (N, input_ch, image_size, image_size).
            timesteps: one timestep per sample (length N).
            noise: the true noise that was added. When given and ``use_dfa`` is
                True, ``output - noise`` is written into ``self.e`` for the DFA
                backward pass.

        Returns:
            Tensor of shape (N, input_ch, H, W).
        """
        v = pos_encoding(timesteps, self.time_embed_dim, x.device)
        x1_pre_inject = self.down1(x, v)
        x1 = Inject_e.apply(x1_pre_inject, self.e, self.Bc1) if self.use_dfa else x1_pre_inject
        x1_conv2 = self.down1_2(x1, v)
        x1_conv2_injected = Inject_e.apply(x1_conv2, self.e, self.Bc1_2) if self.use_dfa else x1_conv2
        x_down1_pooled = self.maxpool(x1_conv2_injected)
        x2_pre_inject = self.down2(x_down1_pooled, v)
        x2 = Inject_e.apply(x2_pre_inject, self.e, self.Bc2) if self.use_dfa else x2_pre_inject
        x2_conv2 = self.down2_2(x2, v)
        x2_conv2_injected = Inject_e.apply(x2_conv2, self.e, self.Bc2_2) if self.use_dfa else x2_conv2
        x_down2_pooled = self.maxpool(x2_conv2_injected)
        x_bot_pre_inject = self.bot1(x_down2_pooled, v)
        x_bot = Inject_e.apply(x_bot_pre_inject, self.e, self.Bc3) if self.use_dfa else x_bot_pre_inject
        x_bot2_pre_inject = self.bot1_2(x_bot, v)
        x_bot2 = Inject_e.apply(x_bot2_pre_inject, self.e, self.Bc3_2) if self.use_dfa else x_bot2_pre_inject
        x_up2_upsampled = self.upsample(x_bot2)
        x_up2_concat = torch.cat([x_up2_upsampled, x2_conv2_injected], dim=1)
        x_up2_pre_inject = self.up2(x_up2_concat, v)
        x_up2 = Inject_e.apply(x_up2_pre_inject, self.e, self.Bc4) if self.use_dfa else x_up2_pre_inject
        x_up2_conv2 = self.up2_2(x_up2, v)
        x_up2_conv2_injected = Inject_e.apply(x_up2_conv2, self.e, self.Bc4_2) if self.use_dfa else x_up2_conv2
        x_up1_upsampled = self.upsample(x_up2_conv2_injected)
        x_up1_concat = torch.cat([x_up1_upsampled, x1_conv2_injected], dim=1)
        x_final_conv1 = self.up1(x_up1_concat, v)
        x_final_conv2 = self.up1_2(x_final_conv1, v)
        output_noise_pred = self.out(x_final_conv2)
        if self.use_dfa and noise is not None:
            # Overwrite in place via .data (outside autograd). The Inject_e
            # calls above saved this same tensor, so their backward reads
            # this error.
            self.e.data.copy_(output_noise_pred - noise)
        return output_noise_pred

# 6. Diffusers
class Diffuser:
    """Linear-beta forward diffusion plus a set of diffusers samplers.

    Timestep convention: ``add_noise`` takes 1-indexed ``t`` in
    [1, num_timesteps] and uses ``alpha_bars[t - 1]``. The diffusers
    schedulers use 0-indexed timesteps (timestep ``t`` <-> ``alphas_cumprod[t]``),
    so ``sample`` conditions the model on ``t + 1`` to match training.
    """

    def __init__(self, num_timesteps=num_timesteps, beta_start=beta_start, beta_end=beta_end, device=device):
        self.num_timesteps = num_timesteps
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.schedulers = {}

        # 6-1. DDPMScheduler
        self.schedulers['ddpm'] = DDPMScheduler(
            num_train_timesteps=self.num_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule='linear',
            prediction_type='epsilon',
            clip_sample=True,
        )

        # 6-2. DDIM-like sampler (DPM-Solver, first order)
        self.schedulers['ddim'] = DPMSolverMultistepScheduler(
            num_train_timesteps=self.num_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule='linear',
            algorithm_type="dpmsolver",
            solver_order=1,
            final_sigmas_type="sigma_min",
            timestep_spacing='linspace',
        )

        # 6-3. DPM-Solver-2
        self.schedulers['dpm_solver_2'] = DPMSolverMultistepScheduler(
            num_train_timesteps=self.num_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule='linear',
            algorithm_type="dpmsolver",
            solver_order=2,
            final_sigmas_type="sigma_min",
            timestep_spacing='linspace',
        )

        # 6-4. DPM-Solver++ (2M, Karras sigmas)
        self.schedulers['dpm_solver_pp'] = DPMSolverMultistepScheduler(
            num_train_timesteps=self.num_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule='linear',
            algorithm_type="dpmsolver++",
            solver_order=2,
            use_karras_sigmas=True,
            timestep_spacing='linspace',
        )

        print(f"Diffuser initialized with schedulers: {list(self.schedulers.keys())}")

    def add_noise(self, x_0, t):
        """Sample x_t ~ q(x_t | x_0) for 1-indexed timesteps ``t`` (one per sample).

        Returns:
            (x_t, noise) where ``noise`` is the standard normal noise used.
        """
        T = self.num_timesteps
        assert (t >= 1).all() and (t <= T).all()
        t_idx = t - 1
        alpha_bar = self.alpha_bars[t_idx]
        N = alpha_bar.size(0)
        alpha_bar = alpha_bar.view(N, 1, 1, 1)
        noise = torch.randn_like(x_0, device=self.device)
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
        return x_t, noise

    def reverse_to_img(self, x):
        """Convert one (C, H, W) tensor with values in [0, 1] to a PIL image.

        Values are scaled by 255, clamped to [0, 255] and cast to uint8.
        """
        x = x * 255
        x = x.clamp(0, 255)
        x = x.to(torch.uint8)
        x = x.cpu()
        to_pil = transforms.ToPILImage()
        return to_pil(x)

    def sample(self, model, sampler_type='ddpm', x_shape=(20, input_ch, image_size, image_size), num_sampling_steps=None, show_progress=True):
        """Generate images by running ``model`` through the chosen scheduler.

        Args:
            model: noise-prediction network called as ``model(x, t, noise=None)``.
            sampler_type: key of ``self.schedulers``; unknown keys fall back to 'ddpm'.
            x_shape: shape of the initial Gaussian noise (batch first).
            num_sampling_steps: scheduler steps. If None, 'ddpm' uses
                ``num_timesteps`` and the other samplers use ``DPM_SOLVER_STEPS``.
            show_progress: wrap the timestep loop in a tqdm progress bar.

        Returns:
            List of PIL images, one per batch element.

        The model's train/eval mode is restored to what it was on entry.
        """
        if sampler_type not in self.schedulers:
            print(f"Warning: Sampler '{sampler_type}' not found. Defaulting to 'ddpm'.")
            sampler_type = 'ddpm'
        scheduler = self.schedulers[sampler_type]

        if num_sampling_steps is None:
            # DDPM runs every training timestep; the DPM-Solver family uses the
            # shared default step count.
            num_sampling_steps = self.num_timesteps if sampler_type == 'ddpm' else DPM_SOLVER_STEPS

        batch_size = x_shape[0]
        x = torch.randn(x_shape, device=self.device)
        scheduler.set_timesteps(num_sampling_steps, device=self.device)
        timesteps = scheduler.timesteps

        was_training = model.training
        model.eval()

        desc = f"Sampling ({sampler_type}, {num_sampling_steps} steps)"
        iterator = tqdm(timesteps, desc=desc) if show_progress else timesteps

        try:
            with torch.no_grad():
                for t in iterator:
                    # Scheduler timestep t is the noise level alpha_bars[t], which
                    # training labels as t + 1 (add_noise indexes alpha_bars[t - 1]).
                    t_input = torch.full((batch_size,), t.item() + 1, device=self.device, dtype=torch.long)
                    noise_pred = model(x, t_input, noise=None)
                    x = scheduler.step(noise_pred, t, x, return_dict=False)[0]
        finally:
            model.train(was_training)

        return [self.reverse_to_img(x[i]) for i in range(batch_size)]


# 7. Training
# --- Select num_timesteps ---
# Prompt until a positive integer is entered.
num_timesteps = None
while num_timesteps is None:
    try:
        timesteps_input = int(input("Select num_timesteps (e.g. 250, 500, 750, 1000): ").strip())
    except ValueError:
        print("Invalid input. Please enter a number.")
        continue
    if timesteps_input <= 0:
        print("Invalid input. Please enter a positive number.")
        continue
    num_timesteps = timesteps_input
    print(f"num_timesteps selected: {num_timesteps}")

# --- Select Training Mode (DFA/BP) ---
_MODE_CHOICES = {'dfa': True, 'bp': False}
USE_DFA = None
while USE_DFA is None:
    mode_input = input("Select training mode (DFA/BP): ").strip().lower()
    USE_DFA = _MODE_CHOICES.get(mode_input)
    if USE_DFA is None:
        print("Invalid input. Please type 'DFA' or 'BP'.")
# DFA runs with small batches and wider channels; BP with larger batches.
batch_size, scale = (8, 3) if USE_DFA else (64, 1)
print(f"Batch size set to: {batch_size}, scale set to: {scale}")

# --- Select U-Net version ---
_UNET_CHOICES = {
    '1': (1, "UNet version selected: UNet1 (1 ConvBlock per layer)"),
    '2': (2, "UNet version selected: UNet2 (2 ConvBlocks per layer)"),
}
model_version = None
while model_version is None:
    model_input = input("Select UNet version (1/2): ").strip()
    if model_input not in _UNET_CHOICES:
        print("Invalid input. Please type '1' or '2'.")
        continue
    model_version, _selected_msg = _UNET_CHOICES[model_input]
    print(_selected_msg)

# --- Select image_size ---
# UNet1/UNet2 halve the resolution twice (MaxPool2d(2)) and upsample twice
# before concatenating with the skip connections. Any size that is not a
# positive multiple of 4 therefore fails later with a shape mismatch.
image_size = None
while image_size is None:
    try:
        size_input = int(input("Select image size (e.g., 8, 16, 32, 64): ").strip())
    except ValueError:
        print("Invalid input. Please enter a number.")
        continue
    if size_input <= 0:
        print("Invalid input. Please enter a positive number.")
    elif size_input % 4 != 0:
        print("Invalid input. Image size must be a multiple of 4 (the U-Net downsamples twice).")
    else:
        image_size = size_input
        print(f"Image size selected: {image_size}")

# --- Select Sampler  ---
# Menu key -> scheduler name registered in Diffuser.schedulers.
_SAMPLER_MENU = {
    '1': 'ddpm',
    '2': 'ddim',
    '3': 'dpm_solver_2',
    '4': 'dpm_solver_pp',
}
sampler_type = None
sampling_steps = DPM_SOLVER_STEPS  # DPM-Solver default (40 unless changed above)

while sampler_type is None:
    print("\nSelect Sampler for Evaluation:")
    print("1: DDPM (Slow but standard)")
    print("2: DDIM (DPM-Solver-1)")
    print("3: DPM-Solver-2")
    print("4: DPM-Solver++")

    s_input = input("Enter choice (1-4): ").strip()
    sampler_type = _SAMPLER_MENU.get(s_input)
    if sampler_type is None:
        print("Invalid input.")
    elif sampler_type == 'ddpm':
        # DDPM is run over every training timestep.
        sampling_steps = num_timesteps

print(f"Sampler selected: {sampler_type}, Sampling steps: {sampling_steps}")


# --- Load dataset, combine, split ---
# Rebuild the preprocessing for the image size selected above. The
# module-level preprocess_tensor was created with the default image_size and
# would keep resizing to that size. That would mismatch the model (and the DFA
# feedback matrices) built for `image_size`.
preprocess_tensor = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
])

print("\nLoading Oxford Flowers 102 dataset (all splits: train, val, test) and combining them...")
try:
    original_train_ds = torchvision.datasets.Flowers102(root='./data', split='train', download=True, transform=preprocess_tensor)
    original_val_ds = torchvision.datasets.Flowers102(root='./data', split='val', download=True, transform=preprocess_tensor)
    original_test_ds = torchvision.datasets.Flowers102(root='./data', split='test', download=True, transform=preprocess_tensor)
    full_combined_dataset = ConcatDataset([original_train_ds, original_val_ds, original_test_ds])
    print(f"Total images: {len(full_combined_dataset)}")
    # Deterministic 80/20 split so the validation images are the same across runs.
    train_size = int(0.8 * len(full_combined_dataset))
    val_size = len(full_combined_dataset) - train_size
    train_dataset, val_dataset = random_split(full_combined_dataset, [train_size, val_size],
                                              generator=torch.Generator().manual_seed(42))
    print(f"Training : ({len(train_dataset)}), Validation({len(val_dataset)})")
except Exception as e:
    print(f"Error loading Flowers102 dataset: {e}")
    try:
        train_dataset = torchvision.datasets.Flowers102(root='./data', download=True, transform=preprocess_tensor)
        val_dataset = None
        full_combined_dataset = train_dataset
    except Exception as e_retry:
        print(f"Fatal: Could not load Flowers102 dataset even with default settings: {e_retry}")
        # Non-zero status so callers/shells see the failure.
        sys.exit(1)

# --- Setting for the saving directory ---
method_name = "DFA" if USE_DFA else "BP"

timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
run_name = f"UNet{model_version}_{method_name}_train timestep{num_timesteps}_size{image_size}_batch{batch_size}_{timestamp}"
MODEL_SAVE_DIR = f"./models/{run_name}"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
print(f"Results will be saved in '{MODEL_SAVE_DIR}' ")

# Use half the CPU cores for loading (0 = main process if the count is unknown).
_cpu_total = os.cpu_count()
_loader_workers = _cpu_total // 2 if _cpu_total else 0
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # DFA error buffers have a fixed batch dimension
    num_workers=_loader_workers,
)

diffuser = Diffuser(num_timesteps, device=device)

# model_version 1 -> UNet1; anything else -> UNet2.
_unet_cls = UNet1 if model_version == 1 else UNet2
model = _unet_cls(
    input_ch=input_ch,
    time_embed_dim=time_embed_dim,
    image_size=image_size,
    batch_size=batch_size,
    device=device,
    use_dfa=USE_DFA,
)

model.to(device)
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

REAL_DIR = f"./real_images_size{image_size}"
real_images_for_eval_source = val_dataset if val_dataset else train_dataset
actual_num_real_eval_samples = min(num_eval_samples, len(real_images_for_eval_source))

print(f"\n Saving for torch-fidelity (from Validation dataset : {actual_num_real_eval_samples})")
_need_real_images = (not os.path.exists(REAL_DIR)) or len(os.listdir(REAL_DIR)) < actual_num_real_eval_samples
if not _need_real_images:
    print(f"Enough real images already exist in {REAL_DIR}. Skipping saving.")
else:
    print(f"Saving {actual_num_real_eval_samples} real images for calculation...")
    # Start from an empty directory so leftovers are not mixed in.
    if os.path.exists(REAL_DIR):
        shutil.rmtree(REAL_DIR)
    os.makedirs(REAL_DIR, exist_ok=True)

    _to_pil = transforms.ToPILImage()
    real_images_saved_count = 0
    temp_dataloader_for_real = DataLoader(real_images_for_eval_source, batch_size=1, shuffle=True)
    for img_tensor, _ in tqdm(temp_dataloader_for_real, desc="Saving real images"):
        if real_images_saved_count >= actual_num_real_eval_samples:
            break
        _out_path = os.path.join(REAL_DIR, f"real_{real_images_saved_count:05d}.png")
        _to_pil(img_tensor.squeeze(0)).save(_out_path)
        real_images_saved_count += 1
    print(f"{real_images_saved_count} real images saved to {REAL_DIR}")

print(f"scale : {scale}, batch_size : {batch_size}, learning_rate : {lr}, weight_decay : {weight_decay}")

# Per-epoch history.
losses, fid_scores, kid_scores, is_scores = [], [], [], []

# Best value seen so far per metric, and the epoch where it occurred
# (FID/KID/loss: lower is better; IS: higher is better).
best_scores = {name: {'score': float('inf'), 'epoch': -1} for name in ('fid', 'kid')}
best_scores['is'] = {'score': float('-inf'), 'epoch': -1}
best_scores['loss'] = {'score': float('inf'), 'epoch': -1}

for epoch in range(epochs):
    loss_sum = 0.0
    cnt = 0

    # Visualize samples from the current model with the sampler and step count selected above.
    print(f"\nEpoch {epoch}: Sampling images for visualization ({sampler_type}, {sampling_steps} steps)...")
    model.eval()
    with torch.no_grad():
        sampled_images = diffuser.sample(model, x_shape=(20, input_ch, image_size, image_size),
                                         sampler_type=sampler_type,
                                         num_sampling_steps=sampling_steps,
                                         show_progress=True)
        show_images(sampled_images, title=f"Epoch {epoch} Generated Images ({sampler_type.upper()})")

    model.train()

    if (epoch + 1) % eval_freq == 0:
        print(f"Epoch {epoch}: Calculating FID, KID, IS...")

        model.eval()

        GENERATED_DIR_RUN = os.path.join(MODEL_SAVE_DIR, f"generated_images_epoch_{epoch}")
        if os.path.exists(GENERATED_DIR_RUN):
            shutil.rmtree(GENERATED_DIR_RUN)
        os.makedirs(GENERATED_DIR_RUN, exist_ok=True)

        generated_images_count = 0
        with tqdm(total=num_eval_samples, desc="Generating & Saving for evaluation", unit="img") as pbar:
            num_to_generate_per_call = 20
            num_sample_calls = math.ceil(num_eval_samples / num_to_generate_per_call)

            for _ in range(num_sample_calls):
                if generated_images_count >= num_eval_samples:
                    break

                # Same sampler and step count as the visualization above.
                gen_batch_pil = diffuser.sample(model, x_shape=(num_to_generate_per_call, input_ch, image_size, image_size),
                                                sampler_type=sampler_type,
                                                num_sampling_steps=sampling_steps,
                                                show_progress=False)

                for img_pil in gen_batch_pil:
                    if generated_images_count >= num_eval_samples:
                        break
                    img_pil.save(os.path.join(GENERATED_DIR_RUN, f"gen_{generated_images_count:05d}.png"))
                    generated_images_count += 1
                    pbar.update(1)

        print(f"{generated_images_count} generated images saved to {GENERATED_DIR_RUN}")

        # KID draws subsets of this size from BOTH image sets, so it must not
        # exceed the smaller of the two.
        kid_subset_size = min(num_eval_samples, generated_images_count, actual_num_real_eval_samples)
        try:
            metrics_dict = torch_fidelity.calculate_metrics(
                input1=GENERATED_DIR_RUN,
                input2=REAL_DIR,
                cuda=(device == 'cuda'),  # use the GPU only when one is actually in use
                fid=True,
                kid=True,
                isc=True,
                prc=False,
                kid_subset_size=kid_subset_size,
            )
            fid_score = metrics_dict['frechet_inception_distance']
            kid_score = metrics_dict['kernel_inception_distance_mean']
            is_score = metrics_dict['inception_score_mean']
        except Exception as e:
            print(f"FID/KID/IS calculation failed: {e}")
            print("Make sure you have enough samples and torch-fidelity is correctly installed.")
            # Keep one entry per evaluated epoch so the lists stay aligned.
            fid_scores.append(None)
            kid_scores.append(None)
            is_scores.append(None)
        else:
            print(f"Epoch {epoch} | FID: {fid_score:.4f} | KID: {kid_score:.4f} | IS: {is_score:.4f}")

            fid_scores.append(fid_score)
            kid_scores.append(kid_score)
            is_scores.append(is_score)

            # Checkpoint whenever a metric improves (KID/FID: lower, IS: higher).
            for key, label, value, higher_is_better in (
                ('kid', 'KID', kid_score, False),
                ('fid', 'FID', fid_score, False),
                ('is', 'IS', is_score, True),
            ):
                best = best_scores[key]
                improved = value > best['score'] if higher_is_better else value < best['score']
                if improved:
                    best['score'] = value
                    best['epoch'] = epoch
                    torch.save(model.state_dict(), os.path.join(MODEL_SAVE_DIR, f"best_{key}_model.pth"))
                    print(f"New best {label} model saved with {label}: {best['score']:.4f} at Epoch {epoch}")

        shutil.rmtree(GENERATED_DIR_RUN)
        model.train()

    for images, _labels in tqdm(dataloader, desc=f"Epoch {epoch} Training"):
        optimizer.zero_grad()
        x = images.to(device)
        # 1-indexed timesteps, matching Diffuser.add_noise.
        t = torch.randint(1, num_timesteps + 1, (len(x),), device=device)

        x_noisy, noise = diffuser.add_noise(x, t)
        # Passing the true noise lets DFA models record their output error.
        noise_pred = model(x_noisy, t, noise)

        loss = F.mse_loss(noise, noise_pred)

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        cnt += 1

    if cnt == 0:
        raise RuntimeError(
            "The training dataloader produced no batches "
            f"(dataset smaller than batch_size={batch_size} with drop_last=True)."
        )
    loss_avg = loss_sum / cnt
    if loss_avg < best_scores['loss']['score']:
        best_scores['loss']['score'] = loss_avg
        best_scores['loss']['epoch'] = epoch  # 0-based, like the FID/KID/IS entries
        torch.save(model.state_dict(), os.path.join(MODEL_SAVE_DIR, "best_loss_model.pth"))
    losses.append(loss_avg)
    print(f'Epoch {epoch} | Loss: {loss_avg}')

print("Training finished.")

# 8. Graphs
import matplotlib.pyplot as plt
import numpy as np
import os
from typing import List, Tuple


# Epochs (0-based) at which FID/KID/IS were evaluated; one per entry of
# fid_scores / kid_scores / is_scores.
epochs_evaluated_fid_kid_is = [i for i in range(epochs) if (i + 1) % eval_freq == 0]


def _best_index(values, pick):
    """Return the index of the best non-None entry of ``values``.

    ``pick`` is ``np.argmin`` or ``np.argmax``. Returns None when every entry
    is None (failed evaluations) or ``values`` is empty.
    """
    valid = [(i, v) for i, v in enumerate(values) if v is not None]
    if not valid:
        return None
    positions, scores = zip(*valid)
    return positions[int(pick(scores))]


best_scores = {
    'loss': {'score': float('inf'), 'epoch': -1},  # min
    'fid': {'score': float('inf'), 'epoch': -1},   # min
    'kid': {'score': float('inf'), 'epoch': -1},   # min
    'is': {'score': float('-inf'), 'epoch': -1},   # max
}

# Loss: one entry per trained epoch, so the list index is the epoch.
min_loss_idx = _best_index(losses, np.argmin)
if min_loss_idx is not None:
    best_scores['loss']['score'] = losses[min_loss_idx]
    best_scores['loss']['epoch'] = min_loss_idx

# FID / KID / IS: None entries mark epochs whose metric computation failed;
# if all failed, the entry keeps epoch -1 (nothing to highlight).
for _key, _series, _pick in (
    ('fid', fid_scores, np.argmin),
    ('kid', kid_scores, np.argmin),
    ('is', is_scores, np.argmax),
):
    _idx = _best_index(_series, _pick)
    if _idx is not None:
        best_scores[_key]['score'] = _series[_idx]
        best_scores[_key]['epoch'] = epochs_evaluated_fid_kid_is[_idx]

def plot_scores_with_annotations(
    epochs_list: List[int],
    scores: List[float],
    best_info: dict,
    title: str,
    ylabel: str,
    metric_type: str,
    color: str
):
    """Plot a per-epoch metric curve, mark its best point, save and show it.

    Args:
        epochs_list: epoch for each entry of ``scores``.
        scores: metric values; ``None`` entries (failed evaluations) are skipped.
        best_info: dict with 'score' and 'epoch'; 'epoch' == -1 means there is
            no best point to highlight.
        title: figure title.
        ylabel: y-axis label; also used (lowercased, spaces -> underscores)
            as the file name prefix ``<ylabel>_curve.png``.
        metric_type: 'min' or 'max' (which direction is better); controls the
            marker label and where the annotation text is placed.
        color: matplotlib color for the curve.

    The figure is saved into MODEL_SAVE_DIR. Prints a message and returns
    early if there is no valid data.
    """
    valid_data = [(e, s) for e, s in zip(epochs_list, scores) if s is not None]
    if not valid_data:
        print(f"No valid data to plot for {ylabel}.")
        return

    valid_epochs = [e for e, _ in valid_data]
    valid_scores = [s for _, s in valid_data]

    plt.figure(figsize=(10, 6))
    plt.plot(valid_epochs, valid_scores, label=ylabel, color=color, marker='o', markersize=5, linestyle='-')

    best_epoch = best_info['epoch']
    best_score = best_info['score']
    has_best = best_epoch != -1

    if has_best:
        plt.scatter(best_epoch, best_score, color='red', s=150, zorder=5, label=f'Best ({metric_type.capitalize()})', marker='*')

        best_text = f'{metric_type.capitalize()}: {best_score:.4f}\n(Epoch {best_epoch})'

        # Offset the label by half a standard deviation, flipping direction if it
        # would land outside the plotted value range.
        offset = np.std(valid_scores) * 0.5
        if metric_type == 'min':
            text_y_pos = best_score + offset
            if text_y_pos > max(valid_scores):
                text_y_pos = best_score - offset
        else:  # metric_type == 'max'
            text_y_pos = best_score - offset
            if text_y_pos < min(valid_scores):
                text_y_pos = best_score + offset

        plt.annotate(
            best_text,
            xy=(best_epoch, best_score),
            xytext=(best_epoch, text_y_pos),
            arrowprops=dict(facecolor='black', shrink=0.05, width=0.5, headwidth=8),
            fontsize=10,
            ha='center',
            bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7)
        )

    plt.title(title, fontsize=12, fontweight='bold')
    plt.xlabel('Epoch', fontsize=10)
    plt.ylabel(ylabel, fontsize=10)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend(loc='best')

    # Thin ticks to roughly 10 while always showing the best epoch. The -1
    # "no best" sentinel is never turned into a tick.
    tick_step = max(1, len(valid_epochs) // 10)
    ticks = {e for e in valid_epochs if e % tick_step == 0}
    if has_best:
        ticks.add(best_epoch)
    plt.xticks(sorted(ticks))

    plt.tight_layout()
    plt.savefig(os.path.join(MODEL_SAVE_DIR, f"{ylabel.lower().replace(' ', '_')}_curve.png"))
    plt.show()

# Training loss per epoch, with the minimum highlighted.
_loss_best = best_scores['loss']
_loss_epochs = list(range(len(losses)))

plt.figure(figsize=(10, 6))
plt.plot(_loss_epochs, losses, label='Training Loss')
plt.scatter(_loss_best['epoch'], _loss_best['score'], color='red', s=150, zorder=5, label='Min Loss', marker='*')
if _loss_best['epoch'] != -1 and len(losses) > 1:
    # Place the label half a standard deviation above the minimum.
    _label_y = _loss_best['score'] + np.std(losses) * 0.5
    plt.annotate(
        f'Min Loss: {_loss_best["score"]:.4f}\n(Epoch {_loss_best["epoch"]})',
        xy=(_loss_best['epoch'], _loss_best['score']),
        xytext=(_loss_best['epoch'], _label_y),
        arrowprops=dict(facecolor='black', shrink=0.05, width=0.5, headwidth=8),
        fontsize=10,
        ha='center',
        bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7)
    )
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title(f"Training Loss Curve for {run_name}")
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(MODEL_SAVE_DIR, "loss_curve.png"))
plt.show()

# FID, KID and IS curves:
# (scores, best-score key, title prefix, label, better direction, color)
for _series, _key, _title_prefix, _label, _direction, _color in (
    (fid_scores, 'fid', 'FID Score (Lower is Better)', 'FID', 'min', 'tab:blue'),
    (kid_scores, 'kid', 'KID Score (Lower is Better)', 'KID', 'min', 'tab:orange'),
    (is_scores, 'is', 'Inception Score (Higher is Better)', 'IS', 'max', 'tab:green'),
):
    plot_scores_with_annotations(
        epochs_evaluated_fid_kid_is, _series, best_scores[_key],
        f'{_title_prefix} for {run_name}', _label, _direction, _color
    )