In [1]:
import os, sys

# Put the project checkout on sys.path so the project-local imports in this
# notebook can be resolved. The membership check keeps the cell idempotent:
# re-running it does not keep appending duplicate entries to sys.path.
PROJECT_ROOT = os.path.abspath("/workspace/nas100/forGPU2/Kimjihoo/ASAN_02_BB_T1CE")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# --- Imports ---
import torch
from omegaconf import OmegaConf
from datasets import BB2T1CEDataset2D
from models import Model

# Device selection: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)


Device: cuda


In [2]:
from pathlib import Path

# Sanity check: show the names of the first ten *.nii.gz files (sorted order)
# found directly inside the meta-bb folder.
root = Path("/workspace/nas100/forGPU2/Kimjihoo/ASAN_02_BB_T1CE/data/0_RAW_DATA/meta-bb")
first_ten = sorted(root.glob("*.nii.gz"))[:10]
for nii_path in first_ten:
    print(nii_path.name)


12028537_tp20201206_bb_coreg.nii.gz
13771269_tp20201224_bb_coreg.nii.gz
15369028_tp20201215_bb_coreg.nii.gz
17530376_tp20201104_bb_coreg.nii.gz
18482526_tp20210129_bb_coreg.nii.gz
18849484_tp20201124_bb_coreg.nii.gz
19002581_tp20201106_bb_coreg.nii.gz
19722063_tp20201127_bb_coreg.nii.gz
20721084_tp20201219_bb_coreg.nii.gz
20851512_tp20201209_bb_coreg.nii.gz


In [4]:
from datasets import BB2T1CEDataset2D

# Build the paired BB/T1CE dataset from the raw-data root, requesting
# four slices per volume.
ds = BB2T1CEDataset2D(
    n_slices_per_vol=4,
    root_dir="/workspace/nas100/forGPU2/Kimjihoo/ASAN_02_BB_T1CE/data/0_RAW_DATA",
)

# Fetch the first item and report the shape of each tensor in the pair.
first_item = ds[0]
bb, t1ce = first_item
print(bb.shape, t1ce.shape)
# Expected: torch.Size([4, 1, 256, 256]) torch.Size([4, 1, 256, 256])


torch.Size([4, 1, 256, 256]) torch.Size([4, 1, 256, 256])


In [5]:
from types import SimpleNamespace

import torch
from datasets import BB2T1CEDataset2D
from models import Model

# Smoke test: build the model from an in-notebook config, push one dataset
# sample through it and print the output shape.

# Pick the device here instead of hard-coding .cuda(), so the cell also runs
# on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = SimpleNamespace(
    model=SimpleNamespace(
        ch=64,
        out_ch=1,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks=2,
        attn_resolutions=[16],
        dropout=0.1,
        in_channels=1,
        resamp_with_conv=True,
        ema=True,
        ema_rate=0.9999,
        var_type="fixedsmall"
    ),
    data=SimpleNamespace(
        image_size=256
    ),
    diffusion=SimpleNamespace(
        num_diffusion_timesteps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        beta_schedule="linear"
    )
)

# Initialise the model on the selected device.
model = Model(config).to(device)

# Load one sample from the dataset; only the BB tensor is fed to the model here.
ds = BB2T1CEDataset2D(
    root_dir="/workspace/nas100/forGPU2/Kimjihoo/ASAN_02_BB_T1CE/data/0_RAW_DATA",
    n_slices_per_vol=4
)
bb, t1ce = ds[0]
bb = bb.to(device)

# One random diffusion timestep per entry along the first dimension of `bb`,
# created directly on the target device.
t = torch.randint(
    0, config.diffusion.num_diffusion_timesteps, (bb.size(0),), device=device
)

# Forward pass only. no_grad() avoids building an autograd graph that the
# notebook globals `out` / `feats` would otherwise keep alive on the device.
with torch.no_grad():
    out, feats = model(bb, t)
print("✅ Output:", out.shape)


✅ Output: torch.Size([4, 1, 256, 256])


In [None]:
# ================================================================
# 🔹 BB→T1CE Diffusion Training (No WandB)
# ================================================================
import os, torch, time
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
from models import Model
from datasets import BB2T1CEDataset2D
from functions.losses import noise_estimation_loss
from functions.utils import get_beta_schedule
from models.ema_helper import EMAHelper


def save_checkpoint(path, model, ema_helper, optimizer, step, epoch):
    """Save model, EMA and optimizer state plus progress counters to ``path``.

    Args:
        path: Destination file for ``torch.save``.
        model: Network whose ``state_dict()`` is stored under ``"model"``.
        ema_helper: EMA tracker whose ``state_dict()`` is stored under ``"ema"``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``"optimizer"``.
        step: Global optimizer step count, stored under ``"step"``.
        epoch: Epoch index, stored under ``"epoch"``.

    Periodic and final checkpoints both go through this helper so that they
    always carry the same set of keys.
    """
    torch.save({
        "model": model.state_dict(),
        "ema": ema_helper.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
        "epoch": epoch,
    }, path)


# ------------------------------------------------
# 1️⃣ Load Config
# ------------------------------------------------
cfg = OmegaConf.load("/workspace/nas100/forGPU2/Kimjihoo/ASAN_02_BB_T1CE/configs/bb2t1ce_train.yml")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.makedirs(cfg.log_dir, exist_ok=True)
os.makedirs(cfg.ckpt_dir, exist_ok=True)

# Checkpoint interval in optimizer steps. Falls back to the previously
# hard-coded value of 5 when `training.save_freq` is absent from the YAML;
# a value of 0 disables periodic checkpoints.
save_freq = cfg.training.get("save_freq", 5)

# ------------------------------------------------
# 2️⃣ Dataset / Dataloader
# ------------------------------------------------
ds = BB2T1CEDataset2D(cfg.data.root_dir, n_slices_per_vol=cfg.data.n_slices_per_vol)
train_loader = DataLoader(
    ds, batch_size=cfg.data.batch_size, shuffle=True, num_workers=cfg.data.num_workers
)

# ------------------------------------------------
# 3️⃣ Model / Optimizer / EMA
# ------------------------------------------------
model = Model(cfg).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.optim.lr)
ema_helper = EMAHelper(mu=cfg.model.ema_rate)
ema_helper.register(model)

# Diffusion betas
# NOTE(review): the TypeError recorded under this cell ("takes 1 positional
# argument but 4 were given") suggests get_beta_schedule accepts only its first
# argument positionally, so everything is passed by keyword - confirm against
# functions/utils.py.
betas = torch.from_numpy(get_beta_schedule(
    beta_schedule=cfg.diffusion.beta_schedule,
    beta_start=cfg.diffusion.beta_start,
    beta_end=cfg.diffusion.beta_end,
    num_diffusion_timesteps=cfg.diffusion.num_diffusion_timesteps
)).float().to(device)

# ------------------------------------------------
# 4️⃣ Training Loop
# ------------------------------------------------
global_step = 0
last_epoch = -1  # index of the most recent epoch that ran; stays -1 if the loop never executes
for epoch in range(cfg.training.epochs):
    model.train()
    epoch_loss = 0.0
    start_time = time.time()

    for i, (bb, t1ce) in enumerate(train_loader):
        bb = bb.to(device)         # conditional input
        t1ce = t1ce.to(device)     # target to predict noise
        e = torch.randn_like(t1ce)
        t = torch.randint(0, cfg.diffusion.num_diffusion_timesteps, (t1ce.size(0),), device=device).long()

        loss = noise_estimation_loss(model, t1ce, bb, t, e, betas)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.optim.grad_clip)
        optimizer.step()
        ema_helper.update(model)

        global_step += 1
        # Read the scalar once; it is reused for the running sum and the log line.
        loss_value = loss.item()
        epoch_loss += loss_value

        # ---- Logging ----
        if global_step % cfg.training.print_freq == 0:
            print(f"[Epoch {epoch:03d} | Step {global_step:05d}] Loss: {loss_value:.6f}")

        # ---- Save every `save_freq` steps (default 5) ----
        if save_freq and global_step % save_freq == 0:
            ckpt_path = os.path.join(cfg.ckpt_dir, f"ckpt_step{global_step:05d}.pth")
            save_checkpoint(ckpt_path, model, ema_helper, optimizer, global_step, epoch)
            print(f"✅ Checkpoint saved at {ckpt_path}")

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
    avg_loss = epoch_loss / max(len(train_loader), 1)
    print(f"✅ Epoch {epoch} done | Avg Loss: {avg_loss:.6f} | Time: {time.time() - start_time:.2f}s")
    last_epoch = epoch

# ------------------------------------------------
# 5️⃣ Save Final Model
# ------------------------------------------------
final_path = os.path.join(cfg.ckpt_dir, "final_model.pth")
# The final checkpoint now also records the epoch, matching the periodic checkpoints.
save_checkpoint(final_path, model, ema_helper, optimizer, global_step, last_epoch)
print(f"🎯 Final model saved to {final_path}")


TypeError: get_beta_schedule() takes 1 positional argument but 4 were given