In [None]:
# Imports
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from transformers import (
    SegformerForSemanticSegmentation, 
    TrainingArguments,
    Trainer
)
from dsa_helpers.ml.transforms.segformer_transforms import (
    train_transforms, val_transforms
)
from dsa_helpers.ml.datasets.utils import create_segformer_segmentation_dataset

import torch
from torch import nn
import numpy as np

# Class index -> display name for the annotated tissue classes. Index 0
# (listed as "Background" in LABEL_2_IDX) intentionally has no entry here.
idx2label = dict(enumerate(
    ("Gray Matter", "White Matter", "Superficial", "Leptomeninges"), start=1
))

# Outline colour per class index, as CSS-style "rgb(r,g,b)" strings.
LINE_COLORS = {
    class_idx: f"rgb({red},{green},{blue})"
    for class_idx, (red, green, blue) in enumerate(
        [(0, 128, 0), (0, 0, 255), (255, 255, 0), (0, 0, 0)], start=1
    )
}

# Fill colour per class index: the outline colour at 50% opacity.
FILL_COLORS = {
    class_idx: line_color.replace("rgb(", "rgba(").replace(")", ",0.5)")
    for class_idx, line_color in LINE_COLORS.items()
}

In [None]:

# Functions
def compute_metrics(batch_output: tuple[np.ndarray, np.ndarray]) -> dict:
    """Per-class and mean IoU for SegFormer multi-class predictions.

    The logits are bilinearly resized to the spatial size of the labels,
    turned into a class map with an argmax over the class axis, and then
    compared to the labels one class at a time.

    Args:
        batch_output (tuple[numpy.ndarray, numpy.ndarray]): Tuple of
            (logits, labels). The number of classes is read from
            ``logits.shape[1]`` and the target size from
            ``labels.shape[-2:]``.

    Returns:
        dict: ``'label <i> IoU'`` for every class index ``i`` followed by
            ``'mean_iou'``. A class that appears in neither the prediction
            nor the labels gets NaN and is ignored by the mean.

    """
    logits, labels = batch_output

    with torch.no_grad():
        targets = torch.from_numpy(labels).cpu()

        # Bring the logits up to label resolution before taking the argmax.
        upsampled = nn.functional.interpolate(
            torch.from_numpy(logits).cpu(),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        predicted = upsampled.argmax(dim=1)

        metrics = {}

        for class_idx in range(logits.shape[1]):
            is_predicted = predicted == class_idx
            is_target = targets == class_idx

            overlap = (is_predicted & is_target).sum().item()
            combined = (is_predicted | is_target).sum().item()

            # An empty union means the class is absent everywhere: report NaN
            # so it does not count towards the mean.
            metrics[f'label {class_idx} IoU'] = (
                overlap / combined if combined != 0 else float('nan')
            )

        # Only the per-class IoUs are in the dict at this point.
        metrics['mean_iou'] = np.nanmean(list(metrics.values()))

        return metrics


In [None]:
# Constants
# --- Locations ---
DATASET_DIR = "/rayCode/bert2024/candanoDev/uk_wsi"
SAVE_DIR = f"{DATASET_DIR}/experiments"  # experiment outputs live under the dataset directory

# --- Train / validation split ---
RANDOM_STATE = 42
VAL_FRAC = 0.2  # fraction of WSIs to use for validation

# --- Classes: label name -> integer index, in index order ---
LABEL_2_IDX = {
    name: index
    for index, name in enumerate(
        ("Background", "Gray Matter", "White Matter", "Superficial", "Leptomeninges")
    )
}

# --- Training hyper-parameters ---
LEARNING_RATE = 1e-4
EPOCHS = 10
BATCH_SIZE = 16
EVAL_ACCUMULATION_STEPS = 100


In [None]:
dataset_dir = Path(DATASET_DIR)

# Metadata for each tile image. Training and validation tiles come from two
# separate CSV files rather than from a random split of a single file.
tile_df = pd.read_csv(dataset_dir / "tile_10_metadata.csv")
val_df = pd.read_csv(dataset_dir / "tile_test10_metadata.csv")
print(f"{len(tile_df)} training tiles and {len(val_df)} validation tiles found.")

# Get list of unique source WSI image names.
train_wsi_names = tile_df['wsi_name'].unique()
val_wsi_names = val_df['wsi_name'].unique()

# Split the unique WSI names into training and validation sets.
#train_wsi_names, val_wsi_names = train_test_split(wsi_names, test_size=VAL_FRAC, random_state=RANDOM_STATE)

print(f"Training set has {len(train_wsi_names)} WSI images.")
print(f"Validation set has {len(val_wsi_names)} WSI images.")

# Tiles from a WSI present in both files would leak into validation; warn
# instead of failing so the notebook still runs.
shared_wsi_names = set(train_wsi_names) & set(val_wsi_names)
if shared_wsi_names:
    print(
        f"WARNING: {len(shared_wsi_names)} WSI(s) appear in both the "
        "training and validation metadata."
    )

# Split the tiles into training and validation sets. Each frame must be
# filtered with a mask built from that same frame: indexing tile_df with a
# mask derived from val_df selects the wrong rows (or raises when the indexes
# do not align).
train_tiles_df = tile_df[tile_df["wsi_name"].isin(train_wsi_names)]
val_tiles_df = val_df[val_df["wsi_name"].isin(val_wsi_names)]

## Remove after testing
# TODO: temporary cap on the number of tiles to keep test runs short.
train_tiles_df = train_tiles_df.head(900)
val_tiles_df = val_tiles_df.head(100)

print(f"Training set has {len(train_tiles_df)} tiles.")
print(f"Validation set has {len(val_tiles_df)} tiles.")

# Create SegFormer dataset objects.
train_dataset = create_segformer_segmentation_dataset(
    train_tiles_df, transforms=train_transforms
)
val_dataset = create_segformer_segmentation_dataset(
    val_tiles_df, transforms=val_transforms
)

In [None]:
# Invert LABEL_2_IDX (label -> index) to get the index -> label mapping.
id_2_label = dict(zip(LABEL_2_IDX.values(), LABEL_2_IDX.keys()))

# Load the "nvidia/mit-b0" checkpoint as the starting point for training,
# passing both directions of the label mapping to the model.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    id2label=id_2_label,
    label2id=LABEL_2_IDX,
)


In [None]:
# Directory handed to TrainingArguments as its output location.
model_dir = Path(SAVE_DIR, "model")

training_args = TrainingArguments(
    output_dir=model_dir,
    # Optimisation.
    learning_rate=LEARNING_RATE,
    num_train_epochs=EPOCHS,
    # Batching (same batch size for training and evaluation).
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_accumulation_steps=EVAL_ACCUMULATION_STEPS,
    # Evaluate, checkpoint and log on the same per-epoch schedule.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    # Checkpoint retention.
    save_total_limit=1,
    load_best_model_at_end=True,
)

In [None]:
# Bundle the model, training arguments, both datasets and the IoU metric
# function into a single Trainer. compute_metrics receives a (logits, labels)
# tuple and returns the per-class IoUs plus 'mean_iou'.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# Run training with the configuration in training_args (evaluation, saving and
# logging are all set to happen once per epoch).
trainer.train()

In [None]:
from shapely.geometry import Polygon
from shapely.affinity import translate
import cv2 as cv


# Functions
def mask_to_geojson(
    mask: str | np.ndarray, x_offset: int = 0, y_offset: int = 0, 
    background_label: int = 0, min_area: int = 0
) -> list[list]:
    """
    Extract contours from a label mask and convert them into shapely 
    polygons.

    Despite the name, the function returns shapely objects, not GeoJSON.
    
    Args:
        mask (str | np.ndarray): Path to the mask image or the mask.
        x_offset (int): Offset to add to x coordinates of polygons.
        y_offset (int): Offset to add to y coordinates of polygons.
        background_label (int): Label value of the background class, 
            which is ignored.
        min_area (int): Polygons whose area is below this value are 
            dropped.
    
    Returns:
        list[list]: One ``[polygon, label]`` pair per extracted region, 
            where ``polygon`` is a shapely Polygon (holes included) and 
            ``label`` is the mask value it was extracted for.

    Raises:
        FileNotFoundError: If ``mask`` is a path that cannot be read as 
            an image.
            
    """
    if isinstance(mask, str):
        # Read with OpenCV, which this cell already imports. cv.imread
        # returns None instead of raising when the file cannot be read.
        mask_path = mask
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)

        if mask is None:
            raise FileNotFoundError(f"Could not read mask image: {mask_path}")
    
    # Find unique labels (excluding the background label).
    labels = [label for label in np.unique(mask) if label != background_label]
    
    polygons = []  # Track all polygons.
        
    # Loop through unique label index.
    for label in labels:
        # Filter to mask for this label.
        label_mask = (mask == label).astype(np.uint8)
        
        # RETR_CCOMP gives a two-level hierarchy: outer boundaries on top
        # (parent == -1) and hole boundaries beneath them. A region nested
        # inside a hole is reported as a new top-level boundary, so it is kept
        # as its own polygon instead of being mistaken for a hole of a hole.
        contours, hierarchy = cv.findContours(
            label_mask, cv.RETR_CCOMP, cv.CHAIN_APPROX_SIMPLE
        )
        
        if hierarchy is None:
            continue
        
        outlines = {}    # contour index -> exterior ring
        hole_rings = {}  # parent contour index -> list of hole rings

        for idx, (contour, info) in enumerate(zip(contours, hierarchy[0])):
            # Contours with three or fewer points are ignored.
            if len(contour) <= 3:
                continue
            
            ring = contour.reshape(-1, 2)
            parent = int(info[3])
            
            if parent == -1:
                outlines[idx] = ring
            else:
                hole_rings.setdefault(parent, []).append(ring)
        
        # Build one Polygon per exterior. Holes whose parent exterior was
        # skipped above are never looked up, so they are dropped silently.
        for idx, outline in outlines.items():
            polygon = Polygon(outline, holes=hole_rings.get(idx, []))
            
            # Shift the polygon by the offset.
            polygon = translate(polygon, xoff=x_offset, yoff=y_offset)
                        
            # Skip small polygons.
            if polygon.area >= min_area:
                polygons.append([polygon, label])
        
    return polygons

In [None]:
# Constants
# NOTE(review): this rebinds BATCH_SIZE; a TrainingArguments object that was
# already built from the previous value is not affected by this assignment.
BATCH_SIZE = 1
# Find the checkpoint directory in the model directory.
# NOTE(review): placeholder only - nothing in this cell assigns or reads
# checkpoint_dir; the checkpoint path is hard-coded in from_pretrained below.
checkpoint_dir = None

# Load the model.
# NOTE(review): relative path, so it resolves against the kernel's working
# directory. The prediction loop in this notebook calls trainer.predict, so
# this reloaded `model` only takes effect if a Trainer is built from it -
# confirm which model is meant to produce the predictions.
model = SegformerForSemanticSegmentation.from_pretrained("experiments/model/checkpoint-174")
# Read the datasets from csv.
# NOTE(review): both frames are read from the same "tile_test10_metadata.csv"
# (relative path) - confirm the training frame is meant to use the test file.
train_tiles_df = pd.read_csv("tile_test10_metadata.csv")
val_tiles_df = pd.read_csv("tile_test10_metadata.csv")

# Create SegFormer dataset objects.
train_dataset = create_segformer_segmentation_dataset(  
    train_tiles_df, transforms=train_transforms
)
val_dataset = create_segformer_segmentation_dataset(
    val_tiles_df, transforms=val_transforms
)

In [None]:
# Display the distinct WSI names present in the validation tile metadata.
list(val_tiles_df['wsi_name'].unique())

In [None]:
from tqdm.notebook import tqdm
from geopandas import GeoDataFrame

from dsa_helpers.girder_utils import login

# Authenticate once and reuse the client for every upload below.
gc = login("http://bdsa.pathology.emory.edu:8080/api/v1")
itms_list = [i for i in gc.listItem('673b566f900c0c0559bf156f')]

# DSA item name -> item id, built once instead of rescanning itms_list for
# every WSI. If several items share a name the last one wins.
item_id_by_name = {item['name']: item['_id'] for item in itms_list}

tqdm.pandas()

sf = 5 / 40  # hard-coded scaling factor to go from scan magnification to tile magnification
coord_scale = 1 / sf  # multiplier from tile coordinates back to scan coordinates (8)
tile_size = (512, 512)  # spatial size the logits are resized to before the argmax
tolerance = 0.5  # simplification tolerance for very long polygon outlines


def _ring_to_points(coords) -> list:
    """Convert a ring of tile-space (x, y) coordinates to DSA [x, y, 0] points.

    Coordinates are scaled first and rounded afterwards, so floating-point
    noise from buffering (e.g. 127.9999) cannot shift a vertex by a whole
    tile pixel.
    """
    return [
        [int(round(xy[0] * coord_scale)), int(round(xy[1] * coord_scale)), 0]
        for xy in coords
    ]


def _polygon_to_element(poly, label) -> dict:
    """Build one DSA polyline element, with holes if any, from a shapely polygon."""
    if len(poly.exterior.coords) > 1000:
        poly = poly.simplify(tolerance, preserve_topology=True)

    element = {
        "points": _ring_to_points(poly.exterior.coords),
        "fillColor": FILL_COLORS[label],
        "lineColor": LINE_COLORS[label],
        "type": "polyline",
        "lineWidth": 2,
        "closed": True,
        "label": {"value": idx2label[label]},
        "group": "Gray White Segmentation"
    }

    # Read the holes from the same (possibly simplified) polygon as the
    # exterior so both rings stay consistent with each other.
    holes = [_ring_to_points(interior.coords) for interior in poly.interiors]

    if len(holes):
        element["holes"] = holes

    return element


# Predict labels on the validation dataset, one WSI at a time.
# NOTE(review): predictions come from `trainer` and the model it was built
# with; a `model` variable reassigned in a later cell is not used here.
for wsi_name in list(val_tiles_df['wsi_name'].unique()):
    # Resolve the upload target first: without a matching item there is
    # nowhere to post to, and reusing the id of a previous WSI would attach
    # the annotation to the wrong slide.
    save_id = item_id_by_name.get(wsi_name)

    if save_id is None:
        print(f"No DSA item named {wsi_name!r}; skipping.")
        continue

    print(save_id)
    print(wsi_name)

    wsi_tiles_df = val_tiles_df[val_tiles_df['wsi_name'] == wsi_name]

    # Create dataset for this WSI tiles.
    wsi_dataset = create_segformer_segmentation_dataset(
        wsi_tiles_df, transforms=val_transforms
    )

    # Predict the labels.
    predictions = trainer.predict(wsi_dataset).predictions
    pred_tensor = torch.from_numpy(predictions)

    # Scale the logits back to the tile size.
    pred_tensor = nn.functional.interpolate(
        pred_tensor,
        size=tile_size,
        mode="bilinear",
        align_corners=False,
    )

    # Turn the logits to predictions.
    pred_tensor = pred_tensor.argmax(dim=1)

    # Collect [polygon, label] pairs from every tile of this WSI.
    polygons_with_labels = []

    for i in range(len(wsi_tiles_df)):
        row = wsi_tiles_df.iloc[i]
        pred = pred_tensor[i].numpy()

        # If the tile has no predictions, skip it.
        if np.count_nonzero(pred) == 0:
            continue

        # Process the mask for polygons.
        polygons_with_labels.extend(mask_to_geojson(
            pred, x_offset=int(row["x"]), y_offset=int(row["y"])
        ))

    if not polygons_with_labels:
        print(f"No polygons predicted for {wsi_name!r}; skipping.")
        continue

    gdf = GeoDataFrame(polygons_with_labels, columns=["geometry", "label"])

    # Apply a buffer to make edges touch.
    gdf["geometry"] = gdf["geometry"].buffer(1)
    gdf.plot(edgecolor="black", column="label")

    # Dissolve the dataframe by the label.
    gdf_dissolved = gdf.dissolve(by="label", as_index=False)

    # Undo the buffer on the dissolved geometries.
    gdf_dissolved["geometry"] = gdf_dissolved["geometry"].buffer(-1)
    gdf_dissolved.plot(edgecolor="black", column="label")

    # Format the geometries into DSA annotation format.
    elements = []

    for _, row in gdf_dissolved.iterrows():
        geometry = row['geometry']
        label = row['label']

        if geometry is None or geometry.is_empty:
            continue

        # A label can dissolve to a single Polygon, which has no .geoms
        # attribute; handle it the same way as a multi-part geometry.
        if geometry.geom_type == "Polygon":
            parts = [geometry]
        else:
            parts = list(geometry.geoms)

        for poly in parts:
            if poly.geom_type != "Polygon" or poly.is_empty:
                continue

            elements.append(_polygon_to_element(poly, label))

    if not elements:
        print(f"No annotation elements for {wsi_name!r}; skipping upload.")
        continue

    _ = gc.post(
        f'/annotation?itemId={save_id}', 
        json={
            'name': "temp", 
            'description': '', 
            'elements': elements
        }
    )