In [2]:
import argparse
import sys
import os, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import subgraph
from torch_scatter import scatter
import dataset
from dataset import input_dataset
from utils import read_data, scGT_output
from parse import parse_method
from train import model_train
from plot import umap_emb
from scGT import *
import matplotlib.pyplot as plt
import time
import warnings
warnings.filterwarnings('ignore')

def fix_seed(seed):
    """Seed every RNG used by this notebook and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global RNG, and PyTorch (CPU and every visible CUDA device).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one, so multi-GPU runs reproduce too.
    # Safe to call without CUDA: PyTorch defers it until CUDA is initialised.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # deterministic=True alone is not sufficient: if benchmark mode was turned on
    # elsewhere (e.g. by a wildcard import), cuDNN may autotune and choose
    # non-deterministic algorithms, so pin it off explicitly.
    torch.backends.cudnn.benchmark = False

In [None]:
# Run configuration. `args` is passed to read_data / parse_method / model_train /
# umap_emb / scGT_output, which presumably read these keys by name -- so do not
# rename them (including the historical "lamda" spelling).
default_args = {
    'dataset': 'data',                    # dataset name passed to input_dataset()
    'data_dir': '../data/dataset_name/',  # input path
    'device': 0,                          # CUDA device index (CPU is used when CUDA is unavailable)
    'seed': 42,                           # seed for random / NumPy / PyTorch (see fix_seed)
    'maxepochs': 5000,                    # presumably the maximum number of training epochs
    'eval_step': 10,                      # presumably: evaluate every `eval_step` epochs
    'model_dir': '../model/',             # model storage path
    'hidden_channels': 128,               # presumably the hidden-layer width
    'lr': 1e-4,                           # learning rate
    'weight_decay': 1e-2,                 # presumably the optimizer's weight decay
    'lamda1': 0.1,                        # hard regularity relaxation coefficient
    'lamda2': 0.1,                        # query graph regularity relaxation coefficient
    'num_batch': 1,                       # how many subgraphs to create; if the GPU allows, keep it as small as possible
    'early_stop': 30,                     # early stop epoch (presumably patience, in epochs)
    'is_move': True,                      # whether to move query data to the reference if confidence score > 0.95
}

# A plain Namespace gives the same attribute access as parsed CLI args without
# parsing sys.argv (which inside a Jupyter kernel holds the kernel's own flags).
args = argparse.Namespace(**default_args)
print(args)

fix_seed(args.seed)
# Use the configured CUDA device when a GPU is present, otherwise fall back to CPU.
device = torch.device(f"cuda:{args.device}") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
%%time
import time

### Load and preprocess data ###
dataset = input_dataset(args.data_dir, args.dataset)
dataset, x, n, c, d, split_idx, adjs, adj_loss_inter, adj_loss_intra2 = read_data(args, dataset)

In [None]:
### Load method ###
# Build the model from the run configuration and the shapes returned by
# read_data, then put it into training mode before model_train is called.
model = parse_method(args, dataset, n, c, d, device)
model.train()

print('MODEL:', model)  # echo the architecture for the record

In [None]:
### Train ###
# Train the model on the preprocessed graph data produced by read_data.
model_train(
    args, dataset, model,              # configuration, data object, model
    split_idx, device,                 # index split and target device
    x, n, adjs,                        # node features, node count, adjacencies
    adj_loss_inter, adj_loss_intra2,   # adjacencies used by the regularisation losses
)

In [None]:
### umap_emb ###
# Project the trained model's representation to 2-D; `label` and `tech` are
# used below to colour the embedding (by cell type and by technology).
(embedding_umap,
 label,
 pre,
 tech) = umap_emb(
    args, model, dataset, x, adjs,
)

In [None]:
# celltype
%matplotlib inline
labels = np.unique(label)
for i in labels:
    plt.scatter(embedding_umap[label == i, 0], embedding_umap[label == i, 1], s=0.5, label=str(i))
plt.legend()
plt.show()

In [None]:
# tech
# Colour the same UMAP embedding by technology / batch to inspect mixing.
tech_arr = np.asarray(tech)  # boolean-mask indexing below needs an array, not a list
techs = np.unique(tech_arr)
fig, ax = plt.subplots()
for tech_name in techs:
    mask = tech_arr == tech_name
    ax.scatter(embedding_umap[mask, 0], embedding_umap[mask, 1], s=0.01, label=str(tech_name))
ax.set(title='UMAP embedding by technology', xlabel='UMAP 1', ylabel='UMAP 2')
# s=0.01 markers are invisible in a default legend; markerscale multiplies the
# marker diameter, so 50x brings them to a readable size.
ax.legend(markerscale=50, loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()

In [None]:
# Gather scGT's results for this run.
# The low-dimensional embedding for joint visualization is stored at
# args.data_dir + 'results/embedding.pt'.
output = scGT_output(args)

# Preview entries -20 .. -11 of the output as a quick sanity check.
output[-20:-10]