In [None]:
!git clone https://github.com/Kei0501/LincSpectr

In [None]:
# Mount Google Drive into the Colab filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Installation

In [None]:
!pip install scanpy==1.9.6 ssqueezepy==0.6.4 pynwb==2.5.0

# Import packages

In [None]:
import torch
import torch.distributions as dist
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch import functional as F
from torch.distributions.kl import kl_divergence
from torch.nn import init
from torchvision import transforms
import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
from matplotlib import pyplot as plt
import umap
from ssqueezepy import ssq_cwt
torch.cuda.is_available()

import scipy.stats as stats
import scipy.io as io
import scipy.signal as signal
from scipy.optimize import curve_fit
from scipy import integrate
import logging
import importlib
import os
import warnings
from PIL import Image
from pynwb import NWBHDF5IO
import random
from functorch import vmap
from functorch import vjp
from tqdm import tqdm

sns.set_style()

import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42

from sklearn import linear_model
ransac = linear_model.RANSACRegressor()

%matplotlib inline

# Load and preprocess data
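LincSpectr requires single-cell transcriptome data and electrophysiological data transformed with the continuous wavelet transform (CWT). Raw counts are saved in `adata.layers['count']` before normalization, so that after pre-processing the normalized values can be replaced by the raw counts.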

In [None]:
# Prepare transcriptome data:
# attach per-cell metadata (including "RNA family"), drop low-quality cells,
# keep the raw counts, then normalize, log-transform and flag highly variable genes.
adata = sc.read_csv('./m1_patchseq_exon_counts.csv')
plus_data = pd.read_table('./m1_patchseq_meta_data.csv', sep='\t', index_col=1)
# Transpose so that cells are observations (rows) and genes are variables,
# matching the metadata, which is looked up by obs_names below.
adata = adata.T
adata.obs = plus_data.loc[adata.obs_names]
# .copy() turns the boolean subset into an independent AnnData instead of a
# view, which would otherwise be implicitly copied (with a warning) when the
# layer is assigned below.
adata = adata[adata.obs["RNA family"] != "low quality"].copy()
# Store a *copy* of the raw counts: normalize_total/log1p below rewrite X
# (possibly in place), so the layer must not alias the X array.
adata.layers['count'] = adata.X.copy()
sc.pp.filter_cells(adata, min_counts=100)
sc.pp.filter_genes(adata, min_cells=10)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2000)

In [None]:
# Prepare electrophysiological data: gather the input files under ./000008 and
# transform them into ./data_for_VAE/ (presumably the .npy files that later
# cells load from that directory).
# NOTE(review): collect_filename / transform_efeatures are not defined in this
# notebook — presumably helpers from the cloned LincSpectr repo; make sure they
# are imported before running this cell.
ephys_pattern = "./000008/*"
vae_input_dir = "./data_for_VAE/"
file_names = collect_filename(ephys_pattern)
transform_efeatures(file_names, vae_input_dir)

# Training the models
LincSpectr trains three models and uses them to estimate electrophysiological features from transcriptomic features.

In [None]:
# NOTE(review): `workflow` is not defined in any visible cell, so this cell raises
# NameError on a fresh kernel. The training code that presumably defines t_vae,
# e_vae, dataset, valid_list, set_timeax/set_freqax, etc. seems to be missing —
# TODO restore it.
# NOTE(review): the captured log below (Dynamics/Kinetics opt, s/u correlations)
# does not obviously match a t-/e-feature model — verify it belongs to this notebook.
workflow

loss_mode poisson
Start Dynamics opt
Dynamics opt patience 10
val loss mean (post 10 epochs) at epoch 0 is 7.434419631958008
Early Stopping! at 99 epoch
Done Dynamics opt
Dynamics_last_val_loss:5.581026554107666
Dynamics_last_test_loss:11.45341682434082
Start Kinetics opt
Kinetics opt patience 10
val loss mean (post 10 epochs) at epoch 0 is 5.602231502532959
Early Stopping! at 23 epoch
Done Kinetics opt
Kinetics_last_val_loss:5.5755414962768555
Kinetics_last_test_loss:11.440404891967773
train_s_correlation 0.37785741099056397
train_u_correlation 0.28100604102252885
val_s_correlation 0.33543039621025844
val_u_correlation 0.19849278860828423
test_s_correlation 0.34006264804505393
test_u_correlation 0.19878720975214437


# Visualization of the latent space

In [None]:
# Encode every sample of `dataset` with the t-VAE and project its latent space
# with UMAP. Item [0] of each dataset entry is the t-VAE input and item [1] the
# e-feature input; each entry is fetched once and split into the two stacks.
samples = [dataset[i] for i in range(len(dataset))]
t_test = [sample[0] for sample in samples]
e_test = [sample[1] for sample in samples]
test_x = torch.stack(t_test, dim=0).to(device)
test_xcell_id = torch.stack(e_test, dim=0).to(device)

t_vae.to(device)
t_vae.eval()
with torch.no_grad():
    # Call t_vae, the module just moved to `device` and switched to eval mode
    # (the original called `t_model` here).
    tz, qz, xld = t_vae(test_x)
utils.make_umap(tz)

In [None]:
# Encode the e-features with the e-VAE and project its latent space with UMAP.
e_vae.to(device)
e_vae.eval()
with torch.no_grad():
    # Flatten each sample to a vector of length set_timeax * set_freqax and
    # encode it with e_vae (the original mistakenly passed it to t_vae).
    ez, qz, ld_img = e_vae(test_xcell_id.view(-1, set_timeax * set_freqax))
utils.make_umap(ez)

# Estimating e-features from t-features

In [None]:
# Pick two distinct validation cells at random and display the predictions for
# them. random.sample draws without replacement, so — unlike two independent
# randrange draws — the same cell can never be picked twice.
vae_input_dir = './data_for_VAE/'
valid_sample1, valid_sample2 = random.sample(list(valid_list), 2)
cell_name1 = vae_input_dir + valid_sample1 + '.npy'
cell_name2 = vae_input_dir + valid_sample2 + '.npy'

utils.show_prediction(cell_name1, cell_name2)

# Inverse analysis of the model

In [None]:
# Inverse analysis of Vip cells: average the expression over the Vip cells
# (presumably those in valid_list), then run the inverse analysis on that profile.
Vip, Lamp5, Pvalb, Sst, ET, IT, CT = utils.celltype_list(valid_list)
avr_express = utils.average_expression(adata, Vip)
image_shape = (128, 128)  # NOTE(review): presumably the e-feature image size — confirm against set_timeax/set_freqax
n_top_genes = 50
# Pass both options by keyword: the original call `(avr_express, N=50, image_shape)`
# put a positional argument after a keyword argument, which is a SyntaxError.
# NOTE(review): assumes utils.inverse_analysis names these parameters `N` and `image_shape`.
u_pick, top_genes, top_expression = utils.inverse_analysis(
    avr_express, N=n_top_genes, image_shape=image_shape
)

In [None]:
# Show the inverse-analysis result `u_pick` as a heat map with the same shape
# as a stored e-feature array.
# NOTE(review): `sample_data` is not defined in this notebook — presumably the
# path of one ./data_for_VAE/*.npy file; define it before running this cell.
u_pick = u_pick.reshape(np.load(sample_data).shape)
# Detach from the autograd graph before the host copy; .copy() decouples the
# NumPy array from the tensor's memory.
upick_image = u_pick.detach().cpu().numpy().copy()

fig, ax = plt.subplots(figsize=(10, 8))
heatmap = ax.imshow(upick_image, aspect='auto', cmap='turbo', vmin=0)
# Without a colour bar the magnitudes in the heat map cannot be read off.
fig.colorbar(heatmap, ax=ax)
ax.set_title('Inverse analysis: u_pick')
plt.show()

In [None]:
# Bar chart of the expression of the top genes returned by the inverse analysis.
fig, ax = plt.subplots(figsize=(12, 4))
# The original passed the undefined name `top_exoression` (NameError).
ax.bar(top_genes, top_expression)
ax.set_ylabel('expression')
ax.set_title('Top genes from inverse analysis')
# Right-align the rotated labels so each one ends under its own bar.
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.show()

In [None]:
# List the top genes as plain text, one per line.
gene_lines = [str(gene) for gene in top_genes]
if gene_lines:
    print("\n".join(gene_lines))