In [2]:
import pandas as pd
import numpy as np
import scanpy as sc
from sklearn.linear_model import LinearRegression
from collections import Counter
from scipy import stats
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import math
from pathlib import Path
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import label_binarize
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
from sklearn.metrics import r2_score



In [3]:
## optional
## delete celltypes with cell counts lower than certain amount
## specifing an lower bond is a must
def filter_low_counts(celltype_df, age_df, celltype_col, threshold):
    print("Checking low count cell types...")
    
    celltype_count = Counter(celltype_df[celltype_col])
    for key in celltype_count:
        if threshold == None:
            unique_ages = np.unique(age_df)
            num_groups = (len(unique_ages) + 1) * 100
            if celltype_count[key] < num_groups:
                print(key, " has too low counts")
                celltype_df = celltype_df[celltype_df[celltype_col] != key]
        else:
            if celltype_count[key] < threshold:
                print(key, " has too low counts")
                celltype_df = celltype_df[celltype_df[celltype_col] != key]
    return celltype_df

In [5]:
def get_skewed_count_info(adata, class_col, age_col, age_threshold):
    print("Checking skewed count cell types...")
    
    # Compute the fraction of cells for each age group within each cell ontology class
    group_counts = adata.obs.groupby([class_col, age_col]).size()
    total_counts = adata.obs.groupby([class_col]).size()
    
    # Calculate the fraction of each age group within each class
    class_age_fraction = group_counts / total_counts
    
    # Find the cell classes to filter out based on age distribution
    classes_to_filter = class_age_fraction[class_age_fraction > age_threshold].index.get_level_values(0).unique()
    
    return classes_to_filter

In [6]:
## Read h5ad file 
## and do cell type filtering based on age distribution and cell count thresholds.
def read_and_filter_h5ad(filepath_1, filepath_2 = None, class_col="celltype", age_col="age", age_threshold=0.8, count_threshold=None):
    """Parameters:
    adata: AnnData object
        The Scanpy AnnData object containing single-cell data.
    class_col: str, optional (default: 'celltype')
        The column name in adata.obs representing the cell ontology class.
    age_col: str, optional (default: 'age')
        The column name in adata.obs representing the age of the cells.
    age_threshold: float, optional (default: 0.8)
        The threshold fraction for filtering based on age distribution. If one age group has more than this
        fraction of cells in a class, the class will be filtered out.
    count_threshold: list, optional (default: [100])
        Threshold for filtering cell types based on count. If a single value is provided,
        it filters out cell types with counts lower than this value. If a range is provided,
        it filters out cell types outside this range.
    
    Returns:
    filtered_adata: AnnData object
        The filtered AnnData object with specified cell ontology classes removed based on both criteria."""
    try:
        adata1 = sc.read_h5ad(filepath_1)
        if adata2!=None:
            adata2 = sc.read_h5ad(filepath_2)
            adata1 = adata1.concatenate(adata2)
        adata = adata1
        
        celltype_df = adata.obs[[class_col]].copy()
        age_df = adata.obs[[age_col]].copy()
        
        # Apply the cell count threshold filtering
        celltype_df = filter_low_counts(celltype_df, age_df, class_col, count_threshold)
    
        # Create a filtered AnnData object based on cell count filtering
        filtered_adata = adata[celltype_df.index].copy()
        
        # Identify the skewed classes to filter based on age distribution
        classes_to_filter = get_skewed_count_info(filtered_adata, class_col, age_col, age_threshold)
        
        if len(classes_to_filter):
            print(classes_to_filter[0], " has skewed cell counts")
        # Further filter the AnnData object based on age distribution
        final_filtered_adata = filtered_adata[~filtered_adata.obs[class_col].isin(classes_to_filter)].copy()
        
        return final_filtered_adata
    except Exception as e:
        raise(e)

In [7]:
# Get the current working directory, this should get the path automatically
# hope it works for mac
current_dir = Path.cwd()
print(f"Current working directory: {current_dir}")

# Construct full file paths using the `/` operator
file1 = current_dir / "tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad"
# file2 = current_dir / "tabula-muris-senis-facs-processed-official-annotations-Brain_Non-Myeloid.h5ad"
# Print the paths for verification
print(f"File 1 path: {file1}")
# print(f"File 2 path: {file2}")

# Verify that files exist
assert file1.is_file(), f"File not found: {file1}"
# assert file2.is_file(), f"File not found: {file2}"

# Read and filter the data
# adata = read_and_filter_h5ad(str(file1), str(file2), "cell_ontology_class", "age")

# adata = read_and_filter_h5ad("../Mouse Tabula Muris/tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad", 
#                              "../Mouse Tabula Muris/tabula-muris-senis-facs-processed-official-annotations-Brain_Non-Myeloid.h5ad",
#                              "cell_ontology_class", "age")
# # 

Current working directory: /home/hang/SC_Ageing_Prediction
File 1 path: /home/hang/SC_Ageing_Prediction/tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad


In [67]:
# What does adata look like?

print(type(adata))
print(adata)
print(type(adata.obs))
# print out all obs columns
print(adata.obs.columns)    
print(adata.obsm['X_tsne'].shape)
print(adata.X.T)

<class 'anndata._core.anndata.AnnData'>
AnnData object with n_obs × n_vars = 19154 × 22966
    obs: 'FACS.selection', 'age', 'cell', 'cell_ontology_class', 'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'sex', 'subtissue', 'tissue', 'n_genes', 'n_counts', 'louvain', 'leiden', 'batch'
    var: 'n_cells', 'means-0', 'dispersions-0', 'dispersions_norm-0', 'highly_variable-0', 'means-1', 'dispersions-1', 'dispersions_norm-1', 'highly_variable-1'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
<class 'pandas.core.frame.DataFrame'>
Index(['FACS.selection', 'age', 'cell', 'cell_ontology_class',
       'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'sex',
       'subtissue', 'tissue', 'n_genes', 'n_counts', 'louvain', 'leiden',
       'batch'],
      dtype='object')
(19154, 2)
  (1, 0)	2.314707
  (24, 0)	4.270282
  (33, 0)	2.0799525
  (87, 0)	3.9693682
  (91, 0)	2.137829
  (111, 0)	4.3484535
  (119, 0)	1.5670886
  (141, 0)	1.7044617
  (381, 0)	5.772542
  (544, 0)	0.6663411
  

In [68]:
# print out all obs columns
for col in adata.obs.columns:
    print(adata.obs[col].value_counts())
    print("***************************************************")
    print("\n")


# print(adata.obs['age'].value_counts())

# question: why there is not 18m here

FACS.selection
Microglia    8642
nan          7394
Neurons      3118
Name: count, dtype: int64
***************************************************


age
3m     7394
18m    6928
24m    4832
Name: count, dtype: int64
***************************************************


cell
A10_B001060                1
O8.MAA000617.3_10_M.1.1    1
O8.MAA000593.3_8_M.1.1     1
O8.MAA000592.3_9_M.1.1     1
O8.MAA000591.3_8_M.1.1     1
                          ..
B7_B003910                 1
B7_B003907                 1
B7_B003899                 1
B6_B003914                 1
P9.MAA001894.3_39_F.1.1    1
Name: count, Length: 19154, dtype: int64
***************************************************


cell_ontology_class
microglial cell     13268
endothelial cell     2232
oligodendrocyte      2094
astrocyte             592
brain pericyte        484
neuron                484
Name: count, dtype: int64
***************************************************


cell_ontology_id
nan           11925
CL:0000129     4394

In [69]:
# data preprocessing

sc.pp.filter_genes(adata, min_cells=5)
sc.pp.filter_cells(adata, min_genes=500)
adata.obs['n_counts'] = np.sum(adata.X, axis=1).A1
adata = adata[adata.obs['n_counts']>=3000]
sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4) #simple lib size normalization?
adata = sc.pp.filter_genes_dispersion(adata, subset = False, min_disp=.5, max_disp=None, 
                              min_mean=.0125, max_mean=10, n_bins=20, n_top_genes=None, 
                              log=True, copy=True)
sc.pp.log1p(adata)
sc.pp.scale(adata, max_value=10, zero_center=False)
sc.tl.pca(adata,use_highly_variable=True)
sc.pp.neighbors(adata, n_neighbors=18)
sc.tl.louvain(adata, resolution = 1)
sc.tl.umap(adata)
if 'X_umap' not in adata.obsm.keys():
    sc.pp.neighbors(adata, n_neighbors=15, use_rep='X_pca')
    sc.tl.umap(adata)

  adata.obs[key_n_counts] = counts_per_cell


In [70]:
# Index(['FACS.selection', 'age', 'cell', 'cell_ontology_class',
#        'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'sex',
#        'subtissue', 'tissue', 'n_genes', 'n_counts', 'louvain', 'leiden',
#        'batch']
# all indexes.



le = LabelEncoder()

celltype_df = pd.DataFrame(adata.obs["cell_ontology_class"])
celltype_df = celltype_df.rename(columns={"cell_ontology_class": "celltype"})


#  Extract and display all unique cell types
unique_cell_types = celltype_df['celltype'].unique()
celltype_encoded = le.fit_transform(adata.obs["cell_ontology_class"])
celltype_df = pd.DataFrame(celltype_encoded, index=adata.obs.index, columns=['celltype_encoded'])
print(celltype_df.values)


age_df = pd.DataFrame(adata.obs["age"])
age_df = age_df.rename(columns={"age": "age"})

gender_df = pd.DataFrame(adata.obs["sex"])
gender_df = gender_df.rename(columns={"sex": "gender"})
#map male to 0, female to 1
gender_mapping = {'male': 0, 'female': 1}
gender_df['gender'] = gender_df['gender'].map(gender_mapping)



ngenes_df = pd.DataFrame(adata.obs["n_genes"])
ngenes_df = ngenes_df.rename(columns={"n_genes": "n_genes"})






[[3]
 [3]
 [3]
 ...
 [2]
 [5]
 [0]]


In [71]:
def clean_age(age_df, substring):
    values = []
    for x in age_df["age"]:
        try:
            # Attempt to strip the substring and convert to integer
            value = int(x.strip(substring))
            values.append(value)
        except ValueError:
            # Handle the case where conversion fails
            warnings.warn(f"Warning: '{x}' could not be converted to an integer.")
            break
    age_df["age"] = values
    return age_df

def get_raw_counts(adata, celltype_df):
    raw_count = pd.DataFrame.sparse.from_spmatrix(adata.X.T, 
                                               index = adata.var_names, 
                                               columns = adata.obs_names).astype(int)
    raw_count = raw_count[list(celltype_df.index)]
    return raw_count

In [72]:
cleaned_age_df = clean_age(age_df, "m")


In [73]:
raw_count = get_raw_counts(adata, celltype_df)
raw_count.values

array([[0, 0, 0, ..., 0, 0, 0],
       [2, 0, 0, ..., 1, 1, 0],
       [0, 0, 0, ..., 1, 2, 0],
       ...,
       [0, 1, 0, ..., 1, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [74]:
print(raw_count[:10])
print(age_df[:10])
print(gender_df[:10])
print(celltype_df[:10])


index          A10_B001060_B009250_S214.mm10-plus-1-0-0  \
index                                                     
0610005C13Rik                                         0   
0610007C21Rik                                         2   
0610007L01Rik                                         0   
0610007N19Rik                                         0   
0610007P08Rik                                         0   
0610007P14Rik                                         0   
0610007P22Rik                                         0   
0610008F07Rik                                         0   
0610009B14Rik                                         0   
0610009B22Rik                                         0   

index          A10_B001061_B009251_S298.mm10-plus-1-0-0  \
index                                                     
0610005C13Rik                                         0   
0610007C21Rik                                         0   
0610007L01Rik                                         0

In [75]:
# now concatenate all the dataframes together
# Transpose raw_count so that cells are the index
raw_count_T = raw_count.T


cells_in_raw = set(raw_count_T.index)
cells_in_age = set(age_df.index)
cells_in_gender = set(gender_df.index)
cells_in_celltype = set(celltype_df.index)

# Find common cells present in all DataFrames

print("Shape of raw_count_T:", raw_count_T.shape)
print("Shape of age_df:", age_df.shape)
print("Shape of gender_df:", gender_df.shape)
print("Shape of celltype_df:", celltype_df.shape)



Shape of raw_count_T: (15692, 21026)
Shape of age_df: (15692, 1)
Shape of gender_df: (15692, 1)
Shape of celltype_df: (15692, 1)


In [76]:
# Optionally, filter DataFrames to include only common cells
common_cells = cells_in_raw  & cells_in_gender & cells_in_celltype & cells_in_age
common_cells_list = list(common_cells)

# Filter DataFrames to include only common cells
raw_count_T = raw_count_T.loc[common_cells_list]
age_df = age_df.loc[common_cells_list]
gender_df = gender_df.loc[common_cells_list]
celltype_df = celltype_df.loc[common_cells_list]

In [77]:
# Concatenate the DataFrames for deep learning
combined_df = pd.concat([raw_count_T,  gender_df, celltype_df, age_df], axis=1)



In [78]:
# validate the combined_df with the original dataframes
# the cell type is marked out since it has been changed to numbers

#  Ensure DataFrames are properly aligned
def prepare_dataframe(df):
    df.index = df.index.astype(str).str.strip().str.lower()
    df.sort_index(inplace=True)
    return df

raw_count_T = prepare_dataframe(raw_count_T)
age_df = prepare_dataframe(age_df)
gender_df = prepare_dataframe(gender_df)
celltype_df = prepare_dataframe(celltype_df)
combined_df = prepare_dataframe(combined_df)

# Ensure that indices match across all DataFrames
common_indices = raw_count_T.index.intersection(age_df.index).intersection(gender_df.index).intersection(celltype_df.index)
raw_count_T = raw_count_T.loc[common_indices]
age_df = age_df.loc[common_indices]
gender_df = gender_df.loc[common_indices]
celltype_df = celltype_df.loc[common_indices]
combined_df = combined_df.loc[common_indices]

#  Compare Gene Expression Data
gene_columns = raw_count_T.columns
combined_gene_data = combined_df[gene_columns]
gene_data_matches = combined_gene_data.equals(raw_count_T)
print("Gene expression data matches:", gene_data_matches)

#  Compare Metadata Columns
age_matches = combined_df['age'].equals(age_df['age'])
print("Age data matches:", age_matches)

gender_matches = combined_df['gender'].equals(gender_df['gender'])
print("Gender data matches:", gender_matches)

# celltype_matches = combined_df['celltype'].equals(celltype_df['celltype'])
# print("Cell type data matches:", celltype_matches)

#  Report Overall Match
all_data_matches = gene_data_matches and age_matches and gender_matches #and celltype_matches
print("\nOverall data matches:", all_data_matches)

# Step 5: Identify and Report Discrepancies
if not all_data_matches:
    if not gene_data_matches:
        # Identify discrepancies in gene expression data
        gene_diff = (combined_gene_data != raw_count_T)
        cells_with_diff = gene_diff.any(axis=1)
        genes_with_diff = gene_diff.any(axis=0)
        print("\nDiscrepancies found in gene expression data.")
        print(f"Number of cells with discrepancies: {cells_with_diff.sum()}")
        print(f"Number of genes with discrepancies: {genes_with_diff.sum()}")
        # List first few discrepancies
        discrepant_cells = cells_with_diff[cells_with_diff].index[:5]
        for cell in discrepant_cells:
            diff_genes = gene_diff.loc[cell][gene_diff.loc[cell]].index.tolist()
            print(f"Cell '{cell}' has discrepancies in genes: {diff_genes[:5]}")
    
    if not age_matches:
        age_diff = combined_df['age'] != age_df['age']
        discrepant_cells = age_diff[age_diff].index.tolist()
        print("\nDiscrepancies found in age data for cells:", discrepant_cells)
    
    if not gender_matches:
        gender_diff = combined_df['gender'] != gender_df['gender']
        discrepant_cells = gender_diff[gender_diff].index.tolist()
        print("\nDiscrepancies found in gender data for cells:", discrepant_cells)
    
    # if not celltype_matches:
    #     celltype_diff = combined_df['celltype'] != celltype_df['celltype']
    #     discrepant_cells = celltype_diff[celltype_diff].index.tolist()
    #     print("\nDiscrepancies found in cell type data for cells:", discrepant_cells)
else:
    print("\nAll data in combined_df matches the original DataFrames.")


Gene expression data matches: True
Age data matches: True
Gender data matches: True

Overall data matches: True

All data in combined_df matches the original DataFrames.


In [79]:
print(combined_df[:10].values)


[[0 0 0 ... 1 3 3]
 [0 0 0 ... 1 3 3]
 [0 0 0 ... 1 3 3]
 ...
 [0 0 3 ... 0 3 3]
 [0 0 0 ... 0 3 3]
 [0 2 0 ... 0 3 3]]


In [80]:
# pre process data for machine learning
# sperate data into training and testing sets

data = combined_df.values


# Separate features and target
x = data[:, :-1]  # All columns except the last one
y = data[:, -1]   # The last column

# Use LabelEncoder to encode labels to 0, 1, 2
le = LabelEncoder()
le.fit(y)
y = le.transform(y)

# optional, balance the class weights since the data is imbalanced
# Compute class weights
class_weights = compute_class_weight('balanced', classes=np.unique(y), y=y)
class_weights = torch.tensor(class_weights, dtype=torch.float)

# Update criterion
class_criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

# Split data into training and test sets (80% training, 20% test)
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=42, shuffle=True
)

# Check the shapes of the splits
print("Training features shape:", x_train.shape)
print("Test features shape:", x_test.shape)
print("Training target shape:", y_train.shape)
print("Test target shape:", y_test.shape)

# Convert tensors to NumPy for preprocessing
X_train_np = x_train
X_test_np = x_test


scaler = StandardScaler()
X_train_np[:, :-2] = scaler.fit_transform(X_train_np[:, :-2])
X_test_np[:, :-2] = scaler.transform(X_test_np[:, :-2])

# Convert back to tensors
X_train = torch.from_numpy(X_train_np).float()
X_test = torch.from_numpy(X_test_np).float()

# Reduce gene expression features to 100 components
pca = PCA(n_components=100)
X_train_pca = pca.fit_transform(X_train_np[:, :-2])
X_test_pca = pca.transform(X_test_np[:, :-2])

# Concatenate the last two features back
X_train_pca = np.concatenate((X_train_pca, X_train_np[:, -2:]), axis=1)
X_test_pca = np.concatenate((X_test_pca, X_test_np[:, -2:]), axis=1)

# Convert back to tensors
X_train_pca = torch.from_numpy(X_train_pca).float()
X_test_pca = torch.from_numpy(X_test_pca).float()




Training features shape: (12553, 21028)
Test features shape: (3139, 21028)
Training target shape: (12553,)
Test target shape: (3139,)


In [81]:
print(x_train.shape)

(12553, 21028)


In [82]:
# convert data to tensor dataset
y_train = torch.from_numpy(y_train).long()
y_test = torch.from_numpy(y_test).long()
train_dataset = TensorDataset(X_train_pca, y_train)
test_dataset = TensorDataset(X_test_pca, y_test)

# Data loaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Define model, loss function, optimizer
input_size = X_train_pca.shape[1]
y = torch.from_numpy(y).long()
num_classes = len(torch.unique(y))


In [95]:
# networks
# use a MLP as a start point

################################################################################
## simple MLP
# class Net(nn.Module):
#     def __init__(self, input_size, num_classes):
#         super(Net, self).__init__()
#         self.fc1 = nn.Linear(input_size, 128)
#         self.dropout1 = nn.Dropout(0.5)
#         self.fc2 = nn.Linear(128, 64)
#         self.dropout2 = nn.Dropout(0.5)
#         self.fc3 = nn.Linear(64, num_classes)
    
#     def forward(self, x):
#         # Weight the last two features more heavily
#         x[:, -2:] *= 5.0  # Adjust the weighting factor as needed
        
#         x = F.relu(self.fc1(x))
#         x = self.dropout1(x)
#         x = F.relu(self.fc2(x))
#         x = self.dropout2(x)
#         x = self.fc3(x)
#         return x
################################################################################    


################################################################################    
# slightly more complex MLP with out weights
# result log: acc: 92% r2: 0.83

# class Net(nn.Module):
#     def __init__(self, input_size, num_classes):
#         super(Net, self).__init__()
#         self.fc1 = nn.Linear(input_size, 128)
#         self.dropout1 = nn.Dropout(0.3)
#         self.fc2 = nn.Linear(128, 64)
#         self.dropout2 = nn.Dropout(0.3)
#         self.fc3 = nn.Linear(64, 32)
#         self.dropout3 = nn.Dropout(0.3)
#         self.fc4 = nn.Linear(32, num_classes)
    
#     def forward(self, x):
#         x[:, -2:] *= 5.0  # Adjust the weighting factor as needed
#         x = F.relu(self.fc1(x))
#         x = self.dropout1(x)
#         x = F.relu(self.fc2(x))
#         x = self.dropout2(x)
#         x = self.fc3(x)
#         x = self.dropout3(x)
#         x = self.fc4(x)
#         return x
################################################################################    



################################################################################
# a configurable MLP


class Net(nn.Module):
    def __init__(self, input_size, num_classes, depth, base_exp=7, decay_factor=1):
        super(Net, self).__init__()
        self.depth = depth
        self.base_exp = base_exp
        
        # Generate hidden layer sizes as powers of 2
        # Starts from 2^base_exp and decreases by 1 exponent with each layer
        hidden_sizes = [2 ** (self.base_exp - int(i/decay_factor)) for i in range(depth - 1)]
        hidden_sizes.append(32)  # Last hidden layer has 32 neurons
        print("Hidden layer sizes:", hidden_sizes)
        self.fcs = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        in_features = input_size
        
        # Create hidden layers and dropout layers
        for out_features in hidden_sizes:
            self.fcs.append(nn.Linear(in_features, out_features))
            self.dropouts.append(nn.Dropout(0.3))
            in_features = out_features
        
        # Output layer
        self.fc_out = nn.Linear(in_features, num_classes)
    
    def forward(self, x):
        x[:, -2:] *= 5.0  # Adjust the weighting factor as needed
        
        for i in range(len(self.fcs)):
            if i < len(self.fcs) - 1:  # Apply ReLU and Dropout after all but last hidden layer
                x = F.relu(self.fcs[i](x))
                x = self.dropouts[i](x)
            else:  # For the last hidden layer, do not apply ReLU
                x = self.fcs[i](x)
                x = self.dropouts[i](x)
        
        x = self.fc_out(x)
        return x
################################################################################



################################################################################
## transformer
# I fixed the lenght of the gene sequence to 100, should change later 

# result log: acc: 88%, r2:0.76

# class Net(nn.Module):
#     def __init__(self, num_classes):
#         super(Net, self).__init__()
#         # Gene sequence parameters
#         gene_seq_length = 100  # Since we have 100 gene features
#         embedding_dim = 16     # Adjust based on your preference

#         # Embedding layer for gene sequence features
#         self.embedding = nn.Linear(1, embedding_dim)

#         # Positional Encoding
#         self.positional_encoding = nn.Parameter(torch.zeros(1, gene_seq_length, embedding_dim))

#         # Transformer Encoder
#         encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4)
#         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

#         # Processing the last two features
#         self.fc_aux = nn.Sequential(
#             nn.Linear(2, 16),
#             nn.ReLU(),
#             nn.Linear(16, 16),
#             nn.ReLU()
#         )

#         # Final fully connected layer
#         self.fc_out = nn.Linear(gene_seq_length * embedding_dim + 16, num_classes)

#     def forward(self, x):
#         # Split the input into gene sequence features and auxiliary features
#         gene_seq_features = x[:, :100]  # (batch_size, 100)
#         aux_features = x[:, 100:]       # (batch_size, 2)

#         batch_size = gene_seq_features.size(0)

#         # Reshape gene sequence features to (batch_size, seq_length, 1)
#         gene_seq_features = gene_seq_features.unsqueeze(-1)  # (batch_size, 100, 1)

#         # Embedding
#         gene_seq_features = self.embedding(gene_seq_features)  # (batch_size, 100, embedding_dim)

#         # Add positional encoding
#         gene_seq_features = gene_seq_features + self.positional_encoding  # (batch_size, 100, embedding_dim)

#         # Transformer expects input shape (seq_length, batch_size, embedding_dim)
#         gene_seq_features = gene_seq_features.permute(1, 0, 2)  # (100, batch_size, embedding_dim)

#         # Pass through Transformer Encoder
#         gene_seq_features = self.transformer_encoder(gene_seq_features)

#         # Permute back to (batch_size, seq_length, embedding_dim)
#         gene_seq_features = gene_seq_features.permute(1, 0, 2)

#         # Flatten gene sequence features
#         gene_seq_features = gene_seq_features.contiguous().view(batch_size, -1)

#         # Process the auxiliary features
#         aux_features = self.fc_aux(aux_features)  # (batch_size, 16)

#         # Concatenate gene sequence features and auxiliary features
#         x = torch.cat([gene_seq_features, aux_features], dim=1)  # (batch_size, total_features)

#         # Output layer
#         x = self.fc_out(x)
#         return x


################################################################################

In [84]:
# encoder and decoder structure
# instead of doing pca, we can use autoencoder to reduce the dimensionality of the data
# seems not work, the dimension seems too high for the autoencoder to work 

# class Autoencoder(nn.Module):
#     def __init__(self, input_dim, latent_dim=128, aux_dim=2):
#         super(Autoencoder, self).__init__()
#         gene_dim = input_dim - aux_dim  # Dimension of gene expression data
        
#         # Encoder for gene expression data
#         self.encoder = nn.Sequential(
#             nn.Linear(gene_dim, 4096),
#             nn.ReLU(True),
#             nn.Dropout(0.3),
#             nn.Linear(4096, 2048),
#             nn.ReLU(True),
#             nn.Dropout(0.3),
#             nn.Linear(2048, 1024),
#             nn.ReLU(True),
#             nn.Dropout(0.3),
#             nn.Linear(1024, 512),
#             nn.ReLU(True),
#             nn.Linear(512, latent_dim)
#         )
        
#         # Decoder for gene expression data
#         self.decoder = nn.Sequential(
#             nn.Linear(latent_dim + aux_dim, 512),
#             nn.ReLU(True),
#             nn.Linear(512, 1024),
#             nn.ReLU(True),
#             nn.Dropout(0.3),
#             nn.Linear(1024, 2048),
#             nn.ReLU(True),
#             nn.Dropout(0.3),
#             nn.Linear(2048, 4096),
#             nn.ReLU(True),
#             nn.Dropout(0.3),
#             nn.Linear(4096, gene_dim),
#             nn.Sigmoid()  # Use Sigmoid if gene data is normalized between 0 and 1
#         )
        
#         # Auxiliary feature processor
#         self.aux_processor = nn.Sequential(
#             nn.Linear(aux_dim, aux_dim * 8),
#             nn.ReLU(True),
#             nn.Linear(aux_dim * 8, aux_dim * 8),
#             nn.ReLU(True),
#             nn.Linear(aux_dim * 8, aux_dim * 4),
#             nn.ReLU(True),
#             nn.Linear(aux_dim * 4, aux_dim * 2),
#             nn.ReLU(True),
#             nn.Linear(aux_dim * 2, aux_dim),
#             nn.ReLU(True)
#         )
        
#     def forward(self, x):
#         # Split the input into gene expression data and auxiliary features
#         gene_data = x[:, :-2]  # (batch_size, gene_dim)
#         aux_features = x[:, -2:]  # (batch_size, aux_dim)
        
#         # Process auxiliary features
#         aux_processed = self.aux_processor(aux_features)
        
#         # Encode gene expression data
#         latent = self.encoder(gene_data)
        
#         # Concatenate latent representation with auxiliary features
#         combined = torch.cat((latent, aux_processed), dim=1)
        
#         # Decode to reconstruct gene expression data
#         reconstructed_gene = self.decoder(combined)
        
#         # Optionally, reconstruct auxiliary features (if desired)
#         # For now, we assume we only reconstruct gene expression data
        
#         # Combine reconstructed gene data with original auxiliary features
#         output = torch.cat((reconstructed_gene, aux_features), dim=1)
        
#         return output
    
# # convert data to tensor dataset differently if pca is not used
# y_train = torch.from_numpy(y_train).long()
# y_test = torch.from_numpy(y_test).long()
# train_dataset = TensorDataset(X_train, y_train)
# test_dataset = TensorDataset(X_test, y_test)

# # Data loaders
# batch_size = 32
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=batch_size)

# # Define model, loss function, optimizer
# input_size = X_train.shape[1]
# y = torch.from_numpy(y).long()
# num_classes = len(torch.unique(y))


In [96]:
# Training the model
print(input_size, num_classes)
#model = Net(input_size, num_classes)
#model = Autoencoder(input_size, latent_dim=128, aux_dim=2)
model = Net(input_size, num_classes, depth=5, base_exp=7, decay_factor=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)



num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        #loss = class_criterion(outputs, labels) # if you want to balance the class weights
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
# 

102 3
Hidden layer sizes: [128.0, 90.50966799187809, 64.0, 45.254833995939045, 32]


TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [89]:
# evlaute the model

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        #print(labels)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
accuracy = 100 * correct / total
print(f'Accuracy on test set: {accuracy:.2f}%')


Accuracy on test set: 92.55%


In [90]:
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        # For classification probabilities
        probabilities = torch.softmax(outputs, dim=1)
        # Assuming labels are on CPU
        all_labels.extend(labels.cpu().numpy())
        # Collect probabilities or predicted class indices
        all_preds.extend(probabilities.cpu().numpy())



# Convert lists to NumPy arrays
all_preds = np.array(all_preds)        # Shape: [num_samples, num_classes]
all_labels = np.array(all_labels)      # Shape: [num_samples]

# Number of classes
num_classes = all_preds.shape[1]

# Convert true labels to one-hot encoding
all_labels_one_hot = label_binarize(all_labels, classes=range(num_classes))



r_squared_values = []

for i in range(num_classes):
    y_true = all_labels_one_hot[:, i]
    y_pred = all_preds[:, i]
    r_squared = r2_score(y_true, y_pred)
    r_squared_values.append(r_squared)
    print(f"Class {i} - R²: {r_squared:.4f}")


# Flatten the arrays
y_true_flat = all_labels_one_hot.flatten()
y_pred_flat = all_preds.flatten()

# Calculate R²
overall_r_squared = r2_score(y_true_flat, y_pred_flat)
print(f"Overall R²: {overall_r_squared:.4f}")




Class 0 - R²: 0.9237
Class 1 - R²: 0.7814
Class 2 - R²: 0.7124
Overall R²: 0.8320
