In [17]:
import json
import torch
from torch.utils.data import Dataset
import os
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torch.optim as optim

In [18]:
class WSIEmbeddingsDataset(Dataset):
    def __init__(self, base_dir="cases"):
        self.metadata = []

        case_dirs = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]

        if "GENERAL_METADATA" in case_dirs:
            case_dirs.remove("GENERAL_METADATA")

        for case_id in case_dirs:
            case_dir = os.path.join(base_dir, case_id)
            metadata_path = os.path.join(case_dir, "aggregated_data", f'{case_id}_data.json')

            if not os.path.exists(metadata_path):
                print(f"No metadata found for case {case_id}, skipping...")
                continue
            
            with open(metadata_path, 'r') as f:
                case_metadata = json.load(f)

            if case_metadata["biospecimen"]["has_data"] == False or \
                case_metadata["clinical"]["has_data"] == False or \
                case_metadata["methylation"]["has_data"] == False:
                continue

            for sample in case_metadata["biospecimen"]["biospecimen_data"]:
                if sample["sample_type"] != "Primary Tumor":
                    continue

                for slide in sample["slides"]:
                    censored = 0
                    slide_barcode = slide["slide_barcode"]
                    survival_time = case_metadata["clinical"]["clinical_patient_data"].get("days_to_death")
                    
                    if survival_time is None:
                        survival_time = case_metadata["clinical"]["clinical_patient_data"].get("days_to_last_followup")
                        censored = 1

                    if survival_time is None:
                        continue

                    self.metadata.append({
                        "json_file": metadata_path,
                        "case_id": case_id,
                        "slide_barcode": slide_barcode,
                        "survival_time": survival_time,
                        "censored": censored
                    })
    
    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        item = self.metadata[index]
        
        with open(item["json_file"], 'r') as f:
            case_metadata = json.load(f)

        slide_data = None
        for sample in case_metadata["biospecimen"]["biospecimen_data"]:
            if sample["sample_type"] == "Primary Tumor":
                for slide in sample["slides"]:
                    if slide["slide_barcode"] == item["slide_barcode"]:
                        slide_data = slide
                        break
            if slide_data:
                break
        
        if not slide_data:
            raise ValueError(f"Slide data not found for {item['slide_barcode']} in {item['json_file']}.")

        embedding = slide_data["embedding"]
        survival_time = case_metadata["clinical"]["clinical_patient_data"]["days_to_death"]
        censored = 0

        if survival_time is None:
            survival_time = case_metadata["clinical"]["clinical_patient_data"]["days_to_last_followup"]
            censored = 1

        embedding_tensor = torch.tensor(embedding, dtype=torch.float32)
        survival_time_tensor = torch.tensor(int(survival_time), dtype=torch.float32)
        censored_tensor = torch.tensor(censored, dtype=torch.float32)

        return embedding_tensor, survival_time_tensor, censored_tensor

In [19]:
class SurvivalNN(nn.Module):
    def __init__(self, embedding_size):
        super(SurvivalNN, self).__init__()
        
        # Define a simple feedforward neural network
        self.fc1 = nn.Linear(embedding_size + 2, 64)  # +2 for survival time and censored
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)  # Output single value for survival prediction

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [21]:
dataset = WSIEmbeddingsDataset(base_dir="cases_TEST_TRAIN_100")
all_data = DataLoader(dataset)

for data in all_data:
    print(f"Data: {data}")

train_size = int(0.75 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset)
test_loader = DataLoader(test_dataset)

embedding_size = 2048
model = SurvivalNN(embedding_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    
    for batch in train_loader:
        embedding, survival_time, censored = batch
        
        input_data = torch.cat((embedding, survival_time.unsqueeze(1), censored.unsqueeze(1)), dim=1)
        
        output = model(input_data)
        loss = criterion(output, survival_time.unsqueeze(1))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
    

Data: [tensor([[0.0059, 0.2152, 0.0093,  ..., 0.0064, 0.0457, 0.1279]]), tensor([155.]), tensor([1.])]
Data: [tensor([[0.0185, 0.1392, 0.0075,  ..., 0.0047, 0.0365, 0.0790]]), tensor([0.]), tensor([1.])]
Data: [tensor([[0.0068, 0.1594, 0.0053,  ..., 0.0050, 0.0276, 0.0602]]), tensor([0.]), tensor([1.])]
Data: [tensor([[0.0076, 0.1382, 0.0070,  ..., 0.0040, 0.0267, 0.0489]]), tensor([28.]), tensor([1.])]
Data: [tensor([[0.0035, 0.1567, 0.0128,  ..., 0.0034, 0.0312, 0.0600]]), tensor([28.]), tensor([1.])]
Data: [tensor([[0.0036, 0.1382, 0.0037,  ..., 0.0039, 0.0314, 0.0513]]), tensor([0.]), tensor([1.])]
Data: [tensor([[0.0094, 0.1390, 0.0112,  ..., 0.0021, 0.0277, 0.0418]]), tensor([0.]), tensor([1.])]
Data: [tensor([[0.0029, 0.1134, 0.0039,  ..., 0.0030, 0.0159, 0.0370]]), tensor([61.]), tensor([0.])]
Data: [tensor([[0.0063, 0.2225, 0.0050,  ..., 0.0024, 0.0136, 0.0214]]), tensor([61.]), tensor([0.])]
Data: [tensor([[0.0134, 0.3036, 0.0186,  ..., 0.0079, 0.0244, 0.0850]]), tensor([61.]