In [1]:
%load_ext autoreload
%autoreload 2
import tqdm
import numpy as np
import pandas as pd

import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader

from data import load_test_data, generate_training_samples
from dataset import STDataset
from evaluate import Evaluator
from utils import *

from parameter import create_args
from baseline import LatentSpaceGAN

In [2]:
# Data-size configuration (kept as named constants instead of inline magic numbers).
NUM_SAMPLES_PER_SLICE = 10  # training samples generated per slice
NUM_TEST_HOLES = 10         # number of held-out holes passed to load_test_data

# NOTE(review): both loaders print "Seeding all randomness with seed=2024" in the
# captured output, so they appear to seed internally — confirm in data.py before
# relying on this cell for reproducibility.
training_samples = generate_training_samples(num_samples_per_slice=NUM_SAMPLES_PER_SLICE)
all_test_items = load_test_data(num_holes=NUM_TEST_HOLES)

Seeding all randomness with seed=2024
Seeding all randomness with seed=2024
Donor_id: MsBrainAgingSpatialDonor_1
Slice_id: 0
Donor_id: MsBrainAgingSpatialDonor_2
Slice_id: 0
Slice_id: 1
Donor_id: MsBrainAgingSpatialDonor_3
Slice_id: 0
Slice_id: 1
Donor_id: MsBrainAgingSpatialDonor_4
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_5
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_6
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_7
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_8
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_9
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_10
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_11
Slice_id: 0
Slice_id: 1
Slice_id: 2
Donor_id: MsBrainAgingSpatialDonor_12
Slice_id: 0
Slice_id: 1


In [3]:
# Mini-batch size used when training the baseline.
BATCH_SIZE = 100

# Wrap the generated training samples and batch them; shuffle=True makes the
# DataLoader reshuffle the sample order at every epoch.
dataset = STDataset(training_samples)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [6]:
# Build the latent-space GAN baseline and train it on the shuffled training batches.
# `model` is reused by the evaluation loop below (model.fill_region).
# NOTE(review): the captured output shows a "Training VAE..." stage whose average
# loss plateaus around ~1.19 from roughly epoch 11 onward — consider early stopping
# or fewer epochs (confirm the epoch count inside baseline.LatentSpaceGAN).
model = LatentSpaceGAN()
model.train_model(dataloader=dataloader)


Training VAE...
Epoch 1, Average Loss: 3.6533
Epoch 2, Average Loss: 2.3188
Epoch 3, Average Loss: 1.6686
Epoch 4, Average Loss: 1.4297
Epoch 5, Average Loss: 1.3457
Epoch 6, Average Loss: 1.2878
Epoch 7, Average Loss: 1.2684
Epoch 8, Average Loss: 1.2250
Epoch 9, Average Loss: 1.2254
Epoch 10, Average Loss: 1.2165
Epoch 11, Average Loss: 1.1724
Epoch 12, Average Loss: 1.1922
Epoch 13, Average Loss: 1.1971
Epoch 14, Average Loss: 1.1964
Epoch 15, Average Loss: 1.2207
Epoch 16, Average Loss: 1.1921
Epoch 17, Average Loss: 1.1939
Epoch 18, Average Loss: 1.2032
Epoch 19, Average Loss: 1.1884
Epoch 20, Average Loss: 1.1883
Epoch 21, Average Loss: 1.2023
Epoch 22, Average Loss: 1.1870
Epoch 23, Average Loss: 1.2146
Epoch 24, Average Loss: 1.1978
Epoch 25, Average Loss: 1.1973
Epoch 26, Average Loss: 1.1938
Epoch 27, Average Loss: 1.1944
Epoch 28, Average Loss: 1.1919
Epoch 29, Average Loss: 1.1888
Epoch 30, Average Loss: 1.1877
Epoch 31, Average Loss: 1.1732
Epoch 32, Average Loss: 1.2069
E

In [7]:
# Per-method score containers: one fresh list per metric, filled with one value
# per test area by the evaluation loop.
metrics = {
    'L-GANBaseline': {
        metric_name: []
        for metric_name in ('mse', 'f1', 'cosine_sim', 'chamfer_dist', 'emd')
    },
}

In [9]:
# Evaluate the trained model on every test area and record its scores.
METHOD_NAME = 'L-GANBaseline'

# Clear previous results first: this cell appends to a dict created in another
# cell, so re-running it would otherwise store every score twice and silently
# skew the mean/std computed afterwards.
for metric_values in metrics[METHOD_NAME].values():
    metric_values.clear()

for i, test_item in enumerate(all_test_items, start=1):
    print(f"Test Area {i}:")
    print(f"  Dominant Tissue: {test_item.test_area.dominant_tissue}")
    print(f"  Number of cells in ground truth: {len(test_item.ground_truth.hole_cells)}")

    # Ground truth for the hole: cell coordinates (center_x, center_y) and gene expression.
    true_coords = test_item.ground_truth.hole_cells[['center_x', 'center_y']].values
    true_gene_expressions = test_item.ground_truth.gene_expression

    # Model prediction for the same test area.
    coords, gene_expressions = model.fill_region(adata=test_item.adata, test_area=test_item.test_area)

    mse, f1, cosine_sim = Evaluator.evaluate_expression(true_coords, true_gene_expressions, coords, gene_expressions)
    chamfer_dist = Evaluator.chamfer_distance(true_coords, coords)
    emd = Evaluator.calculate_emd(true_coords, coords)

    # Collect results for the L-GAN baseline; keys match the metrics dict.
    scores = {
        'mse': mse,
        'f1': f1,
        'cosine_sim': cosine_sim,
        'chamfer_dist': chamfer_dist,
        'emd': emd,
    }
    for metric_name, value in scores.items():
        metrics[METHOD_NAME][metric_name].append(value)


Test Area 1:
  Dominant Tissue: striatum
  Number of cells in ground truth: 50
Test Area 2:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
Test Area 3:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
Test Area 4:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
Test Area 5:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
Test Area 6:
  Dominant Tissue: striatum
  Number of cells in ground truth: 50
Test Area 7:
  Dominant Tissue: striatum
  Number of cells in ground truth: 50
Test Area 8:
  Dominant Tissue: striatum
  Number of cells in ground truth: 50
Test Area 9:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
Test Area 10:
  Dominant Tissue: cortical layer VI
  Number of cells in ground truth: 50
Test Area 11:
  Dominant Tissue: striatum
  Number of cells in ground truth: 50
Test Area 12:
  Dominant Tissue: corpus callosum
  Number of cells in groun

In [10]:
# Summarise each method's per-test-area scores as mean and standard deviation
# (np.std defaults to the population std, ddof=0).
for method, per_metric in metrics.items():
    print(f"Results for {method}:")
    for metric, values in per_metric.items():
        mean_value, std_value = np.mean(values), np.std(values)
        label = metric.capitalize()
        print(f"  {label}: Mean = {mean_value:.4f}, Std = {std_value:.4f}")

Results for L-GANBaseline:
  Mse: Mean = 0.9674, Std = 0.2391
  F1: Mean = 0.5966, Std = 0.0539
  Cosine_sim: Mean = 0.0216, Std = 0.0596
  Chamfer_dist: Mean = 78.6854, Std = 9.5834
  Emd: Mean = 72.3266, Std = 6.9106
