# 2d Segmentation of Sagittal Lumbar Spine MRI

- Training data Spider dataset (https://doi.org/10.5281/zenodo.10159290)
- Very simple model using segmentation_models_pytorch
- Used ChatGPT and Gemini to help with coding
- Trained model attached
- Images resized to 256x256

In [1]:
!pip install segmentation_models_pytorch -q

In [9]:
import numpy as np 
import pandas as pd 
import os
from pathlib import Path
from PIL import Image

from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from segmentation_models_pytorch import Unet

In [10]:
#transforms
newsize = (256, 256)
#dataset
fold = 1
#dataloader
batch_size = 64
num_workers = 4
#model
num_classes = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#run
epochs = 100
learning_rate = 1e-3

TRAIN = True #or False for inference only

### Model

In [11]:
model = Unet(
  encoder_name="resnet34",  # Choose encoder (e.g. resnet18, efficientnet-b0)
  classes=num_classes,  # Number of output classes
  in_channels=3  # Number of input channels (e.g. 3 for RGB)
)

### Create folds

In [5]:
output_dir = "/kaggle/input/spider-mri-spine-t2-png/data"
im_dir = os.path.join(output_dir, "images")
mask_dir = os.path.join(output_dir, "masks")

# get list of data
items = list(Path(im_dir).glob("*.png"))
image_names = [o.name for o in items]
images = list(set([o.split('_')[0] for o in image_names]))

fold_df = pd.DataFrame({"image_name": images})
# Seed for reproducibility
np.random.seed(42)

# Split the DataFrame into 5 folds
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for i, (_, v_ind) in enumerate(kf.split(fold_df)):
    fold_df.loc[v_ind, 'fold'] = i+1

# Create df with image_names and their respective folds
def get_fold(fn, df):
    image_name = fn.name.split("_")[0] 
    return df.loc[df.image_name==image_name, 'fold'].values[0]

folds = [get_fold(o, fold_df) for o in items]
df = pd.DataFrame({"image": image_names, "fold": folds})

df.head()

ValueError: Cannot have number of splits n_splits=5 greater than the number of samples: n_samples=0.

### Dataset class

In [None]:
class SEGDataset(Dataset):
    def __init__(self, df, mode, transforms=None):
        self.df = df.reset_index()
        self.mode = mode
        self.transforms = transforms

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]

        image_path = os.path.join(im_dir, row.image)
        mask_path = os.path.join(mask_dir, row.image)

        # Open image
        image = Image.open(image_path)
        if image.mode != 'RGB':  # Ensure image is RGB
            image = image.convert('RGB')
        image = np.asarray(image)
        if (image > 1).any():  # Normalize if pixel values are between 0-255
            image = image / 255.0

        # Open mask
        mask = Image.open(mask_path)
        mask = np.asarray(mask)
        assert mask.max() < num_classes, f"Mask value {mask.max()} exceeds number of classes {num_classes}"

        # Apply transformations
        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
        
        # Create one layer for each label
        mask = torch.as_tensor(mask).long()
        mask = torch.nn.functional.one_hot(mask, num_classes=num_classes).permute(2,0,1).float()
        #mask = torch.nn.functional.one_hot(mask, num_classes=num_classes).permute(0,3,1,2).squeeze(0).float()

        # Convert image to tensor
        image = torch.as_tensor(image).float()

        return image, mask          

### Transforms

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

transforms_train = A.Compose([
    A.Resize(newsize[0], newsize[1]),
    A.HorizontalFlip(),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2()
])

transforms_valid = A.Compose([
    A.Resize(newsize[0], newsize[1]),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2()
])

### Loss

In [None]:
class CombinedLoss(nn.Module):
    def __init__(self, weight_ce=1.0, weight_iou=1.0):
        super(CombinedLoss, self).__init__()
        self.weight_ce = weight_ce
        self.weight_iou = weight_iou
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        # Cross-Entropy Loss
        ce_loss = self.cross_entropy_loss(inputs, targets)

        # IoU Loss
        # Apply softmax to the inputs to get probabilities
        probs = F.softmax(inputs, dim=1)

        intersection = torch.sum(probs * targets, dim=(2, 3))
        union = torch.sum(probs + targets, dim=(2, 3)) - intersection
        iou = (intersection + 1e-6) / (union + 1e-6)
        iou_loss = 1 - iou.mean()

        # Combine losses
        loss = self.weight_ce * ce_loss + self.weight_iou * iou_loss
        return loss

### Create datasets and dataloaders

In [None]:
train_ = df[df['fold'] != fold].reset_index(drop=True)
valid_ = df[df['fold'] == fold].reset_index(drop=True)

dataset_train = SEGDataset(train_, 'train',  transforms_train)
dataset_valid = SEGDataset(valid_, 'valid',  transforms_valid)

train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=num_workers)

### Run function

In [None]:
from torch import optim
from torch.nn import BCEWithLogitsLoss
from torch.optim.lr_scheduler import ReduceLROnPlateau


def run(train_loader, val_loader, model, learning_rate, criterion, epochs, device):
  """
  Trains a U-net model for multi-label segmentation.

  Args:
      train_loader: DataLoader for training data.
      val_loader: DataLoader for validation data.
      model: U-net model instance.
      learning_rate: Learning rate for optimizer.
      epochs: Number of epochs to train.
      device: Device to use for training (CPU or GPU).
  """
  # Define loss function and optimizer
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)

  # Define a learning rate scheduler
  scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)


  # Training loop
  for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
      images, masks = images.to(device), masks.to(device)

      # Forward pass and calculate loss
      outputs = model(images)
      loss = criterion(outputs, masks)

      # Backward pass and update weights
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      train_loss += loss.item()

    train_loss /= len(train_loader)

    # Validation step (optional)
    model.eval()
    with torch.no_grad():
      val_loss = 0.0
      for images, masks in val_loader:
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        val_loss += criterion(outputs, masks).item()

    val_loss /= len(val_loader)
    
    # Step the scheduler
    scheduler.step(val_loss)

    # Print training and validation loss
    print(f"Epoch: {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

### Train

In [6]:
criterion = CombinedLoss()
model.to(device)

if TRAIN:
    run(train_loader, val_loader, model, learning_rate, criterion, epochs, device)
else:
    model.load_state_dict(torch.load("/kaggle/input/simple_unet_2d_lspine/pytorch/one/1/simple_unet.pth"))
                      

NameError: name 'CombinedLoss' is not defined

### Inference

In [7]:
import matplotlib.pyplot as plt

def inference(model, dataloader, device, num_samples=16):
    model.eval()
    images_batch = []
    preds_batch = []
    
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            
            images_batch.append(images.cpu())
            preds_batch.append(preds.cpu())
            
            if len(images_batch) * images.size(0) >= num_samples:
                break

    images_batch = torch.cat(images_batch)[:num_samples]
    preds_batch = torch.cat(preds_batch)[:num_samples]
    
    return images_batch, preds_batch


# Define a color map with fixed colors for each label
def get_label_colors(num_classes):
    colors = plt.cm.tab20(np.linspace(0, 1, num_classes))
    return colors

def visualize_predictions(images, masks, num_classes=20, num_samples=16):
    num_samples = min(num_samples, len(images))
    plt.figure(figsize=(20, 20))
    
    label_colors = get_label_colors(num_classes)
    
    for i in range(num_samples):
        plt.subplot(4, 8, i * 2 + 1)
        im = images[i].numpy()
        im = np.transpose(im, (1, 2, 0))
        #denormalize
        im = ((im * [0.229, 0.224, 0.225]) + [0.485, 0.456, 0.406]) * 255
        plt.imshow(im)
        plt.title("Input Image")
        plt.axis('off')
        
        plt.subplot(4, 8, i * 2 + 2)
        mask = masks[i].numpy()

        color_mask = np.zeros((mask.shape[0], mask.shape[1], 3))
        for label in range(num_classes):
            color_mask[mask == label] = label_colors[label][:3] * 255
        
        plt.imshow(color_mask.astype(np.uint8))
        plt.title("Predicted Mask")
        plt.axis('off')

    plt.show()

In [8]:
model.eval()
model.to(device)

images, masks = inference(model, val_loader, device, num_samples=160000)
visualize_predictions(images, masks, num_samples=160000)

NameError: name 'val_loader' is not defined

### Save model

In [40]:
torch.save(model.state_dict(), './simple_unet.pth')