In [14]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd

from torch.utils.data import Dataset, DataLoader

from data import load_test_data, generate_training_samples
from dataset import STDataset
from baseline import GANBaseline
from evaluate import Evaluator
from utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
all_test_items = load_test_data(num_holes=10)

Seeding all randomness with seed=2024
Donor_id: MsBrainAgingSpatialDonor_1
Slice_id: 0
Donor_id: MsBrainAgingSpatialDonor_2
Slice_id: 0
Slice_id: 1
Donor_id: MsBrainAgingSpatialDonor_3
Slice_id: 0
Slice_id: 1
Donor_id: MsBrainAgingSpatialDonor_4
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_5
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_6
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_7
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_8
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_9
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_10
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_11
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_12
Slice_id: 0
Slice_id: 1


In [9]:
training_samples = generate_training_samples(num_samples_per_slice=10)

Seeding all randomness with seed=2024


In [18]:
training_samples

[{'normalized_positions': array([[0.56965703, 0.14636786],
         [0.77828302, 0.12849187],
         [0.50011503, 0.45695788],
         [0.71833302, 0.07028587],
         [0.45542503, 0.85175587],
         [0.35978301, 0.06723387],
         [0.73206702, 0.82385187],
         [0.34168901, 0.69283387],
         [0.24402502, 0.63179387],
         [0.98843502, 0.49161988],
         [0.79790302, 0.69741187],
         [0.20870902, 0.04390787],
         [0.28152101, 0.30937188],
         [0.31552901, 0.32179788],
         [0.23203502, 0.31373188],
         [0.00989302, 0.35951188],
         [0.30484701, 0.01098987],
         [0.89796502, 0.11105187],
         [0.73250302, 0.34904788],
         [0.65358702, 0.13437787],
         [0.24511502, 0.55549387],
         [0.37744101, 0.10124187],
         [0.46567103, 0.03213587],
         [0.46545303, 0.18386386],
         [0.25012902, 0.48180988],
         [0.89774702, 0.30021588],
         [0.62720903, 0.89034187],
         [0.81970302, 0.5293338

In [10]:
# Initialize dictionaries to store metrics
metrics = {
    'GANBaseline': {
        'mse': [],
        'f1': [],
        'cosine_sim': [],
        'chamfer_dist': [],
        'emd': []
    }
}

In [17]:
test_item.test_area.hole_max_x

-4165.913093647361

In [16]:
for i, test_item in enumerate(all_test_items):
    print(f"Test Area {i+1}:")
    print(f"  Dominant Tissue: {test_item.test_area.dominant_tissue}")
    print(f"  Number of cells in ground truth: {len(test_item.ground_truth.hole_cells)}")
    
    # filter out  current slice from the training samples
    current_donor = test_item.meta_data['donor_id']
    current_slice = test_item.meta_data['slice_id']
    print(f'current_donor: {current_donor}')
    print(f'current_slice: {current_slice}')

    filtered_samples = [sample for sample in training_samples if not (sample['metadata']['donor_id'] == current_donor and sample['metadata']['slice_id'] == current_slice)]

    dataset = STDataset(filtered_samples)
    dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=4)

    # Apply GANBaseline
    gan_baseline = GANBaseline(test_item.adata, test_item.test_area, dataloader, num_epochs=10)

    gan_coords, gan_gene_expressions = gan_baseline.fill_region()

    # Evaluate predictions for GANBaseline
    true_coords = test_item.ground_truth.hole_cells[['center_x', 'center_y']].values
    true_gene_expressions = test_item.ground_truth.gene_expression

    mse_gan, f1_gan, cosine_sim_gan = Evaluator.evaluate_expression(true_coords, true_gene_expressions, gan_coords, gan_gene_expressions)
    chamfer_dist_gan = Evaluator.chamfer_distance(true_coords, gan_coords)
    emd_gan = Evaluator.calculate_emd(true_coords, gan_coords)

    # Collect results for KNNBaseline
    metrics['GANBaseline']['mse'].append(mse_gan)
    metrics['GANBaseline']['f1'].append(f1_gan)
    metrics['GANBaseline']['cosine_sim'].append(cosine_sim_gan)
    metrics['GANBaseline']['chamfer_dist'].append(chamfer_dist_gan)
    metrics['GANBaseline']['emd'].append(emd_gan)



Test Area 1:
  Dominant Tissue: striatum
  Number of cells in ground truth: 50
current_donor: MsBrainAgingSpatialDonor_1
current_slice: 0
Epoch [10/10], d_loss: 0.20117241144180298, g_loss: 3.5646862983703613
Test Area 2:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
current_donor: MsBrainAgingSpatialDonor_1
current_slice: 0
Epoch [10/10], d_loss: 0.4711822271347046, g_loss: 2.8742623329162598
Test Area 3:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
current_donor: MsBrainAgingSpatialDonor_1
current_slice: 0
Epoch [10/10], d_loss: 0.3205311894416809, g_loss: 3.433074951171875
Test Area 4:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
current_donor: MsBrainAgingSpatialDonor_1
current_slice: 0
Epoch [10/10], d_loss: 0.2198580801486969, g_loss: 3.601646900177002
Test Area 5:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
current_donor: MsBrainAgingSpatialDonor_1
current_slic

KeyboardInterrupt: 

In [7]:
for method in metrics:
    print(f"Results for {method}:")
    for metric in metrics[method]:
        mean_value = np.mean(metrics[method][metric])
        std_value = np.std(metrics[method][metric])
        print(f"  {metric.capitalize()}: Mean = {mean_value:.4f}, Std = {std_value:.4f}")

Results for GANBaseline:
  Mse: Mean = 0.9371, Std = 0.3065
  F1: Mean = 0.6545, Std = 0.0455
  Cosine_sim: Mean = -0.0342, Std = 0.1118
  Chamfer_dist: Mean = 173.0040, Std = 20.2853
  Emd: Mean = 145.8393, Std = 12.3561


In [4]:
# Initialize dictionaries to store metrics
metrics = {
    'TissueGANBaseline': {
        'mse': [],
        'f1': [],
        'cosine_sim': [],
        'chamfer_dist': [],
        'emd': []
    }
}

In [7]:
for i, test_item in enumerate(all_test_items):
    print(f"Test Area {i+1}:")
    print(f"  Dominant Tissue: {test_item.test_area.dominant_tissue}")
    print(f"  Number of cells in ground truth: {len(test_item.ground_truth.hole_cells)}")
    
    # filter out  current slice from the training samples
    current_donor = test_item.meta_data['donor_id']
    current_slice = test_item.meta_data['slice_id']
    current_tissue = test_item.test_area.dominant_tissue
    print(f'current_donor: {current_donor}')
    print(f'current_slice: {current_slice}')

    # filtered_samples = [sample for sample in training_samples if not (sample['metadata']['donor_id'] == current_donor and sample['metadata']['slice_id'] == current_slice)]
    # filtered_samples = [sample for sample in filtered_samples if sample['metadata']['dominant_tissue'] == current_tissue]
    filtered_samples = [sample for sample in training_samples if  (sample['metadata']['donor_id'] == current_donor and sample['metadata']['slice_id'] == current_slice)]
    filtered_samples = [sample for sample in filtered_samples if sample['metadata']['dominant_tissue'] == current_tissue]
    print(len(filtered_samples))
    
    dataset = STDataset(filtered_samples)
    dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=4)

    # Apply GANBaseline
    gan_baseline = GANBaseline(test_item.adata, test_item.test_area, dataloader, num_epochs=50)

    gan_coords, gan_gene_expressions = gan_baseline.fill_region()

    # Evaluate predictions for GANBaseline
    true_coords = test_item.ground_truth.hole_cells[['center_x', 'center_y']].values
    true_gene_expressions = test_item.ground_truth.gene_expression

    print(true_coords)
    print(gan_coords)
    print(max(true_coords[0]), max(true_coords[1]))
    print(max(gan_coords[0]), max(gan_coords[1]))

    mse_gan, f1_gan, cosine_sim_gan = Evaluator.evaluate_expression(true_coords, true_gene_expressions, gan_coords, gan_gene_expressions)
    chamfer_dist_gan = Evaluator.chamfer_distance(true_coords, gan_coords)
    emd_gan = Evaluator.calculate_emd(true_coords, gan_coords)

    # Collect results for KNNBaseline
    metrics['TissueGANBaseline']['mse'].append(mse_gan)
    metrics['TissueGANBaseline']['f1'].append(f1_gan)
    metrics['TissueGANBaseline']['cosine_sim'].append(cosine_sim_gan)
    metrics['TissueGANBaseline']['chamfer_dist'].append(chamfer_dist_gan)
    metrics['TissueGANBaseline']['emd'].append(emd_gan)

    break


Test Area 1:
  Dominant Tissue: striatum
  Number of cells in ground truth: 50
current_donor: MsBrainAgingSpatialDonor_1
current_slice: 0
1
Epoch [10/50], d_loss: 1.2368333339691162, g_loss: 2.06490159034729
Epoch [20/50], d_loss: 1.113448977470398, g_loss: 1.9872255325317383
Epoch [30/50], d_loss: 1.0303536653518677, g_loss: 1.9377881288528442
Epoch [40/50], d_loss: 0.980897068977356, g_loss: 1.912305235862732
Epoch [50/50], d_loss: 0.9483224153518677, g_loss: 1.9025253057479858
[[-4282.49949084  1784.46449143]
 [-4314.49099007  1825.88449043]
 [-4219.82449235  1839.29149011]
 [-4250.07199162  1800.37849104]
 [-4189.14099309  1834.65899022]
 [-4283.42599082  1880.22098912]
 [-4208.27049263  1807.29999088]
 [-4297.86849047  1811.38749078]
 [-4245.87549172  1874.98898925]
 [-4210.66849257  1858.80248964]
 [-4292.6364906   1884.19948903]
 [-4241.13399184  1815.58399068]
 [-4177.36899337  1768.72898699]
 [-4340.86898944  1822.77799051]
 [-4338.79798949  1887.08798896]
 [-4330.18698969  18

In [8]:
for method in metrics:
    print(f"Results for {method}:")
    for metric in metrics[method]:
        mean_value = np.mean(metrics[method][metric])
        std_value = np.std(metrics[method][metric])
        print(f"  {metric.capitalize()}: Mean = {mean_value:.4f}, Std = {std_value:.4f}")

Results for TissueGANBaseline:
  Mse: Mean = 1.1995, Std = 0.0652
  F1: Mean = 0.6739, Std = 0.0091
  Cosine_sim: Mean = 0.0609, Std = 0.0214
  Chamfer_dist: Mean = 59.5196, Std = 12.9991
  Emd: Mean = 54.0886, Std = 8.3394
