In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import h5py
import numpy as np
from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split

import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import wandb
from scipy.stats import spearmanr
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import spearmanr
import cv2


## EDA and data preprocessing

In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import math

def plot_celltype_abundance(slide_name, image, spots, plots_per_line=4, show_image=False, num_celltypes=35):
    """
    Plot one panel per cell type (C1 .. C<num_celltypes>) with spots coloured by abundance.

    Parameters:
        slide_name (str): Name of the slide, used in the figure title.
        image (np.array): 2D or 3D array representing the slide image.
        spots (np.array): Structured NumPy array containing spot information with fields
                          'x', 'y', 'C1', 'C2', ..., 'C<num_celltypes>'.
        plots_per_line (int): Number of subplots per row (default 4).
        show_image (bool): If True, draw the slide image behind the spots; otherwise only
                           set the axis limits to the image extent (default False).
        num_celltypes (int): Number of cell-type fields to plot (default 35).

    Each panel scatters the spots at their (x, y) positions, coloured with "viridis" by
    that cell type's abundance. There is no shared colorbar, and each panel's colour
    scale is normalised to that panel's own values.
    """
    x = spots["x"]
    y = spots["y"]
    num_rows = math.ceil(num_celltypes / plots_per_line)
    height, width = image.shape[:2]

    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(
        num_rows,
        plots_per_line,
        figsize=(plots_per_line * 3, num_rows * 3),
        squeeze=False,
    )
    axes = axes.flatten()

    for i in range(num_celltypes):
        var_name = f"C{i+1}"
        ax = axes[i]
        if show_image:
            ax.imshow(image, aspect="auto")
        else:
            ax.set_xlim(0, width)
            ax.set_ylim(height, 0)  # invert y-axis to mimic image display
        ax.scatter(x, y, c=spots[var_name], cmap="viridis", s=2, alpha=1)
        ax.set_title(var_name, fontsize=8)
        ax.axis("off")

    # Hide the unused cells at the end of the grid.
    for ax in axes[num_celltypes:]:
        ax.axis("off")

    fig.suptitle(f"Slide {slide_name}", fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the suptitle
    plt.show()


def plot_total_abundance(slide_name, image, spots, show_colorbar=True, colorbar_outside=True, show_image=True, shift_x=0, shift_y=0):
    """
    Plot the slide with each spot coloured by its total abundance.

    The total abundance is the sum of the fields C1 .. C35. For slide "S_7" the field
    "Test_Set" is plotted instead.

    Parameters:
        slide_name (str): Name of the slide.
        image (np.array): The slide image.
        spots (np.array): Structured NumPy array with fields 'x', 'y' and 'C1' ... 'C35'
                          (or 'Test_Set' for slide "S_7").
        show_colorbar (bool): Whether to display a colorbar (default True).
        colorbar_outside (bool): If True, put the colorbar in its own axes to the right
                                 of the plot.
        show_image (bool): If True, draw the slide image behind the spots; otherwise only
                           set the axis limits to the image extent (default True).
        shift_x, shift_y (int or float): Offsets added to the spot x / y coordinates,
                                         used to check spot-to-image alignment.
    """
    x = spots["x"] + shift_x
    y = spots["y"] + shift_y

    if slide_name == "S_7":
        total_abundance = spots["Test_Set"]
    else:
        total_abundance = np.zeros_like(x, dtype=float)
        for i in range(1, 36):  # Fields C1 to C35
            total_abundance += spots[f"C{i}"]

    fig, ax = plt.subplots(figsize=(6, 6))
    if show_image:
        ax.imshow(image, aspect="auto")
    else:
        height, width = image.shape[:2]
        ax.set_xlim(0, width)
        ax.set_ylim(height, 0)  # invert y-axis to mimic image display

    sc = ax.scatter(x, y, c=total_abundance, cmap="viridis", s=2, alpha=0.7)
    ax.set_title(f"Total Abundance for Slide {slide_name}")
    ax.axis("off")

    if show_colorbar and colorbar_outside:
        # Lay out the main axes first, within the left 85% of the figure, then add the
        # colorbar axes in the freed margin. Calling tight_layout() after add_axes() warns
        # about the incompatible colorbar axes and can override the reserved margin.
        fig.tight_layout(rect=[0, 0, 0.85, 1])
        cbar_ax = fig.add_axes([0.88, 0.15, 0.03, 0.7])
        fig.colorbar(sc, cax=cbar_ax, label="Total Abundance")
    else:
        if show_colorbar:
            fig.colorbar(sc, ax=ax, label="Total Abundance")
        fig.tight_layout()
    plt.show()

In [None]:
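# Per-slide pixel offsets that align the spot coordinates with the slide image
# (presumably found by inspecting the overlays below); the same values are in `config_shilf`.
# S1: x - 60, y - 60
# S2: x - 70, y - 70
# S3: x - 20, y - 20
# S4: x - 10, y - 10
# S5: x - 10, y - 10
# S6: x - 10, y - 10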

In [None]:
# Load one training slide, then compare the spot overlay shifted by the slide's offset
# (drawn on the image) with the raw, unshifted spot positions (no image).
slice_name = "S_6"
with h5py.File("/kaggle/input/el-hackathon-2025/elucidata_ai_challenge_data.h5", "r") as h5file:
    image = np.array(h5file["images/Train"][slice_name])
    print(image.shape)
    spots = np.array(h5file["spots/Train"][slice_name])

plot_total_abundance(
    slice_name, image, spots,
    show_colorbar=True, colorbar_outside=True,
    shift_x=-10, shift_y=-10,
)
plot_total_abundance(
    slice_name, image, spots,
    show_colorbar=True, colorbar_outside=True,
    show_image=False,
)

In [None]:
# Per-slide offset (pixels) subtracted from both spot x and y before patches are cropped,
# so the spots line up with the slide image. The name is kept as-is because other cells
# index into it.
config_shilf = dict(
    S_1=60,
    S_2=70,
    S_3=20,
    S_4=10,
    S_5=10,
    S_6=10,
    S_7=0,
)
class CellDataset(data.Dataset):
    """
    Dataset of square image patches centred on the spots of each slide.

    In "Train"/"Test" mode, when `images` is None, every slide under `images/<mode>` of
    the HDF5 file is read. For each spot, in file order, a `subimage_size` x
    `subimage_size` patch is cropped around the spot position shifted by
    `config_shilf[slide]`. Patches that extend past the slide border are zero-padded to
    full size rather than dropped, so the dataset holds exactly one sample per spot. This
    keeps prediction rows aligned with spot order when the "Test" set is used to write a
    submission indexed by row.

    Alternatively, pre-computed `images`, `coords` and `labels` lists can be passed in
    directly; they are used as-is.

    Each item is `(image, coords, labels)`:
      * image  - float32 tensor (C, H, W) scaled to [0, 1] (assumes HWC uint8 patches),
      * coords - float32 tensor with the shifted (x, y) spot position,
      * labels - float32 tensor with the spot fields after the first two columns
                 (assumed to be 'x' and 'y').
    `transform`, if given, is applied to the image tensor.
    """

    # Default location of the challenge HDF5 file.
    H5_PATH = "/kaggle/input/el-hackathon-2025/elucidata_ai_challenge_data.h5"

    def __init__(self, mode, images=None, coords=None, labels=None, transform=None,
                 subimage_size=110, h5_path=H5_PATH):
        self.images = images if images is not None else []
        self.coords = coords if coords is not None else []
        self.labels = labels if labels is not None else []
        self.transform = transform  # optional callable applied to the image tensor

        if mode in ["Train", "Test"] and images is None:
            self._load_from_h5(mode, h5_path, subimage_size)

    @staticmethod
    def _crop_patch(image, x, y, size):
        """Return the size x size patch of `image` centred on (x, y), zero-padded at borders."""
        half = size // 2
        x0, y0 = x - half, y - half
        height, width = image.shape[:2]
        patch = np.zeros((size, size) + image.shape[2:], dtype=image.dtype)
        # Intersection of the requested window with the image.
        src_y0, src_y1 = max(0, y0), min(height, y0 + size)
        src_x0, src_x1 = max(0, x0), min(width, x0 + size)
        if src_y0 < src_y1 and src_x0 < src_x1:
            patch[src_y0 - y0:src_y1 - y0, src_x0 - x0:src_x1 - x0] = image[src_y0:src_y1, src_x0:src_x1]
        return patch

    def _load_from_h5(self, mode, h5_path, subimage_size):
        """Append one patch, coordinate pair and label row per spot of every `mode` slide."""
        with h5py.File(h5_path, "r") as h5file:
            images_group = h5file[f"images/{mode}"]
            spots_group = h5file[f"spots/{mode}"]

            for slide_name in images_group.keys():
                image = np.array(images_group[slide_name])
                spots = pd.DataFrame(np.array(spots_group[slide_name]))
                offset = config_shilf[slide_name]
                # Columns after the first two (assumed 'x', 'y') are the per-spot targets.
                label_rows = spots.iloc[:, 2:].to_numpy().tolist()

                for x_raw, y_raw, cell_values in zip(spots["x"], spots["y"], label_rows):
                    x = int(int(x_raw) - offset)
                    y = int(int(y_raw) - offset)
                    patch = self._crop_patch(image, x, y, subimage_size)
                    # Assumes the slide image is float in [0, 1] -- TODO confirm HDF5 dtype.
                    self.images.append((patch * 255).astype(np.uint8))
                    self.coords.append((x, y))
                    self.labels.append(cell_values)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = torch.tensor(self.images[idx], dtype=torch.float32).permute(2, 0, 1) / 255.0
        coords = torch.tensor(self.coords[idx], dtype=torch.float32)
        labels = torch.tensor(self.labels[idx], dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, coords, labels

In [None]:
# 0.3
class CellProportionPredictor(nn.Module):
    """
    Image + coordinate regressor for per-spot cell-type values.

    A ResNet-50 backbone (classification layer removed) encodes the image patch into
    2048 features. A small MLP embeds the (x, y) coordinates into 128 features. The
    concatenation goes through an MLP head whose final ReLU keeps every prediction
    non-negative.

    forward(image, coords) -> tensor of shape (batch, num_cell_types).
    """

    def __init__(self, num_cell_types):
        super().__init__()
        # Pretrained ResNet-50 fetched through torch.hub; swapping `fc` for Identity
        # exposes the pooled 2048-dim features.
        backbone = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
        backbone.fc = nn.Identity()
        self.cnn = backbone

        # Embedding of the raw spot coordinates.
        self.coord_fc = nn.Sequential(nn.Linear(2, 128), nn.ReLU(), nn.BatchNorm1d(128))

        fused_dim = 2048 + 128
        self.head = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_cell_types),
            nn.ReLU(),
        )

    def forward(self, image, coords):
        fused = torch.cat((self.cnn(image), self.coord_fc(coords)), dim=1)
        return self.head(fused)


class KLDivLoss(nn.Module):
    """
    KL(target || pred), summed over dim 1 and averaged over the batch.

    Both inputs are clamped to [1e-7, 1.0] before the logarithms, which avoids log(0)
    for zero entries. This adds a small bias for zero targets.
    """

    def forward(self, pred, target):
        eps = 1e-7
        p = pred.clamp(min=eps, max=1.0)
        q = target.clamp(min=eps, max=1.0)
        per_sample = (q * (q.log() - p.log())).sum(dim=1)
        return per_sample.mean()

# Metrics
def metric_monitor(pred, target):
    """
    Compute evaluation metrics between predicted and target profiles.

    Args:
        pred, target: 2D tensors of shape (batch, num_cell_types).

    Returns:
        dict with
          'Cosine'   - mean row-wise cosine similarity (0-dim tensor),
          'Spearman' - mean row-wise Spearman correlation (float). Rows whose correlation
                       is undefined (e.g. constant input) are skipped; 0.0 if none remain.
    """
    pred_np = pred.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()

    # spearmanr(a, b) on two 2D arrays treats columns as variables and returns a full
    # correlation matrix, so correlate each spot (row) separately instead.
    row_corrs = []
    for pred_row, target_row in zip(pred_np, target_np):
        rho, _ = spearmanr(pred_row, target_row)
        if not np.isnan(rho):
            row_corrs.append(rho)

    metrics = {
        'Cosine': nn.CosineSimilarity()(pred, target).mean(),
        'Spearman': float(np.mean(row_corrs)) if row_corrs else 0.0,
    }
    return metrics


class CombinedLoss(nn.Module):
    """
    alpha * KL divergence + (1 - alpha) * (1 - differentiable Spearman correlation).

    Prediction ranks are not computed with argsort: argsort has no gradient, so such a
    term would only add a constant to the loss and never influence training. Soft ranks
    are used for the prediction instead.
    """

    def __init__(self, alpha=0.5, temperature=0.1):
        """
        :param alpha: Weight for the KL loss; (1 - alpha) weights the Spearman loss.
        :param temperature: Softness of the prediction ranks, relative to the per-row
            standard deviation of the prediction. Smaller values approach hard ranks but
            give sparser, spikier gradients.
        """
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.kl_loss = KLDivLoss()  # Use your defined KLDivLoss

    def spearman_loss(self, pred, target):
        """
        Differentiable approximation of 1 - Spearman correlation, averaged over rows.

        Target ranks are exact (double argsort; ties broken by sort order). Prediction
        ranks are soft: rank_i = sum_j sigmoid((p_i - p_j) / (temperature * std(p))). This
        tends to the hard rank + 0.5 as temperature -> 0 and is differentiable in `pred`.
        Spearman's rho is the Pearson correlation of the ranks, computed here as the cosine
        similarity of the mean-centred rank vectors.
        """
        pred = pred.float()
        target_rank = torch.argsort(torch.argsort(target, dim=1), dim=1).float()

        # Scale-free temperature: measure pairwise differences in units of the row's
        # spread. The spread is detached so it acts as a constant normaliser.
        detached = pred.detach()
        spread = (detached - detached.mean(dim=1, keepdim=True)).pow(2).mean(dim=1, keepdim=True).sqrt()
        spread = spread.clamp_min(1e-6)
        pairwise = (pred.unsqueeze(2) - pred.unsqueeze(1)) / (self.temperature * spread.unsqueeze(2))
        pred_rank = torch.sigmoid(pairwise).sum(dim=2)

        pred_rank = pred_rank - pred_rank.mean(dim=1, keepdim=True)
        target_rank = target_rank - target_rank.mean(dim=1, keepdim=True)
        spearman_corr = F.cosine_similarity(pred_rank, target_rank, dim=1)
        return 1 - spearman_corr.mean()

    def forward(self, pred, target):
        """
        Compute the combined loss.
        """
        kl = self.kl_loss(pred, target)
        spearman = self.spearman_loss(pred, target)
        return self.alpha * kl + (1 - self.alpha) * spearman

from scipy.stats import spearmanr
import numpy as np

def compute_spearman(y_pred, y_true):
    """
    Mean per-row Spearman correlation between two 2D tensors.

    Each row (one spot) of `y_pred` is rank-correlated with the same row of `y_true`.
    Rows whose correlation is NaN (e.g. constant input) are ignored. Returns 0.0 when no
    row yields a valid correlation.
    """
    pred_rows = y_pred.cpu().numpy()
    true_rows = y_true.cpu().numpy()

    valid = []
    for row_idx in range(pred_rows.shape[0]):
        rho, _ = spearmanr(pred_rows[row_idx], true_rows[row_idx])
        if np.isnan(rho):
            continue
        valid.append(rho)

    if not valid:
        return 0.0
    return np.mean(valid)


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp):
    """Run one optimisation pass over `loader`; return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for images, coords, labels in loader:
        images, coords, labels = images.to(device), coords.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images, coords)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)


def _evaluate(model, loader, criterion, device):
    """Return (mean loss, mean per-spot Spearman) of `model` over a labelled loader."""
    model.eval()
    total_loss = 0.0
    total_spearman = 0.0
    with torch.no_grad():
        for images, coords, labels in loader:
            images, coords, labels = images.to(device), coords.to(device), labels.to(device)
            outputs = model(images, coords)
            total_loss += criterion(outputs, labels).item()
            total_spearman += compute_spearman(outputs, labels)
    num_batches = max(len(loader), 1)
    return total_loss / num_batches, total_spearman / num_batches


def _predict(model, loader, device):
    """Return the model's prediction rows for every sample of `loader`, in loader order."""
    model.eval()
    rows = []
    with torch.no_grad():
        for images, coords, _ in loader:
            outputs = model(images.to(device), coords.to(device))
            rows.extend(outputs.float().cpu().numpy())
    return rows


def train_model(num_epochs=1000, batch_size=64, num_workers=4, val_fraction=0.05, seed=42):
    """
    Train `CellProportionPredictor` on the "Train" patches and write "Test" predictions.

    A `val_fraction` share of the training patches is held out (split fixed by `seed`)
    for model selection. The LR scheduler steps on 1 - validation Spearman. Whenever the
    validation Spearman improves, the model is saved to "best_model.pth" and its "Test"
    predictions are written to "predictions.csv" (columns C1..C35 plus ID = row index).
    A resumable checkpoint is written to "checkpoint.pth" every epoch.

    Args:
        num_epochs: number of training epochs.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        val_fraction: share of training patches held out for validation, in (0, 1).
        seed: random seed of the train/validation split.
    """
    best_val_spearman = float("-inf")
    best_val_loss = float("inf")
    best_model_path = "best_model.pth"
    checkpoint_path = "checkpoint.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision only applies on CUDA; on CPU the scaler and autocast are disabled.
    use_amp = device.type == "cuda"

    full_train_set = CellDataset("Train")
    test_set = CellDataset("Test")

    # Validate on held-out labelled patches; the Test patches are only used for prediction.
    train_indices, val_indices = train_test_split(
        range(len(full_train_set)), test_size=val_fraction, random_state=seed)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    # drop_last avoids a final batch of size 1, which BatchNorm1d rejects in train mode.
    train_loader = DataLoader(
        Subset(full_train_set, train_indices), shuffle=True, drop_last=True,
        persistent_workers=num_workers > 0, **loader_kwargs)
    val_loader = DataLoader(Subset(full_train_set, val_indices), shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

    model = CellProportionPredictor(num_cell_types=35).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    criterion = CombinedLoss()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    columns = [f"C{k}" for k in range(1, 36)]
    for epoch in range(num_epochs):
        avg_train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, use_amp)
        avg_val_loss, val_spearman = _evaluate(model, val_loader, criterion, device)
        scheduler.step(1 - val_spearman)

        print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}, "
              f"val spearman = {val_spearman:.4f}")

        improved = val_spearman > best_val_spearman
        if improved:
            best_val_spearman = val_spearman
            best_val_loss = avg_val_loss

        # Save checkpoint (for continuing training)
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_loss': best_val_loss,
            'best_val_spearman': best_val_spearman,
        }
        torch.save(checkpoint, checkpoint_path)

        if improved:
            torch.save(checkpoint, best_model_path)
            df_predictions = pd.DataFrame(_predict(model, test_loader, device), columns=columns)
            df_predictions['ID'] = df_predictions.index
            df_predictions.to_csv("predictions.csv", index=False)
            print(f"New best model saved (val spearman {best_val_spearman:.4f})")

    print("Training complete.")


train_model()

### ViT-Based Model

In [None]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SpotViT(nn.Module):
    """
    torchvision ViT-B/16 backbone (pretrained) with a 2-layer MLP regression head.

    forward(x) takes a float image batch (B, 3, H, W) and returns (B, num_classes).
    Inputs whose spatial size differs from the backbone's `image_size` (224 for ViT-B/16)
    are bilinearly resized first. torchvision's ViT otherwise rejects them, e.g. the
    110x110 patches cut by the patch dataset.
    """

    def __init__(self, num_classes=35):
        super(SpotViT, self).__init__()

        # Load pre-trained ViT and drop its classifier to expose the pooled features.
        self.vit = vit_b_16(pretrained=True)
        in_features = self.vit.heads.head.in_features
        self.vit.heads = nn.Identity()

        # MLP head for cell-type prediction
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        target_size = self.vit.image_size
        if tuple(x.shape[-2:]) != (target_size, target_size):
            x = nn.functional.interpolate(
                x, size=(target_size, target_size), mode="bilinear", align_corners=False)
        features = self.vit(x)  # pooled ViT features
        return self.fc(features)  # one value per cell type


In [None]:
import torch.optim as optim
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data augmentation -- training patches only.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
])

# Labelled training patches, augmented.
full_train_set = CellDataset("Train", transform=train_transform)
# The same patches without augmentation, for validation. This reuses the loaded lists
# instead of reading the HDF5 file again.
eval_train_set = CellDataset(
    "Train",
    images=full_train_set.images,
    coords=full_train_set.coords,
    labels=full_train_set.labels,
)
# Test patches, used only to produce the submission. They must NOT be augmented,
# otherwise every prediction is made on a randomly flipped / rotated patch.
submission_set = CellDataset("Test")

# Hold out 5% of the labelled patches for validation / model selection.
train_indices, holdout_indices = train_test_split(
    range(len(full_train_set)), test_size=0.05, random_state=42
)
train_set = Subset(full_train_set, train_indices)
holdout_set = Subset(eval_train_set, holdout_indices)

train_loader = DataLoader(
    train_set, batch_size=64, shuffle=True, num_workers=4,
    pin_memory=True, persistent_workers=True
)
holdout_loader = DataLoader(
    holdout_set, batch_size=64, shuffle=False, num_workers=4,
    pin_memory=True
)
submission_loader = DataLoader(
    submission_set, batch_size=64, shuffle=False, num_workers=4,
    pin_memory=True
)

# Initialize Model
model = SpotViT().to(device)
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

num_epochs = 1000
best_val_loss = float("inf")
# ViT-specific file names so other training cells in this notebook do not overwrite them.
best_model_path = "vit_best_model.pth"
checkpoint_path = "vit_checkpoint.pth"
columns = [f"C{k}" for k in range(1, 36)]

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, _coords, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    avg_train_loss = running_loss / len(train_loader)

    # Validation on the held-out labelled patches.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, _coords, labels in holdout_loader:
            images, labels = images.to(device), labels.to(device)
            val_loss += criterion(model(images), labels).item()
    avg_val_loss = val_loss / len(holdout_loader)
    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Validation Loss={avg_val_loss:.4f}")

    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(checkpoint, checkpoint_path)

    # Save the best model and its test predictions (model is still in eval mode).
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(checkpoint, best_model_path)
        print(f"New best model saved with val loss: {best_val_loss:.4f}")

        predictions = []
        with torch.no_grad():
            for images, _coords, _labels in submission_loader:
                predictions.extend(model(images.to(device)).cpu().numpy())
        df_predictions = pd.DataFrame(predictions, columns=columns)
        df_predictions['ID'] = df_predictions.index
        df_predictions.to_csv("predictions.csv", index=False)

print("Training complete.")


### Graph neural network (GNN) model

In [None]:
!pip install torch_geometric

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv, GraphSAGE
import numpy as np
import pandas as pd
from sklearn.neighbors import kneighbors_graph


class CellDataset(data.Dataset):
    """
    Patch dataset for the graph model: crops of up to 224x224 around each spot, plus a slide id.

    NOTE(review): this redefines the `CellDataset` class from an earlier cell. Code run
    after this cell uses this version (4-tuple items, no padding); code run before it used
    the earlier one.

    In "Train"/"Test" mode, when `images` is None, every slide under `images/<mode>` of
    the HDF5 file is read. For each spot, in file order, a crop of up to `subimage_size`
    x `subimage_size` pixels is cut around the spot position shifted by
    `config_shilf[slide]`. Crops at the slide border are truncated by the image edge, so
    they can be smaller; callers that need a fixed size must resize them. Crops are
    stored as uint8 after scaling by 255 (assumes a float slide image in [0, 1] -- TODO
    confirm). Alternatively, the `images`, `coords`, `labels` and `slide_ids` lists can be
    passed in directly.

    Each item is `(image, coords, labels, slide_id)`; `transform`, if given, is applied to
    the image tensor.
    """

    # Default location of the challenge HDF5 file.
    H5_PATH = "/kaggle/input/el-hackathon-2025/elucidata_ai_challenge_data.h5"

    def __init__(self, mode, images=None, coords=None, labels=None, slide_ids=None, transform=None,
                 subimage_size=224, h5_path=H5_PATH):
        self.images = images if images is not None else []
        self.coords = coords if coords is not None else []
        self.labels = labels if labels is not None else []
        self.slide_ids = slide_ids if slide_ids is not None else []  # source slide index per patch
        self.transform = transform

        if mode in ["Train", "Test"] and images is None:
            half_size = subimage_size // 2
            with h5py.File(h5_path, "r") as h5file:
                images_group = h5file[f"images/{mode}"]
                spots_group = h5file[f"spots/{mode}"]

                # The slide id is the slide's position in the HDF5 group.
                for slide_id, slide_name in enumerate(images_group.keys()):
                    image = np.array(images_group[slide_name])
                    spots = pd.DataFrame(np.array(spots_group[slide_name]))
                    height, width = image.shape[:2]
                    offset = config_shilf[slide_name]
                    # Columns after the first two (assumed 'x', 'y') are the per-spot targets.
                    label_rows = spots.iloc[:, 2:].to_numpy().tolist()

                    for x_raw, y_raw, cell_values in zip(spots["x"], spots["y"], label_rows):
                        x = int(int(x_raw) - offset)
                        y = int(int(y_raw) - offset)
                        subimage = image[max(0, y - half_size):min(height, y + half_size),
                                         max(0, x - half_size):min(width, x + half_size)]
                        self.images.append((subimage * 255).astype(np.uint8))
                        self.coords.append((x, y))
                        self.labels.append(cell_values)
                        self.slide_ids.append(slide_id)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = torch.tensor(self.images[idx], dtype=torch.float32).permute(2, 0, 1) / 255.0
        coords = torch.tensor(self.coords[idx], dtype=torch.float32)
        labels = torch.tensor(self.labels[idx], dtype=torch.float32)
        slide_id = torch.tensor(self.slide_ids[idx], dtype=torch.int64)

        if self.transform:
            image = self.transform(image)

        return image, coords, labels, slide_id


In [None]:
import torch
import timm
import numpy as np
from torch_geometric.data import Data
from sklearn.neighbors import kneighbors_graph
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm


# Pretrained ViT-B/16 (timm), used as a frozen feature extractor for the spot patches.
# num_classes=0 removes the classification head, so the model returns its pooled
# 768-dim embedding. Do not swap the head for a new nn.Linear: that layer would be
# randomly initialised (unseeded) and never trained, so the "features" would be a random
# projection that changes on every run.
vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
vit.eval()

# Preprocessing for PIL patches: resize to the 224x224 ViT input, convert to a tensor and
# normalise each channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [None]:
def to_uint8_image(image):
    """
    Return `image` as a uint8 array with values in [0, 255].

    uint8 input is returned unchanged. Any other dtype is treated as a float image in
    [0, 1]: it is clipped to that range, scaled by 255 and truncated to uint8.
    """
    if image.dtype == np.uint8:
        return image
    return (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)


def extract_features(image):
    """
    Embed one image patch with the global `vit` model and `transform` preprocessing.

    Args:
        image: HxW(xC) array, either uint8 in [0, 255] (as stored by CellDataset) or
            float in [0, 1].

    Returns:
        CPU tensor with the model's features for the patch (batch dimension squeezed).
    """
    # CellDataset already stores uint8 patches. Scaling those by 255 again wraps around
    # in uint8 arithmetic and corrupts the pixels, so only float input is rescaled.
    pil_image = Image.fromarray(to_uint8_image(image)).convert("RGB")
    model_device = next(vit.parameters()).device
    batch = transform(pil_image).unsqueeze(0).to(model_device)
    with torch.no_grad():
        features = vit(batch)
    return features.squeeze().cpu()

In [None]:
# Training graph: one node per spot (ViT features) with k-NN edges inside each slide.
# The dataset must not be called `data`: that name is the `torch.utils.data` module alias
# the `CellDataset` class definitions subclass from, and shadowing it breaks re-runs.
train_dataset = CellDataset(mode='Train')

images = list(train_dataset.images)
coords = np.array(train_dataset.coords)
slide_ids = np.array(train_dataset.slide_ids)

k = 10
edge_index_list = []
for wsi_id in np.unique(slide_ids):
    node_indices = np.where(slide_ids == wsi_id)[0]
    coords_wsi = coords[node_indices]
    # kneighbors_graph needs more samples than neighbours; skip single-spot slides.
    n_neighbors = min(k, len(node_indices) - 1)
    if n_neighbors < 1:
        continue
    adj_matrix = kneighbors_graph(coords_wsi, n_neighbors, mode="connectivity", include_self=False)
    local_edges = np.array(adj_matrix.nonzero())
    # Map per-slide row positions back to global node indices.
    edge_index_list.append(torch.tensor(node_indices[local_edges], dtype=torch.long))

# Merge the edges of all slides.
edge_index = (torch.cat(edge_index_list, dim=1) if edge_index_list
              else torch.empty((2, 0), dtype=torch.long))

# Node features: ViT embedding of each patch.
features = [extract_features(img) for img in tqdm(images, desc="Extracting features", total=len(images))]
x = torch.stack(features)

# Node targets: one row of label values per spot.
y = torch.tensor(np.asarray(train_dataset.labels, dtype=np.float32), dtype=torch.float32)

graph_data = Data(x=x, edge_index=edge_index, y=y)

print(f"Graph created with {graph_data.num_nodes} nodes and {graph_data.num_edges} edges.")


In [None]:
# Test graph, built the same way as the training graph.
datatest = CellDataset(mode='Test')
images_test = list(datatest.images)
coords = np.array(datatest.coords)
labels = np.array(datatest.labels, dtype=object)
slide_ids = np.array(datatest.slide_ids)

# k-NN graph (k = 10) within each slide; no edges between slides.
k = 10
edge_index_list = []
for wsi_id in np.unique(slide_ids):
    node_indices = np.where(slide_ids == wsi_id)[0]
    coords_wsi = coords[node_indices]
    # kneighbors_graph needs more samples than neighbours; skip single-spot slides.
    n_neighbors = min(k, len(node_indices) - 1)
    if n_neighbors < 1:
        continue
    adj_matrix = kneighbors_graph(coords_wsi, n_neighbors, mode="connectivity", include_self=False)
    local_edges = np.array(adj_matrix.nonzero())
    # Map per-slide row positions back to global node indices.
    edge_index_list.append(torch.tensor(node_indices[local_edges], dtype=torch.long))

# Merge the edges of ALL slides. Using only the last slide's adjacency, with its per-slide
# indices, would drop the other slides' edges and mis-index nodes whenever the test set
# has more than one slide.
edge_index = (torch.cat(edge_index_list, dim=1) if edge_index_list
              else torch.empty((2, 0), dtype=torch.long))

# Node features: ViT embedding of each patch.
features_test = [extract_features(img) for img in tqdm(images_test, desc="Extracting features", total=len(images_test))]
x = torch.stack(features_test)

# Convert labels
labels = [np.array(label, dtype=np.float32) for label in labels]
y = torch.tensor(np.vstack(labels), dtype=torch.float32)

graph_data_test = Data(x=x, edge_index=edge_index, y=y)

print(f"Graph created with {graph_data_test.num_nodes} nodes and {graph_data_test.num_edges} edges.")


In [None]:
class GNNModel(nn.Module):
    """
    Two GCN layers followed by a linear read-out, with ReLU after every stage.

    forward(x, edge_index): node features of shape (num_nodes, in_dim) and a (2, E) edge
    index -> non-negative predictions of shape (num_nodes, out_dim).
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, edge_index))
        return F.relu(self.fc(hidden))

In [None]:
def spearman_score(y_true, y_pred):
    """
    Mean Spearman correlation over output columns.

    For every column i, the Spearman correlation between `y_true[:, i]` and
    `y_pred[:, i]` is computed across rows; an undefined (NaN) correlation counts as 0.0.
    Returns the mean over columns.
    """
    truth = y_true.cpu().detach().numpy()
    preds = y_pred.cpu().detach().numpy()

    def _column_rho(col):
        rho, _ = spearmanr(truth[:, col], preds[:, col])
        return 0.0 if np.isnan(rho) else rho

    return np.mean([_column_rho(col) for col in range(truth.shape[1])])

In [None]:
# Initialize Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GNNModel(in_dim=x.shape[1], hidden_dim=1024, out_dim=35).to(device)

# Define Loss Function (MSE Loss for regression task)
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Move Data to Device
graph_data = graph_data.to(device)
graph_data_test = graph_data_test.to(device)


# Training Loop
num_epochs = 100000
best_loss = float("inf")
best_model_path = "best_model.pth"
checkpoint_path = "checkpoint.pth"

import os

if os.path.exists('/kaggle/working/best_mode.pth'):
    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']  # Resume from last epoch
    print(f"Resumed training from epoch {start_epoch}")
else:
    print("No checkpoint found. Starting from scratch.")
    start_epoch = 0

for epoch in range(start_epoch, num_epochs):  # Start from last saved epoch
    total_spearman = 0
    model.train()
    optimizer.zero_grad()
    
    outputs = model(graph_data.x, graph_data.edge_index)
    loss = criterion(outputs, graph_data.y)
    loss.backward()
    optimizer.step()
# for epoch in range(num_epochs):
#     total_spearman = 0
#     model.train()
#     optimizer.zero_grad()
#     outputs = model(graph_data.x, graph_data.edge_index)
#     loss = criterion(outputs, graph_data.y)
#     loss.backward()
#     optimizer.step()
    spearman_corr = spearman_score(graph_data.y, outputs)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}, Spearman: {spearman_corr}")

    
    # Save checkpoint
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    if loss.item() < best_loss:
        best_loss = loss.item()
        torch.save(checkpoint, best_model_path)
        print(f"New best model saved with loss: {best_loss:.4f}")
        torch.save(checkpoint, checkpoint_path)
        model.eval()
        with torch.no_grad():
            predictions = model(graph_data_test.x, graph_data_test.edge_index).cpu().numpy()
            df_predictions = pd.DataFrame(predictions, columns=[f"C{k}" for k in range(1, 36)])
            df_predictions['ID'] = df_predictions.index
            df_predictions.to_csv("predictions.csv", index=False)
print("Training complete.")

