# In [None]:
import os
import torch
import json
from torch import nn
from monai.networks.nets import SwinUNETR
import numpy as np
import random
import time
from glob import glob
import monai
from monai import transforms
from monai.transforms import MapTransform
from monai import data
from monai.metrics import DiceMetric
import matplotlib.pyplot as plt
from monai.transforms import Compose, AsDiscrete, Activations

def get_loader(
    batch_size,
    data_root="/data2/user123/c2/c2_1_0/data/data_segment",
    num_workers=8,
):
    """Build the training and validation data loaders.

    Images are read from ``<data_root>/train`` and ``<data_root>/val``;
    the matching annotations from ``<data_root>/train_ann`` and
    ``<data_root>/val_ann``. Images and annotations are paired by their
    position in the sorted directory listings.

    Args:
        batch_size: Batch size of the training loader. The validation
            loader always uses a batch size of 1.
        data_root: Directory containing the four sub-directories above.
        num_workers: Worker process count used by both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader
        shuffles; the validation loader does not.

    Raises:
        ValueError: If a split has a different number of images and
            annotations (``zip`` would otherwise silently drop the extras
            and the remaining pairs could be misaligned).
    """

    def _paired_files(split):
        # Pair every image of a split with the annotation at the same
        # position in the sorted listing.
        img_files = sorted(glob(os.path.join(data_root, split, "*")))
        ann_files = sorted(glob(os.path.join(data_root, f"{split}_ann", "*")))
        if len(img_files) != len(ann_files):
            raise ValueError(
                f"{split}: found {len(img_files)} images but "
                f"{len(ann_files)} annotations"
            )
        return [
            {"image": img, "label": ann}
            for img, ann in zip(img_files, ann_files)
        ]

    def _make_transform():
        # Training and validation share the same deterministic pipeline.
        # NOTE(review): no channel-first transform runs before
        # NormalizeIntensityd(channel_wise=True); if the reader returns
        # channel-last images the per-"channel" statistics are taken along
        # a spatial axis — confirm this is intended.
        return transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
                transforms.ToTensord(keys=["image", "label"]),
            ]
        )

    train_files = _paired_files("train")
    validation_files = _paired_files("val")
    print(len(train_files), len(validation_files))

    train_ds = data.Dataset(data=train_files, transform=_make_transform())
    train_loader = data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_ds = data.Dataset(data=validation_files, transform=_make_transform())
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader


def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU.

    Prints a short notice when CUDA is selected.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')

    print("Got CUDA!")
    return torch.device('cuda')
    
  

# Seed every random number generator used by the run.
seed_val = 2
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

device = get_default_device()
torch.multiprocessing.set_sharing_strategy('file_system')

# Initialise dataloaders
path_save = "/data2/user123/c2/c2_1_0/code/save"
batch_size = 64
train_loader, val_loader = get_loader(batch_size)
print("Dataloaders Initialised")

# Initialise the model: a 2D SwinUNETR taking 3-channel 224-sized inputs and
# producing one output channel; all dropout rates are zero.
model = SwinUNETR(
    img_size=224,
    in_channels=3,
    out_channels=1,
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
    spatial_dims=2,
).to(device)

# NOTE(review): the model has a single output channel, so
# include_background=False presumably has no effect here — confirm.
loss_function = monai.losses.DiceLoss(include_background=False,sigmoid=True)
#loss_function1 = monai.losses.DiceCELoss(to_onehot_y=False, sigmoid=True, jaccard=True)
# NOTE(review): step_size is not used anywhere in the visible code (no
# learning-rate scheduler is created).
step_size = len(train_loader) * 8
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")


# In [None]:
# Training configuration and per-epoch bookkeeping.
max_epochs = 100
val_interval = 2  # validation runs every `val_interval` epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
epoch_val_loss_values = []
metric_values = []
# Predictions: sigmoid then binarise at 0.4. Labels: binarise at 0.4.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.4)])
post_label = Compose([AsDiscrete(threshold=0.4)])

# Create the checkpoint directory up front so that the first torch.save
# cannot fail after epochs of training have already been spent.
os.makedirs(path_save, exist_ok=True)


for epoch in range(max_epochs):
    start_time = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        # Move the last image axis to position 1 and add a channel axis to
        # the labels (assumes channel-last images and channel-less labels
        # from the loader — TODO confirm).
        inputs = inputs.permute(0, 3, 1, 2)
        labels = labels.unsqueeze(1)
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    end_time = time.time()
    epoch_time = end_time - start_time

    print(f"Time taken for epoch {epoch + 1}: {epoch_time:.2f} seconds")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            epoch_val_loss = 0
            step = 0
            for val_data in val_loader:
                step += 1
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_inputs = val_inputs.permute(0, 3, 1, 2)
                val_labels = val_labels.unsqueeze(1)
                val_outputs = model(val_inputs)
                # The loss is computed on raw outputs, before thresholding.
                loss = loss_function(val_outputs, val_labels)
                epoch_val_loss += loss.item()
                val_outputs = post_trans(val_outputs)
                val_labels = post_label(val_labels)
                # accumulate dice for this batch; aggregated after the loop
                dice_metric(y_pred=val_outputs, y=val_labels)

            epoch_val_loss /= step
            epoch_val_loss_values.append(epoch_val_loss)
            print(f"epoch {epoch + 1} average val loss: {epoch_val_loss:.4f}")

            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            metric_values.append(metric)

            # Keep only the checkpoint with the best validation dice.
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(path_save, "model_segment_swinunetr.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )


print(f"train completed, best_metric: {best_metric:.4f} " f"at epoch: {best_metric_epoch}")

# Plot training loss, validation loss and validation dice side by side.
plt.figure("train", (12, 6))
plt.subplot(1, 3, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y, color="red")
plt.subplot(1, 3, 2)
plt.title("Epoch Average Val Loss")
# Validation runs at epochs val_interval, 2 * val_interval, ...; use the
# same x positions as the dice plot below.
x1 = [val_interval * (i + 1) for i in range(len(epoch_val_loss_values))]
y1 = epoch_val_loss_values
plt.xlabel("epoch")
plt.plot(x1, y1, color="blue")
plt.subplot(1, 3, 3)
plt.title("Val Mean Dice")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y, color="green")
plt.show()


# In [None]:
import os
import torch
import pandas as pd
from monai.networks.nets import SwinUNETR
from monai.transforms import Compose, Activations, AsDiscrete
from monai.metrics import DiceMetric, compute_iou
from monai.data import DataLoader
from monai.transforms import LoadImaged, NormalizeIntensityd, ToTensord
from monai.utils import set_determinism
from glob import glob
import monai

# Set the random seed for reproducibility
set_determinism(seed=42)

# Define the paths
model_path = "/data2/user123/c2/c2_1_0/code/save/model_segment_swinunetr.pth"
test_img_dir = "/data2/user123/c2/c2_1_0/data_org_test/Test Dataset 2/Images"
test_ann_dir = "/data2/user123/c2/c2_1_0/data_org_test/Test Dataset 2/Annotations"
output_dir = "/data2/user123/c2/c2_1_0/code/save/anns"
# The results sheet is written into output_dir at the end; create it now so
# the run cannot fail after inference has finished.
os.makedirs(output_dir, exist_ok=True)

# Load the model. The architecture must match the one that produced the
# checkpoint: the training cell saves to this same path with
# feature_size=48 and all dropout rates at 0.0, so the same values are used
# here (a different feature_size makes load_state_dict fail on shape
# mismatches).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinUNETR(
    img_size=224,
    in_channels=3,
    out_channels=1,
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
    spatial_dims=2
).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Define the transforms. Named test_transforms so that it does not shadow
# the `monai.transforms` module imported as `transforms` earlier in the file.
test_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ToTensord(keys=["image", "label"]),
])

# Create the test dataset and data loader
test_files = sorted(glob(os.path.join(test_img_dir, "*")))
test_labels = sorted(glob(os.path.join(test_ann_dir, "*")))
if len(test_files) != len(test_labels):
    # zip would silently drop the extras and could misalign the pairs.
    raise ValueError(
        f"found {len(test_files)} test images but {len(test_labels)} annotations"
    )
test_data = [{"image": img, "label": ann} for img, ann in zip(test_files, test_labels)]
test_dataset = monai.data.Dataset(data=test_data, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=8, pin_memory=True)

# Define the metrics and the post-processing steps (built once, reused for
# every sample).
dice_metric = DiceMetric(include_background=False, reduction="mean")
to_probability = Activations(sigmoid=True)
binarize = AsDiscrete(threshold=0.4)

image_names = []
dice_scores = []
iou_scores = []


# Iterate over the test data and calculate the metrics.
# Inference only: no_grad avoids building autograd graphs for every sample.
with torch.no_grad():
    for i, batch_data in enumerate(test_loader):
        # The loader is unshuffled with batch_size=1, so batch i is file i.
        image_name = os.path.basename(test_files[i])
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        # Same layout handling as in training (assumes channel-last images
        # and channel-less labels from the loader — TODO confirm).
        inputs = inputs.permute(0, 3, 1, 2)
        labels = labels.unsqueeze(1)
        outputs = model(inputs)
        outputs = to_probability(outputs)
        outputs = binarize(outputs)
        labels = binarize(labels)
        dice = dice_metric(y_pred=outputs, y=labels)
        iou = compute_iou(y_pred=outputs, y=labels)
        image_names.append(image_name)
        dice_scores.append(dice.item())
        iou_scores.append(iou.item())

        # Save the segmentation mask
        # NOTE(review): the mask and its target path are computed but
        # nothing is written to disk yet — add the actual save call once
        # the output format and orientation are decided.
        mask = outputs.squeeze().detach().cpu().numpy()
        mask_path = os.path.join(output_dir, f"{image_name}")


# Create a DataFrame to store the results
results_df = pd.DataFrame({
    "Image": image_names,
    "IoU": iou_scores,
    "Dice": dice_scores

})



# Calculate the average metrics
average_dice = results_df["Dice"].mean()
average_iou = results_df["IoU"].mean()

# Save the results to an Excel sheet
results_path = os.path.join(output_dir, "results.xlsx")
results_df.to_excel(results_path, index=False)

print("Dice and IoU calculation completed.")
print(f"Average Dice: {average_dice:.4f}")
print(f"Average IoU: {average_iou:.4f}")
print(f"Results saved to: {results_path}")
