In [None]:
# Standard Libraries
import os
import io
import monai
import random
import tempfile
from multiprocessing import Manager
from tqdm.notebook import tqdm
from monai.config import USE_COMPILED
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy
import warnings
import torch.nn.functional as F
import SimpleITK as sitk
from monai.networks.nets import UNet
import matplotlib.gridspec as gridspec
import torch
import torch.nn as nn
import torchio as tio
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import GradScaler, autocast

# MONAI Libraries
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader
# from monai.transforms import (
#     AddChanneld, 
#     CenterSpatialCropd, 
#     Compose, 
#     Lambdad, 
#     LoadImaged, 
#     Resized, 
#     ScaleIntensityd
# )
from monai.utils import set_determinism

# Custom Libraries
from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler, DDIMScheduler
from dataloader import Train ,Eval 

import resource
def nofile_soft_limit(soft, hard, target=8192):
    """Return the soft RLIMIT_NOFILE value to request.

    The result is ``target`` when the hard limit allows it. It is never above ``hard``,
    because ``setrlimit`` raises ``ValueError`` when soft > hard. It never lowers a soft
    limit that is already at least ``target`` or unlimited (``resource.RLIM_INFINITY``).
    """
    if soft == resource.RLIM_INFINITY or soft >= target:
        return soft
    if hard == resource.RLIM_INFINITY:
        return target
    return min(target, hard)


# Raise the per-process open-file (soft) limit toward 8192, staying within the hard limit.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (nofile_soft_limit(*rlimit), rlimit[1]))

print("This is registration job.")

# Experiment tracking (wandb) is not enabled for this run.

# ---- Global runtime settings ------------------------------------------------
# Select SimpleITK's "Platform" multi-threader backend for its filters.
sitk.ProcessObject.SetGlobalDefaultThreader("Platform")
# Silence every Python warning (note: this also hides deprecation notices).
warnings.filterwarnings('ignore')

# NOTE(review): this binds a Python global that nothing in this notebook reads. Jupyter looks
# for an *environment variable* of this name, so this assignment presumably has no effect.
JUPYTER_ALLOW_INSECURE_WRITES = True


# ---- Experiment configuration ----------------------------------------------
# Only 'batch_size' is read directly in this notebook. The whole dict is also passed to
# dataloader.Train, which presumably consumes the remaining keys (not visible here).
config = dict(
    [
        ('batch_size', 2),
        ('imgDimResize', (160, 192, 160)),
        ('imgDimPad', (208, 256, 208)),
        ('spatialDims', '3D'),
        ('unisotropic_sampling', True),
        ('perc_low', 1),
        ('perc_high', 99),
        ('rescaleFactor', 2),
        ('base_path', '/scratch1/akrami/Latest_Data/Data'),
        ('lambda', 100),
    ]
)

# Seed and Device Configuration
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CUDA and CUDNN Configuration
# Uncomment the following line to specify CUDA_VISIBLE_DEVICES
# os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,5,6'
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# SimpleITK Configuration
# Set the default number of threads and global behavior for SimpleITK
sitk.ProcessObject.SetGlobalDefaultThreader("Platform")
    
# ---- Data ----------------------------------------------------------------------
# Earlier IXI split, kept for reference:
#   Train(pd.read_csv('/project/ajoshi_27/akrami/monai3D/GenerativeModels/data/split/IXI_train_fold0.csv',
#                     converters={'img_path': pd.eval}), config)   # IXI_val_fold0.csv for validation
data_train = Train(pd.read_csv('./LesionData/train.csv'), config)
# NOTE(review): validation is built with `Train` although `Eval` is also imported.
# Confirm this is intended.
data_val = Train(pd.read_csv('./LesionData/val.csv'), config)

train_loader = DataLoader(data_train, batch_size=config.get('batch_size', 2), shuffle=True)
# No shuffling for validation. The averaged loss does not need it, and a fixed order means the
# last validation batch is always the same regardless of RNG state. The sampling cell later
# reuses that batch via `input_data` / `masked_input_data`.
val_loader = DataLoader(data_val, batch_size=config.get('batch_size', 2), shuffle=False)

In [None]:
# Conditional diffusion UNet.
# Input: 2 channels (masked image, noisy target), concatenated in the training loop.
# Output: 1 channel (the predicted noise).
# spatial_dims must match the data. config['spatialDims'] is '3D', and the training loop
# reshapes the per-sample peak to (B, 1, 1, 1, 1), i.e. batched 3-D volumes. A 2-D UNet
# (conv2d) cannot accept that 5-D input.
# NOTE(review): attention at the deepest level of a 3-D volume can be memory-heavy;
# check GPU memory for the configured image size.
model = DiffusionModelUNet(
    spatial_dims=3 if config.get('spatialDims', '3D') == '3D' else 2,
    in_channels=2,
    out_channels=1,
    num_channels=(64, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=1,
    num_head_channels=64,
    with_conditioning=False,
)
model.to(device)

scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)
inferer = DiffusionInferer(scheduler)

# Loss scaler for mixed-precision training, used together with `autocast` in the training loop.
scaler = GradScaler()

In [None]:
# ---- Training configuration -------------------------------------------------------
max_epochs = 500
SAVE_EVERY = 1  # write a checkpoint every N epochs (1 = every epoch)


def _prepare_batch(data):
    """Move one loader batch to `device` and build the model inputs.

    Args:
        data: a batch from `train_loader`/`val_loader`. Reads data['vol']['data'],
            data['peak'] and data['mask']['data'].

    Returns:
        (input_data, masked_input_data): the volume divided by its per-sample peak, and the
        same tensor with every voxel where the mask is non-zero set to 0.
    """
    volumes = data['vol']['data'].to(device)
    # One peak per sample, reshaped to (B, 1, 1, 1, 1) so it broadcasts over the batched volume.
    # Cast to float, not long: an integer cast truncates fractional peaks and turns a peak
    # below 1 into 0, which would be a division by zero.
    peak = data['peak'].to(device).float().reshape(-1, 1, 1, 1, 1)
    input_data = volumes / peak
    keep = (data['mask']['data'].to(device) == 0).to(input_data.dtype)
    return input_data, input_data * keep


def _sample_timesteps(batch_size):
    """Draw one diffusion timestep per sample, uniformly from [0, num_train_timesteps)."""
    return torch.randint(0, scheduler.num_train_timesteps, (batch_size,)).to(device)


def _noise_prediction_loss(groundtruth, masked_input_data, timesteps):
    """Return the MSE between the model's noise prediction and the noise added to `groundtruth`.

    The target image is noised at `timesteps`, then concatenated along the channel axis with
    the masked image. This conditions the model on the unmasked voxels.
    """
    noise = torch.randn_like(groundtruth)
    noisy_groundtruth = scheduler.add_noise(
        original_samples=groundtruth, noise=noise, timesteps=timesteps
    )
    combined = torch.cat((masked_input_data, noisy_groundtruth), dim=1)
    prediction = model(x=combined, timesteps=timesteps)
    return F.mse_loss(prediction.float(), noise.float())


for epoch in range(max_epochs):
    # ---- train ----
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_loader):
        input_data, masked_input_data = _prepare_batch(data)
        groundtruth = input_data.detach()

        optimizer.zero_grad(set_to_none=True)
        timesteps = _sample_timesteps(len(input_data))
        with autocast(enabled=True):
            loss = _noise_prediction_loss(groundtruth, masked_input_data, timesteps)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()

    print('  * train  ' +
          f'Loss: {epoch_loss / max(len(train_loader), 1):.7f}, ')

    # ---- validate ----
    # Use the same normalization and loss as training, but with fresh timesteps per batch
    # (the training batch size may differ) and no gradients.
    model.eval()
    val_epoch_loss = 0.0
    with torch.no_grad():
        for data in tqdm(val_loader):
            input_data, masked_input_data = _prepare_batch(data)
            groundtruth = input_data
            timesteps = _sample_timesteps(len(input_data))
            with autocast(enabled=True):
                val_loss = _noise_prediction_loss(groundtruth, masked_input_data, timesteps)
            val_epoch_loss += val_loss.item()

    # Report once per epoch, after all validation batches.
    print('  * val  ' +
          f'Loss: {val_epoch_loss / max(len(val_loader), 1):.7f}, ')

    if epoch % SAVE_EVERY == 0:
        torch.save(model.state_dict(), f"./model{epoch}.pt")

In [None]:
# Sampling starts from Gaussian noise shaped like `input_data`, conditioned on `masked_input_data`.
# NOTE(review): both names hold the last validation batch left over by the training cell above.
noise = torch.randn_like(input_data).to(device)
current_img = noise
# Model input: the masked conditioning image and the current sample, stacked along channels.
combined = torch.cat([masked_input_data, current_img], dim=1)

scheduler.set_timesteps(num_inference_steps=1000)  # full-length schedule for sampling
progress_bar = tqdm(scheduler.timesteps)
# Zero-filled CPU tensor with the sample's shape (not updated in the visible part of this cell).
chain = torch.zeros(current_img.shape)
# Reverse diffusion: walk the scheduler's timesteps, repeatedly replacing `current_img` with
# the scheduler's less-noisy estimate. Mixed precision is disabled for sampling.
for t in progress_bar:  # go through the denoising (reverse) process
    with autocast(enabled=False):
        with torch.no_grad():
            # The timestep is passed as a 1-element float tensor on the sample's device.
            model_output = model(combined, timesteps=torch.Tensor((t,)).to(current_img.device))
            current_img, _ = scheduler.step(
                model_output, t, current_img
            )  # sample after one reverse step from t (the second return value is discarded)
            combined = torch.cat(
                (masked_input_data, current_img), dim=1
            )  # rebuild the model input: masked conditioning image + updated sample