In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import scprep

In [277]:
import torch.nn.functional as F

In [146]:
from torchvision import transforms

In [29]:
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader


### Cell type classification using single cell RNA seq profiles

For this notebook, I will use the retinal bipolar cells dataset from the visualization notebook and explore how well expression profiles can predict the cell type (the ground truth for which is obtained from expert annotation by the authors) using a neural network implementation in PyTorch.


In [3]:
# Load the expression table and the per-cell metadata from gzipped pickles.
# NOTE: both paths are absolute paths on the author's machine and must be
# adjusted to run elsewhere. Unpickling can execute arbitrary code, so only
# load files from a trusted source.
data_path = "/Users/anuraglimdi/Desktop/Single_cell_workshop/Datasets/retinal_bipolar/retinal_bipolar_data.pickle.gz"
metadata_path = "/Users/anuraglimdi/Desktop/Single_cell_workshop/Datasets/retinal_bipolar/retinal_bipolar_metadata.pickle.gz"
data_raw = pd.read_pickle(data_path)
metadata = pd.read_pickle(metadata_path)

In [4]:
# Dimensions of the raw expression table, as (rows, columns).
data_raw.shape

(21552, 15524)

This is too much data to work with comfortably on a local computer, so I will reduce it to fewer dimensions using PCA and then work with the leading components (the first 50 here; 100 would be another reasonable choice).

In [5]:
# Reduce the raw table to its first 50 principal components with the scprep
# helper, then convert the result to a plain NumPy array.
pca_result = scprep.reduce.pca(data_raw, n_components=50, method='dense')
data_pca = pca_result.to_numpy()

Standardizing the data column by column, so that each PCA component has mean 0 and variance 1

In [166]:
# Instantiate the standard scaler, learn the per-column mean and variance,
# then apply the transformation.
# NOTE(review): the scaler is fit on all cells, before the train/test split
# later in the notebook, so the held-out cells contribute to the statistics.
scaler = StandardScaler()
scaler.fit(data_pca)
data_scaled = scaler.transform(data_pca)

In [167]:
## Encode the cell type annotations as integer codes using pandas.
## `labels` holds one code per row of `metadata`; `cluster_names` holds the
## distinct cell type values, indexed by code.
celltype_column = metadata['CELLTYPE']
labels, cluster_names = pd.factorize(celltype_column)

### Splitting the dataset into training/test sets

Using the train_test_split function; note that this function is incredibly slow if not using the PCA reduced data

In [213]:
# Hold out 20% of the cells for testing. Fixing `random_state` makes the split,
# and therefore every accuracy figure derived from it, reproducible across
# kernel restarts; without it each run evaluates on a different set of cells.
expr_train, expr_test, cell_train, cell_test = train_test_split(
    data_scaled, labels, test_size=0.2, random_state=0
)

In [214]:
# Size of the training split, as (rows, columns).
expr_train.shape

(17241, 50)

### Building the neural network using PyTorch

Defining the dataset class for the single cell RNA seq input data

In [255]:
class scRNAseq_Dataset(Dataset):
    """Dataset pairing each cell's expression vector with its cell type label.

    Parameters
    ----------
    expression : indexable
        PCA reduced and scaled expression matrix; indexing with an integer
        must yield an object with an ``astype`` method (e.g. a NumPy row).
    labels : sized indexable
        Cell type label for every cell; its length defines the dataset length.
    """

    def __init__(self, expression, labels):
        self.expression = expression
        self.labels = labels

    def __len__(self):
        # One sample per label.
        n_cells = len(self.labels)
        return n_cells

    def __getitem__(self, idx):
        """Return ``(expression_vector, label)`` for the cell at position `idx`."""
        cell_label = self.labels[idx]
        # 'float' means 64-bit floating point here.
        cell_vector = self.expression[idx].astype('float')
        return cell_vector, cell_label

Defining the neural network itself:

The network has two fully connected layers followed by a log(softmax) output

In [278]:
class NeuralNetwork(nn.Module):
    """Two-layer fully connected classifier that outputs log-probabilities.

    Parameters
    ----------
    input_components : int
        Number of input features per sample (here, the number of PCA components).
    n_classes : int, optional
        Width of the output layer, i.e. how many classes are scored.
        Defaults to 50, the value that used to be hard-coded. Passing the
        actual number of distinct labels avoids output units that can never
        be a correct prediction.
    hidden_units : int, optional
        Width of the hidden layer. Defaults to 128, the previous fixed value.
    """

    def __init__(self, input_components, n_classes=50, hidden_units=128):
        super().__init__()
        # dense (fully connected) layer: input_components -> hidden_units
        self.linear1 = nn.Linear(input_components, hidden_units)
        # rectified linear unit activation function
        self.activation = nn.ReLU()
        # dense output layer: hidden_units -> n_classes
        self.linear2 = nn.Linear(hidden_units, n_classes)

    def forward(self, x):
        """Propagate `x` through the network and return log-probabilities.

        The log-softmax is taken over the last dimension, so every output row
        exponentiates to a probability distribution over the classes.
        """
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        return F.log_softmax(x, dim=-1)

# Size the input layer from the training matrix rather than repeating the PCA
# dimension as a literal, so the model stays in sync if the number of
# components is changed upstream.
model = NeuralNetwork(input_components=expr_train.shape[1])

Instantiation of the `Dataset` class using the split train and test data

In [279]:
# Wrap each split as (expression, labels) so a DataLoader can index it.
training_data = scRNAseq_Dataset(expr_train, cell_train)
test_data = scRNAseq_Dataset(expr_test, cell_test)

Using the `DataLoader` class to get the iterable for batch learning

In [280]:
# Training batches are reshuffled every epoch. The test loader keeps a fixed
# order: the evaluation averages per-batch mean losses, so with a smaller
# final batch a shuffled loader makes the reported average loss vary from run
# to run even for an unchanged model.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

The loss function is the negative log-likelihood. This is the appropriate choice because the model outputs log-probabilities; if the model instead returned raw, unnormalized scores (logits), the equivalent choice would be a cross-entropy loss, which applies the log-softmax internally.

In [281]:
# Negative log-likelihood loss. It takes log-probabilities as input, which is
# what the model's final log_softmax produces.
loss_function = nn.NLLLoss()   #negative log-likelihood loss

Setting parameters for the learning process

In [282]:
# Step size handed to the SGD optimizer.
learning_rate = 1e-3
# NOTE(review): the DataLoaders are built with a literal batch_size=64 and do
# not read this variable, so changing it here has no effect in the visible code.
batch_size = 64
# Number of train/evaluate passes over the data.
epochs = 50

The optimizer uses stochastic gradient descent

In [283]:
# Plain stochastic gradient descent over all trainable model parameters.
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)

Defining functions for the training and test loop for the neural network model

In [291]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one pass over `dataloader`, taking an optimizer step after every batch.

    Parameters
    ----------
    dataloader : iterable of (inputs, targets) batches with a ``dataset`` attribute
    model : callable mapping a float batch to predictions
    loss_fn : callable ``loss_fn(predictions, targets)`` returning a scalar tensor
    optimizer : optimizer holding the model's parameters

    Prints the current batch loss every 100 batches; returns nothing.
    """
    size = len(dataloader.dataset)
    step = 0
    for inputs, targets in dataloader:
        # Forward pass. `.float()` casts the batch to 32-bit floats; without
        # the cast the forward propagation fails.
        batch_loss = loss_fn(model(inputs.float()), targets)

        # Backpropagation: clear stale gradients, compute new ones, update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        report_now = step % 100 == 0
        if report_now:
            current = step * len(inputs)
            loss = batch_loss.item()
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
        step += 1


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X.float())
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()   #count how many predictions are correct
            # this works by computing how many times the model prediction of highest log(probability) matches
            # the true label

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [292]:
# Train for `epochs` passes over the training set, evaluating on the test set
# after each pass. The loss object is bound to the name `loss_function`; the
# name `loss_fn` exists only as a parameter inside train_loop/test_loop, so
# passing it here raises NameError on a fresh kernel.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_function, optimizer)
    test_loop(test_dataloader, model, loss_function)
print("Done!")

Epoch 1
-------------------------------
loss: 0.908354  [    0/17241]
loss: 1.048406  [ 6400/17241]
loss: 0.758581  [12800/17241]
Test Error: 
 Accuracy: 81.0%, Avg loss: 0.717485 

Epoch 2
-------------------------------
loss: 0.721321  [    0/17241]
loss: 0.520383  [ 6400/17241]
loss: 0.685628  [12800/17241]
Test Error: 
 Accuracy: 81.4%, Avg loss: 0.707738 

Epoch 3
-------------------------------
loss: 0.669907  [    0/17241]
loss: 0.795179  [ 6400/17241]
loss: 0.647724  [12800/17241]
Test Error: 
 Accuracy: 81.6%, Avg loss: 0.698879 

Epoch 4
-------------------------------
loss: 0.829800  [    0/17241]
loss: 0.754176  [ 6400/17241]
loss: 0.601442  [12800/17241]
Test Error: 
 Accuracy: 81.9%, Avg loss: 0.690395 

Epoch 5
-------------------------------
loss: 0.832526  [    0/17241]
loss: 0.771664  [ 6400/17241]
loss: 0.669948  [12800/17241]
Test Error: 
 Accuracy: 82.1%, Avg loss: 0.677487 

Epoch 6
-------------------------------
loss: 0.578678  [    0/17241]
loss: 0.799996  [ 64

loss: 0.336177  [12800/17241]
Test Error: 
 Accuracy: 87.5%, Avg loss: 0.459638 

Epoch 47
-------------------------------
loss: 0.384228  [    0/17241]
loss: 0.520854  [ 6400/17241]
loss: 0.573929  [12800/17241]
Test Error: 
 Accuracy: 87.5%, Avg loss: 0.457010 

Epoch 48
-------------------------------
loss: 0.505392  [    0/17241]
loss: 0.236574  [ 6400/17241]
loss: 0.555562  [12800/17241]
Test Error: 
 Accuracy: 87.5%, Avg loss: 0.456293 

Epoch 49
-------------------------------
loss: 0.449335  [    0/17241]
loss: 0.309241  [ 6400/17241]
loss: 0.425991  [12800/17241]
Test Error: 
 Accuracy: 87.6%, Avg loss: 0.454103 

Epoch 50
-------------------------------
loss: 0.366268  [    0/17241]
loss: 0.544982  [ 6400/17241]
loss: 0.380674  [12800/17241]
Test Error: 
 Accuracy: 87.7%, Avg loss: 0.451009 

Done!
