In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import VisionTransformer
from torchvision.models.resnet import Bottleneck, ResNet

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
import os
import json
import torch
import pandas as pd
from torch.utils.data import Dataset

class MultiModalSurvivalDataset(Dataset):
    def __init__(self, base_dir="cases", embedding_modalities=None, other_data_modalities=None):
        self.metadata = []
        self.missing_counts = {modality: 0 for modality in embedding_modalities} if embedding_modalities else {}
        self.missing_other_data_counts = {modality: 0 for modality in other_data_modalities} if other_data_modalities else {}
        self.missing_survival_data_count = 0
        self.failed_cases = {}

        self.slide_stats_df = pd.read_csv(f"{base_dir}/slide_statistics.csv")
        case_dirs = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d != "GENERAL_METADATA"]

        for case_id in case_dirs:
            case_dir = os.path.join(base_dir, case_id)
            metadata_path = os.path.join(case_dir, "aggregated_data", f'{case_id}_data.json')

            if not os.path.exists(metadata_path):
                print(f"No metadata found for case {case_id}, skipping...")
                continue

            with open(metadata_path, 'r') as f:
                case_metadata = json.load(f)

            if not all(case_metadata.get(key, {}).get("has_data", False) for key in ["biospecimen", "clinical", "methylation"]):
                self.failed_cases[case_id] = self.failed_cases.get(case_id, []) + ["Missing required data"]
                continue

            for sample in case_metadata["biospecimen"]["biospecimen_data"]:
                if sample["sample_type"] == "Primary Tumor":
                    for slide in sample["slides"]:
                        slide_barcode = slide["slide_barcode"]
                        
                        if slide["resolution_level"] <= 2:
                            continue
                        
                        survival_time = case_metadata["clinical"]["clinical_patient_data"].get("days_to_death")
                        event = 1
                        
                        if survival_time is None:
                            survival_time = case_metadata["clinical"]["clinical_patient_data"].get("days_to_last_followup")
                            event = 0

                        if survival_time is None:
                            self.missing_survival_data_count += 1
                            continue
                        
                        sample_data = {"survival_time": float(survival_time), "event": event}

                        # Embedding Data
                        for modality in embedding_modalities:
                            embedding_key = f"embedding_{modality}"
                            if embedding_key in slide and slide[embedding_key] is not None:
                                sample_data[modality] = slide[embedding_key]
                            else:
                                self.missing_counts[modality] += 1
                                
                        # Other Data Types (e.g., Methylation, Clinical)
                        if "methylation" in other_data_modalities:
                            methylation_data = self.extract_methylation_embedding(case_metadata)
                            if methylation_data is not None:
                                sample_data["methylation"] = methylation_data
                            else:
                                self.missing_other_data_counts["methylation"] += 1

                        if "clinical" in other_data_modalities:
                            clinical_data = self.extract_clinical_features(case_metadata)
                            if clinical_data is not None:
                                sample_data["clinical"] = clinical_data
                            else:
                                self.missing_other_data_counts["clinical"] += 1

                        if len(sample_data) == len(embedding_modalities) + len(other_data_modalities) + 2:
                            self.metadata.append(sample_data)
                        else:
                            self.failed_cases[case_id] = self.failed_cases.get(case_id, []) + ["Incomplete data"]
        if self.failed_cases:
            print("Failed cases and reasons:")
            for case_id, errors in self.failed_cases.items():
                print(f"Case {case_id} failed due to: {', '.join(errors)}")
            else:
                print("All cases processed successfully.")

    def extract_methylation_embedding(self, case_metadata):
        """Extract methylation embedding if available."""
        methylation_data = case_metadata.get("methylation", {}).get("methylation_embedding", [])
        
        if methylation_data:
            return methylation_data
        else:
            return None

    def extract_clinical_features(self, case_metadata):
        """Extract and encode clinical features."""
        clinical_data = case_metadata.get("clinical", {}).get("clinical_patient_data", {})

        features = {
            "age": self.transform_age(clinical_data.get("days_to_birth", 0)),
            "gender": self.encode_gender(clinical_data.get("gender", "unknown")),
            "tumor_stage": self.encode_tumor_stage(clinical_data.get("pathologic_stage", "unknown")),
            "tumor_site": self.encode_tumor_site(clinical_data.get("tumor_tissue_site", "unknown")),
            "ajcc_stage": self.encode_ajcc_stage(clinical_data.get("ajcc_stage_version", "unknown")),
            "pathologic_T": self.encode_pathologic_stage(clinical_data.get("pathologic_T", "unknown")),
            "pathologic_N": self.encode_pathologic_stage(clinical_data.get("pathologic_N", "unknown")),
            "pathologic_M": self.encode_pathologic_stage(clinical_data.get("pathologic_M", "unknown"))
        }

        return torch.tensor(list(features.values()), dtype=torch.float32)

    def transform_age(self, days_to_birth):
        """Convert days_to_birth to age. Return 0 if invalid."""
        try:
            age = abs(int(days_to_birth) / 365)
            return age
        except:
            return 0

    def encode_gender(self, gender):
        """Encode gender as 0 for Male, 1 for Female, or -1 for unknown."""
        gender_map = {"MALE": 0, "FEMALE": 1}
        return gender_map.get(gender.upper(), -1)

    def encode_tumor_stage(self, tumor_stage):
        """Encode tumor stage as categorical values."""
        tumor_stage_map = {
            "Stage I": 0, "Stage IIA": 1, "Stage IIB": 2, "Stage IIIA": 3,
            "Stage IIIB": 4, "Stage IIIC": 5, "Stage IV": 6, "unknown": -1
        }
        return tumor_stage_map.get(tumor_stage, -1)

    def encode_tumor_site(self, tumor_site):
        """Encode tumor tissue site as categorical values."""
        tumor_site_map = {
            "Colon": 0, "Lung": 1, "Breast": 2, "Prostate": 3, "unknown": -1
        }
        return tumor_site_map.get(tumor_site, -1)

    def encode_ajcc_stage(self, ajcc_stage):
        """Encode AJCC stage as categorical values."""
        ajcc_stage_map = {
            "7th": 0, "8th": 1, "unknown": -1
        }
        return ajcc_stage_map.get(ajcc_stage, -1)

    def encode_pathologic_stage(self, pathologic_stage):
        """Encode pathologic stage (T, N, M) as categorical values."""
        pathologic_map = {
            "T1": 0, "T2": 1, "T3": 2, "T4": 3,
            "N1": 0, "N2": 1, "N3": 2,
            "M0": 0, "M1": 1, "unknown": -1
        }
        return pathologic_map.get(pathologic_stage, -1)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        item = self.metadata[index]
        embeddings = {modality: torch.tensor(item[modality], dtype=torch.float32) for modality in item if modality not in ["survival_time", "event"]}
        return embeddings, torch.tensor(item["survival_time"], dtype=torch.float32), torch.tensor(item["event"], dtype=torch.float32)


In [None]:
embedding_modalities = {"vit_DINO": True, "resnet50": True}
other_data_modalities = {"clinical": True}

dataset = MultiModalSurvivalDataset(base_dir="cases", 
                                   embedding_modalities=embedding_modalities, 
                                   other_data_modalities=other_data_modalities)

print(f"Total samples in dataset: {len(dataset)}\n")

# Print the first 3 samples
for i in range(min(3, len(dataset))):
    embeddings, survival_time, event = dataset[i]
    print(f"Sample {i+1}:")
    for modality, tensor in embeddings.items():
        print(f"  {modality} embedding shape: {tensor.shape}")
    print(f"  Survival Time: {survival_time.item()}, Event: {event.item()}\n")

Failed cases and reasons:
Case 86a9e44f-728e-48e4-9e1b-682b023f7c0a failed due to: Missing required data
Case c616f95f-4ea1-4019-8c17-041878085c88 failed due to: Incomplete data
Case 41658cf8-0de1-4a97-bb60-0f0107c84e00 failed due to: Missing required data
Case a9e4b6db-8b4b-4345-a4a1-d63175eb0174 failed due to: Missing required data
Case 515a062e-f68e-4100-a923-bd2ce0bbe071 failed due to: Missing required data
Case 286cbae3-8324-43b7-9b20-0f40dda1588e failed due to: Missing required data
Case 65bb7520-f055-43a8-b735-1152fa2c9e04 failed due to: Missing required data
Case 733d8b6a-ca9d-4a69-8c9c-1f88733e8b68 failed due to: Missing required data
All cases processed successfully.
Total samples in dataset: 847

Sample 1:
  vit_DINO embedding shape: torch.Size([384])
  resnet50 embedding shape: torch.Size([2048])
  clinical embedding shape: torch.Size([8])
  Survival Time: 43.0, Event: 0.0

Sample 2:
  vit_DINO embedding shape: torch.Size([384])
  resnet50 embedding shape: torch.Size([2048]

  embeddings = {modality: torch.tensor(item[modality], dtype=torch.float32) for modality in item if modality not in ["survival_time", "event"]}


In [44]:
class EarlyFusionSurvivalModel(nn.Module):
    def __init__(self, input_dims, extra_modalities):
        super(EarlyFusionSurvivalModel, self).__init__()
        self.extra_modalities = extra_modalities

        total_dim = sum(input_dims[key] for key in extra_modalities if extra_modalities[key])

        self.fc = nn.Sequential(
            nn.Linear(total_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, inputs):
        embeddings = [inputs[key] for key in self.extra_modalities if self.extra_modalities[key]]
        x = torch.cat(embeddings, dim=1)
        return self.fc(x)

In [20]:
import numpy as np
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sksurv.util import Surv
from sklearn.model_selection import train_test_split
from sksurv.metrics import concordance_index_censored

dataset = MultiModalSurvivalDataset(base_dir="cases", 
                                   embedding_modalities=embedding_modalities, 
                                   other_data_modalities=other_data_modalities)

features = []
survival_times = []
events = []

for i in range(len(dataset)):
    embeddings, survival_time, event = dataset[i]

    flattened_embeddings = np.concatenate([embeddings[modality].numpy().flatten() for modality in embeddings])
    
    features.append(flattened_embeddings)
    survival_times.append(survival_time.item())
    events.append(event.item())

features = np.array(features)
survival_times = np.array(survival_times)
events = np.array(events)

survival_data = Surv.from_arrays(event=events, time=survival_times)
X_train, X_test, y_train, y_test = train_test_split(features, survival_data, test_size=0.2, random_state=42)

model = GradientBoostingSurvivalAnalysis(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
model.fit(X_train, y_train)

# Predict risk scores on the test set
risk_scores_train = model.predict(X_train)
risk_scores_test = model.predict(X_test)

# Calculate concordance index (C-index) for both train and test sets
c_index_train = concordance_index_censored(y_train['event'], y_train['time'], risk_scores_train)
c_index_test = concordance_index_censored(y_test['event'], y_test['time'], risk_scores_test)

# Print the C-index for train and test
print(f"Train Concordance Index: {c_index_train[0]}")
print(f"Test Concordance Index: {c_index_test[0]}")

Failed cases and reasons:
Case 86a9e44f-728e-48e4-9e1b-682b023f7c0a failed due to: Missing required data
Case c616f95f-4ea1-4019-8c17-041878085c88 failed due to: Incomplete data
Case 41658cf8-0de1-4a97-bb60-0f0107c84e00 failed due to: Missing required data
Case a9e4b6db-8b4b-4345-a4a1-d63175eb0174 failed due to: Missing required data
Case 515a062e-f68e-4100-a923-bd2ce0bbe071 failed due to: Missing required data
Case 286cbae3-8324-43b7-9b20-0f40dda1588e failed due to: Missing required data
Case 65bb7520-f055-43a8-b735-1152fa2c9e04 failed due to: Missing required data
Case 733d8b6a-ca9d-4a69-8c9c-1f88733e8b68 failed due to: Missing required data
All cases processed successfully.


  embeddings = {modality: torch.tensor(item[modality], dtype=torch.float32) for modality in item if modality not in ["survival_time", "event"]}


Train Concordance Index: 0.9699532014465008
Test Concordance Index: 0.6175609756097561
