# Imports

In [None]:
from os import path, listdir
from copy import deepcopy
import stlearn as st
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder
from torch import tensor
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

%load_ext autoreload
%autoreload 2

from scanpy_stlearn_loaders import StlearnLoader
import trainer as trainer
from data import get_data
from models import get_model
from tester import tester

In [None]:
# Use one font size for every figure produced in this notebook.
plt.rcParams.update({'font.size': 12})

# Run on the GPU when torch can see one, otherwise on the CPU.
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_type)
print('Using device:', device)

# Load Data 

In [None]:
# Dataset identifier: used as the directory name under /data when loading and
# passed to the trainer further down.
dataset_name = 'Visium_Mouse_Olfactory_Bulb'

## Genes-Spots Expression Values 

In [None]:
# Load the Visium dataset stored under /data/<dataset_name> with the project loader.
data_dir = path.join('/', 'data', dataset_name)
obj = StlearnLoader().load_local_visum(path=data_dir, count_file='filtered_feature_bc_matrix.h5')

# Dense copy of the expression matrix; its two axes are read as (spots, genes).
x = obj.X.toarray()
n_spots, n_genes = x.shape
print(f'# spots: {n_spots} | # genes: {n_genes}')
obj

# Filter Genes
- min_cells = keep genes that have non-zero expression (x[spot, gene] > 0) in at least X spots
- min_counts = keep genes whose total expression, summed over all spots, is at least X (sum(x[:, gene]) >= X)

In [None]:
# First filter: keep only genes that are non-zero in at least 15% of the spots.
# The threshold is derived from the current number of spots.
min_cells = int(n_spots * 0.15)
print(f'Keep genes with at least {min_cells} non zero spots')
st.pp.filter_genes(obj, min_cells=min_cells)
# Re-read the dense matrix and its dimensions from obj after the filter.
x = obj.X.toarray()
n_spots, n_genes = x.shape
print(f'# spots: {n_spots} | # genes: {n_genes}')

In [None]:
# Second filter: keep only genes whose summed expression reaches min_counts.
min_counts = 10
print(f'Keep genes with total expression of at least {min_counts} over all spots')
st.pp.filter_genes(obj, min_counts=min_counts)
# Re-read the dense matrix and its dimensions from obj after the filter.
x = obj.X.toarray()
n_spots, n_genes = x.shape
print(f'# spots: {n_spots} | # genes: {n_genes}')

# Generate Expression DF

## Transfer Matrix to DF 

In [None]:
# Label the dense matrix with spot names (rows) and gene names (columns).
spots_values, genes_values = obj.obs.index.values, obj.var.index.values
df_expressions_matrix = pd.DataFrame(data=x, index=spots_values, columns=genes_values)

# Reshape to long format: one row per (spot, gene) pair with its expression value.
long_expressions = df_expressions_matrix.stack()
df_expressions = long_expressions.reset_index()
df_expressions.columns = ['spot', 'gene', 'expression']

print(f'shape: {df_expressions.shape}')
display(df_expressions['expression'].describe())
df_expressions.head()

In [None]:
# Histogram of the low end of the distribution only (expression below 10).
plt.figure(figsize=(15, 5))
low_expressions = df_expressions.loc[df_expressions['expression'] < 10, 'expression']
low_expressions.plot.hist(bins=10)
plt.title('Gene-Spot expression histogram')
plt.show()

## Encode Genes and Spots 

In [None]:
# Replace gene and spot names with ordinal codes. A separate encoder is kept
# per column so the codes can be mapped back to names later on.
oe_genes, oe_spots = OrdinalEncoder(), OrdinalEncoder()
df_expressions[['gene']] = oe_genes.fit_transform(df_expressions[['gene']].values)
df_expressions[['spot']] = oe_spots.fit_transform(df_expressions[['spot']].values)

# The codes are cast to integers after encoding.
id_columns = ['spot', 'gene']
df_expressions[id_columns] = df_expressions[id_columns].astype(int)
df_expressions.head()

# Train Test Split 

In [None]:
def plot_datasets_distribution(df_train, df_valid, df_test):
    """
    Print the shape of each split and plot their expression distributions side by side.

    For every split the normalized frequency of each expression value below 10 is
    drawn as a bar chart; the panel title reports the split's mean expression
    (computed over all rows, not only those below 10).

    The three panels share both axes and pandas positions bars by their rank in
    the index, so all three distributions are first re-indexed onto the same set
    of expression values. Without this, a value missing from one split would
    shift its bars, and the tick labels / limits of the last panel would be
    applied to all of them.

    :param df_train: train split, must contain an 'expression' column
    :param df_valid: validation split, must contain an 'expression' column
    :param df_test: test split, must contain an 'expression' column
    """
    print(f'Split to train, valid, and test:\nTrain shape:{df_train.shape}\nValid shape:{df_valid.shape}\nTest shape:{df_test.shape}')
    splits = (('Train', df_train), ('Valid', df_valid), ('Test', df_test))

    # Normalized frequency of every expression value below 10, per split.
    distributions = [
        df.loc[df['expression'] < 10, 'expression'].value_counts(normalize=True).sort_index()
        for _, df in splits
    ]

    # Common (sorted) support so that bar i means the same value in every panel.
    support = distributions[0].index
    for distribution in distributions[1:]:
        support = support.union(distribution.index)

    f, axs = plt.subplots(1, 3, figsize=(19, 5), sharey=True, sharex=True)
    for ax, (split_name, df), distribution in zip(axs, splits, distributions):
        # Values absent from a split are drawn as zero-height bars.
        distribution.reindex(support, fill_value=0).plot.bar(ax=ax)
        ax.set_title(f'{split_name} Expression Dist | Avg Expression: {df["expression"].mean():.2f}')
    plt.show()

## Normal Random

In [None]:
# Hold out 10% of the rows for testing, then 10% of the remainder for validation.
# No random_state is passed, so the split differs between runs.
df_train_valid, df_test = train_test_split(df_expressions, test_size=0.10)
df_train, df_valid = train_test_split(df_train_valid, test_size=0.10)

In [None]:
# Compare the expression distribution of the three random splits.
plot_datasets_distribution(df_train, df_valid, df_test)
print('We can see that the train, valid and test datasets share the same expression distribution')

## Random on Top Expressed Genes
Keep only the top N expressed genes (those with the highest expression ratio, i.e. total expression divided by the number of non-zero spots plus one) and split them randomly

In [None]:
# Number of genes to keep.
N = 100

# Score per gene: total expression divided by (number of non-zero spots + 1).
total_expression = np.sum(x, axis=0)
non_zero_spots = np.count_nonzero(x, axis=0)
genes_expressed = total_expression / (non_zero_spots + 1)

# Column positions of the N best-scoring genes, best first.
top_genes_indices = genes_expressed.argsort()[-N:][::-1]
top_genes_names = obj.var.index[top_genes_indices]

# Translate the gene names into the ordinal codes used in df_expressions.
top_genes_column = pd.DataFrame(np.array(top_genes_names)).values
top_genes_codes = oe_genes.transform(X=top_genes_column)[:, 0]

In [None]:
# Restrict the long table to rows that belong to the selected top genes.
mask = df_expressions['gene'].isin(top_genes_codes)
df_expressions_top_genes = df_expressions.loc[mask]

# Same 10% test / 10% validation random split as for the full table.
df_train_valid_top_genes, df_test_top_genes = train_test_split(df_expressions_top_genes, test_size=0.10)
df_train_top_genes, df_valid_top_genes = train_test_split(df_train_valid_top_genes, test_size=0.10)

In [None]:
# Compare the expression distribution of the three top-genes splits.
plot_datasets_distribution(df_train_top_genes, df_valid_top_genes, df_test_top_genes)

# Create Pytorch Data Loaders

In [None]:
# Mini-batch size shared by every DataLoader below and recorded in the model params.
batch_size = 128

## Generate DataSets 

In [None]:
class ExpressionDataset(Dataset):
    """
    Torch dataset built from a long-format expression table.

    Each sample is the triple (gene, spot, expression) taken from one row of
    the data frame given to the constructor.
    """

    def __init__(self, df, device):
        """
        :param df: data frame with 'gene', 'spot' and 'expression' columns
        :param device: device the gene and spot tensors are moved to
        """
        gene_codes = df['gene']
        spot_codes = df['spot']

        # Gene and spot ids are moved to `device`; the labels tensor is not.
        self.genes = tensor(gene_codes.values).to(device)
        self.spots = tensor(spot_codes.values).to(device)
        self.labels = tensor(df['expression'].values)

        self.num_samples = len(df)
        # NOTE(review): these store the largest code present in `df`, not the
        # number of distinct codes - confirm this is what the models expect.
        self.num_genes = gene_codes.max()
        self.num_spots = spot_codes.max()

    def __getitem__(self, index):
        """Return (gene, spot, label) for row `index`; the label is a plain Python number."""
        return self.genes[index], self.spots[index], self.labels[index].item()

    def __len__(self):
        """Number of rows in the underlying data frame."""
        return self.num_samples

    def get_all_data(self):
        """Return the full gene, spot and label tensors, in that order."""
        return self.genes, self.spots, self.labels

In [None]:
# Wrap the three random splits in datasets, all targeting the same device.
ds_train, ds_valid, ds_test = (
    ExpressionDataset(df=split_df, device=device)
    for split_df in (df_train, df_valid, df_test)
)

In [None]:
# Wrap the three top-genes splits in datasets, all targeting the same device.
ds_train_top_genes, ds_valid_top_genes, ds_test_top_genes = (
    ExpressionDataset(df=split_df, device=device)
    for split_df in (df_train_top_genes, df_valid_top_genes, df_test_top_genes)
)

## DataSets to DataLoaders 

In [None]:
# One shuffling loader per random split, all with the shared batch size.
dl_train, dl_valid, dl_test = (
    DataLoader(dataset=split_ds, batch_size=batch_size, shuffle=True)
    for split_ds in (ds_train, ds_valid, ds_test)
)

In [None]:
# One shuffling loader per top-genes split, all with the shared batch size.
dl_train_top_genes, dl_valid_top_genes, dl_test_top_genes = (
    DataLoader(dataset=split_ds, batch_size=batch_size, shuffle=True)
    for split_ds in (ds_train_top_genes, ds_valid_top_genes, ds_test_top_genes)
)

# Load Model

In [None]:
# Loaders used by the experiment below; here the top-genes splits are selected.
dl_train_exp, dl_valid_exp, dl_test_exp = (
    dl_train_top_genes,
    dl_valid_top_genes,
    dl_test_top_genes,
)

In [None]:
# Model to build and the hyper-parameters handed to get_model / the optimizer.
model_name = 'NMF'
params = dict(
    learning_rate=0.001,
    optimizer="SGD",
    latent_dim=10,
    batch_size=batch_size,
)

In [None]:
# Build the model through the project factory; it receives the train loader
# in addition to the model name and hyper-parameters.
model = get_model(model_name=model_name, params=params, dl_train=dl_train_exp)

# Train Model 

In [None]:
# Training budget passed to the trainer.
max_epochs = 5
# NOTE(review): presumably the early-stopping patience in epochs - confirm in trainer.
early_stopping = 4

## Load Optimizer 

In [None]:
# Look up the optimizer class in torch.optim by the name stored in params,
# then instantiate it over the model's parameters.
optimizer_cls = getattr(optim, params['optimizer'])
optimizer = optimizer_cls(model.parameters(), lr=params['learning_rate'])

## Train

In [None]:
# Train on the experiment's train loader; the validation loader is passed to
# the trainer through its `dl_test` argument.
trainer_kwargs = dict(
    model=model,
    optimizer=optimizer,
    max_epochs=max_epochs,
    early_stopping=early_stopping,
    dl_train=dl_train_exp,
    dl_test=dl_valid_exp,
    device=device,
    dataset_name=dataset_name,
    model_name=model_name,
)
model, valid_loss = trainer.trainer(**trainer_kwargs)

## Test 

In [None]:
# Evaluate the trained model on the held-out test loader.
test_loss = tester(model=model, dl_test=dl_test_exp, device=device)
print(f'Test loss = {test_loss}')

## Reconstruction of the whole matrix

In [None]:
# Run the trained model over every sample of the train, valid and test loaders
# and collect, per sample, the gene code, spot code, prediction and true value.
model = model.to(device)
model.eval()

all_gens = []
all_spots = []
expressions_pred = []
expressions_true = []

with torch.no_grad():
    for set_dl in [dl_train_exp, dl_valid_exp, dl_test_exp]:
        for batch in set_dl:
            gens, spots, y = batch
            # Tensor.to returns the moved tensor instead of modifying it in
            # place, so the result has to be re-bound.
            gens = gens.to(device)
            spots = spots.to(device)
            y_pred = model(gens, spots)
            # Floor the predictions at 0. torch.clamp keeps the computation on
            # the tensor's own device, whereas np.clip needs a CPU array.
            y_pred = torch.clamp(y_pred, min=0)

            all_gens.extend(gens.tolist())
            all_spots.extend(spots.tolist())
            expressions_pred.extend(y_pred.tolist())
            expressions_true.extend(y.tolist())

In [None]:
# Long table of predictions, still keyed by the ordinal gene / spot codes.
predictions_by_column = {'gene': all_gens, 'spot': all_spots, 'expression': expressions_pred}
df_expressions_preds = pd.DataFrame(predictions_by_column)

# Map the codes back to the original gene and spot names.
gene_codes = df_expressions_preds[['gene']].values
df_expressions_preds[['gene']] = oe_genes.inverse_transform(gene_codes)
spot_codes = df_expressions_preds[['spot']].values
df_expressions_preds[['spot']] = oe_spots.inverse_transform(spot_codes)

In [None]:
# Same (gene, spot) rows as the predictions table, with the expression column
# replaced by the true values collected in the same loop (hence the same order).
df_expressions_true = df_expressions_preds.copy()
df_expressions_true['expression'] = expressions_true

In [None]:
# Pivot both long tables back to wide form: one row per spot, one column per gene.
pivot_kwargs = dict(index='spot', columns='gene', values='expression')
df_expressions_preds_matrix = df_expressions_preds.pivot(**pivot_kwargs)
df_expressions_true_matrix = df_expressions_true.pivot(**pivot_kwargs)

In [None]:
# Side-by-side histograms of the true and the predicted expression values.
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(19, 6))

df_expressions_true['expression'].hist(ax=ax1)
ax1.set_title('True Genes Expression Histogram')

df_expressions_preds['expression'].hist(ax=ax2)
ax2.set_title('Prediction Genes Expression Histogram')

plt.show()

<b>The model regresses to the mean, so high expression values are not predicted - the data needs to be normalized</b>

In [None]:
# Work on a copy so the original object keeps the measured expression values.
new_obj = deepcopy(obj)

# Column position in obj of every gene that appears in the reconstructed matrix,
# listed in the pivoted matrix's column order.
reconstructed_genes = df_expressions_true_matrix.columns
tmp_genes_locations = [obj.var.index.get_loc(key=gene_name) for gene_name in reconstructed_genes]

In [None]:
# Sanity check before overwriting: print the number of cells in the pivoted
# true matrix, then count how many of them equal the corresponding cell of
# new_obj.X (compared by position). The two numbers should match if the
# pivoted rows/columns are in the same order as the selected columns of new_obj.
print(df_expressions_true_matrix.shape[0]*df_expressions_true_matrix.shape[1])
np.equal(new_obj.X.toarray()[:, tmp_genes_locations], df_expressions_true_matrix).sum().sum()

In [None]:
# Overwrite the selected gene columns of the copy with the predicted values.
# NOTE(review): the assignment is positional, so it relies on the pivoted
# matrix having its rows in the same order as new_obj's spots - verify with
# the equality check on the true matrix before running this.
new_obj.X[:, tmp_genes_locations] = df_expressions_preds_matrix.values

In [None]:
# Post-replacement check: print the number of cells in the predictions matrix,
# then count how many cells of new_obj.X now equal it (compared by position).
# NOTE(review): this compares against the same matrix that was just written,
# so it confirms the write happened but cannot detect a row-order mismatch.
print(df_expressions_preds_matrix.shape[0]*df_expressions_preds_matrix.shape[1])
np.equal(new_obj.X.toarray()[:, tmp_genes_locations], df_expressions_preds_matrix).sum().sum()

In [None]:
# For the first five top genes, draw the measured and the reconstructed
# expression next to each other.
for gene_symbol in top_genes_names[:5]:
    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(17, 8))
    panels = (
        (obj, ax1, 'True Expression'),
        (new_obj, ax2, 'Reconstructed Expression'),
    )
    for panel_data, panel_ax, panel_title in panels:
        st.pl.gene_plot(panel_data, gene_symbols=gene_symbol, size=20, ax=panel_ax)
        panel_ax.set_title(panel_title)
    print(f'Gene: {gene_symbol}')
    plt.show()

## Clustering Before & After 

In [None]:
# Independent copies for clustering, so that the preprocessing below does not
# touch the original object or the reconstructed one.
obj_clusters = deepcopy(obj)
new_obj_clusters = deepcopy(new_obj)

In [None]:
# Apply the same preprocessing (total normalization, then log1p) to the
# original data ...
st.pp.normalize_total(obj_clusters)
st.pp.log1p(obj_clusters)

# ... and to the reconstructed data, so the two clusterings are comparable.
st.pp.normalize_total(new_obj_clusters)
st.pp.log1p(new_obj_clusters)

In [None]:
# Original data: run PCA with 50 components on the gene expression
st.em.run_pca(obj_clusters, n_comps=50)
# K-means with 7 clusters on the PCA embedding; labels are stored under "X_pca_kmeans"
st.tl.clustering.kmeans(obj_clusters, n_clusters=7, use_data="X_pca", key_added="X_pca_kmeans")

In [None]:
# Reconstructed data: run PCA with 50 components on the gene expression
st.em.run_pca(new_obj_clusters, n_comps=50)
# K-means with 7 clusters on the PCA embedding; labels are stored under "X_pca_kmeans"
st.tl.clustering.kmeans(new_obj_clusters, n_clusters=7, use_data="X_pca", key_added="X_pca_kmeans")

In [None]:
# Clusters of the original data (left) next to clusters of the reconstructed data (right).
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
st.pl.cluster_plot(obj_clusters, use_label="X_pca_kmeans", ax=ax1)
st.pl.cluster_plot(new_obj_clusters, use_label="X_pca_kmeans", ax=ax2)