# Experimentation: Segmentation
---

In [1]:
# Import libraries
import sys
sys.path.append("..")

from monai.config import print_config
from monai.networks.nets import SwinUNETR
from src.preprocessing import get_transforms, get_datasets, get_dataloaders

import torch
import wandb
import matplotlib.pyplot as plt
import numpy as np
from torch import nn

# Select the compute device: the first CUDA GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# wandb.login()  # uncomment to authenticate with Weights & Biases before logging runs

cuda:0


In [None]:
# Build the preprocessing pipeline, then split the HCC-TACE-Seg collection into
# train/validation sets (download=False: the data is expected under ../data already).
transform = get_transforms()
train_ds, val_ds = get_datasets(
    root_dir="../data",
    collection="HCC-TACE-Seg",
    transform=transform,
    download=False,
    download_len=1,
    val_frac=0.2,
)

# Wrap both splits in dataloaders (batch_size=1, 4 worker processes).
train_loader, val_loader = get_dataloaders(train_ds, val_ds, batch_size=1, num_workers=4)

# Keep the raw datalists around for inspection.
train_dl, val_dl = train_ds.datalist, val_ds.datalist

# Sanity check: which indices ended up in each split, and how many batches each loader yields.
print(train_ds.get_indices(), val_ds.get_indices())
print(len(train_loader), len(val_loader))

In [None]:
# Pull one training batch and print the image / segmentation tensor shapes
# as a quick check of the preprocessing output.
batch = next(iter(train_loader))
image = batch["image"]
seg = batch["seg"]
print(image.shape, seg.shape)


In [2]:
import model_source.LRM

# Build TACEnet and switch it to inference mode; nn.Module.eval() returns the
# module itself, and the bare `model` on the last line displays its layer summary.
model = model_source.LRM.TACEnet().eval()
model

TACEnet(
  (lrm): Conv3DTo2D(
    (conv1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (conv2): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (conv3): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv5): Conv2d(256, 1, kernel_size=(2, 2), stride=(2, 2))
  )
  (cn): ConvNet(
    (conv1): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [3]:
# Random stand-in inputs for a TACEnet smoke test: a vessel volume and a DRR image.
VesselVolume = torch.rand((1, 96, 512, 512))
DRR = torch.rand((1, 1, 256, 256))

In [4]:
# Smoke-test forward pass. Nothing is backpropagated here, so run it under
# no_grad to avoid keeping an autograd graph for the large 3D input in memory.
with torch.no_grad():
    output = model(VesselVolume, DRR)

In [5]:
# Shape of the smoke-test output (displayed as the cell result).
output.shape

torch.Size([1, 1, 256, 256])

In [None]:
# Shape obtained by stacking two single-channel 256x256 maps along the channel axis.
torch.cat([torch.rand(1, 1, 256, 256) for _ in range(2)], dim=1).shape

In [None]:
#plt.imshow(mip[2], cmap='gray')

In [None]:
# Create a Swin-UNet model
#model = SwinUNETR(img_size=(96, 96, 32), in_channels=1, out_channels=1, use_v2 = True, spatial_dims=3, normalize=False)


In [None]:
class SwinUNETRSigmoid(nn.Module):
    """SwinUNETR whose output is passed through a sigmoid.

    forward() therefore returns values in [0, 1] rather than logits, so it should be
    paired with a probability-based loss (e.g. ``nn.BCELoss``), not ``BCEWithLogitsLoss``.

    All constructor arguments are forwarded unchanged to ``SwinUNETR``.
    """

    def __init__(self, img_size, in_channels, out_channels, use_v2=True, spatial_dims=3, normalize=False):
        super().__init__()
        self.model = SwinUNETR(
            img_size=img_size,
            in_channels=in_channels,
            out_channels=out_channels,
            use_v2=use_v2,
            spatial_dims=spatial_dims,
            normalize=normalize,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the SwinUNETR backbone and squash its raw output into [0, 1]."""
        logits = self.model(x)
        return self.sigmoid(logits)

# Instantiate the sigmoid-wrapped SwinUNETR: 3D, single-channel input and output,
# SwinUNETR v2 blocks, no input normalization, img_size (96, 96, 32).
model = SwinUNETRSigmoid(
    img_size=(96, 96, 32),
    in_channels=1,
    out_channels=1,
    use_v2=True,
    spatial_dims=3,
    normalize=False,
)

In [None]:

# Toggle Weights & Biases logging for this run.
wandb_ = True

# Move the model to the selected device before the optimizer captures its parameters.
model = model.to(device)

# Hyperparameters for training
max_epochs = 10
learning_rate = 1e-3
weight_decay = 1e-5
val_interval = 1  # run validation every `val_interval` epochs
batch_size = 1  # keep in sync with the batch_size passed to get_dataloaders

if wandb_:
    run = wandb.init(
        # Set the project where this run will be logged
        project="HCC TACE", name="test",
        # Log plain values; `{x}` literals would build Python sets, not numbers.
        config={
            "learning_rate": learning_rate,
            "epochs": max_epochs,
            "Weight decay": weight_decay,
            "Batch_size": batch_size,
        },
    )

# Set the optimizer and loss function.
optimizer = torch.optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)
# SwinUNETRSigmoid applies a sigmoid in forward(), so the model already outputs
# probabilities. BCEWithLogitsLoss would apply a second sigmoid on top of that;
# BCELoss is the variant that expects probabilities.
loss_function = torch.nn.BCELoss()

for epoch in range(max_epochs):
    # Re-enter training mode every epoch: the validation pass below switches the
    # model to eval mode, and it must not stay there for the following epochs.
    model.train()
    train_loss_sum, n_train_batches = 0.0, 0

    for batch in train_loader:
        image, seg = batch["image"].to(device), batch["seg"].to(device)

        optimizer.zero_grad()
        # Match the prediction to the target's layout instead of calling .squeeze():
        # squeeze() also drops a size-1 batch dimension, and BCE needs the
        # prediction and target to have identical shapes.
        pred = model(image).reshape_as(seg)
        loss = loss_function(pred, seg.float())
        loss.backward()
        optimizer.step()

        train_loss_sum += loss.item()
        n_train_batches += 1
        if wandb_:
            run.log({"Train loss": loss.item()})

    # Report the mean training loss over the epoch, not just the last batch.
    train_loss = train_loss_sum / max(n_train_batches, 1)
    print(f"Epoch [{epoch+1}/{max_epochs}], Loss: {train_loss:.4f}")

    if (epoch + 1) % val_interval != 0:
        continue

    # Validation: eval mode, no gradients, loss averaged over every validation batch.
    model.eval()
    val_loss_sum, n_val_batches = 0.0, 0
    with torch.no_grad():
        for batch in val_loader:
            image, seg = batch["image"].to(device), batch["seg"].to(device)
            pred = model(image).reshape_as(seg)
            val_loss_sum += loss_function(pred, seg.float()).item()
            n_val_batches += 1

    val_loss = val_loss_sum / max(n_val_batches, 1)
    print(f"Validation Loss: {val_loss:.4f}")
    if wandb_:
        run.log({"Validation loss": val_loss})

    

In [None]:
# `run` only exists when W&B logging was enabled (wandb_ = True) in the training cell.
if wandb_:
    run.finish()


In [None]:
from pathlib import Path

# Save the trained weights. Create the models directory first so torch.save
# does not fail when ../models does not exist yet.
checkpoint_path = Path("../models/arteries-2.pth")
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [None]:

# TODO: Saving a CT volume tensor to image format (DICOM / NIFTI)

# Reload the trained weights. map_location lets a checkpoint saved from the GPU
# load on a CPU-only machine; the notebook-wide `device` replaces a hard-coded 'cuda'.
model.load_state_dict(torch.load("../models/arteries-2.pth", map_location=device))
model = model.to(device)
model.eval()

# Qualitative check on one batch from the *training* loader (not a held-out evaluation).
batch = next(iter(train_loader))
image, seg = batch["image"].to(device), batch["seg"].to(device)

# Inference only: skip autograd bookkeeping for the 3D forward pass.
with torch.no_grad():
    output = model(image)

# Drop the channel axis but keep the batch axis, so the [0] indexing below always
# selects the first sample (a bare .squeeze() also removed the batch axis when it
# had size 1). Assumes output is (B, 1, ...) and seg holds the same number of elements.
pred_volume = output[:, 0]
seg_volume = seg.reshape_as(output)[:, 0]

# Maximum-intensity projection (MIP) of both masks along the last axis, as NumPy arrays.
pred_seg = pred_volume.amax(dim=-1).cpu().numpy()
seg = seg_volume.amax(dim=-1).cpu().numpy()

# Show a single slice of the input image: the middle index along the last axis.
slice_idx = image.shape[-1] // 2
image = image[0, 0, :, :, slice_idx].cpu().numpy()

# Plot the image slice next to the ground-truth and predicted MIPs.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(image, cmap="gray")
axes[0].set_title(f"Image (slice {slice_idx})")
axes[1].imshow(seg[0], cmap="gray")
axes[1].set_title("Ground Truth Segmentation (MIP)")
axes[2].imshow(pred_seg[0], cmap="gray")
axes[2].set_title("Predicted Segmentation (MIP)")
plt.show()



In [None]:

from matplotlib.colors import LinearSegmentedColormap

# Overlay the ground-truth mask MIP on one image slice of a training sample.
# Cell-specific names keep this cell from overwriting `image` / `seg` from the
# prediction cell, which the box-plot cells read.
overlay_batch = next(iter(train_loader))
overlay_volume = overlay_batch["image"]
# Assumes seg has as many elements as the single-channel image; reshaping keeps the
# batch axis even when the batch size is 1 (a bare .squeeze() would drop it).
overlay_mask = overlay_batch["seg"].reshape(overlay_volume.shape)
overlay_mip = overlay_mask[0, 0].amax(dim=-1)

# Use the middle index along the last axis; a fixed index can exceed the crop depth.
overlay_slice_idx = overlay_volume.shape[-1] // 2
overlay_image = overlay_volume[0, 0, :, :, overlay_slice_idx]

# Colormap from fully transparent black (background) to opaque red (mask).
transparent_red = LinearSegmentedColormap.from_list(
    "transparent_red", [(0, 0, 0, 0), (1, 0, 0, 1)]
)

fig, ax = plt.subplots()
ax.imshow(overlay_image, cmap="gray")
ax.imshow(overlay_mip, cmap=transparent_red)
ax.set_title(f"Ground-truth MIP over image slice {overlay_slice_idx}")
plt.show()

In [None]:
# Distribution of the predicted MIP values for the first sample.
fig, ax = plt.subplots()
ax.boxplot(pred_seg[0].flatten())
ax.set(title="Predicted segmentation MIP values", ylabel="Value");

In [None]:
# Distribution of the ground-truth MIP values for the first sample.
fig, ax = plt.subplots()
ax.boxplot(seg[0].flatten())
ax.set(title="Ground-truth segmentation MIP values", ylabel="Value");