In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import anndata as ad
import numpy as np
import pandas as pd
import umap
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

In [2]:
# Directory that holds the LarryData train/test .h5ad files
input_dir = "/Users/apple/Desktop/KB/data"

adata_train = ad.read_h5ad(
    input_dir + '/LarryData/train_test/Larry_500_train.h5ad'
)
adata_test = ad.read_h5ad(
    input_dir + '/LarryData/train_test/Larry_500_test.h5ad'
)

# One clone label per observation, taken from obs["clone_id"] of each split
train_labels, test_labels = (
    split.obs["clone_id"].to_numpy() for split in (adata_train, adata_test)
)

print(train_labels.shape, test_labels.shape)

(17054,) (2177,)


## Supervised UMAP (supUMAP) embedding

In [3]:
# Data matrices of the two splits (one row per observation)
train_data = adata_train.X
test_data = adata_test.X

# 10-dimensional UMAP; n_neighbors=15 is the umap-learn default.
# random_state pins the embedding so that the coordinates, and every model
# trained on them, are reproducible after a kernel restart. Trade-off:
# umap-learn disables multi-threading when a seed is set.
reducer = umap.UMAP(n_neighbors=15, n_components=10, random_state=42)



In [4]:
# Supervised fit: the clone labels are passed as the target `y`, and the
# returned array is the embedding of the training observations
X_train = reducer.fit_transform(
    train_data,
    y=train_labels,
)

In [5]:
# Project the test observations with the already-fitted reducer
# (no labels are supplied at this step)
X_test = reducer.transform(
    test_data
)

In [6]:
# Shapes of the train / test embeddings
(X_train.shape, X_test.shape)

((17054, 10), (2177, 10))

### Load Data

### Pick the cells in the neutrophil monocyte trajectory

In [7]:
# Read the per-cell metadata table and the list of cells that belong to the
# neutrophil/monocyte trajectory
meta_df = pd.read_csv(
    "/Users/apple/Desktop/KB/Dataset1/stateFate_inVitro_metadata.txt.gz",
    sep='\t',
)
cell_id = pd.read_csv(
    "/Users/apple/Desktop/KB/Dataset1/stateFate_inVitro_neutrophil_monocyte_trajectory.txt.gz",
    sep='\t',
)
print("meta_df.shape:" ,meta_df.shape, "; cell_id.shape:", cell_id.shape )

# Keep the metadata rows whose index label is listed in 'Cell index'
# (label-based lookup via .loc)
cell_indices = cell_id['Cell index']
filtered_meta_df = meta_df.loc[cell_indices].copy()

# Composite key "<Library>_<Cell barcode>" used to match cells across tables
filtered_meta_df["Lib_Cellbarcode"] = (
    filtered_meta_df['Library'].astype(str)
    + "_"
    + filtered_meta_df['Cell barcode'].astype(str)
)

print("filtered_meta_df.shape: ", filtered_meta_df.shape)

# Add the same composite key to the obs table of both AnnData splits so they
# can be filtered against the trajectory metadata
for _adata in (adata_train, adata_test):
    _adata.obs["Lib_Cellbarcode"] = (
        _adata.obs['Library'].astype(str)
        + "_"
        + _adata.obs['Cell barcode'].astype(str)
    )

# Number of distinct keys per split
print(len(adata_train.obs["Lib_Cellbarcode"].unique()), len(adata_test.obs["Lib_Cellbarcode"].unique()))

meta_df.shape: (130887, 8) ; cell_id.shape: (96373, 1)
filtered_meta_df.shape:  (96373, 9)
17054 2177


#### Train data

In [8]:
# Keys present both in the trajectory metadata and in the training split
shared_barcodes_train = np.intersect1d(
    filtered_meta_df['Lib_Cellbarcode'],
    adata_train.obs['Lib_Cellbarcode'],
)
print("***")
print("len(shared_barcodes_train): ", len(shared_barcodes_train))
print("***")

# Boolean row mask over adata_train; the same mask selects the matching rows
# of the embedding, so both filtered objects stay row-aligned
indices = adata_train.obs['Lib_Cellbarcode'].isin(shared_barcodes_train).values

adata_train_filter = adata_train[indices].copy()
X_train_filter = X_train[indices]

adata_train_filter.shape, X_train_filter.shape


***
len(shared_barcodes_train):  11462
***


((11462, 2000), (11462, 10))

#### Test data

In [9]:
# Keys present both in the trajectory metadata and in the test split
shared_barcodes_test = np.intersect1d(
    filtered_meta_df['Lib_Cellbarcode'],
    adata_test.obs['Lib_Cellbarcode'],
)
print("***")
print("len(shared_barcodes_test): ", len(shared_barcodes_test))
print("***")

# Boolean row mask over adata_test; the same mask selects the matching rows
# of the test embedding, so both filtered objects stay row-aligned
indices = adata_test.obs['Lib_Cellbarcode'].isin(shared_barcodes_test).values

adata_test_filter = adata_test[indices].copy()
X_test_filter = X_test[indices]

adata_test_filter.shape, X_test_filter.shape


***
len(shared_barcodes_test):  1455
***


((1455, 2000), (1455, 10))

### Generate the composition pairs (day-2 embedding → day-6 clone state composition)

In [10]:
def composit_pair_gen(X_train, adata_train, early_day=2.0, late_day=6.0,
                      state_order=("Undifferentiated", "Monocyte", "Neutrophil")):
    """Pair early-time-point embeddings with their clone's late-time-point state composition.

    For every clone observed at ``late_day`` the normalised frequency of each
    ``state_info`` value is computed (rounded to 4 decimals). Every cell
    observed at ``early_day`` then receives the composition of its own clone
    as a soft-label target.

    Parameters
    ----------
    X_train : array-like, indexable by a boolean row mask
        Embedding whose rows are aligned with ``adata_train.obs``.
    adata_train : AnnData-like
        Must expose ``obs`` with the columns ``time_info``, ``clone_id`` and
        ``state_info`` and support boolean row indexing.
    early_day, late_day : float, optional
        ``time_info`` values of the input cells and of the cells that define
        the composition. Defaults (2.0 / 6.0) reproduce the original behaviour.
    state_order : sequence of str, optional
        ``state_info`` values mapped, in order, to the target columns.

    Returns
    -------
    (torch.FloatTensor, torch.FloatTensor)
        Embeddings of the ``early_day`` cells and the matching target matrix
        of shape ``(n_early_cells, len(state_order))``.

    Notes
    -----
    * A cell whose clone has no ``late_day`` cells keeps an all-zero target
      row; the number of such rows is printed so it is not silent.
    * Rows need not sum exactly to 1: values are rounded, and states that are
      not listed in ``state_order`` are dropped.
    """
    obs = adata_train.obs

    # --- per-clone state composition at the late time point ---------------
    adata_6 = adata_train[obs["time_info"] == late_day]
    print("adata_6.shape:", adata_6.shape)
    late_obs = adata_6.obs

    clone_state_info_distribution = {}
    for clone_id in late_obs["clone_id"].unique():
        clone_states = late_obs.loc[late_obs["clone_id"] == clone_id, "state_info"]
        clone_state_info_distribution[clone_id] = (
            clone_states.value_counts(normalize=True).round(4).to_dict()
        )

    # Show the first three clones as a quick sanity check
    for clone_id, distribution in list(clone_state_info_distribution.items())[:3]:
        print(f"Clone ID: {clone_id}, Cell Type Distribution: {distribution}")

    # --- inputs: embeddings of the early-time-point cells ------------------
    day2_mask = obs["time_info"] == early_day
    X_train_day2 = X_train[day2_mask.values]
    # The message previously said "Day 12" although day-2 cells are selected
    print(f"Day {early_day:g} embeddings shape: {X_train_day2.shape}")

    clone_labels_day2 = obs.loc[day2_mask, "clone_id"].to_numpy()

    # --- targets: composition of each cell's clone, columns follow state_order
    y_train_prob = np.zeros((X_train_day2.shape[0], len(state_order)))
    n_missing = 0
    for row, clone_id in enumerate(clone_labels_day2):
        distribution = clone_state_info_distribution.get(clone_id)
        if distribution is None:
            # Clone never observed at late_day: target row stays all zeros
            n_missing += 1
            continue
        for col, state in enumerate(state_order):
            y_train_prob[row, col] = distribution.get(state, 0)

    if n_missing:
        print(
            f"Warning: {n_missing} of {len(clone_labels_day2)} day-{early_day:g} cells "
            f"belong to clones without day-{late_day:g} cells; their target rows are all zeros"
        )

    print(f"y_train_prob shape: {y_train_prob.shape}")
    print(f"First 5 rows of y_train_prob:\n{y_train_prob[:5]}")

    X_train_day2 = torch.tensor(X_train_day2, dtype=torch.float32)
    y_train_prob = torch.tensor(y_train_prob, dtype=torch.float32)

    return X_train_day2, y_train_prob

In [11]:
# Distinct state labels present in the filtered test split
adata_test_filter.obs.loc[:, "state_info"].unique()

['Undifferentiated', 'Monocyte', 'Neutrophil']
Categories (3, object): ['Monocyte', 'Neutrophil', 'Undifferentiated']

In [12]:
# Build (embedding, clone state composition) pairs separately for each split
X_train_day2, y_train_prob = composit_pair_gen(
    X_train_filter,
    adata_train_filter,
)
X_test_day2, y_test_prob = composit_pair_gen(
    X_test_filter,
    adata_test_filter,
)

adata_6.shape: (8104, 2000)
Clone ID: 1261, Cell Type Distribution: {'Neutrophil': 0.6642, 'Monocyte': 0.1971, 'Undifferentiated': 0.1387, 'Erythroid': 0.0}
Clone ID: 2370, Cell Type Distribution: {'Monocyte': 0.4792, 'Neutrophil': 0.2812, 'Undifferentiated': 0.2396, 'Erythroid': 0.0}
Clone ID: 292, Cell Type Distribution: {'Monocyte': 0.6988, 'Neutrophil': 0.1687, 'Undifferentiated': 0.1325, 'Erythroid': 0.0}
Day 12 embeddings shape: (215, 10)
y_train_prob shape: (215, 3)
First 5 rows of y_train_prob:
[[0.1387 0.1971 0.6642]
 [0.0851 0.1596 0.7553]
 [0.0851 0.1596 0.7553]
 [0.2    0.4143 0.3857]
 [0.6    0.32   0.08  ]]
adata_6.shape: (1018, 2000)
Clone ID: 1261, Cell Type Distribution: {'Neutrophil': 0.7778, 'Monocyte': 0.1667, 'Undifferentiated': 0.0556}
Clone ID: 2370, Cell Type Distribution: {'Monocyte': 0.6364, 'Neutrophil': 0.3636, 'Undifferentiated': 0.0}
Clone ID: 292, Cell Type Distribution: {'Monocyte': 0.7778, 'Neutrophil': 0.1111, 'Undifferentiated': 0.1111}
Day 12 embeddi

In [13]:
import torch
import torch.nn as nn

class SoftLabelNN(nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) that returns raw scores.

    No softmax is applied in ``forward``; callers turn the returned logits
    into (log-)probabilities themselves.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Input projection, non-linearity, output projection
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return unnormalised class scores for the batch ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


class Trainer:
    """Full-batch training and evaluation loop for a soft-label classifier.

    The model is expected to return raw scores; ``log_softmax`` is applied
    here and the result is passed to ``criterion(log_probs, target)``.

    Parameters
    ----------
    model : torch.nn.Module
    optimizer : torch.optim.Optimizer
        Already bound to ``model``'s parameters.
    criterion : callable
        Called as ``criterion(log_probs, target_probs)``.
    X_train, y_train_prob : torch.Tensor
        Training inputs and target distributions, used as one full batch.
    num_epochs : int
        Number of optimisation steps performed by ``train``.
    lr : float
        Stored for reference only; it is never read by this class, so the
        effective step size is whatever ``optimizer`` was constructed with.
    """

    def __init__(self, model, optimizer, criterion, X_train, y_train_prob, num_epochs=10000, lr=0.01):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.X_train = X_train
        self.y_train_prob = y_train_prob
        self.num_epochs = num_epochs
        self.lr = lr

    def train(self, log_every=100):
        """Run ``num_epochs`` full-batch updates.

        ``log_every`` controls how often the loss is printed (default: every
        100 epochs, as before); pass 0 or None to silence logging.
        """
        # predict()/evaluate_kl_divergence() leave the model in eval mode, so
        # training mode has to be restored before optimising again.
        self.model.train()

        for epoch in range(self.num_epochs):
            # Forward pass -> log-probabilities -> loss against soft targets
            outputs = self.model(self.X_train)
            outputs_log_prob = torch.log_softmax(outputs, dim=1)
            loss = self.criterion(outputs_log_prob, self.y_train_prob)

            # Backward pass and parameter update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if log_every and (epoch + 1) % log_every == 0:
                print(f'Epoch [{epoch+1}/{self.num_epochs}], Loss: {loss.item():.4f}')

    def predict(self, X_test):
        """Return softmax probabilities for ``X_test`` (eval mode, no autograd)."""
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(X_test)
            probabilities = torch.softmax(outputs, dim=1)
        return probabilities

    def evaluate_kl_divergence(self, X_test, y_test_prob):
        """Return ``criterion(log_softmax(model(X_test)), y_test_prob)`` as a float.

        Runs in eval mode with autograd disabled so that evaluation neither
        builds a graph nor is affected by train-only layers.
        """
        self.model.eval()
        with torch.no_grad():
            predicted_probabilities_log = torch.log_softmax(self.model(X_test), dim=1)
            kl_divergence = self.criterion(predicted_probabilities_log, y_test_prob)
        return kl_divergence.item()



In [49]:
# Seed torch so weight initialisation, and therefore the reported losses,
# is the same on every Restart & Run All
torch.manual_seed(0)

# Network dimensions follow the data: embedding width in, one output per state
input_size = X_train_day2.shape[1]
hidden_size = 10
output_size = y_train_prob.shape[1]

model = SoftLabelNN(input_size, hidden_size, output_size)
# KLDivLoss takes log-probabilities as input and probabilities as target
criterion = nn.KLDivLoss(reduction='batchmean')
optimizer = optim.AdamW(model.parameters(), lr=0.01)

# Train on the full training batch, then score the held-out pairs
trainer = Trainer(model, optimizer, criterion, X_train_day2, y_train_prob, num_epochs=500)
trainer.train()

kl_divergence = trainer.evaluate_kl_divergence(X_test_day2, y_test_prob)
print(f"KL Divergence on test set: {kl_divergence:.4f}")

Epoch [100/500], Loss: 0.5014
Epoch [200/500], Loss: 0.4905
Epoch [300/500], Loss: 0.4889
Epoch [400/500], Loss: 0.4874
Epoch [500/500], Loss: 0.4853
KL Divergence on test set: 0.7824


In [16]:
# Shapes of the test and train (input, target) tensors
(
    X_test_day2.shape,
    y_test_prob.shape,
    X_train_day2.shape,
    y_train_prob.shape,
)

(torch.Size([27, 10]),
 torch.Size([27, 3]),
 torch.Size([215, 10]),
 torch.Size([215, 3]))