In [2]:
import os
import sys

# Make the MADM checkout importable. Guarded so that re-running this cell does
# not keep appending duplicate entries to sys.path.
# NOTE(review): a later cell copies the code bundle to /kaggle/working/cmlaimadm
# and adds that path as well; confirm which checkout `guided_diffusion` should
# resolve from (the first matching sys.path entry wins).
_MADM_DIR = "/kaggle/working/MADM"
if _MADM_DIR not in sys.path:
    sys.path.append(_MADM_DIR)


In [3]:
import os
import shutil
import random

import gzip

# ====== PATH ======
src_root = "/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
dst_root = "/kaggle/working/brats_pretrain"

os.makedirs(dst_root, exist_ok=True)

# ====== SPLIT ======
# Case directories only; sorted before the seeded shuffle so the split is
# reproducible regardless of listdir order.
cases = sorted(
    d for d in os.listdir(src_root)
    if os.path.isdir(os.path.join(src_root, d))
)

random.seed(42)
random.shuffle(cases)

train_ratio = 0.8
split_idx = int(len(cases) * train_ratio)

splits = {
    "train": cases[:split_idx],
    "val": cases[split_idx:],
}

modalities = ["flair", "t1", "t1ce", "t2"]


def _copy_as_nii_gz(src, dst):
    """Copy NIfTI file ``src`` to ``dst`` (a ``.nii.gz`` path).

    Already-compressed sources are copied verbatim; plain ``.nii`` sources are
    gzip-compressed on the fly so the ``.nii.gz`` destination name is truthful
    (nibabel picks compression from the extension).
    """
    if src.endswith(".gz"):
        shutil.copy(src, dst)
        return
    with open(src, "rb") as fin, gzip.open(dst, "wb", compresslevel=1) as fout:
        shutil.copyfileobj(fin, fout)


# ====== REFORMAT ======
# Output layout: <dst_root>/<split>/<case>/<modality>.nii.gz
# Sources may be ``.nii.gz`` or plain ``.nii``; matching only ``.nii.gz``
# silently copied nothing for uncompressed inputs.
n_copied = 0
for split, case_list in splits.items():
    for case in case_list:
        src_case = os.path.join(src_root, case)
        dst_case = os.path.join(dst_root, split, case)

        os.makedirs(dst_case, exist_ok=True)

        found = set()
        for f in os.listdir(src_case):
            if f.endswith(".nii.gz"):
                stem = f[: -len(".nii.gz")]
            elif f.endswith(".nii"):
                stem = f[: -len(".nii")]
            else:
                continue
            for m in modalities:
                # The stem must end with the modality, so "t1" never matches "..._t1ce".
                if stem.endswith(m):
                    _copy_as_nii_gz(
                        os.path.join(src_case, f),
                        os.path.join(dst_case, f"{m}.nii.gz"),
                    )
                    found.add(m)
                    n_copied += 1

        missing = [m for m in modalities if m not in found]
        if missing:
            print(f"⚠️ {case}: missing modalities {missing}")

print(f"Copied {n_copied} files for {len(cases)} cases")
print("✅ Done: BraTS pretrain dataset ready")


✅ Done: BraTS pretrain dataset ready


In [None]:
# List the attached Kaggle inputs and the contents of the cmlaimadm code bundle.
!ls /kaggle/input
!ls /kaggle/input/cmlaimadm


In [None]:
# Copy the cmlaimadm code bundle into /kaggle/working (the directory a later
# cell appends to sys.path).
!cp -r /kaggle/input/cmlaimadm /kaggle/working/

In [None]:
# Check the copied bundle (dataloader_scripts/, guided_diffusion/, ...).
!ls /kaggle/working/cmlaimadm


In [None]:
# Show the first 1500 characters of the project's 2.5D PET data loader.
_loader_src_path = "/kaggle/working/cmlaimadm/dataloader_scripts/load_pet_2_5D.py"
with open(_loader_src_path) as fh:
    _loader_head = fh.read()[:1500]
print(_loader_head)


In [None]:
import sys

# Make the copied cmlaimadm bundle importable (dataloader_scripts, guided_diffusion).
# Guarded so that re-running the cell does not append duplicate sys.path entries.
_CMLAIMADM_DIR = "/kaggle/working/cmlaimadm"
if _CMLAIMADM_DIR not in sys.path:
    sys.path.append(_CMLAIMADM_DIR)



In [None]:
import os
import numpy as np
import random
import torch
import nibabel as nib
import scipy.ndimage as ndimage

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

In [None]:
def normalize_ac(data, mask, scale=5.0):
    """Mean-normalise ``data`` inside ``mask`` and squash it with tanh.

    Computes ``tanh(data / mean(data[mask]) / scale)``.

    Robustness: if ``mask`` selects no voxels, the mean over the whole volume is
    used instead. If that reference mean is zero, the division is skipped. Either
    way an empty mask or blank volume yields finite values instead of NaN/inf.

    Args:
        data: intensity array.
        mask: boolean array with the same shape as ``data``.
        scale: softness of the tanh squashing (default 5, as before).

    Returns:
        Array with the shape of ``data`` and values in (-1, 1).
    """
    selected = data[np.asarray(mask, dtype=bool)]
    if selected.size == 0:
        # Empty mask: fall back to the whole volume rather than np.mean([]) -> NaN.
        selected = data
    reference = np.mean(selected) if selected.size else 0
    if reference != 0:
        data = data / reference
    return np.tanh(data / scale)

def get_mask(nac, num_blurred=5, threshold=0.05, sigma=1):
    """Foreground mask from a repeatedly Gaussian-blurred copy of ``nac``.

    The volume is smoothed ``num_blurred`` times with a Gaussian of width
    ``sigma``, and voxels of the *blurred* volume above ``threshold`` are kept.
    Previously the blurred copy was computed but the raw volume was
    thresholded, so the smoothing had no effect.

    Args:
        nac: input volume.
        num_blurred: number of successive Gaussian passes (0 disables blurring).
        threshold: intensity cut-off applied to the blurred volume.
        sigma: standard deviation of each Gaussian pass.

    Returns:
        Boolean array with the shape of ``nac``.
    """
    blurred_nac = np.copy(nac)
    for _ in range(num_blurred):
        blurred_nac = ndimage.gaussian_filter(blurred_nac, sigma=sigma)
    return blurred_nac > threshold


class LoadPetSlices(Dataset):
    """2.5D training dataset: NAC neighbour slices -> one AC target slice.

    Volumes are read from ``<root_dir>/5NAC/<pid>5_NAC.nii`` and
    ``<root_dir>/100AC/<pid>100_AC.nii``. A sample is created for every
    ``use_every_n_slice``-th slice index that has ``load_adj`` neighbours on
    both sides.

    Each item is ``(ac_2d, {"cond": cond})``:
      * ``ac_2d``: the AC slice at the sampled index, shape ``(1, H, W)``;
      * ``cond``: the ``load_adj`` NAC slices on each side of it, shape
        ``(2 * load_adj, H, W)``. The centre NAC slice itself is excluded.

    Both volumes are normalised with a mask derived from the NAC volume.
    ``cond`` and ``ac_2d`` get the same random flips along H and W.

    Args:
        root_dir: split directory containing ``5NAC/`` and ``100AC/``.
        axis: ``"x"``, ``"y"`` or ``"z"``; slices are taken along volume
            dimension 0, 1 or 2 respectively.
        load_adj: number of neighbouring slices on each side of the centre.
        out_size: stored on the instance but not used (no resizing happens).
        use_every_n_slice: stride between consecutive centre slices.
        seed: seeds Python's global ``random`` module, which drives the flips.
    """

    # Volume dimension sliced for each axis name.
    _AXIS_DIM = {"x": 0, "y": 1, "z": 2}
    # File-name suffixes appended to a case id in the 5NAC / 100AC folders.
    _NAC_SUFFIX = "5_NAC.nii"
    _AC_SUFFIX = "100_AC.nii"

    def __init__(
        self,
        root_dir,
        axis="z",
        load_adj=8,
        out_size=192,
        use_every_n_slice=2,
        seed=1,
    ):
        super().__init__()
        assert axis in ["x", "y", "z"]
        random.seed(seed)

        self.axis = axis
        self.load_adj = load_adj
        self.out_size = out_size
        self.root_dir = root_dir

        nac_dir = os.path.join(root_dir, "5NAC")
        # Sorted because listdir order is arbitrary; this keeps the
        # index -> (case, slice) mapping reproducible. Files without the
        # expected suffix are ignored.
        self.ids = sorted(
            f[: -len(self._NAC_SUFFIX)]
            for f in os.listdir(nac_dir)
            if f.endswith(self._NAC_SUFFIX)
        )

        dim = self._AXIS_DIM[axis]
        self.slice_map = []
        for pid in self.ids:
            depth = nib.load(self._nac_path(pid)).shape[dim]
            # Centre indices that have load_adj neighbours on both sides.
            self.slice_map.extend(
                (pid, s)
                for s in range(load_adj, depth - load_adj, use_every_n_slice)
            )

        print(f"[Dataset] total slices = {len(self.slice_map)}")

    def _nac_path(self, pid):
        """Path of the NAC volume for case ``pid``."""
        return os.path.join(self.root_dir, "5NAC", pid + self._NAC_SUFFIX)

    def _ac_path(self, pid):
        """Path of the AC volume for case ``pid``."""
        return os.path.join(self.root_dir, "100AC", pid + self._AC_SUFFIX)

    def __len__(self):
        return len(self.slice_map)

    def extract_25d(self, vol, center_idx):
        """Return the ``2 * load_adj + 1`` slices around ``center_idx``.

        Slices are taken along ``self.axis`` and stacked first, giving a
        ``(C, H, W)`` torch tensor with the dtype of ``vol``.
        """
        dim = self._AXIS_DIM[self.axis]
        indices = np.arange(center_idx - self.load_adj, center_idx + self.load_adj + 1)
        stack = np.moveaxis(np.take(vol, indices, axis=dim), dim, 0)
        return torch.from_numpy(np.ascontiguousarray(stack))

    def __getitem__(self, idx):
        pid, slice_idx = self.slice_map[idx]

        # -------------------------
        # Load volumes
        # -------------------------
        nac = nib.load(self._nac_path(pid)).get_fdata().astype(np.float32)
        ac = nib.load(self._ac_path(pid)).get_fdata().astype(np.float32)

        # -------------------------
        # Normalize with the NAC-derived mask
        # -------------------------
        mask = get_mask(nac)
        nac = normalize_ac(nac, mask)
        ac = normalize_ac(ac, mask)

        # -------------------------
        # 2.5D NAC
        # -------------------------
        # extract_25d already returns a tensor; the former extra
        # torch.from_numpy() call raised TypeError on every item.
        nac_25d = self.extract_25d(nac, slice_idx).float()

        # Conditioning = neighbours only (centre slice dropped).
        # NOTE(review): the training cell sizes in_channels for
        # 2 * load_adj + 1 conditioning slices; confirm whether the centre NAC
        # slice should be kept.
        center = self.load_adj
        cond = torch.cat(
            [
                nac_25d[:center],
                nac_25d[center + 1:],
            ],
            dim=0,
        )

        # -------------------------
        # AC target (2D), shape (1, H, W)
        # -------------------------
        dim = self._AXIS_DIM[self.axis]
        ac_slice = np.ascontiguousarray(np.take(ac, slice_idx, axis=dim))
        ac_2d = torch.from_numpy(ac_slice).unsqueeze(0).float()

        # Identical random flips for condition and target.
        if random.random() < 0.5:
            cond = torch.flip(cond, dims=[1])
            ac_2d = torch.flip(ac_2d, dims=[1])
        if random.random() < 0.5:
            cond = torch.flip(cond, dims=[2])
            ac_2d = torch.flip(ac_2d, dims=[2])

        if idx == 0:
            print("AC   :", ac_2d.shape)
            print("COND :", cond.shape)

        assert cond.shape[0] == self.load_adj * 2
        assert ac_2d.shape[0] == 1
        return ac_2d, {"cond": cond}



def load_data(
    batch_size,
    root_dir,
    axis="z",
    load_adj=8,
    out_size=192,
    use_every_n_slice=2,
    num_workers=0,
):
    """Yield shuffled ``(ac_2d, {"cond": cond})`` batches from LoadPetSlices forever.

    The dataset is only built when the first batch is requested, because this is
    a generator. Previously it referenced the undefined name
    ``LoadPetSlices25D`` and failed with NameError at that point.

    NOTE(review): the training cell imports ``load_data`` from
    ``dataloader_scripts.load_pet_2_5D``, which shadows this definition.

    Args:
        batch_size: samples per batch.
        root_dir: split directory with ``5NAC/`` and ``100AC/``.
        axis: slicing axis, ``"x"``, ``"y"`` or ``"z"``.
        load_adj: neighbouring slices on each side of the centre slice.
        out_size: forwarded to the dataset (stored there, currently unused).
        use_every_n_slice: stride between sampled centre slices (default 2).
        num_workers: DataLoader worker processes (default 0, in-process).
    """
    dataset = LoadPetSlices(
        root_dir=root_dir,
        axis=axis,
        load_adj=load_adj,
        out_size=out_size,
        use_every_n_slice=use_every_n_slice,
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=False,
    )

    # Infinite stream: restart the (reshuffled) loader after every epoch.
    while True:
        yield from loader


In [None]:
!pip install mpi4py


In [None]:
import os
import shutil
import nibabel as nib

import gzip

src_root = "/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
dst_root = "/kaggle/working/NAC_data"

# Case directories only, sorted. Previously listdir (arbitrary order) was called
# twice and sliced raw, so non-case entries (e.g. CSV files) consumed split
# slots and the split depended on directory order.
case_dirs = sorted(
    d for d in os.listdir(src_root)
    if os.path.isdir(os.path.join(src_root, d))
)
splits = {
    "train": case_dirs[:50],
    "val": case_dirs[50:70],
}

os.makedirs(dst_root, exist_ok=True)


def _find_modality(case_dir, pattern):
    """Return the first (sorted) file in ``case_dir`` whose name contains ``pattern``, or None."""
    matches = sorted(f for f in os.listdir(case_dir) if pattern in f)
    return matches[0] if matches else None


def _copy_as_nii(src, dst):
    """Copy ``src`` to the uncompressed ``.nii`` path ``dst``, gunzipping ``.gz`` sources."""
    if src.endswith(".gz"):
        # A gzipped file under a .nii name would not be readable by extension-based loaders.
        with gzip.open(src, "rb") as fin, open(dst, "wb") as fout:
            shutil.copyfileobj(fin, fout)
    else:
        shutil.copy(src, dst)


for split, cases in splits.items():
    for sub in ["5NAC", "100AC"]:
        os.makedirs(os.path.join(dst_root, split, sub), exist_ok=True)

    for case in cases:
        case_dir = os.path.join(src_root, case)

        # Modality choice: T1 stands in for the "NAC" input, T1ce for the "AC" target.
        nac_file = _find_modality(case_dir, "t1.nii")
        ac_file = _find_modality(case_dir, "t1ce.nii")
        if nac_file is None or ac_file is None:
            print(f"⚠️ skipping {case}: t1/t1ce volume not found")
            continue

        _copy_as_nii(
            os.path.join(case_dir, nac_file),
            os.path.join(dst_root, split, "5NAC", f"{case}_5_NAC.nii"),
        )
        _copy_as_nii(
            os.path.join(case_dir, ac_file),
            os.path.join(dst_root, split, "100AC", f"{case}_100_AC.nii"),
        )

print("✅ Dataset mapped to NAC_data/")


In [None]:
import torch as th
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# ===============================
# Train diffusion model
# ===============================

import os
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
)
from guided_diffusion.train_util import TrainLoop
from dataloader_scripts.load_pet_2_5D import LoadValData, load_data
class Args:
    """Training settings for one axis-specific 2.5D NAC -> AC diffusion model.

    Run this cell once per axis, changing ``train_axis``. ``logdir`` is derived
    from it, so each axis writes to its own checkpoint directory.
    """

    schedule_sampler = "uniform"
    resume_checkpoint = ""
    use_fp16 = False
    fp16_scale_growth = 1e-3

    # 2.5D setting: train three times, once each for "x", "y" and "z".
    train_axis = "z"
    load_adj = 2
    out_channels = 1

    # Paths. logdir follows train_axis so switching axes cannot leave it
    # pointing at another axis' checkpoint directory.
    logdir = f"checkpoint_{train_axis}"
    data_root = "NAC_data"


args = Args()

# Library defaults are applied on top of Args. This overwrites any Args
# attribute that is also a default key; the explicit overrides below then win.
for k, v in model_and_diffusion_defaults().items():
    setattr(args, k, v)

# Optimisation
args.batch_size = 1
args.microbatch = -1
args.lr = 1e-4
args.weight_decay = 0.0
args.lr_anneal_steps = 1000

# Logging / checkpointing
args.log_interval = 50
args.save_interval = 500

args.ema_rate = "0.0"  # disable EMA
args.use_fp16 = False

# Diffusion
args.diffusion_steps = 25  # was 50
args.noise_schedule = "linear"

# Model I/O: 1 noisy target channel + (2 * load_adj + 1) conditioning slices.
# NOTE(review): this notebook's own LoadPetSlices yields only 2 * load_adj
# conditioning channels (centre slice dropped); confirm the imported load_data
# produces 2 * load_adj + 1.
args.in_channels = 1 + (2 * args.load_adj + 1)
args.out_channels = 1
args.image_size = 128

# Small UNet for limited GPU memory.
# NOTE(review): the sampling cell builds model_channels=136,
# channel_mult=(1,2,4,4), num_res_blocks=2, attention_resolutions=(16,) and
# resblock_updown=True. A strict load_state_dict of checkpoints trained with
# this config will fail unless both cells use identical architecture settings.
args.model_channels = 32
args.num_res_blocks = 1  # was 2
args.channel_mult = "1,2,4"
# NOTE(review): presumably chosen so no resolution matches, i.e. no attention
# layers — confirm against guided_diffusion.script_util.
args.attention_resolutions = "9999"

args.use_scale_shift_norm = False
args.resblock_updown = False
args.use_checkpoint = False


dist_util.setup_dist()
logger.configure(dir=args.logdir)

logger.log("creating model and diffusion...")
# create_model_and_diffusion is given comma-separated strings for these.
if isinstance(args.channel_mult, (list, tuple)):
    args.channel_mult = ",".join(map(str, args.channel_mult))
if isinstance(args.attention_resolutions, (list, tuple)):
    args.attention_resolutions = ",".join(map(str, args.attention_resolutions))

model, diffusion = create_model_and_diffusion(
    **args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.to(dist_util.dev())

schedule_sampler = create_named_schedule_sampler(
    args.schedule_sampler, diffusion
)


In [None]:

# -------------------------------
# Data loader
# -------------------------------
logger.log("creating data loader...")

train_dir = os.path.join(args.data_root, "train")
val_dir = os.path.join(args.data_root, "val")

data = load_data(
    batch_size=args.batch_size,
    root_dir=train_dir,
    axis=args.train_axis,
    load_adj=args.load_adj,
)

val_data = LoadValData(
    root_dir=val_dir,
    axis=args.train_axis,
    load_adj=args.load_adj,
)

# -------------------------------
# Timestep guard
# -------------------------------
# diffusion.training_losses is replaced by a wrapper that clamps t into
# [0, num_timesteps - 1]. The unwrapped method is stashed on the diffusion
# object the first time. Re-running this cell therefore wraps the original
# again rather than the previous wrapper; the old global-based wrapper ended
# up calling itself forever on a re-run.
# NOTE(review): clamping hides whatever produces out-of-range t; worth tracking down.
if not hasattr(diffusion, "_unclamped_training_losses"):
    diffusion._unclamped_training_losses = diffusion.training_losses


def _make_safe_training_losses(diff):
    """Return a training_losses replacement for ``diff`` that clamps ``t`` first."""
    original = diff._unclamped_training_losses

    def safe_training_losses(model, x_start, cond, t, *args, **kwargs):
        if not th.is_tensor(t):
            t = th.tensor(t, device=x_start.device)
        t = th.clamp(t, 0, diff.num_timesteps - 1)
        return original(model, x_start, cond, t, *args, **kwargs)

    return safe_training_losses


diffusion.training_losses = _make_safe_training_losses(diffusion)

# -------------------------------
# Train
# -------------------------------
train_loop = TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    val_data=val_data,
    batch_size=args.batch_size,
    microbatch=args.microbatch,
    lr=args.lr,
    ema_rate=args.ema_rate,
    log_interval=args.log_interval,
    save_interval=args.save_interval,
    resume_checkpoint=args.resume_checkpoint,
    use_fp16=args.use_fp16,
    fp16_scale_growth=args.fp16_scale_growth,
    schedule_sampler=schedule_sampler,
    weight_decay=args.weight_decay,
    lr_anneal_steps=args.lr_anneal_steps,
    logdir=args.logdir,  # checkpoints are written here
)
logger.log("training...")
train_loop.run_loop()

import torch
import os

def export_infer_ckpt(train_ckpt, save_path):
    """Re-save a training checkpoint in the ``{"model": state}`` inference format.

    The sampling cell loads ``ckpt["model"]`` with ``load_state_dict``, so the
    object stored in ``train_ckpt`` is wrapped under the ``"model"`` key.

    Args:
        train_ckpt: path of the checkpoint written during training.
        save_path: destination path; missing parent directories are created.

    Returns:
        ``save_path``.
    """
    state = torch.load(train_ckpt, map_location="cpu")

    out_dir = os.path.dirname(save_path)
    if out_dir:
        # Exporting into a not-yet-existing directory used to fail in torch.save.
        os.makedirs(out_dir, exist_ok=True)

    torch.save({"model": state}, save_path)
    print(f"[OK] Exported inference ckpt → {save_path}")
    return save_path


# Export the model of the axis just trained. Paths come from the training Args
# so they follow train_axis/logdir; they were hard-coded to checkpoint_y while
# training axis z into checkpoint_z.
# NOTE(review): "best_model000500.pt" is the file name TrainLoop is assumed to
# write for this run (save_interval=500); confirm.
export_infer_ckpt(
    os.path.join(args.logdir, "best_model000500.pt"),
    os.path.join(args.logdir, f"infer_model_{args.train_axis}.pt"),
)


In [None]:
# ==========================================================
# Validation & Sampling – 2.5D → 3D (x, y, z)
# ==========================================================

import os
import numpy as np
import torch as th
import nibabel as nib

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
)

from dataloader_scripts.load_pet_2_5D import LoadTestData



class Args:
    """Sampling settings for the per-axis (x, y, z) models and their 3D fusion."""

    batch_size = 8  # forwarded to diffusion.p_sample_loop
    clip_denoised = True
    use_ddim = False
    # Forwarded as start_t to diffusion.p_sample_loop.
    # NOTE(review): equals diffusion_steps (25); if timesteps are indexed
    # 0..diffusion_steps-1 this is one past the last — confirm in the project's
    # p_sample_loop.
    prior_start_t = 25

    data_root = "NAC_data"  # cases are read from <data_root>/val
    save_root = "MADM"  # fused volumes are written under <save_root>/adj<load_adj>_xyz_fusion/


    load_adj = 2  # neighbouring slices on each side of the centre slice

    in_channels = 1 + (2 * load_adj + 1)   # = 6 for load_adj=2 (presumably 1 noisy channel + 5 NAC slices)
    out_channels = 1
    image_size = 128


    # NOTE(review): this architecture differs from the training cell
    # (model_channels=32, channel_mult="1,2,4", num_res_blocks=1,
    # attention_resolutions="9999", resblock_updown=False). The checkpoints are
    # loaded with strict=True, which fails unless both cells agree.
    # Tuples are converted to comma-separated strings before model creation.
    model_channels = 136
    channel_mult = (1,2,4,4)
    num_res_blocks = 2
    attention_resolutions = (16,)

    use_scale_shift_norm = False
    resblock_updown = True
    use_checkpoint = False
    use_fp16 = False


    diffusion_steps = 25
    noise_schedule = "linear"


    # Per-axis directory expected to contain infer_model_<axis>.pt.
    model_dirs = {
        "x": "checkpoint_x",
        "y": "checkpoint_y",
        "z": "checkpoint_z",
    }

    axes = ["x", "y", "z"]  # axes sampled, in this order, then averaged

args = Args()

# Back-fill library defaults only for hyper-parameters Args leaves unset.
defaults = model_and_diffusion_defaults()
for name, default in defaults.items():
    if hasattr(args, name):
        continue
    setattr(args, name, default)

# Re-pin model I/O and the diffusion schedule on the instance.
args.in_channels = 1 + (2 * args.load_adj + 1)
args.out_channels, args.image_size = 1, 128
args.diffusion_steps, args.noise_schedule = 25, "linear"

# Distributed / logging setup.
dist_util.setup_dist()
logger.configure()
device = dist_util.dev()
logger.log(f"Device: {device}")

# Axis-less loader used to enumerate the validation cases.
test_dir = os.path.join(args.data_root, "val")
test_data = LoadTestData(root_dir=test_dir, load_adj=args.load_adj)
logger.log(f"Number of test cases: {len(test_data)}")

# Hyper-parameters given as list/tuple are turned into the comma-separated
# strings passed to create_model_and_diffusion. This is idempotent, so it is
# done once here instead of inside the loops.
if isinstance(args.channel_mult, (list, tuple)):
    args.channel_mult = ",".join(map(str, args.channel_mult))
if isinstance(args.attention_resolutions, (list, tuple)):
    args.attention_resolutions = ",".join(map(str, args.attention_resolutions))

# One loader per axis, built once; they do not depend on the case index.
axis_loaders = {
    axis: LoadTestData(root_dir=test_dir, axis=axis, load_adj=args.load_adj)
    for axis in args.axes
}

# Build and load each axis model once and keep it on the CPU between cases.
# It is moved to the GPU only while sampling, so at most one model occupies
# GPU memory at a time.
axis_models = {}
for axis in args.axes:
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, defaults.keys())
    )
    ckpt_path = os.path.join(args.model_dirs[axis], f"infer_model_{axis}.pt")
    logger.log(f"Loading checkpoint: {ckpt_path}")
    # `th` is imported in this cell; the bare `torch` name only existed via an
    # earlier cell's import.
    ckpt = th.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()
    model.requires_grad_(False)
    axis_models[axis] = (model, diffusion)

out_dir = os.path.join(
    args.save_root,
    f"adj{args.load_adj}_xyz_fusion"
)
os.makedirs(out_dir, exist_ok=True)

for idx in range(len(test_data)):
    logger.log(f"\n===== Sampling case {idx + 1}/{len(test_data)} =====")
    test_data.idx = idx

    axis_predictions = {}
    for axis in args.axes:
        logger.log(f"--- Axis: {axis}")

        axis_data = axis_loaders[axis]
        # Select the case on the loader that is actually sampled from. The old
        # code set the index on `test_data` and then rebound `test_data` to a
        # fresh per-axis loader, so every case sampled that loader's initial index.
        axis_data.idx = idx

        model, diffusion = axis_models[axis]
        model.to(dist_util.dev())

        # NOTE(review): the same (H, W, D) shape is requested for every axis;
        # confirm p_sample_loop returns x/y predictions in z orientation so
        # they can be averaged voxel-wise.
        shape = (args.image_size, args.image_size, axis_data.get_zsize())

        with th.no_grad():
            sample = diffusion.p_sample_loop(
                model=model,
                test_data=axis_data,
                shape=shape,
                batch_size=args.batch_size,
                start_t=args.prior_start_t,
                clip_denoised=args.clip_denoised,
                model_kwargs={},
            )

        axis_predictions[axis] = sample.cpu().numpy()

        # Return the model to host memory so the next axis has the GPU to itself.
        model.to("cpu")
        th.cuda.empty_cache()

    # ======================================================
    # Fuse the per-axis predictions -> final 3D volume
    # ======================================================
    logger.log("Fusing x, y, z predictions...")
    final_pred = sum(axis_predictions[axis] for axis in args.axes) / float(len(args.axes))
    final_pred = np.maximum(final_pred, 0)  # clamp negatives to zero

    # ======================================================
    # Save NIfTI
    # ======================================================
    out_path = os.path.join(
        out_dir,
        f"{test_data.get_name(idx)}_pred.nii.gz"
    )

    # NOTE(review): an identity affine discards the source volume's orientation
    # and voxel spacing; consider copying the affine of the NAC/GT volume.
    nii = nib.Nifti1Image(final_pred.astype(np.float32), affine=np.eye(4))
    nib.save(nii, out_path)

    logger.log(f"Saved: {out_path}")

logger.log("\n✅ Sampling & validation finished.")


In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
def volume_metrics(gt, pred):
    """Compare a predicted 3D volume with its ground truth.

    Args:
        gt: ground-truth volume (3D array-like).
        pred: prediction with the same shape as ``gt``.

    Returns:
        ``(mae, psnr, ssim)``:
          * mean absolute error;
          * PSNR over the whole volume;
          * SSIM averaged over slices along the last (z) dimension.
        PSNR and SSIM use ``gt.max() - gt.min()`` as data range.

    Raises:
        ValueError: if the shapes differ. This replaces a cryptic broadcasting error.
    """
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    if gt.shape != pred.shape:
        raise ValueError(f"shape mismatch: gt {gt.shape} vs pred {pred.shape}")
    data_range = gt.max() - gt.min()
    mae = float(np.mean(np.abs(pred - gt)))
    psnr_val = float(psnr(gt, pred, data_range=data_range))
    ssim_val = float(np.mean([
        ssim(gt[:, :, i], pred[:, :, i], data_range=data_range)
        for i in range(gt.shape[2])
    ]))
    return mae, psnr_val, ssim_val


# Score every prediction written by the sampling cell. Previously only the last
# case left in memory (`final_pred`/`idx`) was evaluated.
# NOTE(review): gt is compared as stored on disk, while predictions are
# presumably in the model's normalised (tanh) intensity space — confirm both
# are on the same scale before trusting MAE/PSNR.
pred_dir = os.path.join(args.save_root, f"adj{args.load_adj}_xyz_fusion")
case_metrics = []
for case_idx in range(len(test_data)):
    case_name = test_data.get_name(case_idx)
    pred_path = os.path.join(pred_dir, f"{case_name}_pred.nii.gz")
    if not os.path.exists(pred_path):
        logger.log(f"[METRIC] {case_name}: no prediction at {pred_path}, skipped")
        continue
    gt = nib.load(test_data.get_gt_path(case_idx)).get_fdata()
    pred = nib.load(pred_path).get_fdata()
    mae, psnr_val, ssim_val = volume_metrics(gt, pred)
    case_metrics.append((mae, psnr_val, ssim_val))
    logger.log(
        f"[METRIC] {case_name}: MAE={mae:.4f} | PSNR={psnr_val:.2f} | SSIM={ssim_val:.4f}"
    )

if case_metrics:
    mae, psnr_val, ssim_val = np.mean(case_metrics, axis=0)
    logger.log(
        f"[METRIC] MAE={mae:.4f} | PSNR={psnr_val:.2f} | SSIM={ssim_val:.4f}"
    )
else:
    logger.log("[METRIC] no predictions found to evaluate")
