# ML Single Cell Classification Project

In this project, we build a classifier to predict brain cell types using single-cell RNA sequencing data. The dataset consists of:
- **Expression data:** A normalized and partially preprocessed gene expression matrix (280,186 cells × 254 genes) stored in `counts.h5ad`
- **Cell annotations:** A CSV file (`cell_labels.csv`) containing cell type labels and additional metadata  
  (Three classes: **GABAergic**, **Glutamatergic**, and **Other**)

We use the `scanpy` library for handling single-cell data, along with common machine learning libraries from scikit-learn for model training, hyperparameter tuning, and evaluation.

---

## 1. Environment Setup and Imports

First, install and load the required packages. (If working in Google Colab, install `scanpy` with `!pip install scanpy`.)

```python
# Install scanpy if needed
# !pip install scanpy

import scanpy as sc
import numpy as np
import pandas as pd
import anndata as ad
import pooch 
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.utils import resample
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import f1_score
```


## 2. Data Loading and Exploratory Data Analysis

**2.1. Load Annotation and Expression Data**

The cell annotation file (`cell_labels.csv`) contains the class labels (in the `class_label` column), while the expression data (`counts.h5ad`) contains the gene expression matrix and metadata.

```python
# Load cell annotation file (cell labels); adjust the paths to your own machine
cell_labels = pd.read_csv("C:/Users/Семья/Desktop/bioinformatics/ML/project/cell_labels.csv", index_col=0)

# Load expression data (AnnData .h5ad format)
adata = sc.read_h5ad("C:/Users/Семья/Desktop/bioinformatics/ML/project/counts.h5ad")
```



**2.2. Explore the Annotation Data**

Display the first few rows and get a summary of class labels.
```python
# Display first rows of cell_labels
cell_labels.head()

# Extract class labels and print unique classes
class_label = cell_labels['class_label'].values  
print(f"Unique classes: {np.unique(class_label)}")
```
Expected output (NumPy prints arrays without commas): `Unique classes: ['GABAergic' 'Glutamatergic' 'Other']`


Check the distribution of cell types:

```python
# Count of cells in each class
print(cell_labels['class_label'].value_counts())

# Group by class label for further insights
cell_labels.groupby('class_label').count()
```
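
Relative frequencies are often easier to interpret than raw counts when judging class imbalance. The following is a minimal sketch on a small made-up label series; `toy_labels` and its values are illustrative assumptions, not numbers from the dataset:

```python
import pandas as pd

# Toy labels standing in for cell_labels['class_label'] (hypothetical values)
toy_labels = pd.Series(
    ["Glutamatergic"] * 6 + ["GABAergic"] * 3 + ["Other"] * 1,
    name="class_label",
)

# Relative frequency of each class
class_fractions = toy_labels.value_counts(normalize=True)

# Ratio between the most and the least frequent class
imbalance_ratio = class_fractions.max() / class_fractions.min()

print(class_fractions)
print(f"Imbalance ratio: {imbalance_ratio:.1f}")
```

A large ratio is one reason to prefer stratified splits and a macro-averaged metric, which weights every class equally.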

Visualize the distribution of subclasses (if available) with a bar plot:

```python
cell_labels['subclass'].value_counts().plot(kind='bar')
plt.title("Subclass Distribution")
plt.xlabel("Subclass")
plt.ylabel("Count")
plt.show()
```
Create a comparison between subclass and label:

```python
# Check whether 'subclass' matches 'label' for each cell
cell_labels['is_equal'] = cell_labels['subclass'] == cell_labels['label']
# Fix the order to [False, True] so the bars always match their tick labels
counts = cell_labels['is_equal'].value_counts().reindex([False, True], fill_value=0)

plt.figure(figsize=(6, 4))
sns.barplot(x=['Not Equal', 'Equal'], y=counts.values)
plt.ylabel("Count")
plt.title("Comparison of Subclass and Label")
plt.show()
```

Visualize the mapping of subclasses to class labels using a heatmap:
```python
cross_tab = pd.crosstab(cell_labels['class_label'], cell_labels['subclass'])
plt.figure(figsize=(12, 8))
sns.heatmap(cross_tab, annot=False, cmap="Blues", cbar=True)
plt.title("Mapping of Subclasses to Class Labels", fontsize=16)
plt.xlabel("Subclass", fontsize=12)
plt.ylabel("Class Label", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.show()
```
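
The same cross-tabulation can also be checked numerically, for example to see whether any subclass is assigned to more than one class. This is a sketch on a toy table; the subclass names are hypothetical placeholders:

```python
import pandas as pd

# Toy annotation table (hypothetical subclass names)
toy = pd.DataFrame({
    "class_label": ["GABAergic", "GABAergic", "Glutamatergic", "Other", "Other"],
    "subclass": ["Sst", "Pvalb", "L2/3 IT", "Astro", "Sst"],
})

cross_tab_toy = pd.crosstab(toy["class_label"], toy["subclass"])

# Number of distinct classes in which each subclass occurs
classes_per_subclass = (cross_tab_toy > 0).sum(axis=0)

# Subclasses that map to more than one class
ambiguous = classes_per_subclass[classes_per_subclass > 1].index.tolist()
print("Subclasses found in more than one class:", ambiguous)
```

An empty list would mean that every subclass belongs to exactly one class.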

## 3. Data Preprocessing with Scanpy

**3.1. Explore the scRNAseq Data**

Display key information about the `adata` object:

```python
# Print basic details about the AnnData object
print(adata)
adata.X        # Expression matrix
adata.var      # Gene information
adata.obs      # Cell metadata
adata.layers.keys()  # Other available data layers
```

**3.2. Quality Control (QC) Metrics**

Calculate quality control metrics using Scanpy's built-in function:

```python
# Calculate QC metrics and return them as two DataFrames (adata is not modified)
obs_qc, var_qc = sc.pp.calculate_qc_metrics(adata, percent_top=[20,50,100], inplace=False)
print(obs_qc.columns)
print(var_qc.columns)
```
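
To make the metric names less opaque, the sketch below recomputes three of them by hand on a tiny made-up count matrix. It only illustrates the definitions and is not Scanpy's implementation; in particular, how cells with zero total counts are handled here (set to 0) is an assumption:

```python
import numpy as np

# Toy matrix: 3 cells × 4 genes (made-up values)
counts = np.array([
    [5, 0, 3, 2],
    [0, 0, 0, 0],
    [1, 1, 1, 1],
])

# Number of genes with a non-zero value in each cell
n_genes_by_counts = (counts > 0).sum(axis=1)

# Sum of all values in each cell
total_counts = counts.sum(axis=1)

# Share of each cell's total that comes from its 2 highest genes
top2 = np.sort(counts, axis=1)[:, ::-1][:, :2].sum(axis=1)
pct_counts_in_top_2 = np.divide(
    100.0 * top2,
    total_counts,
    out=np.zeros(len(total_counts), dtype=float),
    where=total_counts > 0,  # avoid dividing by zero for empty cells
)

print(n_genes_by_counts, total_counts, pct_counts_in_top_2)
```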

Visualize QC metrics:

```python
# Add QC metrics to the adata object
sc.pp.calculate_qc_metrics(adata, percent_top=[20,50,100], inplace=True)

# Violin plot for several QC metrics
sc.pl.violin(
    adata,
    ["n_genes_by_counts", "log1p_n_genes_by_counts", "total_counts", "pct_counts_in_top_100_genes",
     "pct_counts_in_top_50_genes", "pct_counts_in_top_20_genes"],
    jitter=0.4,
    multi_panel=True,
)
```

Scatter plot to check the relationship between total counts and number of genes:

```python
sc.pl.scatter(adata, "total_counts", "n_genes_by_counts")
```

Filter out cells with extremely low expression:

```python
# Filter cells with fewer than 1 gene detected (removes cells with zero counts)
sc.pp.filter_cells(adata, min_genes=1)
```

**3.3. Integrate Annotation Data with adata**

Join cell annotation data into `adata.obs`:


```python
adata.obs = adata.obs.join(cell_labels, how='left', rsuffix="_cell_labels")

# Drop unnecessary columns from the joined metadata
cols_to_drop = ["sample_id_cell_labels", "slice_id_cell_labels", "class_label_cell_labels", 
                "subclass_cell_labels", "label_cell_labels"]
adata.obs = adata.obs.drop(columns=cols_to_drop)
```
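
A left join keeps every cell in `adata.obs` and fills annotation columns with `NaN` where no matching index is found, so it is worth checking how many cells end up without a label. A minimal sketch with toy frames (the cell IDs and the `obs_toy`/`labels_toy` names are hypothetical):

```python
import pandas as pd

# Toy stand-ins for adata.obs and cell_labels (hypothetical cell IDs)
obs_toy = pd.DataFrame(
    {"sample_id": ["s1", "s1", "s2"]},
    index=["cell_a", "cell_b", "cell_c"],
)
labels_toy = pd.DataFrame(
    {"class_label": ["GABAergic", "Other"]},
    index=["cell_a", "cell_c"],
)

joined = obs_toy.join(labels_toy, how="left", rsuffix="_cell_labels")

# Cells that received no annotation from the join
unlabeled = joined.index[joined["class_label"].isna()].tolist()
print(f"{len(unlabeled)} cell(s) without a label:", unlabeled)
```

If the real data produce unlabeled cells, they should be dropped or investigated before building `y`.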

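(Optional) Score potential doublets (droplets that captured two cells). Scrublet expects un-normalized counts, so interpret the scores with caution on this normalized matrix:

```python
sc.pp.scrublet(adata, batch_key="sample_id")  # Adds doublet_score and predicted_doublet to adata.obs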
```

## 4. Defining Features and Labels for Classification

**4.1. Define Input (X) and Output (y)**

Extract the expression matrix as features and convert class labels to categorical numeric codes:

```python
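# Define X as a dense expression matrix (cells × genes).
# adata.X may be a SciPy sparse matrix, which StandardScaler cannot center.
X = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)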

# Convert class labels to categorical and then to numeric codes
print("Unique class labels:", adata.obs['class_label'].unique())
adata.obs['class_label'] = pd.Categorical(
    adata.obs['class_label'],
    categories=['Other', 'Glutamatergic', 'GABAergic']
)
y = adata.obs['class_label'].cat.codes.values
print("Categories:", adata.obs['class_label'].cat.categories)
```

Decode the first few numeric codes back to class names as a sanity check (these are the true labels, not predictions):

```python
_labels = [adata.obs['class_label'].cat.categories[code] for code in y[:10]]
print("First 10 cell type labels:", _labels)
```
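
One detail of `pd.Categorical` is easy to miss: values that are missing or not listed in `categories` receive the code `-1`. The sketch below shows this on a toy series; the `"Astrocyte"` value is a made-up example of an unlisted label:

```python
import pandas as pd

# Toy labels including a missing value and an unlisted category
raw = pd.Series(["Other", "GABAergic", None, "Glutamatergic", "Astrocyte"])

cat = pd.Categorical(raw, categories=["Other", "Glutamatergic", "GABAergic"])
codes = cat.codes

# Rows with code -1 have no valid class and should not be used for training
valid = codes >= 0
print("Codes:", codes.tolist())
print("Rows with a valid label:", int(valid.sum()))
```

Applying such a mask to both `X` and `y` keeps the two aligned.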


**4.2. Split the Data into Training, Validation, and Test Sets**

Use stratified splitting to preserve class distributions. The two splits below yield 60% training, 20% validation, and 20% test data:

```python
# First, split into training/validation and test sets (80/20)
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
# Then, split the training/validation set into training and validation (75/25)
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, test_size=0.25, stratify=y_trainval, random_state=42
)
```
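
The resulting proportions can be verified on synthetic data. This sketch uses a made-up label vector with a 50/30/20 class mix; the `_demo` names are illustrative:

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(1000, 5))
y_demo = np.repeat([0, 1, 2], [500, 300, 200])  # 50% / 30% / 20%

# Same two-step split as above: 80/20, then 75/25 of the remainder
X_tv, X_te, y_tv, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=42
)
X_tr, X_va, y_tr, y_va = train_test_split(
    X_tv, y_tv, test_size=0.25, stratify=y_tv, random_state=42
)

split_fractions = {}
for name, part in [("train", y_tr), ("val", y_va), ("test", y_te)]:
    split_fractions[name] = np.bincount(part, minlength=3) / len(part)
    print(name, len(part), split_fractions[name].round(2))
```

Each subset should show roughly the same class mix as the full label vector.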

Set up stratified 5-fold cross-validation for the grid search:
```python
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
```

Standardize the features (zero mean, unit variance), fitting the scaler on the training set only:
```python
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)
```
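
An alternative worth considering is to place the scaler inside the `Pipeline`, so that it is re-fit on the training portion of every cross-validation fold rather than once on the whole training set. The sketch below shows the pattern on synthetic data; it is a variation on the workflow above, not the approach used in this project:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic 3-class data standing in for the expression matrix
X_demo, y_demo = make_classification(
    n_samples=300, n_features=10, n_informative=5, n_classes=3, random_state=0
)

# The scaler is a pipeline step, so each CV fold fits it on its own training part
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(max_iter=500)),
])

cv_demo = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(pipe, X_demo, y_demo, cv=cv_demo, scoring="f1_macro")
print("Macro F1 per fold:", scores.round(3))
```

With this layout the raw (unscaled) matrices would be passed to `fit` and `predict`.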

## 5. Model Training, Hyperparameter Tuning, and Evaluation

**5.1. Define Models and Pipelines**

Set up a dictionary with pipelines and parameter grids for multiple classifiers:

```python
models_params = {
    "KNN": {
        "pipeline": Pipeline([("clf", KNeighborsClassifier())]),
        "params": {
            "clf__n_neighbors": [3, 5, 7, 10],
            "clf__weights": ["uniform", "distance"]
        }
    },
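    "LogisticRegression": {
        "pipeline": Pipeline([("clf", LogisticRegression(random_state=42))]),
        # penalty=None requires scikit-learn >= 1.2 (older versions use "none").
        # liblinear is left out: it does not support penalty=None, and recent
        # scikit-learn releases deprecate it for multiclass problems.
        # multi_class is left out: it is deprecated since scikit-learn 1.5.
        # n_jobs is left out: GridSearchCV already parallelizes the search.
        "params": {
            "clf__penalty": ["l2", None],
            "clf__solver": ["lbfgs", "saga"],
            "clf__max_iter": [100, 200, 500]
        }
    },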
    "RandomForest": {
        "pipeline": Pipeline([("clf", RandomForestClassifier(random_state=42))]),
        "params": {
            "clf__n_estimators": [100, 200, 500],
            "clf__max_depth": [None, 10, 20],
            "clf__min_samples_leaf": [1, 2, 5]
        }
    },
    "SVM": {
        "pipeline": Pipeline([("clf", SVC(random_state=42))]),
        "params": {
            "clf__kernel": ["linear", "rbf"],
            "clf__C": [0.1, 1, 10]
        }
    }
}
```
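
Before launching the search on roughly 280,000 cells, it helps to know how many model fits each grid implies. The sketch below counts them with `ParameterGrid` for three of the grids above, assuming 5-fold cross-validation; the final refit of the best model adds one more fit per search:

```python
from sklearn.model_selection import ParameterGrid

# Grids copied from the dictionary above (KNN, RandomForest, SVM)
grids = {
    "KNN": {
        "clf__n_neighbors": [3, 5, 7, 10],
        "clf__weights": ["uniform", "distance"],
    },
    "RandomForest": {
        "clf__n_estimators": [100, 200, 500],
        "clf__max_depth": [None, 10, 20],
        "clf__min_samples_leaf": [1, 2, 5],
    },
    "SVM": {
        "clf__kernel": ["linear", "rbf"],
        "clf__C": [0.1, 1, 10],
    },
}

n_splits = 5  # matches StratifiedKFold(n_splits=5)

# Cross-validation fits = parameter combinations × folds
cv_fits = {name: len(ParameterGrid(grid)) * n_splits for name, grid in grids.items()}
for name, n_fits in cv_fits.items():
    print(f"{name}: {n_fits} cross-validation fits")
```

If the total is too expensive, subsampling the training set or switching to `RandomizedSearchCV` are common ways to reduce it.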

**5.2. Hyperparameter Tuning and Model Evaluation**

Loop through each model, perform grid search with cross-validation, and evaluate performance on validation and test sets:

```python
results = []  # To store results

for model_name, mp in models_params.items():
    print(f"Running GridSearch for model: {model_name}")
    grid_search = GridSearchCV(
        estimator=mp["pipeline"],
        param_grid=mp["params"],
        cv=cv,
        scoring="f1_macro",  # Macro-averaged F1 score
        n_jobs=-1,
        return_train_score=True
    )
    # Train the model
    grid_search.fit(X_train, y_train)
    
    # Best parameters and cross-validation score
    best_params = grid_search.best_params_
    best_cv_score = grid_search.best_score_
    
    # Evaluate on the validation set
    val_pred = grid_search.predict(X_val)
    val_score = f1_score(y_val, val_pred, average="macro")
    
    # Evaluate on the test set
    test_pred = grid_search.predict(X_test)
    test_score = f1_score(y_test, test_pred, average="macro")
    
    results.append({
        "model": model_name,
        "best_params": best_params, 
        "cv_score": best_cv_score,
        "val_score": val_score,
        "test_score": test_score
    })
    
    # Plot the mean cross-validated F1 score for each parameter combination
    plt.figure(figsize=(8, 6))
    plt.plot(grid_search.cv_results_['mean_test_score'], marker="o", label=f'{model_name}')
    plt.xlabel("Parameter combination (index in cv_results_)")
    plt.ylabel("Mean CV F1 score (macro)")
    plt.title(f"Hyperparameter Tuning for {model_name}")
    plt.legend()
    plt.show()

# Display results for all models
results_df = pd.DataFrame(results)
print(results_df)
```
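
The macro-averaged F1 score used throughout is the unweighted mean of the per-class F1 scores, so a rare class counts as much as a common one. A small sketch with made-up labels and predictions illustrates this and shows the confusion matrix behind the scores:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score

# Made-up true labels and predictions for three classes (codes 0, 1, 2)
y_true_demo = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])
y_pred_demo = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 0])

# One F1 score per class, then their unweighted mean
per_class_f1 = f1_score(y_true_demo, y_pred_demo, average=None)
macro_f1 = f1_score(y_true_demo, y_pred_demo, average="macro")

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_true_demo, y_pred_demo)

print("Per-class F1:", per_class_f1.round(3))
print("Macro F1:", round(macro_f1, 3))
print(cm)
```

Inspecting per-class scores for the best model can reveal whether one cell type is driving the errors.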

**5.3. Discussion on Doublet Information**

Note:

In the preprocessing, we used 
```python
sc.pp.scrublet(adata, batch_key="sample_id")
```
to compute doublet scores. Whether to include `doublet_score` and `predicted_doublet` as features depends on your hypothesis and on whether you expect doublets to affect cell type classification. One option is to train the models with and without these columns and compare the resulting scores.
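
One way to set up such an experiment is to append the doublet score as an extra feature column. The sketch below uses random stand-in arrays; in the real workflow the score would come from `adata.obs['doublet_score']`, and the `_demo` names are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for the expression matrix and the per-cell doublet score
X_demo = rng.normal(size=(6, 4))
doublet_score_demo = rng.uniform(0, 1, size=6)

# Append the score as one additional column (cells × (genes + 1))
X_with_doublet = np.column_stack([X_demo, doublet_score_demo])

print("Without doublet score:", X_demo.shape)
print("With doublet score:", X_with_doublet.shape)
```

Both feature matrices can then go through the same split, scaling, and grid search so that the scores are directly comparable.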