In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from pathlib import Path

from learn.dataset import TabularDataset
from learn.model import CiteAutoencoder
from learn.train import train_model, get_encodings

import umap
import plotly.express as px
import plotly.graph_objects as go

In [3]:
data_path = Path("./data/")
list(data_path.iterdir())

[WindowsPath('data/metadata.csv.gz'),
 WindowsPath('data/protein_scale.csv.gz'),
 WindowsPath('data/rna_scale.csv.gz')]

In [4]:
rna = pd.read_csv(data_path/"rna_scale.csv.gz", index_col=0).T
rna = rna.reset_index(drop=True)
rna.head()

Unnamed: 0,IGKC,HBA2,HBB,HBA1,IGHA1,IGLC2,JCHAIN,HBM,IGHG1,IGHM,...,INSR,RAD23B,COMMD4,PPARA,PFDN6,PDSS1,BANF1,DDI2,DCAF6,HSPA5
0,-0.428204,0.183748,0.406723,-0.488968,1.248214,-0.614123,-0.324617,-0.181656,-0.200979,-0.414643,...,-0.216175,2.892558,-0.237605,-0.122832,-0.32904,-0.130763,-0.506981,3.164261,-0.197585,-0.44033
1,-1.047339,-0.601956,-0.985002,-0.488968,-0.552748,-0.614123,-0.324617,-0.181656,4.964708,-0.414643,...,-0.216175,-0.342985,-0.237605,-0.122832,-0.32904,-0.130763,-0.506981,-0.171434,-0.197585,-0.44033
2,-1.047339,-0.601956,1.534527,-0.488968,-0.552748,-0.614123,-0.324617,-0.181656,-0.200979,-0.414643,...,-0.216175,-0.342985,6.509144,-0.122832,-0.32904,-0.130763,-0.506981,-0.171434,-0.197585,-0.44033
3,0.080052,0.828744,0.614775,1.051572,-0.552748,0.713069,-0.324617,-0.181656,-0.200979,-0.414643,...,4.322459,-0.342985,4.264288,-0.122832,-0.32904,-0.130763,-0.506981,-0.171434,-0.197585,-0.44033
4,0.953832,-0.601956,0.794895,-0.488968,1.413335,-0.614123,-0.324617,-0.181656,-0.200979,-0.414643,...,-0.216175,-0.342985,-0.237605,-0.122832,3.360494,-0.130763,-0.506981,-0.171434,-0.197585,-0.44033


In [5]:
nfeatures = rna.shape[1]
nfeatures

2000

In [6]:
train, valid = train_test_split(rna.to_numpy(dtype=np.float32), test_size=0.1, random_state=0)
train.shape, valid.shape

((27604, 2000), (3068, 2000))

In [7]:
train[0]

array([ 0.4614134 , -0.6019564 ,  1.0438204 , ..., -0.17143415,
       -0.19758475,  2.7492154 ], dtype=float32)

In [8]:
train_ds = TabularDataset(train)
valid_ds = TabularDataset(valid)

In [9]:
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=64, shuffle=False)

In [10]:
x, y = next(iter(train_dl))
x.shape, y.shape

(torch.Size([64, 2000]), torch.Size([64, 2000]))

In [11]:
model = CiteAutoencoder(nfeatures_rna=nfeatures, nfeatures_pro=0, hidden_rna=100, hidden_pro=0, z_dim=20)

In [12]:
model

CiteAutoencoder(
  (encoder): Encoder(
    (encoder_rna): LinBnDrop(
      (0): Linear(in_features=2000, out_features=100, bias=False)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.1, inplace=False)
    )
    (encoder): LinBnDrop(
      (0): Linear(in_features=100, out_features=20, bias=False)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (decoder): Decoder(
    (decoder): Sequential(
      (0): LinBnDrop(
        (0): Linear(in_features=20, out_features=100, bias=False)
        (1): LeakyReLU(negative_slope=0.01)
        (2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): LinBnDrop(
        (0): Linear(in_features=100, out_features=2000, bias=True)
      )
    )
  )
)

In [13]:
model(x).shape

torch.Size([64, 2000])

In [14]:
lr = 1e-2
epochs = 50
model, losses = train_model(model, train_dl, valid_dl, lr=lr, epochs=epochs)

 20%|████████████████▍                                                                 | 10/50 [00:28<01:52,  2.81s/it]

Epoch 10: train loss 0.6335026832304732; valid loss 0.6318057463874718


 40%|████████████████████████████████▊                                                 | 20/50 [00:56<01:24,  2.81s/it]

Epoch 20: train loss 0.6283182992034849; valid loss 0.6294482286626077


 60%|█████████████████████████████████████████████████▏                                | 30/50 [01:21<00:45,  2.28s/it]

Epoch 30: train loss 0.6232206119417955; valid loss 0.623957363614203


 80%|█████████████████████████████████████████████████████████████████▌                | 40/50 [01:43<00:21,  2.12s/it]

Epoch 40: train loss 0.6176626040544774; valid loss 0.620304610303359


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [02:04<00:00,  2.48s/it]

Epoch 50: train loss 0.6148799769326649; valid loss 0.6191976671890299





In [15]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=np.arange(1, epochs+1, 1), y=losses['train'],
                         mode='lines',
                         name='train'))
fig.add_trace(go.Scatter(x=np.arange(1, epochs+1, 1), y=losses['valid'],
                         mode='lines',
                         name='valid'))
fig.show()

In [16]:
test_ds = TabularDataset(rna.to_numpy(dtype=np.float32))
test_dl = DataLoader(test_ds, batch_size=64, shuffle=False)

encodings = get_encodings(model, test_dl)
encodings = encodings.cpu().numpy()
encodings.shape

(30672, 20)

In [17]:
# annotations
metadata = pd.read_csv(data_path/"metadata.csv.gz", index_col=0)
metadata.head()

Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,nCount_ADT,nFeature_ADT,lane,donor,celltype.l1,celltype.l2,RNA.weight,ADT.weight,wsnn_res.2,seurat_clusters
a_AAACCTGAGCTTATCG-1,bmcite,7546,2136,1350,25,HumanHTO4,batch1,Progenitor cells,Prog_RBC,0.487299,0.512701,19,19
a_AAACCTGAGGTGGGTT-1,bmcite,1029,437,2970,25,HumanHTO1,batch1,T cell,gdT,0.245543,0.754457,10,10
a_AAACCTGAGTACATGA-1,bmcite,1111,429,2474,23,HumanHTO5,batch1,T cell,CD4 Naive,0.50168,0.49832,1,1
a_AAACCTGCAAACCTAC-1,bmcite,2741,851,4799,25,HumanHTO3,batch1,T cell,CD4 Memory,0.431308,0.568692,4,4
a_AAACCTGCAAGGTGTG-1,bmcite,2099,843,5434,25,HumanHTO2,batch1,Mono/DC,CD14 Mono,0.572097,0.427903,2,2


In [18]:
metadata.shape

(30672, 13)

In [19]:
# separate CD4 and CD8 in l1
metadata["celltype.l1.5"] = metadata["celltype.l1"].values
metadata.loc[metadata["celltype.l2"].str.startswith("CD4"), "celltype.l1.5"] = "CD4 T"
metadata.loc[metadata["celltype.l2"].str.startswith("CD8"), "celltype.l1.5"] = "CD8 T"

In [20]:
embedding = umap.UMAP(random_state=0).fit_transform(encodings)

In [21]:
plot_df = metadata.copy()
plot_df["UMAP1"] = embedding[:, 0]
plot_df["UMAP2"] = embedding[:, 1]

In [22]:
fig = px.scatter(plot_df, x="UMAP1", y="UMAP2", color="celltype.l1.5")
fig.show()