In [None]:
from torchvision.transforms import RandomCrop, Compose, ToPILImage, Resize, ToTensor, Lambda
from med_ddpm.diffusion_model.trainer import GaussianDiffusion, Trainer
from med_ddpm.diffusion_model.unet import create_model
import argparse
import torch
import os 
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ToTensor, Lambda
from glob import glob
import matplotlib.pyplot as plt
import nibabel as nib
import torchio as tio
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import re
import os
import torch.nn.functional as F
# Preferred GPU index. If fewer GPUs are visible, fall back to the highest
# available index instead of failing later when tensors are moved to a
# non-existent device.
PREFERRED_GPU = 2
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(PREFERRED_GPU, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")
device

In [None]:
# Experiment configuration for DWI -> CT translation (see folder names below).
params = dict(
    # Pass the input volume as `condition_tensors` during training/sampling.
    with_condition=True,
    # Conditioning (input) volumes; every *.nii.gz file here is loaded.
    inputfolder="../../data/registration_data/registration_DWI/",
    # Target volumes, matched to inputs by identical filename.
    targetfolder="../../data/registration_data/CT/",
    batchsize=1,
    epochs=10000,
    # Height and width that volumes are resampled to.
    input_size=128,
    # Depth (number of slices) that volumes are resampled to.
    depth_size=64,
    # Passed straight to create_model.
    num_channels=64,
    num_res_blocks=1,
    # Number of diffusion steps given to GaussianDiffusion.
    timesteps=250,
    # Checkpoint interval, in epochs.
    save_and_sample_every=10,
    # Directory for checkpoints and generated samples.
    model_save_path="../../model/med_ddpm_translation/dwi2ct/",
    # NOTE(review): not read by any cell visible here; the weights cell loads a
    # different, hard-coded checkpoint path.
    resume_weight="../../model/med_ddpm/mri/model-17.pt",
)

In [None]:
# Per-volume preprocessing, applied to each resized volume:
# cast to a float32 tensor, map values x -> 2x - 1, then add a leading channel axis.
_preprocess_steps = [
    lambda arr: torch.tensor(arr).float(),  # array -> float32 tensor
    lambda x: x * 2 - 1,                    # presumably [0, 1] -> [-1, 1]; confirm input range
    lambda x: x.unsqueeze(0),               # (D, H, W) -> (1, D, H, W) for the 3-D volumes used here
]
transform = Compose([Lambda(step) for step in _preprocess_steps])
class NiftiPairImageGenerator(Dataset):
    """Paired (input, target) dataset over already-loaded volumes.

    Each item is built from ``input_image[index]`` and ``target_image[index]``
    with a leading channel axis added, so a (D, H, W) volume is returned as
    (1, D, H, W).

    Args:
        input_image: indexable collection of conditioning volumes (a tensor of
            shape (N, D, H, W) in this notebook).
        target_image: indexable collection of target volumes, aligned with
            ``input_image`` by index; must have the same length.
        input_size, depth_size, input_channel: stored for reference only;
            ``__getitem__`` does not resize or check shapes.
        target_transform, transform: stored but NOT applied; the volumes passed
            in are expected to be fully preprocessed already.
        full_channel_mask: stored but unused by this class.
        combine_output: if True, ``__getitem__`` returns one tensor with the
            target and input concatenated along dim 0 (target first); otherwise
            a dict with keys ``'input'`` and ``'target'``.

    Raises:
        ValueError: if ``input_image`` and ``target_image`` differ in length.
    """

    def __init__(self,
            input_image,
            target_image,
            input_size: int,
            depth_size: int,
            input_channel: int = 3,
            target_transform=None,
            full_channel_mask=False,
            combine_output=False,
            transform=None,
        ):
        if len(input_image) != len(target_image):
            raise ValueError(
                f"input_image has {len(input_image)} items but target_image has {len(target_image)}"
            )
        self.input_image = input_image
        self.target_image = target_image
        self.input_size = input_size
        self.depth_size = depth_size
        self.input_channel = input_channel
        self.scaler = MinMaxScaler()  # not used by this class; kept for compatibility
        self.target_transform = target_transform
        self.transform = transform  # previously accepted but silently dropped
        self.full_channel_mask = full_channel_mask
        self.combine_output = combine_output

    @staticmethod
    def _axial_slice(volume, n_slice):
        """Return slice ``n_slice`` along the depth axis of a (1, D, H, W) or (D, H, W) volume."""
        if volume.ndim == 4:
            # Drop the singleton channel axis added by __getitem__.
            volume = volume[0]
        return volume[n_slice]

    def plot(self, index, n_slice=30):
        """Show depth slice ``n_slice`` of the input and target at ``index`` side by side."""
        data = self[index]
        if self.combine_output:
            # Combined items are concatenated as [target, input] along dim 0.
            target_img, input_img = data[0], data[1]
        else:
            input_img, target_img = data['input'], data['target']
        plt.subplot(1, 2, 1)
        plt.imshow(self._axial_slice(input_img, n_slice))
        plt.subplot(1, 2, 2)
        plt.imshow(self._axial_slice(target_img, n_slice))
        plt.show()

    def __len__(self):
        return len(self.input_image)

    def __getitem__(self, index):
        # Add a leading channel axis to each volume.
        input_img = self.input_image[index].unsqueeze(0)
        target_img = self.target_image[index].unsqueeze(0)

        if self.combine_output:
            return torch.cat([target_img, input_img], 0)

        return {'input': input_img, 'target': target_img}

def resize_img(img, size=None):
    """Resample a 3-D volume onto a fixed grid with trilinear interpolation.

    Args:
        img: array-like volume of shape (D, H, W).
        size: optional target ``(depth, height, width)``. Defaults to
            ``(params['depth_size'], params['input_size'], params['input_size'])``.

    Returns:
        float32 numpy array of shape ``size``.

    Raises:
        ValueError: if ``img`` is not 3-D.
    """
    if size is None:
        size = (params['depth_size'], params['input_size'], params['input_size'])
    if np.ndim(img) != 3:
        raise ValueError(f"resize_img expects a 3-D (D, H, W) volume, got shape {np.shape(img)}")

    # F.interpolate needs (N, C, D, H, W) input for 3-D trilinear resampling.
    volume = torch.as_tensor(img, dtype=torch.float32)[None, None]
    resized = F.interpolate(volume, size=tuple(size), mode='trilinear', align_corners=False)
    return resized[0, 0].numpy()  # (D, H, W)
# Collect input volumes in a deterministic order (glob order is arbitrary); each
# target is expected at the same filename under params['targetfolder'].
input_list = sorted(glob(params['inputfolder'] + "*.nii.gz"))
assert input_list, f"no .nii.gz files found in {params['inputfolder']}"
target_list = [f.replace(params['inputfolder'], params['targetfolder']) for f in input_list]
missing_targets = [f for f in target_list if not os.path.exists(f)]
assert not missing_targets, (
    f"{len(missing_targets)} target volume(s) missing, e.g. {missing_targets[0]}"
)

volume_shape = (params['depth_size'], params['input_size'], params['input_size'])
input_images = torch.zeros((len(input_list), *volume_shape))
target_images = torch.zeros((len(target_list), *volume_shape))
for i, (input_path, target_path) in enumerate(
        tqdm(zip(input_list, target_list), total=len(input_list))):
    # NOTE(review): intensities are halved before `transform` maps x -> 2x - 1;
    # presumably raw volumes lie in [0, 2] so the result lands in [-1, 1] - confirm.
    input_img = nib.load(input_path).get_fdata() / 2.
    target_img = nib.load(target_path).get_fdata() / 2.
    input_images[i] = transform(resize_img(input_img))
    target_images[i] = transform(resize_img(target_img))
    
# Wrap the preprocessed tensors. With combine_output left at its default (False),
# items are dicts with 'input' and 'target' entries, each carrying a leading channel axis.
dataset_kwargs = dict(
    input_size=params['input_size'],
    depth_size=params['depth_size'],
    target_transform=transform,
    full_channel_mask=True,
)
dataset = NiftiPairImageGenerator(input_images, target_images, **dataset_kwargs)
def cycle(dl):
    """Yield items from ``dl`` forever, restarting iteration whenever it is exhausted.

    Raises:
        ValueError: if a full pass over ``dl`` yields nothing (e.g. an empty
            loader or an already-exhausted one-shot iterator), which would
            otherwise make this generator spin forever without yielding.
    """
    while True:
        produced = False
        for data in dl:
            produced = True
            yield data
        if not produced:
            raise ValueError("cycle() received an empty iterable")
# Pinned host memory only helps host-to-GPU copies, so enable it only when training on CUDA.
dataloader = DataLoader(
    dataset,
    batch_size=params['batchsize'],
    shuffle=True,
    num_workers=1,
    pin_memory=(device.type == "cuda"),
)

In [None]:
# The UNet takes 2 input channels and predicts 1 output channel (presumably the
# noisy target plus the conditioning volume; confirm against GaussianDiffusion).
in_channels, out_channels = 2, 1
model = create_model(
    params["input_size"],
    params["num_channels"],
    params["num_res_blocks"],
    in_channels=in_channels,
    out_channels=out_channels,
).to(device)

diffusion = GaussianDiffusion(
    model,
    image_size=params["input_size"],
    depth_size=params["depth_size"],
    timesteps=params["timesteps"],  # number of diffusion steps
    loss_type='l1',                 # 'l1' or 'l2'
    with_condition=params["with_condition"],
    channels=out_channels,
).to(device)

# NOTE(review): this loads a hard-coded checkpoint, not params["resume_weight"]
# (which points at a different file). Confirm which one is intended.
checkpoint_path = '../../model/med_ddpm/translation/dwi2ct.pt'
diffusion.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
def save_sample(diffusion, epoch, save_path, sample_shape, device, condition_tensors=None):
    """Draw samples from ``diffusion`` and write them, plus their conditions, as NIfTI files.

    Args:
        diffusion: module exposing ``sample(batch_size=..., condition_tensors=...)``.
        epoch: epoch number, embedded in the output file names.
        save_path: output directory; created if missing.
        sample_shape: only ``sample_shape[0]`` (the number of samples) is used.
        device: unused; kept for interface compatibility.
        condition_tensors: optional conditioning batch passed to ``diffusion.sample``;
            when given, its channel-0 volumes are saved next to the samples.

    Files are named ``gen_epoch{epoch}.nii.gz`` / ``cond_epoch{epoch}.nii.gz`` for a
    single sample. With several samples, an ``_{i}`` index is appended so they do not
    overwrite one another. An identity affine is written, so the source scans'
    voxel spacing/orientation are not carried over. The module's train/eval mode
    is restored afterwards.
    """
    was_training = diffusion.training
    diffusion.eval()
    try:
        with torch.no_grad():
            samples = diffusion.sample(batch_size=sample_shape[0], condition_tensors=condition_tensors)
    finally:
        diffusion.train(was_training)

    # Assumed (B, C, D, H, W); channel 0 is saved. TODO confirm against diffusion.sample.
    samples = samples.detach().cpu().numpy()
    condition_np = None if condition_tensors is None else condition_tensors.detach().cpu().numpy()

    os.makedirs(save_path, exist_ok=True)

    n_samples = samples.shape[0]
    for i in range(n_samples):
        suffix = "" if n_samples == 1 else f"_{i}"
        nifti_gen = nib.Nifti1Image(samples[i, 0], affine=np.eye(4))
        nib.save(nifti_gen, os.path.join(save_path, f'gen_epoch{epoch}{suffix}.nii.gz'))

        if condition_np is not None:
            nifti_cond = nib.Nifti1Image(condition_np[i, 0], affine=np.eye(4))
            nib.save(nifti_cond, os.path.join(save_path, f'cond_epoch{epoch}{suffix}.nii.gz'))


In [None]:
optimizer = torch.optim.Adam(diffusion.parameters(), lr=2e-5)
num_epochs = params['epochs']
save_every = params['save_and_sample_every']

for epoch in range(num_epochs):
    sum_loss = 0
    diffusion.train()
    with tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        for step, data in enumerate(pbar):
            if params['with_condition']:
                input_tensors = data['input'].to(device)    # condition
                target_tensors = data['target'].to(device)  # target
                loss = diffusion(target_tensors, condition_tensors=input_tensors)
            else:
                # NOTE(review): NiftiPairImageGenerator yields dicts unless
                # combine_output=True, so this branch needs a tensor-yielding dataset.
                input_tensors = data.to(device)
                loss = diffusion(input_tensors)

            loss = loss.sum() / params['batchsize']
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            sum_loss += loss.item()
            pbar.set_postfix(loss=sum_loss / (step + 1))

    # Checkpoint AND sample only every `save_and_sample_every` epochs (and after the
    # last one). Sampling presumably runs all reverse diffusion steps, so doing it
    # every epoch is expensive.
    is_last_epoch = (epoch + 1) == num_epochs
    if (epoch + 1) % save_every != 0 and not is_last_epoch:
        continue

    os.makedirs(params["model_save_path"], exist_ok=True)
    torch.save(diffusion.state_dict(), os.path.join(params["model_save_path"], f"model-{epoch+1}.pt"))
    print(f"✅ Model saved at epoch {epoch+1}")

    condition_example = None
    if params["with_condition"]:
        data_example = next(iter(dataloader))
        condition_example = data_example['input'].to(device)[:1]  # a single conditioning volume
    save_sample(
        diffusion=diffusion,
        epoch=epoch + 1,
        save_path=os.path.join(params["model_save_path"], "samples"),
        sample_shape=(1, diffusion.channels, diffusion.depth_size, diffusion.image_size, diffusion.image_size),
        device=device,
        condition_tensors=condition_example
    )

In [None]:
# Sanity check: number of batches per epoch (the same count the training progress bar uses as its total).
len(dataloader)