# NACT-Pytorch Tutorial: Classifying NeuroCOVID scRNAseq

In this notebook, we will go over how to load a pre-trained NACT model, use it to make predictions, and evaluate its performance.

## Load in pre-trained NACT model

Since our implementation is in PyTorch, we can use PyTorch's `torch.load` function. The checkpoint is stored as a dict: `epoch` holds the training epoch at which it was saved, and `Saved_Model` holds the model itself.
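
For reference, the checkpoint layout described above can be reproduced with a small stand-in model. This is a minimal sketch, not the NACT training code: the `nn.Sequential` model and the epoch value `42` are hypothetical placeholders, and an in-memory buffer replaces the `.pth` file so the example runs anywhere PyTorch is installed.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the real classifier; only the checkpoint layout matters here.
model = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 11))

# Save the whole model object together with the epoch, using the same keys as the NACT checkpoint.
buffer = io.BytesIO()  # an in-memory file stands in for a .pth path
torch.save({"epoch": 42, "Saved_Model": model}, buffer)
buffer.seek(0)

# Loading a pickled nn.Module (not just a state_dict) needs weights_only=False on
# recent PyTorch releases; only do this for checkpoint files you trust.
checkpoint = torch.load(buffer, weights_only=False)
restored = checkpoint["Saved_Model"]
print(checkpoint["epoch"], type(restored).__name__)  # → 42 Sequential
```

Storing the full module keeps loading to a single call, at the cost of requiring the class definition (here `nn.Sequential`, for NACT `FFAttentionClassifier`) to be importable when the file is unpickled.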

In [1]:
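import torch

# The checkpoint pickles the whole model object (not just a state_dict), so PyTorch 2.6
# and later need weights_only=False; only load checkpoint files you trust.
# Unpickling also requires the NACT package (which defines FFAttentionClassifier) to be importable.
model_dict = torch.load("/home/ubuntu/SindiLab/NACT/ClassifierWeights/pbmc-Best_model_Best.pth",
                        weights_only=False)

nact = model_dict["Saved_Model"]  # the trained FFAttentionClassifier

print(nact)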

FFAttentionClassifier(
  (layer0): Linear(in_features=17789, out_features=100, bias=True)
  (attention0): Linear(in_features=100, out_features=100, bias=True)
  (layer1): Linear(in_features=100, out_features=50, bias=True)
  (attention1): Linear(in_features=50, out_features=50, bias=True)
  (layer2): Linear(in_features=50, out_features=25, bias=True)
  (attention2): Linear(in_features=25, out_features=25, bias=True)
  (out_layer): Linear(in_features=25, out_features=11, bias=True)
  (test_layer): Linear(in_features=100, out_features=25, bias=True)
  (relu): ReLU()
  (leaky_relu): LeakyReLU(negative_slope=0.01)
)


### Choose the device to run inference on

We recommend using GPUs for *training*; for inference, either CPUs or GPUs work fine, although GPUs are faster.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if str(device) == "cuda":
    print('Using GPU (CUDA)')
else:
    print('Using CPU')

Using GPU (CUDA)


## Load in Data

Now, let us load the test dataset.

In [3]:
from NACT.utils import *
from NACT import Scanpy_IO

In [4]:
# Scanpy_IO returns a pair of data loaders; only the second (test) loader is used here.
_, test_data_loader = Scanpy_IO('/home/ubuntu/RawData/68kPBMCs_preprocessed.h5ad',
                                test_no_valid=True,
                                log=False,
                                verbose=1)

==> Reading in Scanpy/Seurat AnnData
    -> Splitting Train and Validation Data
==> Using cluster info for generating train and validation labels
==> Checking if we have sparse matrix into dense
    -> Seems the data is dense
==> sample of the training data: tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000, 16.4474,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, 11.3572,  0.0000]])
==> sample of the test data: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
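
`evaluate_classifier` computes predictions for you, but if you want the raw predicted labels, a loop along the following lines should work. It is a sketch under assumptions: the hypothetical `predict` helper assumes each batch is a `(features, labels)` pair and that the model's forward pass returns one logit per class, neither of which is confirmed here for NACT's loaders or `FFAttentionClassifier`. A toy `nn.Linear` model and random data stand in for `nact` and `test_data_loader`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def predict(model, loader, device):
    """Hypothetical helper: collect argmax predictions and labels from a loader.

    Assumes batches are (features, labels) and the model returns per-class logits.
    """
    model.to(device)
    model.eval()  # disable training-only behaviour such as dropout
    preds, labels = [], []
    with torch.no_grad():  # gradients are not needed for inference
        for features, targets in loader:
            logits = model(features.to(device))
            preds.append(logits.argmax(dim=1).cpu())  # predicted class per cell
            labels.append(targets)
    return torch.cat(preds).numpy(), torch.cat(labels).numpy()


# Toy data standing in for the real test loader: 64 cells x 20 genes, 11 classes.
torch.manual_seed(0)
toy_loader = DataLoader(
    TensorDataset(torch.randn(64, 20), torch.randint(0, 11, (64,))),
    batch_size=16,
)
toy_model = nn.Linear(20, 11)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
y_pred, y_true = predict(toy_model, toy_loader, device)
print(y_pred.shape, y_true.shape)
```

With the real objects you would call `predict(nact, test_data_loader, device)`, provided the assumptions above hold.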


## Use the `evaluate_classifier` function for a full report!

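We provide several utilities to simplify common tasks. One of them is `evaluate_classifier`, which reports accuracy and F1 scores and, if requested, a full classification report generated with `sklearn`. Its signature is `evaluate_classifier(valid_data_loader, model, classification_report=False)`. Because the first argument is named for a validation loader, the printed messages refer to the "validation set" even though here we pass the test loader.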
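
The report printed by `evaluate_classifier` includes both a "Non-Weighted" and a "Weighted" F1 score. To illustrate how the two differ, here is a minimal from-scratch version of both averages, assuming they correspond to scikit-learn's `average="macro"` and `average="weighted"` options. The helper `f1_scores` is hypothetical and not part of NACT; a class with no true positives gets an F1 of 0, mirroring scikit-learn's default zero-division handling.

```python
from collections import Counter


def f1_scores(y_true, y_pred):
    """Hypothetical helper: return (macro_f1, weighted_f1) over all labels seen.

    A class with no true positives gets F1 = 0.
    """
    labels = sorted(set(y_true) | set(y_pred))
    support = Counter(y_true)  # number of true samples per class
    per_class = {}
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        per_class[c] = (2 * precision * recall / (precision + recall)
                        if precision + recall else 0.0)
    macro = sum(per_class.values()) / len(labels)  # every class counts equally
    weighted = sum(per_class[c] * support[c] for c in labels) / len(y_true)  # by support
    return macro, weighted


# Toy labels: class 2 has a single cell, and it is misclassified.
y_true = [0, 0, 0, 0, 1, 1, 1, 1, 2]
y_pred = [0, 0, 0, 0, 1, 1, 1, 1, 0]
macro, weighted = f1_scores(y_true, y_pred)
print(round(macro, 2), round(weighted, 2))  # → 0.63 0.84
```

In the toy example, the single misclassified cell of class 2 zeroes that class's F1, which pulls the macro average well below the weighted one.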


In [6]:
evaluate_classifier(test_data_loader, nact, classification_report=True)

==> Evaluating on Validation Set:
    -> Accuracy of classifier network on validation set: 92.2329 %
    -> Non-Weighted F1 Score on validation set: 0.7482 
    -> Weighted F1 Score on validation set: 0.9217 
              precision    recall  f1-score   support

         0.0       0.90      0.90      0.90      1791
         1.0       0.88      0.88      0.88      1545
         2.0       0.95      0.96      0.95      1515
         3.0       0.90      0.90      0.90       697
         4.0       0.99      0.99      0.99       483
         5.0       0.97      0.97      0.97       466
         6.0       1.00      0.99      1.00       413
         7.0       0.97      0.85      0.90        71
         8.0       0.00      0.00      0.00         8
         9.0       0.00      0.00      0.00         2

    accuracy                           0.92      6991
   macro avg       0.75      0.74      0.75      6991
weighted avg       0.92      0.92      0.92      6991

(array([0., 2., 4., ..., 2., 4., 3.]),
 array([0., 2., 4., ..., 2., 4., 3.]),
 0.7482368776409196)
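
The gap between the non-weighted F1 score (0.7482) and the weighted one (0.9217) comes mostly from classes 8 and 9: they contain only 8 and 2 cells and have an F1 of 0, which counts fully in an unweighted mean but barely at all when classes are weighted by support. The third element of the returned tuple, `0.7482368776409196`, is the same non-weighted F1 score. The arithmetic below recomputes both averages from the rounded per-class values in the report.

```python
# Per-class F1 scores and supports copied from the (rounded) classification report above.
f1 = [0.90, 0.88, 0.95, 0.90, 0.99, 0.97, 1.00, 0.90, 0.00, 0.00]
support = [1791, 1545, 1515, 697, 483, 466, 413, 71, 8, 2]

macro_f1 = sum(f1) / len(f1)  # each class weighted equally
weighted_f1 = sum(f * s for f, s in zip(f1, support)) / sum(support)  # weighted by cell count
print(f"macro: {macro_f1:.3f}, weighted: {weighted_f1:.3f}")  # → macro: 0.749, weighted: 0.922
```

The small differences from 0.7482 and 0.9217 come from using the two-decimal per-class values rather than the exact ones.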