In [1]:
from torchvision.transforms import RandomCrop, Compose, ToPILImage, Resize, ToTensor, Lambda
from med_ddpm.diffusion_model.trainer import GaussianDiffusion, Trainer
from med_ddpm.diffusion_model.unet import create_model
import argparse
import torch
import os 
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ToTensor, Lambda
from glob import glob
import matplotlib.pyplot as plt
import nibabel as nib
import torchio as tio
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import re
import os
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

APEX: ON


device(type='cuda', index=0)

In [None]:
params = {
    "with_condition": True,
    "inputfolder": "../../data/stroke_mri_nii_256/masks/",
    "targetfolder": "../../data/stroke_mri_nii_256/images/",
    "batchsize": 1,
    "epochs": 10000,
    "input_size": 128,
    "depth_size": 64,
    "num_channels": 64,
    "num_res_blocks": 1,
    "timesteps": 250,
    "save_and_sample_every": 10,
    "model_save_path": "../../model/med_ddpm_synthesis/dwi/",
    "resume_weight": "../../model/med_ddpm/mri/model-17.pt"
}

In [32]:
transform = Compose([
    Lambda(lambda t: torch.tensor(t).float()),
    Lambda(lambda t: (t * 2) - 1),
    Lambda(lambda t: t.unsqueeze(0))
])
class NiftiPairImageGenerator(Dataset):
    def __init__(self,
            input_image,
            target_image,
            input_size: int,
            depth_size: int,
            input_channel: int = 3,
            target_transform=None,
            full_channel_mask=False,
            combine_output=False,
            transform=None,
        ):

        self.input_image = input_image
        self.target_image = target_image
        self.input_size = input_size
        self.depth_size = depth_size
        self.input_channel = input_channel
        self.scaler = MinMaxScaler()
        self.target_transform = target_transform
        self.full_channel_mask = full_channel_mask
        self.combine_output = combine_output

   
    def plot(self, index, n_slice=30):
        data = self[index]
        input_img = data['input']
        target_img = data['target']
        plt.subplot(1, 2, 1)
        plt.imshow(input_img[n_slice,:, :])
        plt.subplot(1, 2, 2)
        plt.imshow(target_img[n_slice,:, :])
        plt.show()

    def __len__(self):
        return len(self.input_image)

    def __getitem__(self, index):
        input_img = self.input_image[index].unsqueeze(0)
        target_img = self.label2masks(self.target_image[index]).permute(3, 0, 1, 2)
      

        if self.combine_output:
            return torch.cat([target_img, input_img], 0)

        return {'input':input_img, 'target':target_img}

def resize_img(img):
    d,h, w = img.shape
    if h !=  params['input_size'] or w !=  params['input_size'] or d != params['depth_size']:
        img = tio.ScalarImage(tensor=img[np.newaxis, ...])
        cop = tio.Resize((params['depth_size'], params['input_size'],  params['input_size']))
        img = np.asarray(cop(img))[0]
    return img 
def label2masks(masked_img, input_channel=3):
    result_img =np.zeros(masked_img.shape + ( input_channel - 1,))
    result_img[masked_img==1, 0] = 1
    result_img[masked_img==2, 1] = 1
    return result_img
input_list=glob(params['inputfolder']+"*.nii.gz")
target_list=[f.replace(params['inputfolder'], params['targetfolder']) for f in input_list]
input_images=torch.zeros((len(input_list),params['depth_size'], params['input_size'], params['input_size']))
target_images=torch.zeros((len(target_list),params['depth_size'], params['input_size'], params['input_size'],2))
for i in tqdm(range(len(input_list))):
    input_img = nib.load(input_list[i]).get_fdata()
    target_img = (nib.load(target_list[i]).get_fdata()+1.)/2.
    input_images[i] =transform(resize_img(label2masks(input_img,3)))
    target_images[i] = transform(resize_img(target_img))
    
dataset = NiftiPairImageGenerator(
        input_images,
        target_images,
        input_size=params['input_size'],
        depth_size=params['depth_size'],
        target_transform=transform,
        full_channel_mask=True
    )
def cycle(dl):
    while True:
        for data in dl:
            yield data
dataloader= DataLoader(dataset, batch_size = params['batchsize'], shuffle=True, num_workers=1, pin_memory=True)

  0%|          | 0/1113 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 3)

In [31]:
label2masks(input_img,3).shape

(64, 256, 256, 2)

In [None]:
in_channels=3
out_channels=1
model = create_model(params["input_size"], params["num_channels"], params["num_res_blocks"], in_channels=in_channels, out_channels=out_channels).to(device)

diffusion = GaussianDiffusion(
    model,
    image_size = params["input_size"],
    depth_size = params["depth_size"],
    timesteps = params["timesteps"],   # number of steps
    loss_type = 'l1',    # L1 or L2
    with_condition=params["with_condition"],
    channels=out_channels
).to(device)

In [28]:
def save_sample(diffusion, epoch, save_path, sample_shape, device, condition_tensors=None):
    diffusion.eval()

    # 샘플 생성
    samples = diffusion.sample(batch_size=sample_shape[0], condition_tensors=condition_tensors)
    samples = samples.cpu().numpy()  # (B, C, D, H, W)

    os.makedirs(save_path, exist_ok=True)

    for i in range(samples.shape[0]):
        img = samples[i, 0]  # (D, H, W)
        nifti_img = nib.Nifti1Image(img, affine=np.eye(4))
        nib.save(nifti_img, os.path.join(save_path, f'sample_epoch{epoch}_idx{i}.nii.gz'))

    print(f"✅ Sample saved at epoch {epoch} to {save_path}")


In [6]:
optimizer = torch.optim.Adam(diffusion.parameters(), lr=2e-5)
num_epochs = params['epochs']
save_every = params['save_and_sample_every']

for epoch in range(num_epochs):
    sum_loss=0
    diffusion.train()
    with tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        for step, data in enumerate(pbar):
            if params['with_condition']:
                input_tensors = data['input'].to(device)    # condition
                target_tensors = data['target'].to(device)  # target
                loss = diffusion(target_tensors, condition_tensors=input_tensors)
            else:
                input_tensors = data.to(device)
                loss = diffusion(input_tensors)

            loss = loss.sum() / params['batchsize']
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            sum_loss += loss.item()
            pbar.set_postfix(loss=sum_loss/(step+1))

    # 모델 저장
    if (epoch + 1) % save_every == 0 or (epoch + 1) == num_epochs:
        os.makedirs(params["model_save_path"], exist_ok=True)
        torch.save(diffusion.state_dict(), os.path.join(params["model_save_path"], f"model-{epoch+1}.pt"))
        print(f"✅ Model saved at epoch {epoch+1}")
        
    condition_example = None
    if params["with_condition"]:
        data_example = next(iter(dataloader))
        condition_example = data_example['input'].to(device)[:1]  # 1개만
    save_sample(
        diffusion=diffusion,
        epoch=epoch + 1,
        save_path=os.path.join(params["model_save_path"], "samples"),
        sample_shape=(1, diffusion.channels, diffusion.depth_size, diffusion.image_size, diffusion.image_size),
        device=device,
        condition_tensors=condition_example
    )

Epoch 1/10000:   0%|          | 0/1113 [00:00<?, ?it/s]


AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_165719/255002534.py", line 49, in __getitem__
    target_img = self.label2masks(self.target_image[index]).unsqueeze(0)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'unsqueeze'. Did you mean: 'squeeze'?


In [None]:
len(dataloader)