In [1]:
# Python
import argparse
import glob
import os

import matplotlib.pyplot as plt
#from nibabel import load as load_nii
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [2]:
class NiftiiDataset(Dataset):
    """Paired dataset of middle slices taken from source/target NIfTI images.

    Every image is loaded eagerly in ``__init__``; only the slice at the
    middle index of the third axis is kept, with a leading singleton
    dimension added via ``unsqueeze(0)``.

    Args:
        source_paths: Paths of the source images.
        target_paths: Paths of the target images, in the same order as
            ``source_paths`` (pairs are formed positionally).

    Raises:
        ValueError: If the two path collections differ in length. Pairing
            them with ``zip`` would otherwise silently drop the surplus
            entries.
    """

    def __init__(self, source_paths, target_paths):
        source_paths = list(source_paths)
        target_paths = list(target_paths)
        if len(source_paths) != len(target_paths):
            raise ValueError(
                f"source_paths and target_paths must have the same length "
                f"(got {len(source_paths)} and {len(target_paths)})"
            )

        self.source_slices = []
        self.target_slices = []

        for source_path, target_path in zip(source_paths, target_paths):
            self.source_slices.append(self._load_middle_slice(source_path))
            self.target_slices.append(self._load_middle_slice(target_path))

    @staticmethod
    def _load_middle_slice(path):
        """Load the image at ``path`` and return its middle third-axis slice."""
        image = torch.tensor(nib.load(path).get_fdata(dtype=np.float32))
        return image[:, :, image.shape[2] // 2].unsqueeze(0)

    def __len__(self):
        """Return the number of (source, target) pairs."""
        return len(self.source_slices)

    def __getitem__(self, idx):
        """Return the ``(source_slice, target_slice)`` pair at ``idx``."""
        return self.source_slices[idx], self.target_slices[idx]

def load_data(data_dir="/home/youssef/harmo_4/ALL_training_data", batch_size=4):
    """Build a DataLoader over the paired CHU/COL training images.

    Args:
        data_dir: Directory searched for ``Pat*_CHU_zscore_minmax_unbias.nii.gz``
            (source) and ``Pat*_COL_zscore_minmax_unbias.nii.gz`` (target) files.
            Defaults to the previously hard-coded training directory.
        batch_size: Batch size passed to the ``DataLoader``.

    Returns:
        A ``DataLoader`` wrapping a ``NiftiiDataset`` of the matched files.

    Raises:
        ValueError: If the number of source and target files differs. The two
            lists are sorted independently and paired by position, so a
            missing file would otherwise shift every following pair.
    """
    source_image_paths = sorted(glob.glob(os.path.join(data_dir, "Pat*_CHU_zscore_minmax_unbias.nii.gz")))
    target_image_paths = sorted(glob.glob(os.path.join(data_dir, "Pat*_COL_zscore_minmax_unbias.nii.gz")))

    if len(source_image_paths) != len(target_image_paths):
        raise ValueError(
            f"Found {len(source_image_paths)} source and {len(target_image_paths)} "
            f"target images in {data_dir!r}; the counts must match"
        )

    dataset = NiftiiDataset(source_image_paths, target_image_paths)
    return DataLoader(dataset, batch_size=batch_size)

In [13]:
def define_model():
    """Construct the single-channel ``Unet`` used by this notebook.

    Returns:
        A ``Unet`` with ``dim=128``, ``dim_mults=(1, 2, 4, 8, 16)`` and
        ``channels=1``.
    """
    # Larger configuration chosen for better results; the earlier baseline
    # was dim=64 with dim_mults=(1, 2, 4, 8).
    unet_config = {
        "dim": 128,
        "dim_mults": (1, 2, 4, 8, 16),
        "channels": 1,
    }
    return Unet(**unet_config)

def define_diffusion(model):
    """Wrap ``model`` in a ``GaussianDiffusion`` process.

    Args:
        model: The network handed to ``GaussianDiffusion`` as ``model``.

    Returns:
        A ``GaussianDiffusion`` configured with ``image_size=256`` and
        ``timesteps=1000``.
    """
    return GaussianDiffusion(model=model, image_size=256, timesteps=1000)

In [14]:
def generate_harmonized_image(model_path, source_slice, target_slice, output_path='harmonized_slice8.png'):
    """Run the trained U-Net on ``source_slice`` and save the result as an image.

    NOTE(review): another function with this same name is defined right after
    this one in the same cell and replaces it once the cell has run.

    Args:
        model_path: Path to a ``state_dict`` checkpoint compatible with
            ``define_model()``.
        source_slice: Input tensor; its first dimension is used as the batch
            size when drawing timesteps.
        target_slice: Unused; kept so existing callers keep working.
        output_path: Where the image is written. Defaults to the previously
            hard-coded file name.
    """
    # Imported locally so this always resolves to torchvision's saver: the
    # notebook also defines its own `save_image(image, save_path)` helper,
    # which would shadow a module-level import depending on cell order.
    from torchvision.utils import save_image

    device = source_slice.device

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # the model is then moved to wherever the input lives.
    model = define_model()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # One timestep per batch element, drawn at random, so repeated calls may
    # produce different outputs.
    timesteps = 1000
    t = torch.randint(0, timesteps, (source_slice.size(0),), device=device)

    with torch.no_grad():
        harmonized_slice = model(source_slice, t)

    save_image(harmonized_slice, output_path)
    
def generate_harmonized_image(model_path, source_image_path, target_image_path,
                              output_path='/home/youssef/harmo_4/harmonized_result/harmonized_slice42_9.png'):
    """Harmonize the middle slice of a source NIfTI image and save it as an image.

    Args:
        model_path: Path to a ``state_dict`` checkpoint compatible with
            ``define_model()``.
        source_image_path: Path of the source NIfTI image.
        target_image_path: Unused; kept so existing callers keep working.
            The target image was previously loaded and then discarded, so it
            is no longer read from disk.
        output_path: Where the image is written. Defaults to the previously
            hard-coded location.
    """
    # Imported locally so this always resolves to torchvision's saver: the
    # notebook also defines its own `save_image(image, save_path)` helper,
    # which would shadow a module-level import depending on cell order.
    from torchvision.utils import save_image

    # The input tensor is built on CPU below, so load the weights onto CPU as
    # well; this also allows checkpoints that were saved from a GPU run.
    model = define_model()
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    source_img = torch.tensor(nib.load(source_image_path).get_fdata(dtype=np.float32))

    # Middle slice along the third axis, with two leading singleton
    # dimensions added in front of the 2-D slice.
    source_slice = source_img[:, :, source_img.shape[2] // 2].unsqueeze(0).unsqueeze(0)

    # One timestep per batch element, drawn at random, so repeated calls may
    # produce different outputs.
    timesteps = 1000
    t = torch.randint(0, timesteps, (source_slice.size(0),), device=source_slice.device)

    with torch.no_grad():
        harmonized_slice = model(source_slice, t)

    save_image(harmonized_slice, output_path)


# Inputs for the Pat42 test case: trained checkpoint plus the CHU (source)
# and COL (target) images.
model_path = "/home/youssef/harmo_4/trained_model/savedmodel_02.pt"
source_image_path = "/home/youssef/harmo_4/test_data/Pat42_CHU_zscore_minmax_unbias.nii.gz"
target_image_path = "/home/youssef/harmo_4/test_data/Pat42_COL_zscore_minmax_unbias.nii.gz"

# Run the path-based harmonization defined above in this cell.
generate_harmonized_image(model_path, source_image_path, target_image_path)

training.py

In [15]:
def train_model(dataloader, model, diffusion, epochs=10, lr=1e-4, save_path='savedmodel.pt'):
    """Train ``model`` to map source slices to target slices with an L1 loss.

    Args:
        dataloader: Iterable yielding ``(source_slice, target_slice)`` batches.
        model: Network called as ``model(source_slice, t)``.
        diffusion: Unused; kept so existing callers keep working.
        epochs: Number of passes over ``dataloader`` (default 10).
        lr: Adam learning rate (default 1e-4).
        save_path: File the final ``state_dict`` is written to.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch. Without
            this check the loss print failed with an ``UnboundLocalError``.
    """
    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    timesteps = 1000

    for epoch in range(epochs):
        loss = None
        for source_slice, target_slice in dataloader:
            optimizer.zero_grad()

            # One random timestep per batch element; the model's forward call
            # takes it as its second argument.
            t = torch.randint(0, timesteps, (source_slice.size(0),), device=source_slice.device)
            reconstructed_slice = model(source_slice, t)

            loss = criterion(reconstructed_slice, target_slice)
            loss.backward()
            optimizer.step()

        if loss is None:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        # Reports the loss of the last batch in the epoch.
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    # Save the trained weights.
    torch.save(model.state_dict(), save_path)


def train_model(dataloader, model, diffusion, epochs=100, lr=1e-4, save_path=None):
    """Train ``model`` with an MSE loss, using mixed precision when CUDA is available.

    The model is moved to CUDA if available, otherwise it stays on CPU and
    automatic mixed precision is disabled.

    Args:
        dataloader: Iterable yielding ``(source_slice, target_slice)`` batches.
        model: Network called as ``model(source_slice, t)``; updated in place.
        diffusion: Unused; kept so existing callers keep working.
        epochs: Number of passes over ``dataloader`` (default 100).
        lr: Adam learning rate (default 1e-4).
        save_path: Optional file to write the final ``state_dict`` to. With
            the default ``None`` nothing is written, as before.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    # AMP and loss scaling only make sense on CUDA; disabling both on CPU
    # turns the scaler calls below into plain backward()/step().
    use_amp = device.type == "cuda"

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scaler = GradScaler(enabled=use_amp)
    timesteps = 1000

    for epoch in range(epochs):
        loss = None
        for source_slice, target_slice in dataloader:
            source_slice, target_slice = source_slice.to(device), target_slice.to(device)
            optimizer.zero_grad()

            t = torch.randint(0, timesteps, (source_slice.size(0),), device=device)

            # Forward pass and loss in mixed precision (when enabled).
            with autocast(enabled=use_amp):
                reconstructed_slice = model(source_slice, t)
                loss = criterion(reconstructed_slice, target_slice)

            # Scale the loss for backward, step the optimizer through the
            # scaler, then update the scale for the next iteration.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        if loss is None:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        # Logged once per epoch (last batch's loss), matching the other
        # train_model definitions in this notebook, instead of once per batch.
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    if save_path is not None:
        torch.save(model.state_dict(), save_path)


test.py

In [16]:
def load_model(model_path):
    """Load a trained model from ``model_path`` and put it in eval mode.

    Accepts either a whole pickled ``torch.nn.Module`` or a ``state_dict``.
    The training code in this notebook saves ``model.state_dict()``, which
    the previous implementation could not handle (it called ``.eval()`` on
    the loaded dict). A ``state_dict`` is loaded into a fresh
    ``define_model()`` network, so the checkpoint must match whichever
    ``define_model`` is currently defined.

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        The model, on CPU, in eval mode.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
    checkpoint = torch.load(model_path, map_location="cpu")

    if isinstance(checkpoint, torch.nn.Module):
        model = checkpoint
    else:
        model = define_model()
        model.load_state_dict(checkpoint)

    model.eval()
    return model

def load_images(source_path, target_path):
    """Open two image files and return them as batched tensors.

    NOTE(review): ``Image`` is not imported anywhere in the visible source;
    presumably ``PIL.Image`` - confirm and add the import.

    Args:
        source_path: Path of the source image file.
        target_path: Path of the target image file.

    Returns:
        ``(source, target)``: each image converted with ``ToTensor`` and
        given a leading batch dimension via ``unsqueeze(0)``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    batched = []
    for path in (source_path, target_path):
        batched.append(to_tensor(Image.open(path)))

    source_tensor, target_tensor = batched
    return source_tensor.unsqueeze(0), target_tensor.unsqueeze(0)

def harmonize_images(model, source, target):
    """Call ``model(source, target)`` without gradient tracking.

    NOTE(review): the U-Net calls elsewhere in this notebook pass a timestep
    tensor as the second argument; confirm the model used here really takes
    ``(source, target)``.

    Args:
        model: Callable invoked as ``model(source, target)``.
        source: First argument forwarded to ``model``.
        target: Second argument forwarded to ``model``.

    Returns:
        The model output with dimension 0 squeezed out.
    """
    with torch.no_grad():
        output = model(source, target)

    result = output.squeeze(0)
    return result

def save_image(image, save_path):
    """Convert ``image`` with ``transforms.ToPILImage`` and save it to ``save_path``.

    NOTE(review): this name is also imported from ``torchvision.utils`` in a
    later cell; whichever binding ran last is the one callers get.

    Args:
        image: Object accepted by ``transforms.ToPILImage()``.
        save_path: Destination file path.
    """
    to_pil = transforms.ToPILImage()
    to_pil(image).save(save_path)

def main():
    """Command-line entry point: load a model, harmonize one image pair, save it.

    Required arguments: ``--model``, ``--source``, ``--target``, ``--save``.
    """
    parser = argparse.ArgumentParser()
    # All four options are mandatory and take a single string value.
    for flag in ('--model', '--source', '--target', '--save'):
        parser.add_argument(flag, required=True)
    cli = parser.parse_args()

    network = load_model(cli.model)
    source_tensor, target_tensor = load_images(cli.source, cli.target)
    result = harmonize_images(network, source_tensor, target_tensor)
    save_image(result, cli.save)

original.py

In [17]:
import torch
import nibabel as nib
import numpy as np
from torch.utils.data import Dataset, DataLoader
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
from torchvision.utils import save_image

class NiftiiDataset(Dataset):
    """One-item dataset holding the middle slices of a source/target NIfTI pair.

    Both full images are kept as ``source_img`` / ``target_img``; the slices
    at the middle index of the third axis, with a leading singleton
    dimension, are kept as ``source_slice`` / ``target_slice``.
    """

    def __init__(self, source_path, target_path):
        """Load both images and cache their middle slices."""
        source_nii, target_nii = nib.load(source_path), nib.load(target_path)

        self.source_img = torch.tensor(source_nii.get_fdata(dtype=np.float32))
        self.target_img = torch.tensor(target_nii.get_fdata(dtype=np.float32))

        self.source_slice = self._middle_slice(self.source_img)
        self.target_slice = self._middle_slice(self.target_img)

    @staticmethod
    def _middle_slice(image):
        """Return the middle third-axis slice of ``image`` with a leading dim."""
        mid = image.shape[2] // 2
        return image[:, :, mid].unsqueeze(0)

    def __len__(self):
        """The dataset always contains exactly one pair."""
        return 1

    def __getitem__(self, idx):
        """Return the single ``(source_slice, target_slice)`` pair for any ``idx``."""
        pair = (self.source_slice, self.target_slice)
        return pair

def load_data(source_image_path, target_image_path):
    """Return a batch-size-1 ``DataLoader`` over a single source/target pair.

    Args:
        source_image_path: Path of the source NIfTI image.
        target_image_path: Path of the target NIfTI image.
    """
    pair_dataset = NiftiiDataset(source_image_path, target_image_path)
    return DataLoader(pair_dataset, batch_size=1)

def define_model():
    """Construct the single-channel ``Unet`` used by this script.

    Returns:
        A ``Unet`` with ``dim=64``, ``dim_mults=(1, 2, 4, 8)`` and
        ``channels=1``.
    """
    unet_config = {
        "dim": 64,
        "dim_mults": (1, 2, 4, 8),
        "channels": 1,
    }
    return Unet(**unet_config)

def define_diffusion(model):
    """Wrap ``model`` in a ``GaussianDiffusion`` process.

    Args:
        model: The network handed to ``GaussianDiffusion`` as ``model``.

    Returns:
        A ``GaussianDiffusion`` configured with ``image_size=256`` and
        ``timesteps=1000``.
    """
    return GaussianDiffusion(model=model, image_size=256, timesteps=1000)

def train_model(dataloader, model, diffusion, epochs=10, lr=1e-4, save_path='savedmodel.pt'):
    """Train ``model`` to map source slices to target slices with an L1 loss.

    Args:
        dataloader: Iterable yielding ``(source_slice, target_slice)`` batches.
        model: Network called as ``model(source_slice, t)``.
        diffusion: Unused; kept so existing callers keep working.
        epochs: Number of passes over ``dataloader`` (default 10).
        lr: Adam learning rate (default 1e-4).
        save_path: File the final ``state_dict`` is written to.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch. Without
            this check the loss print failed with an ``UnboundLocalError``.
    """
    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    timesteps = 1000

    for epoch in range(epochs):
        loss = None
        for source_slice, target_slice in dataloader:
            optimizer.zero_grad()

            # One random timestep per batch element; the model's forward call
            # takes it as its second argument.
            t = torch.randint(0, timesteps, (source_slice.size(0),), device=source_slice.device)
            reconstructed_slice = model(source_slice, t)

            loss = criterion(reconstructed_slice, target_slice)
            loss.backward()
            optimizer.step()

        if loss is None:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        # Reports the loss of the last batch in the epoch.
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    # Save the trained weights.
    torch.save(model.state_dict(), save_path)


def generate_harmonized_image(model_path, source_slice, target_slice, output_path='harmonized_slice8.png'):
    """Run the trained U-Net on ``source_slice`` and save the result as an image.

    Args:
        model_path: Path to a ``state_dict`` checkpoint compatible with
            ``define_model()``.
        source_slice: Input tensor; its first dimension is used as the batch
            size when drawing timesteps.
        target_slice: Unused; kept so existing callers keep working.
        output_path: Where the image is written. Defaults to the previously
            hard-coded file name.
    """
    device = source_slice.device

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # the model is then moved to wherever the input lives.
    model = define_model()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # One timestep per batch element, drawn at random, so repeated calls may
    # produce different outputs.
    timesteps = 1000
    t = torch.randint(0, timesteps, (source_slice.size(0),), device=device)

    with torch.no_grad():
        harmonized_slice = model(source_slice, t)

    save_image(harmonized_slice, output_path)


def main():
    """Train on the Pat8 CHU/COL pair, then write a harmonized image per batch."""
    source_image_path = "/home/youssef/harmonization_project/data/train_2/Pat8_CHU_zscore_minmax_unbias.nii.gz"
    target_image_path = "/home/youssef/harmonization_project/data/train_2/Pat8_COL_zscore_minmax_unbias.nii.gz"

    # Build the data pipeline first, then the network and its diffusion wrapper.
    loader = load_data(source_image_path, target_image_path)
    network = define_model()
    train_model(loader, network, define_diffusion(network))

    # train_model wrote the weights to 'savedmodel.pt'; reload them for each batch.
    for src_batch, tgt_batch in loader:
        generate_harmonized_image('savedmodel.pt', src_batch, tgt_batch)


# Run the pipeline when executed as a script or as a notebook cell.
if __name__ == "__main__":
    main()

Epoch 1, Loss: 0.7744356393814087
Epoch 2, Loss: 0.22553113102912903
Epoch 3, Loss: 0.23664915561676025
Epoch 4, Loss: 0.10385526716709137
Epoch 5, Loss: 0.18706238269805908
Epoch 6, Loss: 0.19578926265239716
Epoch 7, Loss: 0.12775063514709473
Epoch 8, Loss: 0.09209910035133362
Epoch 9, Loss: 0.06490813195705414
Epoch 10, Loss: 0.043991755694150925


In [None]:
# Harmonize the Pat42 test pair with the savedmodel_02 checkpoint.
# NOTE(review): this call passes file paths, but the definition of
# generate_harmonized_image in the preceding cell takes slice tensors (it
# calls `source_slice.size(0)`). Confirm which definition is meant to be
# live when this cell runs.
source_image_path = "/home/youssef/harmo_4/test_data/Pat42_CHU_zscore_minmax_unbias.nii.gz"
target_image_path = "/home/youssef/harmo_4/test_data/Pat42_COL_zscore_minmax_unbias.nii.gz"
model_path = "/home/youssef/harmo_4/trained_model/savedmodel_02.pt"

generate_harmonized_image(model_path, source_image_path, target_image_path)
#generate_harmonized_image('savedmodel.pt', source_image_path, target_image_path)
