In [15]:
import torch
import pandas as pd
from torch import nn
import scanpy as sc
import numpy as np

In [2]:
class WeightedPolynomialRegression(nn.Module):
    def __init__(self, num_genes, num_cells, num_categories):
        super(WeightedPolynomialRegression, self).__init__()
        # Initialize G with smaller values
        self.G = nn.Parameter(torch.randn(num_genes, 4, dtype=torch.float64) * 0.1)
        # Initialize K with positive values using softmax
        self.K = nn.Parameter(torch.ones(num_categories, num_genes, dtype=torch.float64))
        with torch.no_grad():
            self.K.data = F.softmax(self.K.data, dim=1)
    
    def forward(self, N):
        N = N.squeeze()
        # Normalize input
        N_normalized = (N - N.mean()) / (N.std() + 1e-8)
        N_poly = torch.stack([
            torch.ones_like(N_normalized),
            N_normalized,
            N_normalized**2,
            N_normalized**3
        ])
        predictions = self.G @ N_poly
        return predictions

def train_model_with_categories(M, N, J, epochs=1000, lr=0.01):
    num_genes, num_cells = M.shape
    num_categories = J.shape[1]
    
    # Normalize M
    M_mean = M.mean(dim=1, keepdim=True)
    M_std = M.std(dim=1, keepdim=True) + 1e-8
    M_normalized = (M - M_mean) / M_std
    
    model = WeightedPolynomialRegression(num_genes, num_cells, num_categories)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=50, min_lr=1e-5
    )
    
    # Pre-compute masks for each category
    category_masks = [J[:, i] == 1 for i in range(num_categories)]
    
    # For early stopping
    best_loss = float('inf')
    patience = 100
    patience_counter = 0
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        predictions = model(N)
        
        # Vectorized loss calculation with clipping
        total_loss = torch.tensor(0., dtype=torch.float64)
        category_gene_losses = torch.zeros(num_categories, num_genes)
        
        for cat_idx, cat_mask in enumerate(category_masks):
            if cat_mask.any():
                # Compute all gene losses for this category at once
                gene_losses = torch.mean((predictions[:, cat_mask] - 
                                        M_normalized[:, cat_mask])**2, dim=1)
                
                # Clip extremely large losses
                gene_losses = torch.clamp(gene_losses, max=10.0)
                
                # Use softmax for weights
                weights = F.softmax(model.K[cat_idx], dim=0)
                weighted_losses = gene_losses * weights
                total_loss += weighted_losses.sum()
                
                category_gene_losses[cat_idx] = gene_losses.detach()
        
        # Add L2 regularization
        l2_reg = torch.norm(model.G)**2 * 1e-4
        total_loss = total_loss + l2_reg
        
        # Backward pass with gradient clipping
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Update learning rate
        scheduler.step(total_loss)
        
        # Early stopping check
        if total_loss.item() < best_loss:
            best_loss = total_loss.item()
            patience_counter = 0
        else:
            patience_counter += 1
            
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
        
        # Vectorized K update
        if (epoch + 1) % 100 == 0:
            with torch.no_grad():
                avg_losses = category_gene_losses.mean(dim=1, keepdim=True)
                update_mask = category_gene_losses > avg_losses
                
                # Smoother updates
                model.K.data[update_mask] *= 0.95
                model.K.data[~update_mask] *= 1.05
                
                # Use softmax for normalization
                for cat_idx in range(num_categories):
                    model.K.data[cat_idx] = F.softmax(model.K.data[cat_idx], dim=0)
            
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss.item():.4f}')
    
    return model

# Example usage
if __name__ == "__main__":
    M = torch.tensor([
        [1.0, 2.0, 3.0],
        [2.0, 4.0, 6.0],
        [3.0, 6.0, 9.0],
        [4.0, 8.0, 12.0]
    ], dtype=torch.float64)
    
    N = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64)
    
    J = torch.tensor([
        [1, 0],
        [0, 1],
        [1, 0]
    ], dtype=torch.float64)
    
    # Train model
    model = train_model_with_categories(M, N, J)
    
    # Print results
    print("\nTrained parameters G:")
    print(model.G.detach())
    print("\nTrained weights K:")
    print(model.K.detach())

Epoch [100/1000], Loss: 0.1007
Epoch [200/1000], Loss: 0.7750
Epoch [300/1000], Loss: 0.5308
Epoch [400/1000], Loss: 0.1940
Epoch [500/1000], Loss: 0.0939
Epoch [600/1000], Loss: 0.0679
Epoch [700/1000], Loss: 0.0594
Epoch [800/1000], Loss: 0.0239
Epoch [900/1000], Loss: 0.0405
Epoch [1000/1000], Loss: 0.0423

Trained parameters G:
tensor([[ 0.3185,  0.3786,  0.3598, -0.0628],
        [ 0.1232, -0.3977,  3.0732, -0.9521],
        [ 0.3106,  0.2579,  1.1174,  0.0876],
        [ 1.7721,  1.6026,  0.7449, -0.0470]], dtype=torch.float64)

Trained weights K:
tensor([[ 2.3795e-01, -1.6308e-04,  1.5782e-03,  7.6063e-01],
        [-5.0745e-02,  1.3939e+00, -3.0130e-01, -4.1902e-02]],
       dtype=torch.float64)


In [3]:
adata = sc.read("../../processed_data/toy_data/20241116_bone_toy_5000.h5ad")

In [4]:
adata

AnnData object with n_obs × n_vars = 4986 × 1992
    obs: 'orig.ident', 'nCount_originalexp', 'nFeature_originalexp', 'Sample', 'Project', 'Limb.Atlas', 'Organ', 'Tissue', 'Tissue.Specific.', 'Stage', 'Gene.type', 'Treatment', 'Age', 'Age.In.Detail.', 'Machine', 'Species', 'Isolation.approach', 'Digestion', 'Enzymes', 'Bone.Forming.Methods', 'Data.Source', 'Related.Assay', 'Origin', 'nCount_RNA', 'nFeature_RNA', 'paper_label', 'coarse_label', 'scDblFinder_class', 'short_id', 'temp_cluster', 'batch', 'batch_atlas', 'size_factors', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'new_totals_log', 'anno_level_1', 'anno_level_2', 'anno_level_3', 'anno_level_4', 'anno_level_5', 'mix_level_1', 'mix_level_2', 'merge_id_level3', 'merge_id_level4', 'merge_id_level5', 'cellid', 'leiden_clusters_level_1', 'leiden_clusters_level_2', 'leide

In [5]:
geneMat = torch.tensor(adata.X.copy())
dpt = torch.tensor(adata.obs['pred_dpt'].values.astype('float64'))

In [16]:
catogory_str = np.repeat(['A', 'B'], 2493)

In [17]:

catogory = pd.get_dummies(catogory_str)

In [18]:
catogory_tensor = torch.tensor(catogory.values, dtype=torch.float64)

In [19]:
catogory_tensor.shape

torch.Size([4986, 2])

In [20]:
dpt.shape

torch.Size([4986])

In [21]:
geneMat.shape

torch.Size([4986, 1992])

In [22]:
N = torch.tensor([[1.0, 2.0, 3.0]])

In [23]:
model = train_model_with_categories(geneMat.T,dpt,catogory_tensor)

Epoch [100/1000], Loss: 2.3475
Epoch [200/1000], Loss: 4.1322
Epoch [300/1000], Loss: 3.5744
Epoch [400/1000], Loss: 4.1715
Epoch [500/1000], Loss: 4.1677
Epoch [600/1000], Loss: 4.2041
Epoch [700/1000], Loss: 4.1863
Epoch [800/1000], Loss: 36.7641
Epoch [900/1000], Loss: 3.9039
Epoch [1000/1000], Loss: 4.2629


WeightedPolynomialRegression()

## Create simulated datasets

In [24]:
adata_small = sc.pp.subsample(adata,fraction=0.2,copy=True)

In [26]:
cellName = adata_small.obs_names

In [28]:
sample_size = len(cellName) // 2  # 50% of the length
sampled_indices = np.random.choice(len(cellName), size=sample_size, replace=False)
sampled_vector = cellName[sampled_indices]

In [32]:
feature = adata_small.var_names

sample_size = len(feature) // 2  # 50% of the length
sampled_indices = np.random.choice(len(cellName), size=sample_size, replace=False)
sampled_feature = feature[sampled_indices]

In [39]:
permute = adata_small[sampled_vector,sampled_feature].copy()

In [40]:
permuted_cols = np.random.permutation(permute.shape[1])
permuted_matrix = permute[:, permuted_cols].copy()

In [43]:
permuted_matrix.X[0:5,0:10]

array([[-0.1851851 , -0.07246266,  0.6304078 , -0.1696592 , -0.17869857,
        -0.39011335, -0.14705946,  0.14839184, -0.2844765 , -0.00644733],
       [-0.1851851 , -0.07246266, -0.7271087 , -0.1696592 , -0.17869857,
        -0.34820375, -0.14705946, -0.60009336, -0.2844765 , -0.4657444 ],
       [-0.1851851 , -0.07246266, -0.7271087 , -0.1696592 , -0.17869857,
         0.27465189, -0.14705946, -0.60009336,  3.1022573 , -0.4657444 ],
       [-0.1851851 , -0.07246266, -0.7271087 , -0.1696592 , -0.17869857,
         0.22084308, -0.14705946,  0.5309738 , -0.2844765 , -0.4657444 ],
       [-0.1851851 , -0.07246266, -0.7271087 , -0.1696592 , -0.17869857,
         2.908886  , -0.14705946, -0.60009336, -0.2844765 , -0.4657444 ]],
      dtype=float32)

In [44]:
permute.X[0:5,0:10]

array([[-0.4843635 , -0.17446429, -0.42961827, -0.53025097,  0.15386802,
        -0.49738604, -0.2650335 , -0.18074319, -0.13501647,  0.89878887],
       [-0.4843635 , -0.17446429, -0.42961827, -0.08181527, -0.4005429 ,
         0.02528306, -0.2650335 , -0.18074319, -0.13501647, -0.39042246],
       [-0.4843635 , -0.17446429, -0.42961827, -0.53025097, -0.4005429 ,
        -0.49738604, -0.2650335 , -0.18074319, -0.13501647, -0.70098454],
       [-0.4843635 , -0.17446429, -0.42961827, -0.53025097,  1.1521194 ,
        -0.49738604,  2.3524287 , -0.18074319, -0.13501647,  1.4016274 ],
       [-0.4843635 , -0.17446429, -0.42961827, -0.53025097, -0.4005429 ,
        -0.49738604, -0.2650335 , -0.18074319, -0.13501647, -0.70098454]],
      dtype=float32)

In [49]:
adata_small[sampled_vector,sampled_feature].X = permuted_matrix.X.copy()

In [52]:
adata_small.obs["simu"] = "real"

In [55]:
adata_small.obs["simu"][sampled_vector] = "not real"

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  adata_small.obs["simu"][sampled_vector] = "not real"


In [56]:
adata_small.obs["simu"]

Unnamed: 0
BmscChondro_Long_TTAGTCTCACAAAGTA-0-2                                          real
PerichondrialP21_Matsushita_FR3CreCxcl12GfpP21_CTCAGAATCAGAGCTT-1_1-0-2    not real
Septoclasts_Kishor_Pdgfra_TCCCATGTCCTTCACG_3-0-2                               real
Suture2021_Farmer_E17_AGGGATGAGGGTCGAT-3_1-0                                   real
Ablation_Matsushita_abl7con2_GCCTGTTGTCGCTGCA-1_3-0-2                      not real
                                                                             ...   
Bmsc2019_Regev_bm4_TTGCCGTGTACCGTTA-1_4-0-2                                    real
Bmsc2019_Regev_samp1_baryawno:std1_TACTTGTGTCACACGC-0-2                        real
Bmsc2019_Regev_bm4_CGTCAGGTCGCTTGTC-1_4-0-2                                    real
Maxillary_Bian_E14.5_E14.5-707_CCGACAACCAGATCTGCAAGACTA_4-0                not real
BmscEndosteal_Ono_Fgfr3CE_CAAGAGGGTGGATTTC-1_1-0-2                             real
Name: simu, Length: 997, dtype: object

In [57]:
geneMat = torch.tensor(adata_small.X.copy())
dpt = torch.tensor(adata_small.obs['pred_dpt'].values.astype('float64'))
catogory_str= adata_small.obs["simu"]
catogory = pd.get_dummies(catogory_str)
catogory_tensor = torch.tensor(catogory.values, dtype=torch.float64)

In [58]:
model2 = train_model_with_categories(geneMat.T,dpt,catogory_tensor)

Epoch [100/1000], Loss: 2.3433
Epoch [200/1000], Loss: 3.9117
Epoch [300/1000], Loss: 4.2102
Epoch [400/1000], Loss: 4.0641
Epoch [500/1000], Loss: 125.3505
Epoch [600/1000], Loss: 3.9821
Epoch [700/1000], Loss: 5.3102
Epoch [800/1000], Loss: 4.0035
Epoch [900/1000], Loss: 3.9597
Epoch [1000/1000], Loss: 3.8225
