In [1]:
import glob
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image

#from data_load import load_data
#from model import define_model
#from harmonize import generate_harmonized_image
#from model import define_diffusion
#from training import train_model
#from training_enhanced import train_model

In [2]:
class NiftiiDataset(Dataset):
    """Paired 2-D slices taken from source/target NIfTI volumes.

    Every file is loaded eagerly in ``__init__``. From each volume only the
    slice ``volume[:, :, k]`` is kept, with a leading channel axis added, so
    an ``(X, Y, Z)`` volume yields a ``(1, X, Y)`` float32 tensor (assumes
    3-D volumes -- TODO confirm no 4-D inputs). ``__getitem__`` returns the
    ``(source_slice, target_slice)`` pair at ``idx``.

    Args:
        source_paths: Source-domain NIfTI file paths.
        target_paths: Target-domain NIfTI file paths, paired with
            ``source_paths`` by position.
        slice_index: Index ``k`` along the third axis. ``None`` (default)
            uses the middle index ``shape[2] // 2`` of each volume.

    Raises:
        ValueError: If ``source_paths`` and ``target_paths`` differ in length.
    """

    def __init__(self, source_paths, target_paths, slice_index=None):
        source_paths = list(source_paths)
        target_paths = list(target_paths)
        # zip() would silently drop the unpaired tail and hide a data problem.
        if len(source_paths) != len(target_paths):
            raise ValueError(
                f"Got {len(source_paths)} source paths but "
                f"{len(target_paths)} target paths; they must be paired 1:1"
            )

        self.source_slices = []
        self.target_slices = []
        for source_path, target_path in zip(source_paths, target_paths):
            self.source_slices.append(self._load_slice(source_path, slice_index))
            self.target_slices.append(self._load_slice(target_path, slice_index))

    @staticmethod
    def _load_slice(path, slice_index=None):
        """Load one NIfTI file and return ``volume[:, :, k]`` as a ``(1, X, Y)`` tensor."""
        volume = torch.from_numpy(nib.load(path).get_fdata(dtype=np.float32))
        k = volume.shape[2] // 2 if slice_index is None else slice_index
        # A plain slice is a view that keeps the whole volume's storage alive;
        # a contiguous clone copies just this slice so the volume can be freed.
        return volume[:, :, k].unsqueeze(0).clone(memory_format=torch.contiguous_format)

    def __len__(self):
        return len(self.source_slices)

    def __getitem__(self, idx):
        return self.source_slices[idx], self.target_slices[idx]


In [24]:
def load_data(data_dir="/home/youssef/harmo_4/ALL_training_data",
              patient_glob="Pat11*",
              batch_size=2,
              shuffle=False):
    """Build a DataLoader of paired CHU (source) / COL (target) slices.

    Source files match ``<data_dir>/<patient_glob>_CHU_zscore_minmax_unbias.nii.gz``
    and target files the same pattern with ``_COL``. The defaults reproduce the
    original hard-coded dataset and loader settings.

    Args:
        data_dir: Directory containing the NIfTI files.
        patient_glob: Glob for the patient-id prefix of each file name.
        batch_size: DataLoader batch size.
        shuffle: Whether the DataLoader shuffles each epoch.

    Returns:
        A ``DataLoader`` over a ``NiftiiDataset`` of the matched pairs.

    Raises:
        FileNotFoundError: If no source file matches the pattern.
        ValueError: If the source and target files do not list the same
            patient prefixes in the same order (positional pairing would be wrong).
    """
    suffix = "_zscore_minmax_unbias.nii.gz"
    source_tail = "_CHU" + suffix
    target_tail = "_COL" + suffix
    source_pattern = os.path.join(data_dir, patient_glob + source_tail)
    target_pattern = os.path.join(data_dir, patient_glob + target_tail)

    source_image_paths = sorted(glob.glob(source_pattern))
    target_image_paths = sorted(glob.glob(target_pattern))

    # An empty loader would otherwise only fail much later, inside training.
    if not source_image_paths:
        raise FileNotFoundError(f"No files match {source_pattern}")

    # NiftiiDataset pairs files by position, so both sorted lists must name
    # the same patients in the same order.
    source_ids = [os.path.basename(p)[:-len(source_tail)] for p in source_image_paths]
    target_ids = [os.path.basename(p)[:-len(target_tail)] for p in target_image_paths]
    if source_ids != target_ids:
        raise ValueError(
            f"Source/target patients do not match: {source_ids} vs {target_ids}"
        )

    dataset = NiftiiDataset(source_image_paths, target_image_paths)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [21]:
def define_model(dim=128, dim_mults=(1, 2, 4, 8), channels=1, self_condition=True):
    """Build the U-Net used for slice-to-slice harmonization.

    The defaults reproduce the original configuration. ``generate_harmonized_image``
    calls this with no arguments to rebuild the network before loading a
    checkpoint, so changing a default changes the architecture it expects.

    Args:
        dim: Base channel width of the U-Net.
        dim_mults: Per-resolution multipliers applied to ``dim``.
        channels: Number of image channels (1; presumably grayscale MRI slices).
        self_condition: Passed through to ``Unet``. NOTE(review): the training
            loop in this notebook never supplies a self-conditioning input.

    Returns:
        A ``denoising_diffusion_pytorch.Unet`` instance.
    """
    return Unet(
        dim=dim,
        dim_mults=dim_mults,
        channels=channels,
        self_condition=self_condition,
    )

In [5]:
def define_diffusion(model, image_size=256, timesteps=1000,
                     objective='pred_noise', beta_schedule='cosine'):
    """Wrap ``model`` in a ``GaussianDiffusion`` process.

    The defaults reproduce the original configuration (256-pixel images,
    1000 steps, noise-prediction objective, cosine beta schedule).

    NOTE(review): the training loop in this notebook calls the U-Net directly
    and does not use the diffusion object's loss or sampler.

    Args:
        model: The denoising network (e.g. from ``define_model``).
        image_size: Spatial size of the images the process operates on.
        timesteps: Number of diffusion steps.
        objective: ``GaussianDiffusion`` training objective.
        beta_schedule: Noise schedule name.

    Returns:
        A ``denoising_diffusion_pytorch.GaussianDiffusion`` instance.
    """
    return GaussianDiffusion(
        model=model,
        image_size=image_size,
        timesteps=timesteps,
        objective=objective,
        beta_schedule=beta_schedule,
    )

In [25]:
def train_model(dataloader, model, diffusion, epochs=20, lr=1e-4,
                model_path='savedmodel_different_model_pat8.pt',
                harmonized_path='harmonized_slice_88.png',
                save_snapshots=True):
    """Train ``model`` to map source slices to target slices with an MSE loss.

    NOTE(review): despite the ``diffusion`` argument, this is plain supervised
    regression rather than diffusion training. The network is called directly
    on the clean source slice with a random timestep, and its output is
    compared to the target slice. No noise is added, and the diffusion object's
    loss and sampler are never used. ``diffusion`` only supplies
    ``num_timesteps`` (1000 if it has no such attribute).

    Args:
        dataloader: Iterable of ``(source_slice, target_slice)`` batches.
        model: Network called as ``model(x, t)``.
        diffusion: Object whose ``num_timesteps`` bounds the sampled
            timesteps; may be ``None``.
        epochs: Number of passes over ``dataloader``.
        lr: Initial Adam learning rate (multiplied by 0.1 every 10 epochs).
        model_path: Where the final ``state_dict`` is written.
        harmonized_path: Where the model output for the last batch is saved
            as an image; ``None`` skips it.
        save_snapshots: If true, save ``source_*``, ``target_*`` and
            ``reconstruct_*`` PNGs for every batch into the working directory.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()
    # Mixed precision is CUDA-only here; on CPU the cuda.amp helpers only warn.
    use_amp = device.type == "cuda"
    num_timesteps = getattr(diffusion, "num_timesteps", 1000)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # lr x0.1 every 10 epochs
    scaler = GradScaler(enabled=use_amp)

    last_output = None
    for epoch in range(epochs):
        n_batches = 0
        for cpt, (source_slice, target_slice) in enumerate(dataloader):
            source_slice = source_slice.to(device)
            target_slice = target_slice.to(device)
            optimizer.zero_grad()

            t = torch.randint(0, num_timesteps, (source_slice.size(0),), device=device)

            with autocast(enabled=use_amp):
                reconstructed_slice = model(source_slice, t)
                loss = criterion(reconstructed_slice, target_slice)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Detached float32 copy: safe to write out and does not hold the graph.
            last_output = reconstructed_slice.detach().float()
            if save_snapshots:
                suffix = f"_{epoch}_{cpt}.png"
                save_image(target_slice, 'target' + suffix)
                save_image(source_slice, 'source' + suffix)
                save_image(last_output, 'reconstruct' + suffix)

            n_batches += 1
            print(f"Epoch {epoch+1}, Loss: {loss.item()}")

        if n_batches == 0:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        scheduler.step()
        print(f"Learning rate adjusted to: {scheduler.get_last_lr()[0]}")

    torch.save(model.state_dict(), model_path)
    # The original saved the *source* slice under this name; the model output
    # is what a "harmonized" image should contain.
    if harmonized_path is not None and last_output is not None:
        save_image(last_output, harmonized_path)

In [26]:
def generate_harmonized_image(model_path, source_image_path, target_image_path,
                              output_path='/home/youssef/harmo_4/harmonized_result/harmonized_slice8_dm_10ep.png',
                              timestep=None):
    """Run the trained U-Net on the middle slice of one volume and save the result.

    Args:
        model_path: Checkpoint containing a ``state_dict`` for the network
            built by ``define_model()`` with its default arguments.
        source_image_path: NIfTI volume to harmonize. Its middle slice along
            the third axis (``shape[2] // 2``) is used.
        target_image_path: Unused; kept so existing calls keep working.
            NOTE(review): presumably intended for a side-by-side comparison.
        output_path: Destination image file for the model output.
        timestep: Timestep passed to the model. ``None`` (default) draws a
            random one in ``[0, 1000)`` as the original did, which makes the
            output non-reproducible; pass an int for deterministic results.

    Returns:
        The model output tensor, moved to the CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = define_model()
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only box.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    source_img = torch.from_numpy(nib.load(source_image_path).get_fdata(dtype=np.float32))
    # (1, 1, X, Y): batch and channel axes added in front of the 2-D slice.
    source_slice = source_img[:, :, source_img.shape[2] // 2].unsqueeze(0).unsqueeze(0).to(device)

    batch = source_slice.size(0)
    if timestep is None:
        # 1000 matches the default timestep count used for training.
        t = torch.randint(0, 1000, (batch,), device=device)
    else:
        t = torch.full((batch,), int(timestep), dtype=torch.long, device=device)

    with torch.no_grad():
        harmonized_slice = model(source_slice, t)

    save_image(harmonized_slice, output_path)
    return harmonized_slice.cpu()

In [27]:
def main():
    """Training entry point: load the paired slices, build the U-Net and its
    diffusion wrapper, then train (device placement happens in train_model)."""
    loader = load_data()
    unet = define_model()
    train_model(loader, unet, define_diffusion(unet))


if __name__ == "__main__":
    main()

4 4
Epoch 1, Loss: 0.1092502772808075
Epoch 1, Loss: 0.13179025053977966
Epoch 1, Loss: 1.4635510444641113
Epoch 1, Loss: 1.4164683818817139
Learning rate adjusted to: 0.0001
Epoch 2, Loss: 1.4095628261566162
Epoch 2, Loss: 0.14970599114894867
Epoch 2, Loss: 0.0638352632522583
Epoch 2, Loss: 0.0031716013327240944
Learning rate adjusted to: 0.0001
Epoch 3, Loss: 0.03259649872779846
Epoch 3, Loss: 0.02056567184627056
Epoch 3, Loss: 0.002363627078011632
Epoch 3, Loss: 0.014938732609152794
Learning rate adjusted to: 0.0001
Epoch 4, Loss: 0.028853412717580795
Epoch 4, Loss: 0.014280201867222786
Epoch 4, Loss: 0.0026247708592563868
Epoch 4, Loss: 0.013024034909904003
Learning rate adjusted to: 0.0001
Epoch 5, Loss: 0.016502177342772484
Epoch 5, Loss: 0.006820809096097946
Epoch 5, Loss: 0.0016480039339512587
Epoch 5, Loss: 0.006739890668541193
Learning rate adjusted to: 0.0001
Epoch 6, Loss: 0.014854387380182743
Epoch 6, Loss: 0.008057093247771263
Epoch 6, Loss: 0.0015543561894446611
Epoch 6,

In [None]:
# Inference on one training pair. Alternative held-out test pair:
#   /home/youssef/harmo_4/test_data/Pat42_CHU_zscore_minmax_unbias.nii.gz
#   /home/youssef/harmo_4/test_data/Pat42_COL_zscore_minmax_unbias.nii.gz
source_image_path = "/home/youssef/harmo_4/training_data/Pat8_CHU_zscore_minmax_unbias.nii.gz"
target_image_path = "/home/youssef/harmo_4/training_data/Pat8_COL_zscore_minmax_unbias.nii.gz"
# NOTE(review): train_model writes 'savedmodel_different_model_pat8.pt' into the
# working directory; this reads from trained_model/ -- confirm the file was moved.
model_path = "/home/youssef/harmo_4/trained_model/savedmodel_different_model_pat8.pt"

generate_harmonized_image(
    model_path=model_path,
    source_image_path=source_image_path,
    target_image_path=target_image_path,
)