In [None]:
import torch
import torch.nn.functional as F
import numpy as np

import dgl
import copy
import gc

#ScDeepSort Imports
from dance.modules.single_modality.cell_type_annotation.scdeepsort import ScDeepSort
from dance.utils import set_seed

import os
os.environ["DGLBACKEND"] = "pytorch"
from pprint import pprint
from dance.datasets.singlemodality import ScDeepSortDataset

import scanpy as sc
from dance.transforms import AnnDataTransform, FilterGenesPercentile
from dance.transforms import Compose, SetConfig
from dance.transforms.graph import PCACellFeatureGraph, CellFeatureGraph
from dance.typing import LogLevel, Optional


# Model configuration. `in_channels` and `hidden_channels` are passed to the
# ScDeepSort constructor below; `out_channels` and `num_classes` are not
# referenced by the cells visible in this notebook.
in_channels = 400
hidden_channels = 400
out_channels = 100
num_classes = 21


# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed through the imported helper. It has to be *called*: writing
# `set_seed = 42` merely rebinds the name to an integer and seeds nothing.
set_seed(42)

# ScDeepSort model for mouse kidney; the input and hidden widths come from the
# configuration constants defined earlier in this cell.
model = ScDeepSort(
    dim_in=in_channels,
    dim_hid=hidden_channels,
    num_layers=1,
    species='mouse',
    tissue='Kidney',
    device=device,
)

# Expression preprocessing: per-cell total-count normalisation, log1p, then
# percentile-based gene filtering.
preprocessing_pipeline = Compose(
    # Normalise every cell to 1e4 total counts. A target of 1e-4 would shrink
    # all values towards zero, where log1p(x) ~= x and the log step does nothing.
    AnnDataTransform(sc.pp.normalize_total, target_sum=1e4),
    AnnDataTransform(sc.pp.log1p),
    FilterGenesPercentile(min_val=1, max_val=99, mode="sum"),
)
def train_pipeline(n_components: int = 400, log_level: LogLevel = "INFO"):
    return Compose(
        PCACellFeatureGraph(n_components=n_components, split_name="train"),
        SetConfig({"label_channel": "cell_type"}),
        log_level=log_level,
    )
def test_pipeline(n_components: int = 400, log_level: LogLevel = "INFO"):
    return Compose(
        PCACellFeatureGraph(n_components=n_components, split_name="test"),
        SetConfig({"label_channel": "cell_type"}),
        log_level=log_level,
    )

In [None]:
# NOTE(review): dataset "203" is listed in both the train and test lists (and
# twice in the test list) — confirm this overlap is intended.
dataset = ScDeepSortDataset(
    species="mouse",
    tissue="Kidney",
    train_dataset=["4682", "203"],
    test_dataset=["203", "203"],
)
data = dataset.load_data()

# Both pipelines are applied to `data` directly; their return values are unused.
preprocessing_pipeline(data)
train_pipeline()(data)

# Element [1] of each split is taken as the label matrix.
train_onehot = data.get_train_data(return_type="torch")[1]
test_onehot = data.get_test_data(return_type="torch")[1]

# y_train holds the train labels followed by the test labels, reduced to the
# arg-max column index per row; y_test is reduced the same way.
y_train = torch.cat((train_onehot, test_onehot), dim=0).argmax(dim=1)
y_test = test_onehot.argmax(dim=1)

In [None]:
# Fit on the graph currently stored under "CellFeatureGraph", using the
# combined label vector prepared in the previous cell.
train_graph = data.data.uns["CellFeatureGraph"]
model.fit(graph=train_graph, labels=y_train)

In [None]:
# Apply the "test"-split pipeline to `data`.
# NOTE(review): presumably this replaces data.data.uns["CellFeatureGraph"] — confirm.
build_test_graph = test_pipeline()
build_test_graph(data)

In [None]:
# Predict class probabilities on the graph currently stored under
# "CellFeatureGraph".
test_graph = data.data.uns["CellFeatureGraph"]
result = model.predict_proba(graph=test_graph)

In [None]:
# Convert the predicted probabilities to a tensor. `as_tensor` accepts arrays
# and tensors alike, so re-running this cell does not copy-construct from an
# existing tensor (which `torch.tensor(result)` would do, with a warning).
result = torch.as_tensor(result)

# Predicted class = column with the highest probability in each row.
predicted = torch.argmax(result, 1)
print(predicted)

# The element-wise comparison below is only meaningful when there is exactly
# one prediction per test label; fail loudly instead of broadcasting.
assert predicted.shape == y_test.shape, (
    f"prediction/label shape mismatch: {tuple(predicted.shape)} vs {tuple(y_test.shape)}"
)

correct = (predicted == y_test).sum().item()
total = y_test.numel()
accuracy = correct / total
print('accuracy: ', accuracy)