In [2]:
import os
import pty
import warnings

import torch
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged, CropForegroundd,
    RandCropByPosNegLabeld, RandFlipd, RandRotate90d, RandShiftIntensityd, EnsureTyped, EnsureType, DivisiblePadd
)
from monai.data import DataLoader, CacheDataset
from monai.networks.nets import SwinUNETR
from monai.losses import DiceLoss
from monai.utils import set_determinism
from monai.data import decollate_batch
from monai.transforms import DivisiblePad
from monai.data.image_reader import NibabelReader
from sklearn.metrics import mean_squared_error
import nibabel as nib

In [3]:
def _fork_stub():
    """Stand-in for ``pty.fork`` that always returns ``(pid, fd) == (0, 0)``."""
    return 0, 0


# NOTE(review): this replaces ``pty.fork`` for the whole kernel process. A real
# ``pty.fork`` returns pid 0 only in the child, so any caller will now behave as
# if it were the child and use fd 0. The reason for this workaround is not
# visible in the notebook; document it or remove it once it is no longer needed.
pty.fork = _fork_stub

In [4]:
# Runtime configuration.
# Enumerate GPUs in PCI bus order instead of CUDA's default ordering.
# Presumably effective because no earlier cell has initialised CUDA yet.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# PYTHONWARNINGS is only read when an interpreter starts. This line therefore
# affects processes launched afterwards, but not the running kernel...
os.environ['PYTHONWARNINGS'] = 'ignore::RuntimeWarning'
# ...so install the equivalent filter in this interpreter explicitly.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Seed the random number generators via MONAI's helper for reproducible training.
set_determinism(seed=0)

# BraTS 2020 data roots. Each contains one sub-directory per patient holding the
# flair / t1 / t1ce / t2 volumes (see create_data_list).
train_path = 'dataset/MICCAI_BraTS2020_TrainingData/'
val_path = 'dataset/MICCAI_BraTS2020_ValidationData/'


In [5]:
# Function to create a list of data dictionaries
def create_data_list(data_dir, modalities=("flair", "t1", "t1ce", "t2"), extension=".nii"):
    """Build one ``{modality: file path}`` dict per patient folder in ``data_dir``.

    Every sub-directory ``<data_dir>/<patient>`` is expected to hold files named
    ``<patient>_<modality><extension>``. The paths are only constructed; they
    are not checked for existence. Plain files at the top level (e.g. CSV
    metadata) are skipped.

    Args:
        data_dir: Directory whose sub-directories are individual patients.
        modalities: Modality suffixes to include. They become the dict keys, in
            this order. Defaults to the four BraTS MRI sequences used here.
        extension: File extension appended to every path (``".nii"`` by
            default; e.g. ``".nii.gz"`` for compressed volumes).

    Returns:
        A list with one dict per patient folder, ordered by folder name. The
        result therefore does not depend on filesystem listing order, which
        ``os.listdir`` leaves arbitrary.

    Raises:
        OSError: propagated from ``os.listdir`` if ``data_dir`` cannot be read.
    """
    data_list = []
    for patient in sorted(os.listdir(data_dir)):
        patient_dir = os.path.join(data_dir, patient)
        if not os.path.isdir(patient_dir):
            continue
        data_list.append(
            {m: os.path.join(patient_dir, f"{patient}_{m}{extension}") for m in modalities}
        )
    return data_list


In [6]:
# Per-patient file dictionaries for both splits.
train_data_list = create_data_list(train_path)
val_data_list = create_data_list(val_path)
# Fail fast on a wrong or empty data folder instead of silently building empty
# datasets further down.
assert train_data_list, f"no patient folders found under {train_path!r}"
assert val_data_list, f"no patient folders found under {val_path!r}"


In [7]:

# Define transforms
# Training pipeline, applied to each patient dict (keys flair / t1 / t1ce / t2):
#   load -> channel-first -> resample to 2 mm isotropic -> RAS orientation
#   -> rescale intensities to [0, 1] -> crop to the FLAIR foreground
#   -> pad every spatial dim to a multiple of 32 -> draw one random 64^3 patch
#   -> random flips / 90-degree rotations / intensity shift -> tensor types.
#
# NOTE(review): a_min=-175 / a_max=250 look like a CT (Hounsfield-unit) window.
# If raw MRI background voxels are 0, they map to 175 / 425 ~= 0.41. Raw values
# above 250 are clipped to 1.0. If no raw value is <= -175, nothing is <= 0
# after scaling, so CropForegroundd and the FLAIR-driven crop sampling likely
# see the whole volume as foreground. Confirm the intended MRI normalisation,
# and keep val_transforms in sync if it changes.
# NOTE(review): the OverflowError recorded for this cell ("Python integer
# 4294967296 out of bounds for uint32") looks like a MONAI / NumPy-2 version
# mismatch while Compose seeds its random transforms. Check installed versions.
train_keys = ["flair", "t1", "t1ce", "t2"]
train_transforms = Compose(
    [
        LoadImaged(keys=train_keys, reader=NibabelReader()),
        EnsureChannelFirstd(keys=train_keys),
        Spacingd(keys=train_keys, pixdim=(2.0, 2.0, 2.0), mode="bilinear"),
        Orientationd(keys=train_keys, axcodes="RAS"),
        ScaleIntensityRanged(
            keys=train_keys,
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=train_keys, source_key="flair", allow_smaller=True),
        DivisiblePadd(keys=train_keys, k=32),
        # The FLAIR volume doubles as the "label" that steers where the 64^3
        # patch centre is drawn (pos=1 / neg=1 weighting, one sample per case).
        RandCropByPosNegLabeld(
            keys=train_keys,
            label_key="flair",
            spatial_size=(64, 64, 64),
            pos=1,
            neg=1,
            num_samples=1,
            image_key="flair",
            image_threshold=0,
        ),
        # Independent random flips along each of the three spatial axes.
        *[RandFlipd(keys=train_keys, prob=0.5, spatial_axis=axis) for axis in (0, 1, 2)],
        RandRotate90d(keys=train_keys, prob=0.5, max_k=3),
        RandShiftIntensityd(keys=train_keys, offsets=0.10, prob=0.5),
        EnsureTyped(keys=train_keys),
    ]
)


OverflowError: Python integer 4294967296 out of bounds for uint32

In [None]:
# Training dataset and loader.
# CacheDataset keeps in memory the output of the transforms that precede the
# first random one. Each epoch therefore only re-runs the random crop and the
# augmentations.
train_ds = CacheDataset(data=train_data_list, transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,  # new patient order every epoch
    num_workers=1,
)


Loading dataset: 100%|██████████| 369/369 [05:43<00:00,  1.08it/s]


In [None]:
# Validation pipeline: the deterministic steps of the training pipeline only.
# There is no random crop or augmentation, so whole padded volumes reach the model.
# NOTE(review): this uses the same CT-looking intensity window
# (a_min=-175, a_max=250) as training. If it is revised, change both pipelines
# together.
val_keys = ["flair", "t1", "t1ce", "t2"]
val_transforms = Compose(
    [
        LoadImaged(keys=val_keys, reader=NibabelReader()),
        EnsureChannelFirstd(keys=val_keys),
        Spacingd(keys=val_keys, pixdim=(2.0, 2.0, 2.0), mode="bilinear"),
        Orientationd(keys=val_keys, axcodes="RAS"),
        ScaleIntensityRanged(
            keys=val_keys,
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=val_keys, source_key="flair", allow_smaller=True),
        # Pad every spatial dimension up to a multiple of 32.
        DivisiblePadd(keys=val_keys, k=32),
        EnsureTyped(keys=val_keys),
    ]
)


In [None]:

# Validation dataset and loader. Every transform here is deterministic, so each
# cached item is final. Patient order is kept fixed (no shuffling) so reported
# numbers are reproducible.
val_ds = CacheDataset(data=val_data_list, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, shuffle=False)


Loading dataset: 100%|██████████| 125/125 [02:55<00:00,  1.41s/it]


In [None]:
# Define the model.
# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Swin UNETR used as an autoencoder: 4 stacked MRI modalities in, 4 out.
# img_size matches the 64^3 training patches. use_checkpoint enables gradient
# checkpointing, trading extra compute for lower memory use.
model = SwinUNETR(
    img_size=(64, 64, 64),
    in_channels=4,
    out_channels=4,
    feature_size=48,
    use_checkpoint=True,
)
model = model.to(device)


  return torch._C._cuda_getDeviceCount() > 0


In [None]:
# Reconstruction objective: voxel-wise mean squared error between the network
# output and its own input.
loss_function = torch.nn.MSELoss()
# AdamW with a small learning rate and light decoupled weight decay.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)


In [None]:
# Training loop
# Self-supervised reconstruction: the four modalities are stacked as channels,
# and the network is optimised to reproduce its own (augmented) input.
# NOTE(review): the input is neither masked nor corrupted, so a network with
# encoder-decoder skip connections may get close to an identity mapping.
# Interpret the low losses with that in mind.
max_epochs = 10
# val_interval / best_metric / best_metric_epoch / metric_values are not used by
# this loop (no validation runs here). They are kept for any code that expects them.
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []  # mean training loss of each epoch
metric_values = []


for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        # Concatenate the modalities along the channel axis (dim=1).
        inputs = torch.cat(
            [batch_data["flair"], batch_data["t1"], batch_data["t1ce"], batch_data["t2"]],
            dim=1,
        ).to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, inputs)  # reconstruction target is the input itself
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    if step == 0:
        # Without this guard the average below raises a bare ZeroDivisionError.
        raise RuntimeError("train_loader yielded no batches; check the training data list")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

print("Training completed.")


----------
epoch 1/10




epoch 1 average loss: 0.0231
----------
epoch 2/10




epoch 2 average loss: 0.0058
----------
epoch 3/10




epoch 3 average loss: 0.0040
----------
epoch 4/10




epoch 4 average loss: 0.0030
----------
epoch 5/10




epoch 5 average loss: 0.0026
----------
epoch 6/10




epoch 6 average loss: 0.0024
----------
epoch 7/10




epoch 7 average loss: 0.0020
----------
epoch 8/10




epoch 8 average loss: 0.0014
----------
epoch 9/10




epoch 9 average loss: 0.0012
----------
epoch 10/10




epoch 10 average loss: 0.0012
Training completed.


In [None]:
# Save the trained weights (state_dict only, not the pickled module).
model_save_path = "model_saved/swin_unetr_reconstruction.pth"
# torch.save does not create missing parent directories, so ensure the target
# folder exists. "or '.'" covers a bare filename with no directory part.
os.makedirs(os.path.dirname(model_save_path) or ".", exist_ok=True)
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to model_saved/swin_unetr_reconstruction.pth


In [None]:
# Load the model from the saved state dictionary.
# Rebuild the same architecture as the training model, then restore its weights.
loaded_model = SwinUNETR(
    img_size=(64, 64, 64),  # same value as the training model
    in_channels=4,  # 4 modalities as input
    out_channels=4,  # 4 modalities as output for reconstruction
    feature_size=48,
    use_checkpoint=True,
).to(device)
# map_location lets a checkpoint written on one device (e.g. a GPU) be restored
# onto whatever `device` resolved to here (e.g. a CPU-only machine).
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Recent PyTorch versions also accept weights_only=True for plain state_dicts.
loaded_model.load_state_dict(torch.load(model_save_path, map_location=device))
loaded_model.eval()  # this copy of the model is used for inference only
print("Model loaded from saved state dictionary")

Model loaded from saved state dictionary


In [None]:
# Evaluate reconstruction quality on the validation split. The score is the
# average over batches of the per-batch mean squared error (one volume per
# batch when batch_size=1).
# NOTE(review): the model sees whole padded volumes here but only 64^3 patches
# during training. Memory use and behaviour may differ from training.
loaded_model.eval()  # ensure inference mode even if this cell is run on its own
mse_values = []
with torch.no_grad():
    for val_data in val_loader:
        # Concatenate the modalities along the channel axis (dim=1).
        val_images = torch.cat(
            [val_data["flair"], val_data["t1"], val_data["t1ce"], val_data["t2"]],
            dim=1,
        ).to(device)
        val_outputs = loaded_model(val_images)

        # MSE between input and reconstruction over all voxels and channels.
        mse_value = mean_squared_error(
            val_images.cpu().numpy().flatten(), val_outputs.cpu().numpy().flatten()
        )
        mse_values.append(mse_value)

if not mse_values:
    # Without this guard the average below raises a bare ZeroDivisionError.
    raise RuntimeError("val_loader yielded no batches; cannot compute validation MSE")
average_mse = sum(mse_values) / len(mse_values)
print(f"Validation Mean Squared Error: {average_mse}")

Validation Mean Squared Error: 0.000937357866205275


In [None]:
# Save the trained weights again under a name that records the 4-modality input.
model_save_path = "model_saved/4_modality_swin_unetr_reconstruction.pth"
# torch.save does not create missing parent directories, so ensure the target
# folder exists. "or '.'" covers a bare filename with no directory part.
os.makedirs(os.path.dirname(model_save_path) or ".", exist_ok=True)
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to model_saved/4_modality_swin_unetr_reconstruction.pth
