# Imports

In [None]:
import stlearn as st
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from sklearn.decomposition import NMF
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

In [None]:
# IPython autoreload mode 2: reload imported modules before each cell runs,
# so edits to the local modules below take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Project-local helpers.
# NOTE(review): the wildcard import hides where names come from. get_data,
# count_zeros, loss_functions (and possibly `data`) are used below but not
# defined in any visible cell; they presumably come from here — confirm and
# consider importing them explicitly.
from data import *
import torch.optim as optim
from models import get_model
from trainer import trainer

In [None]:
# Notebook-wide plotting font size and the torch device used for training.
plt.rcParams['font.size'] = 12
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

In [None]:
# Experiment configuration shared by the cells below.
dataset_name = 'V1_Human_Lymph_Node'  # dataset identifier passed to get_data / trainer
model_name = 'NMF'                    # model key passed to get_model / trainer
max_epochs = 2
early_stopping = 0                    # forwarded to trainer; meaning of 0 is defined there

# Tuned hyperparameters; 'optimizer' names a torch.optim class resolved via getattr.
best_params = dict(
    learning_rate=0.001,
    optimizer="RMSprop",
    latent_dim=20,
    batch_size=512,
)

# Load Data

In [None]:
# Build the train/valid/test/full-train loaders for the configured dataset.
# The batch size is taken from best_params so the loaders match the tuned
# configuration instead of a separate hard-coded value.
# NOTE(review): data is placed on 'cpu' here while `device` may be CUDA;
# presumably `trainer` moves each batch to `device` — confirm.
# dl_valid and dl_full_train are not used in any visible cell.
dl_train, dl_valid, dl_test, dl_full_train = get_data(
    dataset_name=dataset_name,
    batch_size=best_params['batch_size'],
    device='cpu',
)

In [None]:
# Build the model selected by `model_name` using the hyperparameters in best_params.
# NOTE(review): dl_train is passed in too — presumably so get_model can infer
# input dimensions from the data; confirm in models.get_model.
model = get_model(model_name, best_params, dl_train)  # Build model
model  # display the model summary as the cell output

In [None]:
# Look up the optimizer class by name in torch.optim (e.g. "RMSprop") and
# bind it to the model's parameters with the tuned learning rate.
optimizer_cls = getattr(optim, best_params['optimizer'])
optimizer = optimizer_cls(model.parameters(), lr=best_params['learning_rate'])
optimizer

In [None]:
# Train the model and keep trainer's return value as test_loss.
# NOTE(review): dl_valid is not passed; if trainer uses dl_test for early
# stopping, test data would leak into model selection — confirm.
test_loss = trainer(
    model=model,
    optimizer=optimizer,
    dl_train=dl_train,
    dl_test=dl_test,
    max_epochs=max_epochs,
    early_stopping=early_stopping,
    device=device,
    model_name=model_name,
    dataset_name=dataset_name,
)

# EDA 

# Imputation 

In [None]:
# Matrix to factorize / impute in the cells below.
# NOTE(review): `data` is not defined in any visible cell; it presumably comes
# from `from data import *` and is an AnnData-like object exposing `.X` — confirm.
R = data.X
R.shape  # display the matrix dimensions

## scikit-learn NMF

In [None]:
model = NMF(n_components=2, init='random', random_state=0)
model.fit(R)

In [None]:
U = model.transform(R)
V = model.components_
print(f'U shape: {U.shape}\nV shape: {V.shape}')

In [None]:
R_trans = model.inverse_transform(U)
R_trans.shape

In [None]:
rmse = loss_functions.RMSE(y_true=R.toarray(), y_pred=R_trans)
print(f'RMSE = {rmse}')

Next steps:
- Visualize the imputation for a specific gene
- Add dropout
- Add a GMF model
- Analyze the results

## stLearn SME imputation

In [None]:
# Work on a copy so the stLearn SME steps below do not modify `data`
# (assumes `.copy()` returns an independent object, as AnnData's does — confirm).
data_SME = data.copy()
# Presumably reports zero entries before imputation; count_zeros is defined
# outside the visible cells.
count_zeros(data_=data_SME)

In [None]:
from pathlib import Path

# Output directory for the image tiles; create it up front so tiling can write into it.
TILE_PATH = Path("/tmp") / "tiles"
TILE_PATH.mkdir(exist_ok=True, parents=True)

# Cut the tissue image into per-spot tiles saved under TILE_PATH.
st.pp.tiling(data_SME, TILE_PATH)

# Extract high-level image features from each tile with a deep-learning model;
# this can take several minutes.
st.pp.extract_feature(data_SME)

In [None]:
# Run stLearn's SME zero-value imputation (SME_impute0) on data_SME; it
# presumably uses the tile image features extracted in the previous cell.
# With copy=True the result is returned as data_SME_copy and data_SME is
# presumably left unmodified — confirm against the stLearn docs.
data_SME_copy = st.spatial.SME.SME_impute0(data_SME, copy=True)
# NOTE(review): to measure the effect of imputation, count zeros on
# data_SME_copy (not data_SME), and confirm where stLearn stores the imputed
# values (e.g. `.X` vs an `.obsm` entry).
# count_zeros(data_=data_SME)