In [None]:
# Notebook setup (IPython shell escapes).
# NOTE(review): this first line force-reinstalls `requests` on every run — confirm it is still needed.
! pip install -q --force-reinstall requests
# Each `python -c "import X" || pip install ...` line installs X only when importing it fails.
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import aim" || pip install -q aim
!python -c "import matplotlib" || pip install -q matplotlib

# Render matplotlib figures inline in the notebook.
%matplotlib inline


In [None]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)


from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    Resize,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    RandCropByPosNegLabeld,
)



import torch

# Print MONAI's configuration report (installed MONAI / dependency versions) for reproducibility.
print_config()

In [None]:
# Working directory: $MONAI_DATA_DIRECTORY when set (created if missing),
# otherwise a fresh temporary directory.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    os.makedirs(directory, exist_ok=True)
    root_dir = directory
print(root_dir)

In [None]:
  # The statement here was a truncated import fragment (a SyntaxError); it now imports
# every transform used in this cell that the first cell's import list does not cover.
from monai.transforms import (
    EnsureTyped,
    Lambdad,
    MapTransform,
    Resized,
    ToTensord,
)
from monai.data import CacheDataset, Dataset
import torch

patch_size = 16
spatial_size = (64, 64, 96)

# Round every spatial dimension up to the next multiple of patch_size:
# adding (-dim) % patch_size pads dim to that multiple (e.g. 70 -> 80 for 16).
adjusted_spatial_size = tuple(dim + (-dim) % patch_size for dim in spatial_size)

def _shape_summary(sample):
    """Return the shape of ``sample['image']`` when available, else a short type name."""
    if isinstance(sample, dict) and "image" in sample:
        image = sample["image"]
        return getattr(image, "shape", type(image).__name__)
    # e.g. a list of samples produced by a multi-sample crop transform
    return type(sample).__name__


def debug_transform(data, transforms=None):
    """Apply a pipeline one transform at a time, printing the image shape after each step.

    Args:
        data: sample passed to the first transform (normally a dict with an ``"image"`` key).
        transforms: object exposing a ``.transforms`` sequence of callables (e.g. a MONAI
            ``Compose``). Defaults to the module-level ``val_transforms``, looked up at
            call time as before.

    Returns:
        The output of the last transform (``data`` unchanged for an empty pipeline).

    Raises:
        Re-raises whatever a transform raised, after printing which transform failed.
    """
    pipeline = val_transforms if transforms is None else transforms
    # Bound before the loop so the error message can never hit an unbound name.
    current = None
    try:
        for current in pipeline.transforms:
            data = current(data)
            print(f"Applied {current}: {_shape_summary(data)}")
        return data
    except Exception as e:
        print(f"Error applying transform {current}: {e}")
        raise

def handle_none(x, name="unknown", shape=(1, 64, 64, 96)):
    """Replace a ``None`` value with a zero tensor, printing a warning.

    NOTE(review): a zeros placeholder silently turns a missing value into a blank
    sample; prefer failing loudly if that is not intended.

    Args:
        x: value to check.
        name: label included in the warning message.
        shape: shape of the placeholder tensor; defaults to the previously
            hard-coded ``(1, 64, 64, 96)``.

    Returns:
        ``x`` unchanged when it is not ``None``, otherwise ``torch.zeros(shape)``.
    """
    if x is None:
        print(f"Warning: Encountered None {name}")
        return torch.zeros(tuple(shape))
    return x

def print_shape(x, name="unknown"):
    """Print ``"<name> Shape: ..."`` for *x* and return *x* unchanged (pass-through debug hook).

    For a dict, the shape of its ``"image"`` entry is shown (or ``No image``).
    For anything else, its ``.shape`` is shown (or ``No shape``).
    """
    if not isinstance(x, dict):
        shape = x.shape if hasattr(x, "shape") else "No shape"
    elif "image" in x:
        shape = x["image"].shape
    else:
        shape = "No image"
    print(f"{name} Shape: {shape}")
    return x

def validate_item(data):
    """Return False (after printing why) if *data* is None or a dict with a bad value.

    A dict value is bad when it is ``None`` or an empty ``torch.Tensor``; the first
    bad entry stops the scan. Non-dict, non-None inputs are accepted as-is.
    """
    if data is None:
        print("Data is None")
        return False
    if not isinstance(data, dict):
        return True
    for field, val in data.items():
        problem = None
        if val is None:
            problem = f"{field} is None"
        elif isinstance(val, torch.Tensor) and val.numel() == 0:
            problem = f"{field} is empty tensor"
        if problem is not None:
            print(problem)
            return False
    return True

def safe_transform(item):
    """Run the module-level ``train_transforms`` on *item*, returning None instead of raising.

    Items rejected by ``validate_item`` also yield None. Any exception raised while
    validating or transforming is printed together with the offending item.
    """
    try:
        is_valid = validate_item(item)
        if is_valid:
            return train_transforms(item)
        return None
    except Exception as err:
        print(f"Error processing item: {item}. Error: {err}")
        return None

class SafeDataset(CacheDataset):
    """CacheDataset variant that skips items whose transform fails.

    ``__getitem__`` runs ``safe_transform`` on the raw entry in ``self.data``.
    NOTE(review): this bypasses CacheDataset's cached results and ``self.transform``.
    When the result is None, the next index is tried, wrapping around the dataset.
    """

    def __getitem__(self, index):
        """Return the first successfully transformed item at or after ``index``.

        Raises:
            IndexError: if ``index`` is out of range (raised by indexing ``self.data``).
            RuntimeError: if every item in the dataset fails to transform. The previous
                recursive retry ran into RecursionError in that case.
        """
        # Index the raw list first so an out-of-range index still raises IndexError.
        item = self.data[index]
        n = len(self)
        for attempt in range(n):
            if attempt:
                item = self.data[(index + attempt) % n]
            transformed = safe_transform(item)
            if transformed is not None:
                return transformed
        raise RuntimeError(
            f"SafeDataset: every item failed to transform (starting from index {index})"
        )

class ConvertToBinaryLabeld(MapTransform):
    """Turn a multi-class label map into a binary mask of class 2.

    For every key in ``self.keys`` present in the sample, voxels equal to 2 become 1
    and all other voxels become 0. The label tensor/array is copied rather than
    modified in place, so the caller's input is left untouched.
    """

    def __call__(self, data):
        d = dict(data)
        # Honour the configured keys instead of a hard-coded 'label' key.
        for key in self.key_iterator(d):
            label = d[key]
            is_target = label == 2  # build the mask before overwriting any values
            binary = label.clone() if isinstance(label, torch.Tensor) else label.copy()
            binary[...] = 0
            binary[is_target] = 1
            d[key] = binary
        return d

def get_dynamic_spatial_size(image_size, desired_size):
    """Clamp each requested crop dimension to the image's extent along that axis.

    Pairs dimensions positionally (stopping at the shorter input) and returns
    a tuple of ``min(image_dim, desired_dim)``.
    """
    return tuple(map(min, image_size, desired_size))

class DynamicRandCropByPosNegLabeld:
    """Wrap a crop transform so the crop size never exceeds the sample's size.

    Before each call, ``spatial_size`` is clamped per axis to ``data['image'].shape[1:]``.
    This assumes a channel-first image.
    """

    def __init__(self, transform, spatial_size):
        self.transform = transform
        self.spatial_size = spatial_size

    def __call__(self, data):
        image_shape = data['image'].shape[1:]
        adaptive_spatial_size = get_dynamic_spatial_size(image_shape, self.spatial_size)
        # NOTE(review): in recent MONAI releases RandCropByPosNegLabeld delegates to an
        # inner array transform stored as `.cropper`, which owns `spatial_size`. Setting
        # the attribute only on the dict wrapper has no effect there, so update the inner
        # cropper when present — confirm against the installed MONAI version.
        target = getattr(self.transform, "cropper", self.transform)
        target.spatial_size = adaptive_spatial_size
        if target is not self.transform:
            self.transform.spatial_size = adaptive_spatial_size  # keep wrapper in sync
        return self.transform(data)

# Random crops of `spatial_size`, 4 per volume, using a pos:neg ratio of 1:1
# for label-foreground vs background crop centres (image intensities > 0 count as
# candidate background). The wrapper clamps the size to each image's shape.
cropper = RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    image_key="image",
    spatial_size=spatial_size,
    pos=1,
    neg=1,
    num_samples=4,
    image_threshold=0,
)
dynamic_cropper = DynamicRandCropByPosNegLabeld(cropper, spatial_size)

# Training pipeline, in order:
#   1. load, make channel-first, resample to 1.5 x 1.5 x 2.0, reorient to RAS;
#   2. window intensities [-175, 250] into [0, 1] and crop to the image foreground;
#   3. take 4 random pos/neg-balanced crops and resize them to 64 x 64 x 96;
#   4. binarise the label (class 2 -> 1).
# The Lambdad(print_shape / handle_none) steps are debugging hooks that run for every sample.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        Lambdad(keys=["image", "label"], func=lambda x: print_shape(x, "LoadImaged")),
        Lambdad(keys=["image", "label"], func=lambda x: handle_none(x, "LoadImaged")),
        EnsureChannelFirstd(keys=["image", "label"]),
        Lambdad(keys=["image", "label"], func=lambda x: print_shape(x, "EnsureChannelFirstd")),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Lambdad(keys=["image", "label"], func=lambda x: print_shape(x, "Spacingd")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
        Lambdad(keys=["image", "label"], func=lambda x: print_shape(x, "CropForegroundd")),
        dynamic_cropper,
        # Resize the image with Resized's default interpolation but the label ONLY with
        # nearest-neighbour. Interpolating the label blends class ids (e.g. 1 and 2 into
        # 1.5), and ConvertToBinaryLabeld then drops those voxels. Previously the label
        # went through both resizes, so the nearest-mode one was a no-op.
        Resized(keys=["image"], spatial_size=(64, 64, 96)),
        Resized(keys=["label"], spatial_size=(64, 64, 96), mode='nearest'),
        EnsureTyped(keys=["image", "label"]),
        ConvertToBinaryLabeld(keys=["label"]),
        ToTensord(keys=["image", "label"]),
    ]
)

# Validation pipeline: same preprocessing as training, but with no random cropping;
# the whole foreground-cropped volume is resized to 64 x 64 x 96.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            # Spacing resampling modes are grid-sample modes ("bilinear" / "nearest") or
            # spline orders. "bilinear" interpolates trilinearly on 3-D volumes, while
            # "trilinear" is a Resize/interpolate mode. This now matches train_transforms.
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        # Crop on the image, as in training. Cropping around the label would use the
        # ground truth to choose the validation input, which is unavailable at inference.
        CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
        # NOTE(review): resizing the whole cropped volume changes the physical scale
        # compared with the 64x64x96 training crops; sliding-window inference could
        # process the full-resolution volume instead.
        # Image: default interpolation; label: nearest only, so class ids are not blended.
        Resized(keys=["image"], spatial_size=(64, 64, 96)),
        Resized(keys=["label"], spatial_size=(64, 64, 96), mode='nearest'),
        EnsureTyped(keys=["image", "label"]),
        ConvertToBinaryLabeld(keys=["label"]),
        ToTensord(keys=["image", "label"]),
    ]
)


In [None]:
import os
import torch
import numpy as np

import SimpleITK as sitk
from monai.transforms import Resize, ToTensor
from torch.utils.data import Dataset, DataLoader
from monai.data import CacheDataset, load_decathlon_datalist


# Define the data directory and the dataset JSON file
data_dir = "/home/systemx86/Pictures/Tumor_detection/Code_test/"
split_json = "dataset_0.json"
datasets = data_dir + split_json

# Load the training and validation data lists
datalist = load_decathlon_datalist(datasets, True, "training")
val_files = load_decathlon_datalist(datasets, True, "validation")

# Create CacheDatasets for training and validation, applying transformations here
train_ds = CacheDataset(
    data=datalist,  
    transform=train_transforms,
    cache_num=24,
    cache_rate=1.0,
    num_workers=6,
)

val_ds = CacheDataset(
    data=val_files,  
    transform=val_transforms, 
    cache_num=6,
    cache_rate=1.0,
    num_workers=6,
)

# Wrap your dataset with the new class
target_size = (64,64,96)  # Adjust spatial size according to memory constraints

#train_ds = ConsistentSizeDataset(train_ds, target_size)
#val_ds = ConsistentSizeDataset(val_ds, target_size)

# Update DataLoader with the new dataset
batch_size = 1  # Reduce batch size to lower memory usage

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=6, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=6, pin_memory=True)


In [None]:

import nibabel as nib  # `nib` is used below but was never imported in the notebook

case_num = 8
print(f"Case number: {case_num}")

# Fetch the sample once; every `val_ds[...]` call re-fetches the item.
sample = val_ds[case_num]
img = sample['image']
label = sample['label']

img_meta = img.meta
label_meta = label.meta
img_name = os.path.split(img_meta['filename_or_obj'])[1]

# Print shapes
img_shape = img.shape
label_shape = label.shape
print(f"Image shape: {img_shape}, Label shape: {label_shape}")

# Raw (untransformed) files for the same case; the label file is assumed to share
# the image's file name.
task_dir = '/media/systemx86/A23089D03089AC3B/Users/arunk/Music/NIT_DATA/Task07_Pancreas/Task07_Pancreas'
image_path = f'{task_dir}/imagesTr/{img_name}'
label_path = f'{task_dir}/labelsTr/{img_name}'

raw_image_img = nib.load(image_path)
raw_label_img = nib.load(label_path)

raw_image_data = raw_image_img.get_fdata()
raw_label_data = raw_label_img.get_fdata()

slices = [30, 40]

# squeeze=False keeps `axes` 2-D even when only one slice is requested.
fig, axes = plt.subplots(3, len(slices), figsize=(15, 15), squeeze=False)

# NOTE(review): rows 0-1 index the raw voxel grid while row 2 indexes the transformed
# grid, so the same slice number generally shows different anatomy.
for i, slice_idx in enumerate(slices):
    axes[0, i].imshow(raw_image_data[:, :, slice_idx], cmap="gray")
    axes[0, i].set_title(f"Raw Image - Slice {slice_idx}")
    axes[0, i].axis("off")

    axes[1, i].imshow(raw_label_data[:, :, slice_idx], cmap="gray")
    axes[1, i].set_title(f"Raw Label - Slice {slice_idx}")
    axes[1, i].axis("off")

    axes[2, i].imshow(label[0, :, :, slice_idx].detach().cpu().numpy(), cmap="gray")
    axes[2, i].set_title(f"Transformed Label - Slice {slice_idx}")
    axes[2, i].axis("off")

plt.tight_layout()
plt.show()


In [None]:

import nibabel as nib  # `nib` is used below but was never imported in the notebook

case_num = 8
print(f"Case number: {case_num}")

# Fetch the sample once; every `val_ds[...]` call re-fetches the item.
sample = val_ds[case_num]
img = sample['image']
label = sample['label']

img_meta = img.meta
label_meta = label.meta
img_name = os.path.split(img_meta['filename_or_obj'])[1]

img_shape = img.shape
label_shape = label.shape
print(f"Image shape: {img_shape}, Label shape: {label_shape}")

label_path = f'/media/systemx86/A23089D03089AC3B/Users/arunk/Music/NIT_DATA/Task07_Pancreas/Task07_Pancreas/labelsTr/{img_name}'
raw_label_img = nib.load(label_path)
raw_label_data = raw_label_img.get_fdata()

slices = [30, 50, 70]

# squeeze=False keeps `axes` 2-D even when only one slice is requested.
fig, axes = plt.subplots(3, len(slices), figsize=(15, 15), squeeze=False)

for i, slice_idx in enumerate(slices):
    label_slice = label[0, :, :, slice_idx].detach().cpu().numpy()

    axes[0, i].imshow(raw_label_data[:, :, slice_idx], cmap="gray")
    axes[0, i].set_title(f"Raw Label - Slice {slice_idx}")
    axes[0, i].axis("off")

    axes[1, i].imshow(label_slice, cmap="gray")
    axes[1, i].set_title(f"Dataset Label - Slice {slice_idx}")
    axes[1, i].axis("off")

    img_slice = img[0, :, :, slice_idx].detach().cpu().numpy()
    lo, hi = img_slice.min(), img_slice.max()
    # Min-max normalise for display; a constant slice would divide by zero, so map it to 0.
    img_slice = (img_slice - lo) / (hi - lo) if hi > lo else img_slice - lo
    axes[2, i].imshow(img_slice, cmap="gray")
    if label_slice.any():  # contour() warns when the slice has no foreground to outline
        axes[2, i].contour(label_slice, colors='red', linewidths=0.5)
    axes[2, i].set_title(f"Image + Label - Slice {slice_idx}")
    axes[2, i].axis("off")

plt.tight_layout()
plt.show()


In [None]:
import torch

# Sanity check: prints True when PyTorch can see a CUDA device.
print(torch.cuda.is_available())

In [None]:
import torch

# Release cached but currently unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [None]:
from torch.cuda.amp import autocast, GradScaler  # Added

# Release cached, unused GPU memory before building the model.
torch.cuda.empty_cache()

# Fall back to the CPU when no GPU is visible, instead of failing later on `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:

from monai.data import DataLoader, Dataset, decollate_batch
from monai.networks.nets import UNETR  # Import UNETR
from torch import nn
import torch.nn.functional as F

# Debug aid: make CUDA kernel launches synchronous so errors point at the failing op.
# NOTE(review): this presumably only takes effect if set before CUDA is initialised.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Plain BCE-with-logits criterion, kept as a module-level name.
criterion = nn.BCEWithLogitsLoss()

# Best-score / early-stopping bookkeeping shared with train() (some via `global`).
dice_val_best, best_dice = float("-inf"), 0.0
global_step_best, epochs_without_improvement = 0, 0
patience = 20  # validation rounds without improvement before stopping


def validation(epoch_iterator_val, model, dice_metric, post_label, post_pred):
    """Evaluate ``model`` on validation batches and return the mean Dice score.

    Each batch is a dict with channel-first ``"image"`` and ``"label"`` tensors. For each:
      * the image goes through ``sliding_window_inference`` (ROI (64, 64, 96),
        sw_batch_size 4) under autocast;
      * outputs and labels are decollated into per-sample tensors, post-processed
        with ``post_pred`` / ``post_label``, and fed to ``dice_metric``.

    Args:
        epoch_iterator_val: iterable of validation batches (e.g. a tqdm-wrapped DataLoader).
        model: network to evaluate; its train/eval mode is restored on return.
        dice_metric: cumulative MONAI metric; aggregated, then reset, here.
        post_label: per-sample transform applied to each label tensor.
        post_pred: per-sample transform applied to each prediction tensor.

    Returns:
        Mean Dice as a float, or ``nan`` when no batch was processed (for example
        when the first batch produced non-finite outputs).
    """
    was_training = model.training
    model.eval()
    processed = 0
    try:
        with torch.no_grad():
            for batch in epoch_iterator_val:
                # Batches are plain dicts (the old `batch.val_ds[...]` raised AttributeError).
                val_inputs = batch["image"].to(device)
                val_labels = batch["label"].to(device)

                with autocast():
                    val_outputs = sliding_window_inference(val_inputs, (64, 64, 96), 4, model)
                val_outputs = val_outputs.float()  # post-process outside the autocast region

                # Only resample when the prediction grid differs from the label grid.
                if val_outputs.shape[2:] != val_labels.shape[2:]:
                    val_outputs = F.interpolate(
                        val_outputs, size=val_labels.shape[2:], mode="trilinear", align_corners=False
                    )

                if not torch.isfinite(val_outputs).all():
                    print("NaN or Inf values detected in val_outputs during validation")
                    break

                # Decollate first so post_label / post_pred see one channel-first sample at a
                # time. One-hot encoding a squeezed, batched tensor confuses batch and channel.
                val_labels_convert = [post_label(t) for t in decollate_batch(val_labels)]
                val_output_convert = [post_pred(t) for t in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                processed += 1

        if processed == 0:
            dice_metric.reset()
            return float("nan")
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
        return mean_dice_val
    finally:
        model.train(was_training)
    
def _optimizer_step(model, optimizer, scaler, max_norm=1.0):
    """Apply one AMP optimizer update with gradient-norm clipping, then clear gradients."""
    # Unscale first so clip_grad_norm_ sees the true gradients. Clipping the loss-scaled
    # gradients (as before) shrank every update by the loss-scale factor.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def train(global_step, train_loader, val_loader, model, optimizer,
          loss_function, scaler, max_iterations, eval_num, root_dir, dice_metric,
          post_label, post_pred, epochs_without_improvement):
    """Run one pass over ``train_loader`` with AMP, gradient accumulation and periodic validation.

    Every ``eval_num`` global steps, and at ``max_iterations``, the model is scored with
    ``validation``. When the mean Dice beats the module-level ``best_dice``, the
    weights are saved to ``<root_dir>/best_metric_model.pth``.

    The pass ends early on a non-finite loss, on reaching ``max_iterations``, or once
    ``epochs_without_improvement`` (counted in validation rounds) reaches ``patience``.

    Side effects: updates the globals ``best_dice`` and ``global_step_best``, and appends
    to the module-level lists ``epoch_loss_values`` and ``metric_values``.

    Returns:
        ``(global_step, global_step_best, epochs_without_improvement)``.
    """
    global global_step_best
    global best_dice

    accumulation_steps = 4  # micro-batches per optimizer update
    model.train()
    optimizer.zero_grad()

    epoch_loss = 0.0  # running sum over this pass (previously reset on every batch)
    pending = 0       # micro-batches back-propagated since the last optimizer update
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)

    for step, batch in enumerate(epoch_iterator, start=1):
        x = batch["image"].to(device)
        y = batch["label"].to(device)
        # (B, 1, ...) integer labels -> (B, 2, ...) one-hot float targets for the 2-channel logits.
        y = torch.squeeze(y, dim=1).long()
        y = F.one_hot(y, num_classes=2).permute(0, 4, 1, 2, 3).float()

        with autocast():
            logits = model(x)
            loss = loss_function(logits, y)

        if not torch.isfinite(loss):
            print(f"NaN or Inf values detected in loss at global step {global_step}")
            print("Debugging NaNs/Infs:")
            print(f"Inputs: min={x.min().item()}, max={x.max().item()}")
            print(f"Labels: unique={torch.unique(y).tolist()}")
            print(f"Logits finite: {torch.isfinite(logits).all().item()}")
            break

        # Average over the accumulation window so the summed gradient matches one larger batch.
        scaler.scale(loss / accumulation_steps).backward()
        pending += 1
        if pending == accumulation_steps:
            _optimizer_step(model, optimizer, scaler)
            pending = 0

        loss_value = loss.item()
        epoch_loss += loss_value
        epoch_iterator.set_description(
            "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss_value)
        )
        global_step += 1

        if global_step % eval_num == 0 or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val, model, dice_metric, post_label, post_pred)
            model.train()  # validation() may leave the model in eval mode
            # Record the mean loss so far without disturbing the running sum.
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)

            print(f"Validation at step {global_step} - Dice Score: {dice_val}, Best Dice Score: {best_dice}")

            if dice_val > best_dice:
                best_dice = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("Model Was Saved! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(best_dice, dice_val))
                epochs_without_improvement = 0  # Reset counter if improvement
            else:
                print("Model Was Not Saved! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(best_dice, dice_val))
                epochs_without_improvement += 1  # Increment counter if no improvement

            if epochs_without_improvement >= patience:
                print(f"Early stopping at step {global_step} - no improvement for {patience} validation rounds.")
                break

        if global_step >= max_iterations:
            break

    # Apply gradients from a final partial accumulation window instead of discarding them.
    if pending:
        _optimizer_step(model, optimizer, scaler)

    return global_step, global_step_best, epochs_without_improvement

            
            
        

import gc  # gc.collect() below needs it; it is not imported elsewhere in the notebook
import traceback

torch.cuda.empty_cache()
gc.collect()

model = UNETR(
    in_channels=1,
    out_channels=2,
    img_size=(64, 64, 96),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    proj_type="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# `BCEWithLogitsLoss` was referenced unqualified but never imported; use the torch.nn class.
loss_function = torch.nn.BCEWithLogitsLoss()
scaler = GradScaler()

max_iterations = 25000
eval_num = 500
post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=True, to_onehot=2)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
global_step = 0
dice_val_best = 0.0
best_dice = 0.0  # read/updated by train() as a global; reset so a re-run starts fresh
global_step_best = 0
epochs_without_improvement = 0
epoch_loss_values = []
metric_values = []

# Each train() call is one pass over train_loader. Loop until the iteration budget is
# spent; the old `for _ in range(global_step, max_iterations)` allowed up to 25000 passes.
while global_step < max_iterations:
    previous_step = global_step
    try:
        global_step, global_step_best, epochs_without_improvement = train(
            global_step, train_loader, val_loader, model, optimizer, loss_function, scaler,
            max_iterations, eval_num, root_dir, dice_metric, post_label, post_pred,
            epochs_without_improvement
        )
    except Exception as e:
        print(f"Error during training at global step {global_step}: {e}")
        traceback.print_exc()
        break

    if epochs_without_improvement >= patience:
        break
    if global_step == previous_step:
        # An empty loader or an immediate non-finite loss would otherwise loop forever.
        print("No training progress in this pass; stopping.")
        break

