In [None]:
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset
import h5py
import numpy as np
import os
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import anndata as ad
from sklearn.model_selection import train_test_split
from scipy.stats import poisson
from transformers import get_cosine_schedule_with_warmup
import psutil
import math

In [None]:
# Restrict the notebook kernel process to CPU cores 0 and 1.
# NOTE(review): the core indices are hard-coded — confirm they exist (and are the
# intended ones) on the machine this notebook runs on.
p = psutil.Process(os.getpid())
p.cpu_affinity([0, 1])

In [None]:
# Request single-threaded numeric back-ends (OpenMP, MKL, OpenBLAS, numexpr).
# NOTE(review): these variables are assigned after numpy/torch were imported in the
# first cell; such settings are typically read when the library initialises, so
# confirm they actually take effect here (or move this cell above the imports).
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
# CUDA debugging switches.
# NOTE(review): presumably left on from a debugging session — confirm they are still
# wanted for regular training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [None]:
# Release cached CUDA allocator memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()

# Prefer GPU 2 (the original choice). On hosts that expose fewer GPUs, "cuda:2"
# would be an invalid device, so fall back to the first GPU; with no CUDA at all
# fall back to the CPU.
_PREFERRED_GPU_INDEX = 2
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > _PREFERRED_GPU_INDEX:
    device = torch.device(f"cuda:{_PREFERRED_GPU_INDEX}")
else:
    device = torch.device("cuda:0")

In [None]:
class GeneDataset(Dataset):
    """Dataset of synthetic spatial-transcriptomics patches.

    Each item is a tuple ``(x_coords, y_coords, gene_indices)`` of ``torch.long``
    tensors read from an ``.npz`` archive. At the moment a single archive
    (``PATCH_FILENAME`` inside ``path``) is loaded, so the dataset holds one item.
    """

    # Name of the archive (relative to ``path``) that holds the loaded patch.
    PATCH_FILENAME = 'gene_coordinates_test_cells1.npz'

    def __init__(self, path, num_patches):
        """
        Args:
            path (str): The base directory where the data is stored.
            num_patches (int): Number of patches requested. It is stored for
                reference only; the dataset length is the number of patches
                that were actually loaded.
        """
        self.path = path
        self.num_patches = num_patches
        self.data = []
        self.load_data()

    def load_data(self):
        """
        Load the synthetic data from ``self.path`` and append it to ``self.data``.

        The archive must provide the arrays ``x_coords``, ``y_coords`` and
        ``gene_indices``; each is converted to a ``torch.long`` tensor.
        """
        patch_file = os.path.join(self.path, self.PATCH_FILENAME)
        # The context manager closes the underlying file handle; the arrays are
        # materialised into tensors before the archive is closed.
        with np.load(patch_file) as data:
            x_coords = torch.tensor(data['x_coords'], dtype=torch.long)
            y_coords = torch.tensor(data['y_coords'], dtype=torch.long)
            gene_indices = torch.tensor(data['gene_indices'], dtype=torch.long)
        self.data.append((x_coords, y_coords, gene_indices))

    def __len__(self):
        """
        Return the number of loaded patches.

        This is ``len(self.data)`` rather than ``num_patches`` so that every
        index below ``len(dataset)`` is valid for ``__getitem__``.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Get a specific dataset item by index.

        Args:
            idx (int): The index of the dataset item.

        Returns:
            A tuple containing x-coordinates, y-coordinates and gene indices.
        """
        return self.data[idx]

In [None]:
def standardize_coords(coords):
    """Shift coordinates to zero mean and scale them to unit standard deviation.

    Args:
        coords (torch.Tensor): Coordinates of any shape; cast to float.

    Returns:
        torch.Tensor: ``(coords - mean) / std`` computed over all elements. When
        the standard deviation is undefined (fewer than two elements) or zero
        (all values equal), only the mean is removed, so the result stays finite
        instead of becoming NaN/inf.
    """
    coords = coords.float()
    centered = coords - torch.mean(coords)
    # torch.std uses the unbiased estimator, which is undefined for < 2 elements.
    if coords.numel() < 2:
        return centered
    std = torch.std(coords)
    if std == 0:
        return centered
    return centered / std

class PositionalEncoding(nn.Module):
    """Sinusoidal positional-encoding lookup table.

    A ``[max_len, d_model]`` table is precomputed once and registered as the
    buffer ``pe``; ``forward`` indexes that table with integer positions.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        """
        Args:
            d_model: Width of each encoding vector (even or odd).
            max_len: Number of distinct positions; valid indices are
                ``0 .. max_len - 1``.
        """
        super().__init__()

        # Create the positional encodings
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so the
        # cosine block is trimmed to fit (no-op when d_model is even).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, pos_indices):
        """
        Arguments:
            pos_indices: Integer tensor of positional indices, e.g. of shape ``[seq_len]``.
        Returns:
            Tensor of shape ``pos_indices.shape + [d_model]`` containing the corresponding
            positional encodings.
        """
        return self.pe[pos_indices]

In [None]:
class GeneEmbedding(nn.Module):
    """Learnable lookup table that maps integer gene indices to dense vectors."""

    def __init__(self, n_genes, embedding_dim):
        """
        Args:
            n_genes (int): Number of rows in the table (one per gene index).
            embedding_dim (int): Width of each embedding vector.
        """
        super().__init__()
        self.embedding = nn.Embedding(n_genes, embedding_dim)

    def forward(self, gene_indices):
        """Return the embedding vectors for ``gene_indices`` (one trailing dim added)."""
        lookup = self.embedding
        return lookup(gene_indices)

In [None]:
class TransformerModel(nn.Module):
    """Stack of ``num_layers`` batch-first Transformer encoder layers."""

    def __init__(self, embedding_dim, nhead, num_layers, dim_feedforward):
        """
        Args:
            embedding_dim (int): Model width (``d_model``) of every token vector.
            nhead (int): Number of attention heads.
            num_layers (int): Number of encoder layers in the stack.
            dim_feedforward (int): Hidden width of each layer's feed-forward block.
        """
        super().__init__()
        layer_config = dict(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        # Template layer handed to the encoder stack below.
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(**layer_config)
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer,
            num_layers=num_layers,
        )

    def forward(self, x):
        """Run ``x`` through the encoder stack and return the result."""
        encoded = self.transformer_encoder(x)
        return encoded

In [None]:
class STEArgmax(torch.autograd.Function):
    """Argmax with a straight-through gradient.

    Forward returns a one-hot encoding of the argmax along the last dimension;
    backward passes the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, input):
        # Perform the argmax to get the index of the maximum value
        indices = torch.argmax(input, dim=-1)
        # Create a one-hot encoded tensor with the same shape as the input.
        # Nothing is saved on ctx: backward does not need the input, and saving
        # it would only keep the tensor alive until the backward pass.
        return F.one_hot(indices, num_classes=input.shape[-1]).float()

    @staticmethod
    def backward(ctx, grad_output):
        # In the straight-through estimator, we return the gradient as-is
        return grad_output,  # Return as a tuple


In [None]:
class FullPipeline(nn.Module):
    """Maps transcripts (two coordinates plus a gene index) to soft cell assignments.

    The two positional encodings and the gene embedding are summed, passed
    through a Transformer encoder and a feed-forward head, and normalised with
    a temperature-scaled softmax over the ``n_cells`` outputs.
    """

    def __init__(self, height_dim, width_dim, gene_dim, n_genes, n_cells, nhead, num_layers, dim_feedforward):
        """
        Args:
            height_dim (int): Width of the first positional encoding; also used
                as the Transformer model width.
            width_dim (int): Width of the second positional encoding.
            gene_dim (int): Width of the gene embedding.
            n_genes (int): Number of distinct gene indices.
            n_cells (int): Number of output classes per transcript.
            nhead (int): Number of attention heads.
            num_layers (int): Number of Transformer encoder layers.
            dim_feedforward (int): Hidden width of the Transformer feed-forward
                blocks and of the output head.

        NOTE(review): the three encodings are added element-wise in ``forward``,
        so ``height_dim``, ``width_dim`` and ``gene_dim`` are presumably meant
        to be equal — confirm.
        """
        super().__init__()
        # Encodings are summed rather than concatenated, so the model width is height_dim.
        self.d_model = height_dim
        model_width = height_dim
        self.gene_embedding = GeneEmbedding(n_genes, gene_dim)
        self.transformer_model = TransformerModel(model_width, nhead, num_layers, dim_feedforward)
        head_layers = [
            nn.Linear(model_width, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, n_cells),
        ]
        self.ffn = nn.Sequential(*head_layers)

        self.pe1 = PositionalEncoding(height_dim)
        self.pe2 = PositionalEncoding(width_dim)
        self.temperature = 1

    def forward(self, height_coords, width_coords, gene_indices):
        """
        Args:
            height_coords: Integer indices looked up in the first positional table.
            width_coords: Integer indices looked up in the second positional table.
            gene_indices: Integer gene indices for the embedding table.

        Returns:
            Softmax (over dimension 2) of the head output divided by
            ``self.temperature``.
        """
        token_features = (
            self.pe1(height_coords)
            + self.pe2(width_coords)
            + self.gene_embedding(gene_indices)
        )
        logits = self.ffn(self.transformer_model(token_features))
        return F.softmax(logits / self.temperature, dim=2)

In [None]:
def pairwise_distance_loss(height_coords, width_coords, output_matrix):
    """
    Calculate the differentiable pairwise distance loss based on height and width coordinates.

    Every pair of transcripts contributes its Euclidean distance weighted by the
    dot product of the two transcripts' assignment rows, so distant transcripts
    that share a cell are penalised.

    Args:
        height_coords (torch.Tensor): Height coordinates, shape (n_samples,) or (1, n_samples).
        width_coords (torch.Tensor): Width coordinates, same shape as ``height_coords``.
        output_matrix (torch.Tensor): Tensor of shape (n_samples, n_cells) representing the
            softmax output from the model.

    Returns:
        torch.Tensor: The weighted distance sum divided by n_samples (scalar).
    """
    device = output_matrix.device
    height_coords = height_coords.to(device)
    width_coords = width_coords.to(device)

    # Drop a leading batch dimension of size 1, then pair the coordinates up.
    coords = torch.stack((height_coords.squeeze(0), width_coords.squeeze(0)), dim=1).float()  # Shape: (n_samples, 2)
    pairwise_distances = torch.cdist(coords, coords)

    # mask[i, j] is the dot product of the assignment rows of transcripts i and j.
    mask = torch.matmul(output_matrix, output_matrix.T)
    total_loss = (mask * pairwise_distances).sum()

    # Count samples on the squeezed coordinates: on a (1, n_samples) input,
    # height_coords.size(0) is the batch dimension and would always be 1.
    n_samples = coords.size(0)
    return total_loss / n_samples

In [None]:
# Load the cell-type profiles and keep them on the training device as float32.
# NOTE(review): pickle.load can execute arbitrary code — only use with trusted files.
with open('/data/aram/Xenium/output-XETG00056__0004637__Region_1__20230718__204100/code/data/synthetic_data_2/cell_type_profiles.pkl', 'rb') as file:
    cell_type_profiles_np = pickle.load(file)
cell_type_profiles = torch.tensor(cell_type_profiles_np, dtype=torch.float32, device=device)

def one_hot_encode(gene_indices, n_genes):
    """
    Build a float one-hot representation of integer gene indices.

    Args:
        gene_indices (torch.Tensor): Integer tensor of gene indices.
        n_genes (int): Number of classes, i.e. the size of the added last dimension.

    Returns:
        torch.Tensor: Float tensor of shape ``gene_indices.shape + (n_genes,)``.
    """
    encoded = F.one_hot(gene_indices, num_classes=n_genes)
    return encoded.float()

def compute_log_likelihood(expression_vector, cell_type_profile, umi_count):
    """
    Per-UMI Poisson log-likelihood of an expression vector under one profile.

    Args:
        expression_vector (torch.Tensor): Observed (possibly fractional) counts per gene.
        cell_type_profile (torch.Tensor): Per-gene rates of the cell type; multiplied by
            ``umi_count`` to obtain the Poisson means.
        umi_count (torch.Tensor or float): Total count of the cell.

    Returns:
        torch.Tensor: Sum over genes of ``v * log(lambda) - lambda - log(v!)`` divided by
        ``umi_count`` (scalar). The divisor is floored at 1e-8 so an empty cell
        (``umi_count == 0``) yields a finite value instead of NaN.
    """
    lambdas = umi_count * cell_type_profile
    # Compute Poisson log-likelihood: log P(v_i | λ_i)
    # Poisson log-likelihood: v_i * log(λ_i) - λ_i - log(v_i!)
    log_likelihood = expression_vector * torch.log(lambdas + 1e-8) - lambdas - torch.lgamma(expression_vector + 1)

    # Sum the log-probabilities across all genes to get the total log-likelihood
    total_log_likelihood = torch.sum(log_likelihood)

    # Guard the normalisation against 0 / 0; values above the floor are unchanged.
    safe_umi_count = torch.clamp(
        torch.as_tensor(umi_count, dtype=total_log_likelihood.dtype, device=total_log_likelihood.device),
        min=1e-8,
    )
    return total_log_likelihood / safe_umi_count
    

def likelihood_model(expression_vector, temperature):
    """
    Score one cell's expression vector against every cell-type profile.

    The per-UMI log-likelihood is computed for each entry of the global
    ``cell_type_profiles``; the scores are then averaged with softmax weights
    (sharpened by ``temperature``) and negated.

    Args:
        expression_vector (torch.Tensor): Tensor of shape (n_genes,) representing gene counts for a cell.
        temperature (float): Divisor applied to the scores before the softmax.

    Returns:
        torch.Tensor: Negative softmax-weighted sum of the log-likelihoods (scalar).
    """
    total_umis = torch.sum(expression_vector)

    per_type_scores = [
        compute_log_likelihood(expression_vector, profile, total_umis)
        for profile in cell_type_profiles
    ]
    scores = torch.stack(per_type_scores)

    # Soft selection of the best-matching cell type.
    weights = F.softmax(scores / temperature, dim=0)
    expected_score = torch.sum(weights * scores)
    return -expected_score

def likelihood_loss(output_matrix, gene_indices, n_genes, temperature):
    """
    Likelihood loss summed over all predicted cells.

    The soft assignments are used to aggregate one-hot gene vectors into a
    cell-by-gene matrix; each row is scored with ``likelihood_model`` and the
    scores are added up.

    Args:
        output_matrix (torch.Tensor): Tensor of shape (n_samples, n_cells) representing the
            softmax output from the model.
        gene_indices (torch.Tensor): Gene index of every transcript; a leading batch
            dimension of size 1 is squeezed away after aggregation.
        n_genes (int): The number of unique genes.
        temperature (float): Softmax temperature forwarded to ``likelihood_model``.

    Returns:
        torch.Tensor: The computed likelihood loss (scalar).
    """
    print(temperature, "temperature")

    gene_one_hot = one_hot_encode(gene_indices, n_genes)  # (..., n_samples, n_genes)
    soft_counts = torch.matmul(output_matrix.T, gene_one_hot).squeeze(0)  # (n_cells, n_genes)

    accumulated = 0.0
    for cell_counts in soft_counts:
        accumulated = accumulated + likelihood_model(cell_counts, temperature)
    return accumulated

In [None]:
def background_loss(output_matrix):
    """
    Total probability mass assigned to the last (background) class.

    Args:
        output_matrix (torch.Tensor): Tensor of shape (n_samples, n_cells) representing the
            softmax output from the model.

    Returns:
        torch.Tensor: Sum over all samples of the last column of ``output_matrix`` (scalar).
    """
    return output_matrix[:, -1].sum()

In [None]:
def train_model(model, dataloader, optimizer, scheduler, lambda_, n_genes, num_epochs=5):
    """
    Train ``model`` and print the per-epoch loss totals.

    Only the likelihood term is currently active; the pairwise-distance and
    background terms are fixed to 0 but still weighted and reported.

    Args:
        model: Network called as ``model(height_coords, width_coords, gene_indices)``.
        dataloader: Yields batches of ``(height_coords, width_coords, gene_indices)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        lambda_: Three weights for the (pairwise, likelihood, background) terms.
        n_genes (int): Number of genes, forwarded to ``likelihood_loss``.
        num_epochs (int): Number of passes over ``dataloader``.

    Returns:
        tuple: ``(output_matrix, height_coords, width_coords, gene_indices)`` of the
        last processed batch. ``output_matrix`` is returned as produced in training
        mode and is still attached to the autograd graph.
    """
    # Kept as the int 1 so the value printed by likelihood_loss is unchanged.
    temperature = 1
    model.train()

    for epoch in range(num_epochs):
        # Plain Python floats: accumulating the loss tensors themselves would keep
        # every batch's autograd graph alive until the end of the epoch.
        epoch_loss = 0.0
        epoch_loss1 = 0.0
        epoch_loss2 = 0.0
        epoch_loss3 = 0.0

        for batch in dataloader:
            height_coords, width_coords, gene_indices = [x.to(device) for x in batch]
            optimizer.zero_grad()

            # Drop the batch dimension (batch size 1) -> (n_samples, n_cells).
            output_matrix = model(height_coords, width_coords, gene_indices).squeeze(0)

            pairwise_distance = 0  # pairwise_distance_loss(...) is currently disabled
            likelihood = likelihood_loss(output_matrix, gene_indices, n_genes, temperature)
            background = 0  # background_loss(...) is currently disabled

            loss1 = lambda_[0] * pairwise_distance
            loss2 = lambda_[1] * likelihood
            loss3 = lambda_[2] * background
            total_loss = loss1 + loss2 + loss3

            total_loss.backward()
            optimizer.step()
            scheduler.step()

            epoch_loss += total_loss.item()
            # float() accepts both plain numbers and 0-dim tensors.
            epoch_loss1 += float(loss1)
            epoch_loss2 += float(loss2)
            epoch_loss3 += float(loss3)

        print(f'Epoch [{epoch + 1}/{num_epochs}], total Loss: {epoch_loss:.4f}')
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss 1: {epoch_loss1:.4f}')
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss 2: {epoch_loss2:.4f}')
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss 3: {epoch_loss3:.4f}')
    return output_matrix, height_coords, width_coords, gene_indices

In [None]:
# --- Run configuration ---
path = "/data/aram/Xenium/output-XETG00056__0004637__Region_1__20230718__204100/code/data/synthetic_data_2"
num_patches = 1
lambda_ = (0.05, 100, 0)  # weights for the (pairwise, likelihood, background) loss terms
n_genes = 30
num_epochs = 2000

dataset = GeneDataset(path, num_patches)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

model = FullPipeline(
    height_dim=256, width_dim=256, gene_dim=256, n_genes=n_genes, n_cells=5,
    nhead=2, num_layers=6, dim_feedforward=500,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_warmup_steps = 100
# The scheduler is stepped once per batch, so its horizon must cover every
# optimizer step. With a shorter horizon (previously 1000 for 2000 steps) the
# cosine curve runs past its end and the learning rate rises again.
num_training_steps = num_epochs * len(dataloader)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
    num_cycles=0.5
)

output_matrix, height_coords, width_coords, gene_indices = train_model(
    model, dataloader, optimizer, scheduler, lambda_, n_genes, num_epochs=num_epochs
)

In [None]:
# Recompute the likelihood term per predicted cell, with a sharper softmax
# temperature, and print the intermediate values for inspection.
temperature = 0.01
one_hot_genes = one_hot_encode(gene_indices, n_genes)  # Shape: (..., n_samples, n_genes)
cell_by_gene = torch.matmul(output_matrix.T, one_hot_genes).squeeze(0)  # Shape: (n_cells, n_genes)
total_loss = 0.0
cell_num = -1

loss_values = {}  # cell index -> negative weighted log-likelihood of that cell

for cell_num, cell in enumerate(cell_by_gene):
    print(cell_num, "cell_num")
    observed_umi_count = torch.sum(cell)

    log_likelihoods = torch.stack([
        compute_log_likelihood(cell, cell_type_profile, observed_umi_count)
        for cell_type_profile in cell_type_profiles
    ])
    prob = F.softmax(log_likelihoods / temperature, dim=0)
    print(prob, "prob")
    print(log_likelihoods, "log_likelihoods")

    summed_likelihood = torch.sum(prob * log_likelihoods)
    print(summed_likelihood, "summed_likelihood")

    loss_values[cell_num] = -summed_likelihood
    total_loss = total_loss - summed_likelihood
print(total_loss, "total_loss")

In [None]:
# Sum of the per-cell losses; the same quantity as the `total_loss` accumulated in the previous cell.
sum(loss_values.values())

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
import torch
import torch.nn.functional as F
import numpy as np

# One palette colour per distinct gene present in the data. A dedicated name is
# used so the global `n_genes` (needed by the training and loss cells) is not
# overwritten when this cell runs.
unique_genes = np.unique(gene_indices.cpu())
n_gene_colors = len(unique_genes)
cmap = ListedColormap(sns.color_palette("tab20", n_gene_colors))

# Hard assignment: the most likely cell for every transcript.
indices = torch.argmax(output_matrix, dim=-1)
cell_assignments = indices.reshape(-1).tolist()

# Number of cells
n_cells = output_matrix.shape[-1]

# One PATCH_SIZE x PATCH_SIZE image per cell; pixels hold the gene index.
# NOTE(review): the background value 0 is indistinguishable from gene index 0.
PATCH_SIZE = 100
cell_patches = [np.zeros((PATCH_SIZE, PATCH_SIZE)) for _ in range(n_cells)]

# Populate cell patches with gene indices (transcripts outside the patch are skipped)
xs = height_coords.reshape(-1).tolist()
ys = width_coords.reshape(-1).tolist()
genes = gene_indices.reshape(-1).tolist()
for cell_idx, x, y, gene in zip(cell_assignments, xs, ys, genes):
    if 0 <= x < PATCH_SIZE and 0 <= y < PATCH_SIZE:
        cell_patches[cell_idx][y, x] = gene

# Determine grid size based on the number of cells. squeeze=False keeps `axs`
# two-dimensional even for a 1x1 grid, so axs[row, col] is always valid.
grid_size = int(np.ceil(np.sqrt(n_cells)))
fig, axs = plt.subplots(grid_size, grid_size, figsize=(15, 15), squeeze=False)

# Loop through each cell and plot it in the grid
for cell_idx in range(n_cells):
    row, col = divmod(cell_idx, grid_size)
    ax = axs[row, col]

    sns.heatmap(cell_patches[cell_idx], cmap=cmap, cbar=False, vmin=0, vmax=n_gene_colors - 1, ax=ax)
    loss = loss_values[cell_idx]  # filled in by the per-cell loss cell; must be run first
    ax.set_title(f"Cell {cell_idx} (Loss: {loss:.4f})")
    ax.axis('off')

# Remove empty subplots if n_cells doesn't fill the grid
for i in range(n_cells, grid_size * grid_size):
    fig.delaxes(axs.flatten()[i])

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

# One palette colour per distinct gene present in the data. A dedicated name is
# used so the global `n_genes` (needed by the training and loss cells) is not
# overwritten when this cell runs.
unique_genes = np.unique(gene_indices.cpu())
n_gene_colors = len(unique_genes)

cmap = ListedColormap(sns.color_palette("tab20", n_gene_colors))

# Hard assignment: the most likely cell for every transcript.
indices = torch.argmax(output_matrix, dim=-1)
cell_assignments = indices.reshape(-1).tolist()

# Number of cells
n_cells = output_matrix.shape[-1]

# One PATCH_SIZE x PATCH_SIZE image per cell; pixels hold the gene index.
PATCH_SIZE = 100
cell_patches = [np.zeros((PATCH_SIZE, PATCH_SIZE)) for _ in range(n_cells)]

# Write each transcript's gene index into the patch of its assigned cell
# (transcripts outside the patch are skipped).
xs = height_coords.reshape(-1).tolist()
ys = width_coords.reshape(-1).tolist()
genes = gene_indices.reshape(-1).tolist()
for cell_idx, x, y, gene in zip(cell_assignments, xs, ys, genes):
    if 0 <= x < PATCH_SIZE and 0 <= y < PATCH_SIZE:
        cell_patches[cell_idx][y, x] = gene

# One figure per cell, each with its own colour bar.
for cell_idx in range(n_cells):
    fig, ax = plt.subplots()
    sns.heatmap(cell_patches[cell_idx], cmap=cmap, cbar=True, vmin=0, vmax=n_gene_colors - 1, ax=ax)
    ax.set_title(f"Cell {cell_idx}")
    plt.show()

In [None]:
# Dump the gradient of every parameter that currently has one.
for name, param in model.named_parameters():
    if param.grad is None:
        continue
    print(f"Gradient for {name}:")
    print(param.grad)


In [None]:
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         plt.hist(param.detach().cpu().numpy(), bins=50)
#         plt.title(f'Weight distribution for {name}')
#         plt.show()

In [None]:
import matplotlib.pyplot as plt

# Open and read the text file
with open('/data/aram/Xenium/output-XETG00056__0004637__Region_1__20230718__204100/loss.txt', 'r') as file:
    data = file.read()

# Parsing the data
epochs = []
total_loss = []
loss_1 = []
loss_2 = []

# The log is read in records of 5 lines, the first record starting at line
# index 1. Within a record:
#   line i     -> "Epoch [k/N], total Loss: v"
#   line i + 1 -> "...: v"  (loss 1)
#   line i + 2 -> "...: v"  (loss 2)
LINES_PER_RECORD = 5
lines = data.strip().split('\n')
# Stop two lines before the end so lines[i + 2] always exists; a record cut off
# at the end of the file (e.g. an interrupted run) is skipped instead of raising
# IndexError.
for i in range(1, len(lines) - 2, LINES_PER_RECORD):
    epoch_info = lines[i].split(',')
    epoch_num = int(epoch_info[0].split('[')[1].split('/')[0])
    epochs.append(epoch_num)

    total_loss_val = float(epoch_info[1].split(': ')[1])
    total_loss.append(total_loss_val)

    loss_1_val = float(lines[i + 1].split(': ')[1])
    loss_1.append(loss_1_val)

    loss_2_val = float(lines[i + 2].split(': ')[1])
    loss_2.append(loss_2_val)

# Plotting the losses
plt.figure(figsize=(10, 6))
plt.plot(epochs, total_loss, label='Total Loss')
plt.plot(epochs, loss_1, label='Loss 1')
plt.plot(epochs, loss_2, label='Loss 2')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Losses over Epochs')
plt.legend()
plt.grid(True)
plt.show()
