In [2]:
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GraphConv, global_mean_pool
import torch.distributed as dist
from datetime import timedelta
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Subset, SubsetRandomSampler
from torch.amp import GradScaler, autocast
from sklearn.neighbors import kneighbors_graph
from torch_geometric.nn import GraphConv, global_mean_pool
from torch_geometric.data import Data, Dataset
from torch_geometric.data.data import BaseData
from torch_geometric.loader import DataLoader
from torch_geometric.utils import add_self_loops
from torch.nn.parallel import DistributedDataParallel as DDP
from sklearn.metrics import precision_score, accuracy_score, confusion_matrix, classification_report,roc_auc_score, roc_curve, auc
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import StratifiedKFold
from GNN_model_utils import *
import optuna
import copy
import torch.multiprocessing as mp

In [2]:
class GatedAttentionWithInstanceClassifier(nn.Module):
    """
    A variant that computes instance-level logits + an attention weighting,
    then aggregates instance logits to produce a slide-level prediction.
    """
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        # For gating-based attention
        self.w_h = nn.Linear(input_dim, hidden_dim, bias=True)
        self.w_g = nn.Linear(input_dim, hidden_dim, bias=True)
        self.w_a = nn.Linear(hidden_dim, 1, bias=False)
        
        # Instance-level classifier
        self.instance_classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)  # Per-tile logits
        )

    def forward(self, x, batch=None):
        """
        Args:
            x: [N, in_dim] tile embeddings
            batch: [N], which slide each tile belongs to
        Returns:
            slide_logits: [num_slides_in_batch, num_classes]
            instance_logits: [N, num_classes]
            att_weights: [N], attention weight per tile
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        # 1) Per-tile (instance) logits
        instance_logits = self.instance_classifier(x)  # shape [N, num_classes]
        
        # 2) Compute gated attention scores
        h = torch.tanh(self.w_h(x))      # [N, hidden_dim]
        g = torch.sigmoid(self.w_g(x))   # [N, hidden_dim]
        alpha_logits = self.w_a(h * g).squeeze(-1)  # shape [N]
        
        # 3) Aggregate instance logits -> slide logits
        unique_graph_ids = batch.unique()
        slide_logits_list = []
        att_weights_list = torch.zeros_like(alpha_logits)

        for gid in unique_graph_ids:
            mask = (batch == gid)
            alpha_g = alpha_logits[mask]         # shape [n_tiles_in_slide]
            inst_logits_g = instance_logits[mask]  # [n_tiles_in_slide, num_classes]
            
            # Normalize attention weights within this slide
            alpha_g_softmax = F.softmax(alpha_g, dim=0)
            alpha_g_softmax = alpha_g_softmax.to(att_weights_list.dtype)  
            att_weights_list[mask] = alpha_g_softmax

            # Weighted sum of instance logits => slide-level logit
            slide_logit = torch.sum(alpha_g_softmax.unsqueeze(1) * inst_logits_g, dim=0)
            slide_logits_list.append(slide_logit.unsqueeze(0))

        slide_logits = torch.cat(slide_logits_list, dim=0)  # [num_slides_in_batch, num_classes]
        
        return slide_logits, instance_logits, att_weights_list

In [3]:
# The GNN Model
class SlideGNN_AdvancedMIL(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=256, num_classes=2,num_layers=3,dropout=0.3):
        super().__init__()

        # Initialize a ModuleList to hold your GraphConv layers
        self.convs = nn.ModuleList()

        # Add the first layer (input layer)
        self.convs.append(GraphConv(input_dim, hidden_dim))

        # Add intermediate layers
        for _ in range(num_layers - 2):
            self.convs.append(GraphConv(hidden_dim, hidden_dim))

        # Add the last layer
        self.convs.append(GraphConv(hidden_dim, hidden_dim))

        # Store dropout
        self.dropout = dropout

        # Batch normalization layers
        self.bns = torch.nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(num_layers)])

        # Attention-based pooling for slide-level embedding
        self.mil_pool = GatedAttentionWithInstanceClassifier(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes
        )

        # Final classification from the pooled slide-level embedding
        self.classifier = nn.Linear(hidden_dim, num_classes)


    def forward(self, x, edge_index, batch):
        """
        x: Node features [num_nodes, in_channels]
        edge_index: Graph connectivity [2, num_edges] 
        batch: Node-to-graph assignment [num_nodes]
        """

        # Pass through each graph convolutional layer
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # MIL aggregator
        slide_logits, tile_logits, att_weights = self.mil_pool(x, batch)

        return slide_logits, tile_logits, att_weights


In [4]:
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    scaler = GradScaler()
    train_loss = 0.0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            # Forward pass (in mixed-precision)
            slide_logits, _, _= model(data.x, data.edge_index, data.batch)
            
            # Suppose we only have slide-level labels
            loss = criterion(slide_logits, data.y)

        # Backward pass with scaled gradients
        scaler.scale(loss).backward()
        
        # Step with the optimizer using the scaler
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    return train_loss / len(train_loader)
   

def test(model,test_loader,device):
    model.eval()
    slide_preds, slide_labels, slide_score, tile_preds, tile_score, tile_batch = [],[], [], [], [], []
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            slide_logits, tile_logits, att_weights = model(data.x, data.edge_index, data.batch)
            
            # ---------- Slide-level ----------
            # Slide-level predictions
            s_preds = slide_logits.argmax(dim=1)
            s_score = F.softmax(slide_logits, dim=1)
            slide_preds.extend(s_preds.cpu().numpy())
            slide_score.extend(s_score.cpu().numpy())
            slide_labels.extend(data.y.cpu().numpy())

            # ---------- Tile-level ----------
            # Convert tile_logits to probabilities
            t_preds = tile_logits.argmax(dim=1)
            t_score = F.softmax(tile_logits, dim=1)

            #t_preds = att_weights.argmax(dim=1)
            tile_preds.append(t_preds.cpu())
            tile_score.append(t_score.cpu())
            tile_batch.append(data.batch.cpu())

    # Concatenate tile-level outputs across all batches
    tile_preds = torch.cat(tile_preds, dim=0).numpy()
    tile_score = torch.cat(tile_score, dim=0).numpy()  # shape [TotalTiles, num_classes]
    tile_batch = torch.cat(tile_batch, dim=0).numpy()    # shape [TotalTiles]

    # Slide-level accuracy
    slide_accuracy = accuracy_score(slide_labels, slide_preds)
    return slide_accuracy, slide_score, slide_preds, slide_labels, tile_score, tile_preds, tile_batch

In [5]:
def setup_distributed_environment():
    if 'SLURM_NODELIST' in os.environ:
        master_node = os.popen(f"scontrol show hostnames {os.environ['SLURM_NODELIST']}").read().splitlines()[0]
        os.environ['MASTER_ADDR'] = master_node
    else:
        os.environ['MASTER_ADDR'] = '127.0.0.1'

    # Dynamically set MASTER_PORT if not set
    os.environ.setdefault("MASTER_PORT", str(29500 + int(os.environ.get("SLURM_PROCID", 0))))

    # SLURM-specific rank and world size
    #world_size = int(os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS')))
    #rank =int(os.environ.get('RANK', os.environ.get('SLURM_PROCID')))
    #local_rank = int(os.environ.get('SLURM_LOCALID', rank % torch.cuda.device_count()))
    #local_rank = int(os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID')))
    #return world_size,rank,local_rank

In [6]:
def run(world_size: int, rank: int, graph_dataset,sample_names,label_encoder):
    # Will query the runtime environment for `MASTER_ADDR` and `MASTER_PORT`.
    # Make sure, those are set!
    os.environ["MASTER_ADDR"] = os.popen(f"scontrol show hostnames {os.environ['SLURM_NODELIST']}").read().splitlines()[0]
    os.environ.setdefault("MASTER_PORT", str(29500 + int(os.environ.get("SLURM_PROCID", 0))))
       
    #print(f"Initializing process group: Rank={rank}, World Size={world_size}, Local Rank={local_rank}")
    dist.init_process_group('nccl', world_size=world_size, rank=rank)
    print(f"Process group initialized: Rank={rank}")


    # Move to device for faster feature fetch.
    #data = data.to(local_rank, 'x', 'y')

    folds = stratified_kfold_split(graph_dataset,  n_splits=5)
   
    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        # Create a directory for this fold
        fold_dir = os.path.join(outdir, f"result_{fold_idx + 1}")
        os.makedirs(fold_dir, exist_ok=True)  # Create folder if it doesn't exist
        if rank == 0:
            print(f"  Training samples: {len(train_idx)}")
            print(f"  Validation samples: {len(val_idx)}")

        # Split dataset
        train_dataset = Subset(graph_dataset, train_idx)
        test_dataset = Subset(graph_dataset, val_idx)

        print(f"Rank={rank}, train_data:{len(train_dataset)}, test_data:{len(test_dataset)}")

        # Use DistributedSampler
        train_sampler = torch.utils.data.DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        test_sampler = torch.utils.data.DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False)

        train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)

        if rank == 0:
            test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


        # Validate data loaders
        if rank == 0:
            print("Checking train loader...")
            check_batches(train_loader)
            if rank == 0:
                print("Checking validation loader...")
                check_batches(test_loader)

        # Initialize model
        model = SlideGNN_AdvancedMIL(input_dim=1024, hidden_dim=512, num_classes=4,num_layers=3, dropout=0.3).to(rank)
        model = DDP(model, device_ids=[rank])

        # Calculate class weights
        class_weights = calculate_class_weights_from_graph(train_dataset, device=rank)

        # Use CrossEntropyLoss with class weights
        criterion = nn.CrossEntropyLoss(weight=class_weights).to(rank)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

        # Train and evaluate
        for epoch in range(1, 101):
            train_loss = train(model, train_loader, optimizer, criterion, device=rank)
            dist.barrier()

            if rank == 0 and epoch % 5 == 0:
                slide_accuracy, slide_score, slide_preds, slide_labels, tile_score, tile_preds, tile_batch = test(model,test_loader, device=rank)
            dist.barrier()

            if rank == 0:
                print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Accuracy:{slide_accuracy:.4f}')

                # Save model for this fold
                model_path = os.path.join(fold_dir, "GNN_ViT_model.pth")
                torch.save(model.state_dict(), model_path)
                print(f"Model for Fold {fold_idx + 1} saved at {model_path}")

                # Save slide results for this fold in a text file
                save_results_to_text(sample_names, val_idx, slide_labels, slide_preds, slide_score, label_encoder, rank, outdir = fold_dir)
            
                # Save tile predictions to an .npz file
                save_tile_results(tile_preds, tile_score, tile_batch, val_idx, sample_names, rank, output_dir = fold_dir)

        dist.barrier()

    dist.destroy_process_group()

In [None]:
world_size = torch.cuda.device_count()
print(f"world count:{wold_size}")

In [7]:
outdir="GNN_modle_Malignant_5FCross_multiGPU_old"
os.makedirs(outdir, exist_ok=True)

In [8]:
# old data
slide_features=np.load("../features/NCI_CBTN_GP34_features.npy", allow_pickle=True)
slide_classes=pd.read_csv("../features/NCI_CBTN_GP34_Malignant&Myeloid_Cellstate_classes.csv",index_col=0)


In [7]:
## load NCI CBTN Group 3 and 4 features and classes
slide_features=np.load("/data/NCI_LP/pediatric_Data/Medulloblastoma//NCI_CBTN_GP34_features.npy", allow_pickle=True)
slide_classes=pd.read_csv("/data/NCI_LP/pediatric_Data/Medulloblastoma/NCI_CBTN_MB34_320Samples_metadata_reorder.csv",index_col=0)

In [9]:
# Seperate sample names and image tile level features
sample_names, features = extract_sample_names_and_features(slide_features)

# Scale the features
scaled_features = scale_tile_features(features)

# Encode the target classes
slide_classes=slide_classes["Malignant.Cluster"]
label_encoder = LabelEncoder()
y_labels = label_encoder.fit_transform(slide_classes)

# creating graph data from tile level featutes
graph_dataset = TileGraphDataset(scaled_features, y_labels, k=5)

In [10]:
folds = stratified_kfold_split(graph_dataset,  n_splits=5)

Number of labels: 321:
  Training samples: 256
  Validation samples: 65
------------------------------
  Training samples: 257
  Validation samples: 64
------------------------------
  Training samples: 257
  Validation samples: 64
------------------------------
  Training samples: 257
  Validation samples: 64
------------------------------
  Training samples: 257
  Validation samples: 64
------------------------------


In [11]:
device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
for fold_idx, (train_idx, val_idx) in enumerate(folds):
    # Create a directory for this fold
    fold_dir = os.path.join(outdir, f"result_{fold_idx + 1}")
    os.makedirs(fold_dir, exist_ok=True)  # Create folder if it doesn't exist
    
    print(f"  Training samples: {len(train_idx)}")
    print(f"  Validation samples: {len(val_idx)}")

    # Split dataset
    train_dataset = Subset(graph_dataset, train_idx)
    test_dataset = Subset(graph_dataset, val_idx)

    print(f"train_data:{len(train_dataset)}, test_data:{len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


    # Validate data loaders
    print("Checking train loader...")
    check_batches(train_loader)
    print("Checking validation loader...")
    check_batches(test_loader)

    # Initialize model
    model = SlideGNN_AdvancedMIL(input_dim=1024, hidden_dim=1024, num_classes=4,num_layers=3, dropout=0.3).to(device)
    #model = DDP(model, device_ids=device)

    # Calculate class weights
    class_weights = calculate_class_weights_from_graph(train_dataset, device)

    # Use CrossEntropyLoss with class weights
    criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    # Train and evaluate
    best_slide_acc = 0.0  # Best validation accuracy so far
    patience_counter = 0
    patience = 10
    # Train and evaluate
    for epoch in range(1, 101):
        train_loss = train(model, train_loader, optimizer, criterion, device)

        slide_accuracy, slide_score, slide_preds, slide_labels, tile_score, tile_preds, tile_batch = test(model,test_loader, device)

        print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Accuracy:{slide_accuracy:.4f}')
        # -------------- EARLY STOPPING LOGIC --------------
        if slide_accuracy > best_slide_acc:
            best_slide_acc = slide_accuracy
            patience_counter = 0
        else:
            patience_counter += 1

        # Check for early stopping condition
        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch}.")
            break
        
    # Save model for this fold
    model_path = os.path.join(fold_dir, "GNN_ViT_model.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Model for Fold {fold_idx + 1} saved at {model_path}")

    # Save slide results for this fold in a text file
    save_results_to_text(sample_names, val_idx, slide_labels, slide_preds, slide_score, label_encoder, rank=device, outdir = fold_dir)
            
    # Save tile predictions to an .npz file
    save_tile_results(tile_preds, tile_score, tile_batch, val_idx, sample_names, rank=device, output_dir = fold_dir)

  Training samples: 256
  Validation samples: 65
train_data:256, test_data:65
Checking train loader...
Step 1:
Number of graphs in the current batch: 32
Data in Batch:
DataBatch(x=[57518, 1024], edge_index=[2, 287590], y=[32], batch=[57518], ptr=[33])

Step 2:
Number of graphs in the current batch: 32
Data in Batch:
DataBatch(x=[82182, 1024], edge_index=[2, 410910], y=[32], batch=[82182], ptr=[33])

Step 3:
Number of graphs in the current batch: 32
Data in Batch:
DataBatch(x=[75031, 1024], edge_index=[2, 375155], y=[32], batch=[75031], ptr=[33])

Step 4:
Number of graphs in the current batch: 32
Data in Batch:
DataBatch(x=[80717, 1024], edge_index=[2, 403585], y=[32], batch=[80717], ptr=[33])

Step 5:
Number of graphs in the current batch: 32
Data in Batch:
DataBatch(x=[68080, 1024], edge_index=[2, 340400], y=[32], batch=[68080], ptr=[33])

Step 6:
Number of graphs in the current batch: 32
Data in Batch:
DataBatch(x=[82279, 1024], edge_index=[2, 411395], y=[32], batch=[82279], ptr=[33]

In [12]:
# Collect and combine tile level results
combine_tile_level_results(outdir)

# Collect slide results data
true_labels, predicted_labels, predicted_scores = aggregate_slide_level_results(outdir)

# Annalyze slide results data
analyze_results(true_labels, predicted_labels, predicted_scores, label_encoder,outdir=outdir)

Loading GNN_modle_Malignant_5FCross_multiGPU_old/result_2/tile_results_rankcuda.npz...
Loading GNN_modle_Malignant_5FCross_multiGPU_old/result_4/tile_results_rankcuda.npz...
Loading GNN_modle_Malignant_5FCross_multiGPU_old/result_1/tile_results_rankcuda.npz...
Loading GNN_modle_Malignant_5FCross_multiGPU_old/result_3/tile_results_rankcuda.npz...
Loading GNN_modle_Malignant_5FCross_multiGPU_old/result_5/tile_results_rankcuda.npz...
Combined tile results saved to GNN_modle_Malignant_5FCross_multiGPU_old/combined_tile_results.npz
Processing fold directory: GNN_modle_Malignant_5FCross_multiGPU_old/result_1
  Loading file: GNN_modle_Malignant_5FCross_multiGPU_old/result_1/slide_results_rankcuda.txt
Processing fold directory: GNN_modle_Malignant_5FCross_multiGPU_old/result_2
  Loading file: GNN_modle_Malignant_5FCross_multiGPU_old/result_2/slide_results_rankcuda.txt
Processing fold directory: GNN_modle_Malignant_5FCross_multiGPU_old/result_3
  Loading file: GNN_modle_Malignant_5FCross_multiG

In [35]:
from tile_categarical_heamp import*
outdir ="GNN_modle_Malignant_5FCross_multiGPU"
slide_name = "BM51_ST-22-34_A-1_-_2022-01-25_10.54.58.ndpi"
slide_file_name = os.path.join("/data/NCI_LP/pediatric_Data/Medulloblastoma/NCI_MB_GP34/slides",slide_name)
all_tile_result = np.load(os.path.join(outdir,"combined_tile_results.npz"),allow_pickle=True)
tile_result = all_tile_result[slide_name]
tile_dict = tile_result.item()  # Extract the dictionary from the array
tile_labels = tile_dict['tile_preds']
print(tile_labels)
tile_coord_path = "/data/NCI_LP/pediatric_Data/Medulloblastoma/NCI_MB_GP34/outputs/NCI_MB_GP34_example/NCI_MB_GP34_coordinates/"
slide_coord_name = slide_name+"_tile_coordinates.csv"
tile_coords = pd.read_csv(os.path.join(tile_coord_path,slide_coord_name))
tile_coords = tile_coords.to_numpy()


[0 0 0 ... 0 0 0]


In [63]:
import os
import numpy as np
import openslide
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch

def create_categorical_heatmap(
    tile_coords,
    tile_labels,
    slide_file_name,
    class_idx=None,
    tile_size=512,
    downsample=16,
    categories=None,
    category_colors=None,
    overwrite_policy='background_only'
):
    """
    Create a categorical heatmap for a single slide from tile labels.

    Args:
        tile_coords (np.ndarray): shape [N, 2], each row=(row, col) in original resolution.
        tile_labels (np.ndarray): shape [N], integer labels for each tile.
        slide_file_name (str): Path to the WSI.
        class_idx (int, optional): If specified, only visualize this label.
        tile_size (int): Original tile width/height in pixels.
        downsample (int): Factor to downsample the slide for heatmap.
        categories (dict, optional): Mapping from label -> category name.
        category_colors (dict, optional): Mapping from label -> color (string).
        overwrite_policy (str): 'background_only' means only fill empty (0),
            'overwrite' means always overwrite.

    Returns:
        heatmap (np.ndarray): shape (H_ds, W_ds) with integer labels.
    """

    # 1. Check input lengths
    if len(tile_coords) != len(tile_labels):
        raise ValueError("tile_coords and tile_labels must be the same length.")

    # Provide default categories if needed
    if categories is None:
        categories = {
            0: 'Background',
            1: 'Neural.Crest',
            2: 'Neuronal',
            3: 'Photoreceptor',
            4: 'Proliferative'
        }
    # Provide default colors if needed
    if category_colors is None:
        category_colors = {
            0: 'white',     # background
            1: 'green',     
            2: 'yellow',
            3: 'blue',
            4: 'pink'
        }

    # Print some debug info
    print(f"Number of tiles: {len(tile_coords)}")
    unique_tile_labels = np.unique(tile_labels)
    print("Unique labels in tile_labels:", unique_tile_labels)

    # 2. Open slide with OpenSlide
    try:
        slide = openslide.OpenSlide(slide_file_name)
    except Exception as e:
        raise IOError(f"Could not open slide file {slide_file_name}: {e}")

    properties = slide.properties
    mag_key = openslide.PROPERTY_NAME_OBJECTIVE_POWER
    if mag_key in properties:
        mag_max = float(properties[mag_key])
        print(f"Slide magnification: {mag_max}x")
    else:
        mag_assumed = 20.0
        mag_max = mag_assumed
        print(f"[WARNING] Magnification not found. Assuming {mag_assumed}x")

    # 3. Determine slide dimensions at the highest level (level=0)
    slide_width_0, slide_height_0 = slide.level_dimensions[0]

    # Downsample for the heatmap space
    slide_width = slide_width_0 // downsample
    slide_height = slide_height_0 // downsample
    print(f"Slide dimensions at downsample=1/{downsample}: {slide_width}x{slide_height}")

    # 4. Create heatmap array
    heatmap = np.zeros((slide_height, slide_width), dtype=np.int32)

    # Downsample tile size
    tile_size_ds = tile_size // downsample

    # 5. Populate the heatmap
    #   If your coords are row=(y), col=(x), we do row_ds = row//downsample, col_ds = col//downsample
    #   If class_idx is given, we skip tiles that are not that label.
    for idx, (row, col) in enumerate(tile_coords):
        label = tile_labels[idx]

        if class_idx is not None and label != class_idx:
            continue

        row_ds = row // downsample
        col_ds = col // downsample

        row_start = max(0, row_ds)
        row_end   = min(slide_height, row_ds + tile_size_ds)
        col_start = max(0, col_ds)
        col_end   = min(slide_width,  col_ds + tile_size_ds)

        # Print debug info for the first few tiles or if label != 0
        if idx < 10 or label != 0:
            print(f"[Tile {idx}] row={row}, col={col}, label={label}")
            print(f"  region=({row_start}:{row_end}, {col_start}:{col_end})")

        # Depending on the overwrite policy, we either fill only background or overwrite always
        if overwrite_policy == 'background_only':
            # fill label only where heatmap == 0
            existing = heatmap[row_start:row_end, col_start:col_end]
            mask = (existing == 0)
            existing[mask] = label
        else:
            # Always overwrite any existing label
            heatmap[row_start:row_end, col_start:col_end] = label

    # 6. Debug: check final unique labels in the heatmap
    unique_vals_heatmap = np.unique(heatmap)
    print("Unique values in heatmap:", unique_vals_heatmap)

    # 7. Build the color map
    max_label = max(categories.keys())
    color_list = []
    for lbl in range(max_label+1):
        color = category_colors.get(lbl, 'gray')
        color_list.append(color)
    cmap = mcolors.ListedColormap(color_list)
    bounds = list(range(max_label+2))
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # 8. Plot
    plt.figure(figsize=(10, 10))
    plt.imshow(heatmap, cmap=cmap, norm=norm, alpha=1.0)  # alpha=1.0 => fully opaque
    plt.axis('off')
    plt.title('Categorical Heatmap Overlay', fontsize=16)

    # Create legend
    legend_elements = []
    for lbl, cat_name in categories.items():
        if category_colors.get(lbl, 'white') != 'none':
            legend_elements.append(
                Patch(facecolor=category_colors[lbl], edgecolor='black', label=cat_name)
            )
    plt.legend(handles=legend_elements, bbox_to_anchor=(1.05,1), loc='upper left')

    # 9. Save figure
    heatmap_save_path = os.path.splitext(slide_file_name)[0] + '_heatmap.png'
    plt.savefig(heatmap_save_path, bbox_inches='tight', dpi=300)
    plt.close()

    print(f"Categorical heatmap saved to {heatmap_save_path}")

    return heatmap


In [1]:
heatmap = create_categorical_heatmap(tile_coords, tile_labels, slide_file_name,
                           class_idx=None, tile_size=512, downsample=16,
                           categories=categories, category_colors=category_colors)

NameError: name 'create_categorical_heatmap' is not defined

In [9]:
import os
import numpy as np
import openslide
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch

def overlay_tiles_on_wsi(
    slide_path: str,
    tile_coords: np.ndarray,   # shape (N,2) in level 0 coords
    tile_labels: np.ndarray,   # shape (N,) numeric
    tile_size: int = 512,
    level_for_display: int = 2,
    label_names=None,          # dict label->string
    label_colors=None,         # dict label->(r,g,b,a)
    output_dir: str = "./output",
    output_prefix: str = "overlay"
):
    """
    Overlays semi-transparent colored rectangles for tile labels on a downsampled WSI image.
    """

    # Open the slide
    slide = openslide.OpenSlide(slide_path)
    if level_for_display >= slide.level_count:
        raise ValueError(f"Invalid level {level_for_display}, max is {slide.level_count-1}")

    # Dimensions at this level
    display_width, display_height = slide.level_dimensions[level_for_display]
    downsample_factor = slide.level_downsamples[level_for_display]

    print(f"Using level {level_for_display} => dimension: {display_width}x{display_height}, "
          f"downsample={downsample_factor}")

    # Read the region at this level
    display_region = slide.read_region((0,0), level_for_display, (display_width, display_height))
    display_image = display_region.convert('RGB')
    display_image_np = np.array(display_image)

    # Default label->name mapping if not provided
    if label_names is None:
        label_names = {
            0: "Neural-crest",
            1: "Neuronal",
            2: "Photoreceptor",
            3: "Proliferative",
        }

    # Default label->RGBA color mapping if not provided
    if label_colors is None:
        label_colors = {
            0: (1.0, 0.0, 0.0, 0.3),
            1: (0.0, 1.0, 0.0, 0.3),
            2: (0.0, 0.0, 1.0, 0.3),
            3: (1.0, 1.0, 0.0, 0.3),
        }

    # Create matplotlib figure
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(display_image_np)
    ax.axis('off')

    # Draw rectangles for each tile
    for i, (row_level0, col_level0) in enumerate(tile_coords):
        lbl = tile_labels[i]
        color = label_colors.get(lbl, (1.0, 1.0, 1.0, 0.3))  # fallback color if not found

        # Downsample coords
        row_ds = row_level0 / downsample_factor
        col_ds = col_level0 / downsample_factor
        tile_size_ds = tile_size / downsample_factor

        rect = patches.Rectangle(
            (col_ds, row_ds),
            tile_size_ds,
            tile_size_ds,
            linewidth=1,
            edgecolor=color[:3],  # ignoring alpha for edge
            facecolor=color
        )
        ax.add_patch(rect)

        # Optional debug for first few
        if i < 5:
            name = label_names.get(lbl, f"Label {lbl}")
            print(f"Tile {i} => label={lbl} ({name}), region=({col_ds},{row_ds}), size={tile_size_ds}")

    # Build a legend from label_names
    legend_patches = []
    for lbl, name in label_names.items():
        c = label_colors.get(lbl, (1.0, 1.0, 1.0, 0.3))
        legend_patches.append(Patch(facecolor=c, label=name))
    ax.legend(handles=legend_patches, bbox_to_anchor=(1.05,1), loc='upper left')

    # Save
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{output_prefix}_level{level_for_display}.png")
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close()
    print(f"Saved overlay to {output_path}")


In [10]:
# Provide custom names if you prefer
label_names = {
    0: "Neural-crest",
    1: "Neuronal",
    2: "Photoreceptor",
    3: "proliferative"
        # If you have a label 1 or 4, define them too
}
# Provide custom colors if desired
label_colors = {
    0: 'red',
    1: 'green',
    2: 'blue',
    3: 'magenta'
}

overlay_tiles_on_wsi(
    slide_path=slide_file_name,
    tile_coords=tile_coords,
    tile_labels=tile_labels,
    tile_size=512,
    level_for_display=2,    # a downsampled level
    label_names=None,
    label_colors=None,
    output_dir=outdir,
    output_prefix="BM51_Image1"
)

    

Using level 2 => dimension: 22080x23232, downsample=4.0


ValueError: too many values to unpack (expected 2)

In [53]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import os

def create_categorical_heatmap(tile_coords, tile_labels, tile_size=256, downsample=2,
                               categories=None, category_colors=None,
                               output_path="heatmap.png"):
    """
    Create and save a categorical heatmap from tile coordinates and labels,
    without overlaying on an H&E slide.

    Args:
        tile_coords (np.ndarray): shape [N, 2], each row = (row, col) or (y, x)
                                  for top-left corner of tile in original resolution.
        tile_labels (np.ndarray): shape [N], integer labels for each tile.
        tile_size (int): width/height of each tile in pixels.
        downsample (int): factor to downsample the heatmap. E.g. 16 => 1/16 scale.
        categories (dict): label -> category name for the legend.
        category_colors (dict): label -> color (any matplotlib color format).
        output_path (str): file path to save the heatmap image.

    Returns:
        heatmap (np.ndarray): A 2D integer array with shape [H, W] containing labels.
    """

    # Sanity check
    if len(tile_coords) != len(tile_labels):
        raise ValueError("tile_coords and tile_labels must have the same length.")

    # Provide default categories if not given
    if categories is None:
        categories = {
            0: "Background",
            1: "Neural-crest",
            2: "Neuronal",
            3: "Photoreceptor",
            4: "Proliferative"
        }

    # Provide default colors if not given
    if category_colors is None:
        category_colors = {
            0: "white",
            1: "red",
            2: "green",
            3: "blue",
            4: "magenta"
        }

    # 1) Determine the shape of the heatmap
    #    We find the max row, col among tile_coords plus tile_size,
    #    then optionally downsample.
    
    max_row = 0
    max_col = 0
    for (row, col) in tile_coords:
        # The bottom-right corner of this tile is row+tile_size, col+tile_size
        if row + tile_size > max_row:
            max_row = row + tile_size
        if col + tile_size > max_col:
            max_col = col + tile_size

    # Downsample the final array
    heatmap_height = (max_row // downsample)
    heatmap_width  = (max_col // downsample)

    print(f"Heatmap dimensions: {heatmap_height} x {heatmap_width}")

    # 2) Create the empty heatmap (background label = -1)
    heatmap = np.zeros((heatmap_height, heatmap_width), dtype=np.int32)

    # 3) Populate the heatmap
    tile_size_ds = tile_size // downsample
    for i, (row, col) in enumerate(tile_coords):
        label = tile_labels[i]

        row_ds = row // downsample
        col_ds = col // downsample

        row_start = row_ds
        row_end   = row_ds * tile_size_ds
        col_start = col_ds
        col_end   = col_ds * tile_size_ds

        # Bound check
        if row_end > heatmap_height:
            row_end = heatmap_height
        if col_end > heatmap_width:
            col_end = heatmap_width

        heatmap[row_start:row_end, col_start:col_end] = label

    # 4) Build a ListedColormap for the labels
    max_label = max(categories.keys())
    # Create a color list for each label 0..max_label
    color_list = []
    for l in range(max_label+1):
        color_list.append(category_colors.get(l, "yellow"))  # fallback

    cmap = mcolors.ListedColormap(color_list)
    bounds = list(range(max_label+2))  # e.g. 0..max_label+1
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # 5) Plot the heatmap
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(heatmap, cmap=cmap, norm=norm, origin='upper')  # origin='upper' for row->y
    ax.axis("off")
    ax.set_title("Categorical Tile Heatmap", fontsize=16)

    # 6) Create a legend
    from matplotlib.patches import Patch
    legend_elements = []
    for lbl, cat_name in categories.items():
        color_val = category_colors.get(lbl, "yellow")
        legend_elements.append(Patch(facecolor=color_val, label=cat_name))
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.05,1), loc='upper left')

    # 7) Save the figure
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close()

    print(f"Saved heatmap to {output_path}")
    return heatmap

In [55]:
# Create the heatmap
heatmap_array = create_categorical_heatmap(
    tile_coords=tile_coords,
    tile_labels=tile_labels,
    tile_size=256,
    categories=categories,
    category_colors=category_colors,
    output_path=os.path.join(outdir,"BM51_categorical_heatmap.png")
)


Heatmap dimensions: 171 x 169
Saved heatmap to GNN_modle_Malignant_5FCross_multiGPU/BM51_categorical_heatmap.png


In [79]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import os

def create_categorical_heatmap(tile_coords, tile_labels, tile_size=512, downsample=2,
                               categories=None, category_colors=None,
                               output_path="heatmap.png"):
    """
    Create and save a categorical heatmap from tile coordinates and labels,
    without overlaying on an H&E slide.

    Args:
        tile_coords (np.ndarray): shape [N, 2], each row = (row, col) or (y, x)
                                  for top-left corner of tile in original resolution.
        tile_labels (np.ndarray): shape [N], integer labels for each tile.
        tile_size (int): width/height of each tile in pixels.
        downsample (int): factor to downsample the heatmap. E.g. 16 => 1/16 scale.
        categories (dict): label -> category name for the legend.
        category_colors (dict): label -> color (any matplotlib color format).
        output_path (str): file path to save the heatmap image.

    Returns:
        heatmap (np.ndarray): A 2D integer array with shape [H, W] containing labels.
    """

    # Sanity check
    if len(tile_coords) != len(tile_labels):
        raise ValueError("tile_coords and tile_labels must have the same length.")

    # Provide default categories if not given
    if categories is None:
        categories = {
            0: "Background",
            1: "Neural-crest",
            2: "Neuronal",
            3: "Photoreceptor",
            4: "Proliferative"
        }

    # Provide default colors if not given
    if category_colors is None:
        category_colors = {
            0: "white",
            1: "red",
            2: "green",
            3: "blue",
            4: "magenta"
        }
    # 1) Open the WSI slide
    try:
        slide = openslide.OpenSlide(slide_file_name)
    except Exception as e:
        raise IOError(f"Could not open slide file {slide_file_name}: {e}")
    ## magnification max
    if openslide.PROPERTY_NAME_OBJECTIVE_POWER in slide.properties:
        mag_max = slide.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER]
        print("mag_max:", mag_max)
        mag_original = mag_max
    else:
        print("[WARNING] mag not found, assuming: {mag_assumed}")
        mag_max = 40
        mag_original = 0

    ## downsample_level
    downsampling = int(int(mag_max)/20)
    print(f"downsampling: {downsampling}")
    
    # 2) Get dimensions at level 0 and chosen level
    px0, py0 = slide.level_dimensions[0]  # Level 0 (original resolution)
    tile_size0 = int(512*downsampling)
    n_rows,n_cols = int(py0/tile_size0), int(px0/tile_size0)
    print(f"tiles from original image n_rows: {n_rows}, n_cols: {n_cols}, total tiles: {n_rows*n_cols}")
    
    # 1) Determine the shape of the heatmap
    #    We find the max row, col among tile_coords plus tile_size,
    #    then optionally downsample.
    
    print(f"tile size:{len(tile_coords)}")
    # Ensure correct heatmap size
    #max_row = len(np.unique(tile_coords[:, 0]))
    #max_col = len(np.unique(tile_coords[:, 1]))

   

    #print(f"tiles from tile_coord file max_rows: {max_row}, max_cols: {max_col}, total tiles: {max_row*max_col}")
    
    # Downsample the final array
    mask_tile_size = int(np.ceil(tile_size/16))
    heatmap_height = n_rows*tile_size0
    heatmap_width = n_cols*tile_size0

    print(f"Heatmap dimensions: {heatmap_height} x {heatmap_width}")

    # 2) Create the empty heatmap (background label = -1)
    heatmap = np.zeros((heatmap_height, heatmap_width), dtype=np.int32)

    # 3) Populate the heatmap
    tile_size_ds = tile_size // downsample
    for i, (row, col) in enumerate(tile_coords):
        label = tile_labels[i]

        row_start= row // downsample
        col_start = col // downsample

        # Correct tile placement within heatmap bounds
        row_end = min(row_start + (tile_size // downsample), heatmap_height)
        col_end = min(col_start + (tile_size // downsample), heatmap_width)
        # Bound check
        if row_end > heatmap_height:
            row_end = heatmap_height
        if col_end > heatmap_width:
            col_end = heatmap_width

        heatmap[row_start:row_end, col_start:col_end] = label

    # 4) Build a ListedColormap for the labels
    max_label = max(categories.keys())
    # Create a color list for each label 0..max_label
    color_list = []
    for l in range(max_label+1):
        color_list.append(category_colors.get(l, "yellow"))  # fallback

    cmap = mcolors.ListedColormap(color_list)
    bounds = list(range(max_label+2))  # e.g. 0..max_label+1
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # 5) Plot the heatmap
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(heatmap, cmap=cmap, norm=norm, origin='upper')  # origin='upper' for row->y
    ax.axis("off")
    ax.set_title("Categorical Tile Heatmap", fontsize=16)

    # 6) Create a legend
    from matplotlib.patches import Patch
    legend_elements = []
    for lbl, cat_name in categories.items():
        color_val = category_colors.get(lbl, "yellow")
        legend_elements.append(Patch(facecolor=color_val, label=cat_name))
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.05,1), loc='upper left')

    # 7) Save the figure
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close()

    print(f"Saved heatmap to {output_path}")
    return heatmap

In [80]:
# Create the heatmap
heatmap_array = create_categorical_heatmap(
    tile_coords=tile_coords,
    tile_labels=tile_labels,
    tile_size=256,
    categories=categories,
    category_colors=category_colors,
    output_path=os.path.join(outdir,"BM51_categorical_heatmap.png")
)


mag_max: 40
downsampling: 2
tiles from original image n_rows: 90, n_cols: 86, total tiles: 7740
tile size:3006
Heatmap dimensions: 92160 x 88064
Saved heatmap to GNN_modle_Malignant_5FCross_multiGPU/BM51_categorical_heatmap.png


In [57]:
categories = {
    0: "Background",
    1: "Neural-crest",
    2: "Neuronal",
    3: "Photoreceptor",
    4: "Proliferative"
}
category_colors = {
    0: "white",
    1: "red",
    2: "green",
    3: "blue",
    4: "magenta"
}

# Create the heatmap
heatmap_array = create_categorical_heatmap(
    slide_file_name,
    tile_coords=tile_coords,
    tile_labels=tile_labels,
    level=0,
    tile_size=512,
    categories=categories,
    category_colors=category_colors,
    output_path=os.path.join(outdir,"BM51_categorical_heatmap.png")
)


TypeError: create_categorical_heatmap() got multiple values for argument 'tile_coords'

In [21]:
import openslide
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import os

def create_categorical_heatmap(
    slide_file_name,
    tile_coords,
    tile_labels,
    level=2,
    downsample_factor=4,  # Adjust based on visualization needs
    categories=None,
    category_colors=None,
    output_path="heatmap.png"
):
    """
    Create a categorical heatmap without overlapping tiles.

    Args:
        slide_file_name (str): Path to the WSI.
        tile_coords (np.ndarray): shape [N,4], (row, col, row_end, col_end) at level 0.
        tile_labels (np.ndarray): shape [N], integer labels for each tile.
        level (int): WSI level to visualize.
        downsample_factor (int): Factor for reducing resolution.
        categories (dict): Label -> category name for the legend.
        category_colors (dict): Label -> color mapping.
        output_path (str): Path to save heatmap.

    Returns:
        heatmap (np.ndarray): 2D array at the chosen level resolution.
    """
    # 1) Open the WSI slide
    try:
        slide = openslide.OpenSlide(slide_file_name)
    except Exception as e:
        raise IOError(f"Could not open slide file {slide_file_name}: {e}")
    ## magnification max
    if openslide.PROPERTY_NAME_OBJECTIVE_POWER in slide.properties:
        mag_max = slide.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER]
        print("mag_max:", mag_max)
        mag_original = mag_max
    else:
        print("[WARNING] mag not found, assuming: {mag_assumed}")
        mag_max = 40
        mag_original = 0

    ## downsample_level
    downsampling = int(int(mag_max)/20)
    print(f"downsampling: {downsampling}")
    
    # 2) Get dimensions at level 0 and chosen level
    px0, py0 = slide.level_dimensions[0]  # Level 0 (original resolution)
    tile_size0 = int(512*downsampling)
    n_rows,n_cols = int(py0/tile_size0), int(px0/tile_size0)

    print(f"n_rows: {n_rows}, n_cols: {n_cols}, total tiles: {n_rows*n_cols}")

    w_level, h_level = slide.level_dimensions[level]  # Chosen level
   
    # 3) Compute scaling factor from level 0 to target level
    scale_x = w_level / px0
    scale_y = h_level / py0

    print(f"Slide level={level} dimensions: {w_level} x {h_level}")
    print(f"Scaling factors - X: {scale_x}, Y: {scale_y}")

    # 4) Create heatmap array (downsample for efficiency)
    w_down, h_down = n_rows // downsample_factor, n_cols // downsample_factor
    heatmap = np.full((h_down, w_down), fill_value=-1, dtype=np.int32)  # Background is -1

    print(f"Created heatmap of shape: {heatmap.shape} (downsampled)")

    # 5) Downscale tile coordinates and place labels
    for i in range(len(tile_labels)):
        row, col, row_end, col_end = tile_coords[i]
        label = tile_labels[i]

        # Scale coordinates to the target level
        row = int(row * scale_y) // downsample_factor
        col = int(col * scale_x) // downsample_factor
        row_end = int(row_end * scale_y) // downsample_factor
        col_end = int(col_end * scale_x) // downsample_factor

        # Ensure bounds stay within the heatmap size
        row_end = min(row_end, h_down)
        col_end = min(col_end, w_down)

        # Fill the heatmap with the corresponding label
        heatmap[row:row_end, col:col_end] = label
    
    slide.close()

    # 6) Define categories and colors
    if categories is None:
        categories = {0: "Neural-crest", 1: "Neuronal", 2: "Photoreceptor", 3: "Proliferative"}
    if category_colors is None:
        category_colors = {0: "red", 1: "green", 2: "blue", 3: "magenta"}

    all_labels = sorted(categories.keys())  # Only valid labels
    label_to_idx = {lbl: idx for idx, lbl in enumerate(all_labels)}

    # Create colormap for labeled areas
    cmap = mcolors.ListedColormap([category_colors[lbl] for lbl in all_labels])
    norm = mcolors.BoundaryNorm(list(range(len(all_labels) + 1)), cmap.N)

    # Mask background (-1 values)
    masked_heatmap = np.ma.masked_where(heatmap == -1, heatmap)

    # 7) Plot heatmap correctly aligned
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_facecolor("white")  # Set figure background to white
    ax.imshow(masked_heatmap, cmap=cmap, norm=norm, origin="upper")

    # 8) Add Legend
    legend_elems = [Patch(facecolor=category_colors[lbl], label=categories[lbl]) for lbl in all_labels]
    ax.legend(handles=legend_elems, bbox_to_anchor=(1.05, 1), loc="upper left")

    # Hide axes for clean display
    plt.axis("off")

    # 9) Save image
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, bbox_inches="tight", dpi=300, transparent=True)  # Transparent PNG
    plt.close()
    print(f"Saved heatmap to {output_path}")

    return heatmap


In [31]:
import openslide
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import os
import pandas as pd

def create_categorical_heatmap(
    slide_file_name,
    tile_coords,
    tile_labels,
    level=2,
    categories=None,
    category_colors=None,
    output_path="heatmap.png"
):
    """
    Create a categorical heatmap with correctly scaled tiles.

    Args:
        slide_file_name (str): Path to the WSI.
        tile_coordinates_file (str): CSV file containing tile coordinates.
        tile_labels (np.ndarray): shape [N], integer labels for each tile.
        level (int): WSI level to visualize.
        categories (dict): Label -> category name for the legend.
        category_colors (dict): Label -> color mapping.
        output_path (str): Path to save heatmap.

    Returns:
        heatmap (np.ndarray): 2D array at the chosen level resolution.
    """
    # 1) Open the WSI slide
    try:
        slide = openslide.OpenSlide(slide_file_name)
    except Exception as e:
        raise IOError(f"Could not open slide file {slide_file_name}: {e}")

    # 2) Load tile coordinates
    #tile_df = pd.read_csv(tile_coordinates_file)
    
    # Ensure required columns exist
    '''if not {"col", "row", "col_ed", "row_ed"}.issubset(tile_df.columns):
        raise ValueError("CSV file must contain columns: col, row, col_ed, row_ed")'''

    if isinstance(tile_coords, np.ndarray):
        tile_coords = pd.DataFrame(tile_coords, columns=["row", "col", "row_ed", "col_ed"])

    # Extract relevant columns
    tile_coords = tile_coords[["row", "col", "row_ed", "col_ed"]].values

    # 3) Get dimensions at level 0 and chosen level
    w_level0, h_level0 = slide.level_dimensions[0]  # Full resolution
    w_level, h_level = slide.level_dimensions[level]  # Target level resolution

    # 4) Compute scaling factors
    scale_x = w_level / w_level0
    scale_y = h_level / h_level0

    print(f"Slide level={level} dimensions: {w_level} x {h_level}")
    print(f"Scaling factors - X: {scale_x}, Y: {scale_y}")

    # 5) Create heatmap array at the chosen level resolution
    heatmap = np.full((h_level, w_level), fill_value=-1, dtype=np.int32)  # Background as -1

    print(f"Created heatmap of shape: {heatmap.shape}")

    # 6) Properly scale tile coordinates and place labels
    for i in range(len(tile_labels)):
        row, col, row_end, col_end = tile_coords[i]
        label = tile_labels[i]

        # DEBUG: Print original tile coordinates
        print(f"Original Tile {i}: row={row}, col={col}, row_end={row_end}, col_end={col_end}")

        # Scale coordinates using the computed scale factors
        row = int(row * scale_y)
        col = int(col * scale_x)
        row_end = int(row_end * scale_y)
        col_end = int(col_end * scale_x)

        # Ensure bounds stay within the heatmap size
        row = max(0, min(row, h_level - 1))
        col = max(0, min(col, w_level - 1))
        row_end = max(0, min(row_end, h_level))
        col_end = max(0, min(col_end, w_level))

        print(f"Transformed Tile {i}: row={row}, col={col}, row_end={row_end}, col_end={col_end}")

        # Fill the heatmap with the corresponding label
        heatmap[row:row_end, col:col_end] = label
    
    slide.close()

    # 7) Define categories and colors
    if categories is None:
        categories = {0: "Neural-crest", 1: "Neuronal", 2: "Photoreceptor", 3: "Proliferative"}
    if category_colors is None:
        category_colors = {0: "red", 1: "green", 2: "blue", 3: "magenta"}

    all_labels = sorted(categories.keys())  # Only valid labels
    label_to_idx = {lbl: idx for idx, lbl in enumerate(all_labels)}

    # Create colormap for labeled areas
    cmap = mcolors.ListedColormap([category_colors[lbl] for lbl in all_labels])
    norm = mcolors.BoundaryNorm(list(range(len(all_labels) + 1)), cmap.N)

    # Mask background (-1 values)
    masked_heatmap = np.ma.masked_where(heatmap == -1, heatmap)

    # 8) Plot heatmap correctly aligned
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_facecolor("white")  # Set figure background to white
    ax.imshow(masked_heatmap, cmap=cmap, norm=norm, origin="upper")

    # 9) Add Legend
    legend_elems = [Patch(facecolor=category_colors[lbl], label=categories[lbl]) for lbl in all_labels]
    ax.legend(handles=legend_elems, bbox_to_anchor=(1.05, 1), loc="upper left")

    # Hide axes for clean display
    plt.axis("off")

    # 10) Save image
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, bbox_inches="tight", dpi=300, transparent=True)  # Transparent PNG
    plt.close()
    print(f"Saved heatmap to {output_path}")

    return heatmap


In [32]:
categories = {
    0: "Neural-crest",
    1: "Neuronal",
    2: "Photoreceptor",
    3: "Proliferative"
}
category_colors = {
    0: "red",
    1: "green",
    2: "blue",
    3: "magenta"
}

# Create the heatmap
heatmap_array = create_categorical_heatmap(
    slide_file_name,
    tile_coords=tile_coords,
    tile_labels=tile_labels,
    level=2,
    categories=categories,
    category_colors=category_colors,
    output_path=os.path.join(outdir,"BM51_categorical_heatmap.png")
)


Slide level=2 dimensions: 22080 x 23232
Scaling factors - X: 0.25, Y: 0.25
Created heatmap of shape: (23232, 22080)
Original Tile 0: row=0, col=0, row_end=720, col_end=64
Transformed Tile 0: row=0, col=0, row_end=180, col_end=16
Original Tile 1: row=0, col=0, row_end=736, col_end=64
Transformed Tile 1: row=0, col=0, row_end=184, col_end=16
Original Tile 2: row=0, col=0, row_end=752, col_end=64
Transformed Tile 2: row=0, col=0, row_end=188, col_end=16
Original Tile 3: row=0, col=0, row_end=768, col_end=64
Transformed Tile 3: row=0, col=0, row_end=192, col_end=16
Original Tile 4: row=0, col=0, row_end=640, col_end=80
Transformed Tile 4: row=0, col=0, row_end=160, col_end=20
Original Tile 5: row=0, col=0, row_end=656, col_end=80
Transformed Tile 5: row=0, col=0, row_end=164, col_end=20
Original Tile 6: row=0, col=0, row_end=672, col_end=80
Transformed Tile 6: row=0, col=0, row_end=168, col_end=20
Original Tile 7: row=0, col=0, row_end=688, col_end=80
Transformed Tile 7: row=0, col=0, row_

In [5]:
tile_coords

array([[   45,     4, 46080,  4096],
       [   46,     4, 47104,  4096],
       [   47,     4, 48128,  4096],
       ...,
       [   50,    86, 51200, 88064],
       [   51,    86, 52224, 88064],
       [   52,    86, 53248, 88064]])