# SPACI tutorial

## 1. import python modules

In [1]:
import argparse

import yaml
from model.dataloader import TripletData
from model.model import PairModel, TripletModel, TripletGraphModel
import os
import torch.utils.data as Data
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

### 1.1 load yaml configurations

In [2]:
# Path of the YAML experiment configuration; the parsed result `cfg` is used
# as a nested dict (cfg['DATASET'], cfg['MODEL'], cfg['TRAIN'], cfg['TEST'])
# by every later cell.
yaml_file = 'configure_0.90.yml'
with open(yaml_file) as f:
    # NOTE(review): FullLoader is acceptable for a trusted local config file;
    # consider yaml.safe_load if the config could come from an untrusted source.
    cfg = yaml.load(f, Loader=yaml.FullLoader)

### 1.2 fix seed

In [3]:
# Fix every random number generator used in this notebook so that runs are
# repeatable.
seed = 10
random.seed(seed)                 # Python's built-in RNG
torch.manual_seed(seed)           # PyTorch RNG
torch.cuda.manual_seed_all(seed)  # PyTorch RNGs on all CUDA devices
np.random.seed(seed)              # NumPy global RNG
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here presumably only affects processes launched after this point.
os.environ['PYTHONHASHSEED'] = str(seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### 2. build dataloader

In [4]:
def build_dataset(cfg, train=1):
    """Build the dataset for the requested split.

    Args:
        cfg: Parsed configuration dict. Reads ``cfg['DATASET']['NAME']``,
            ``cfg['DATASET']['MATRIX_ROOT']`` and one of ``TRAIN_ROOT`` /
            ``TEST_ROOT`` / ``PRED_ROOT`` depending on ``train``.
        train: Split selector: ``1`` reads ``TRAIN_ROOT``, ``0`` reads
            ``TEST_ROOT`` and ``2`` reads ``PRED_ROOT``. The value is also
            forwarded to the dataset as ``istrain``.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``train`` is not one of ``0``, ``1`` or ``2``.
        NotImplementedError: If ``cfg['DATASET']['NAME']`` is not
            ``'TripletData'``.
    """
    root_keys = {1: 'TRAIN_ROOT', 0: 'TEST_ROOT', 2: 'PRED_ROOT'}
    # Fail loudly on an unknown mode instead of hitting an unbound ``root``.
    if train not in root_keys:
        raise ValueError(
            'train must be 1 (train), 0 (test) or 2 (predict), got %r'
            % (train,))
    root = cfg['DATASET'][root_keys[train]]

    if cfg['DATASET']['NAME'] != 'TripletData':
        raise NotImplementedError

    return TripletData(istrain=train,
                       dataroot=root,
                       matrixroot=cfg['DATASET']['MATRIX_ROOT'])

### 2.1 build train_dataset and test dataset
We set up three different modes (train=0, 1, 2) when building the dataset:
1. train=1 for training (reads `TRAIN_ROOT`).
2. train=0 for evaluating the F1 score (reads `TEST_ROOT`).
3. train=2 for prediction, i.e. saving the embeddings and predictions (reads `PRED_ROOT`).

In [5]:
# Training split (train=1 selects cfg['DATASET']['TRAIN_ROOT']).
train_dataset = build_dataset(cfg, train=1)

# Batches are reshuffled every epoch; batch size comes from the TRAIN section.
train_dataloader = Data.DataLoader(train_dataset,
                                batch_size=cfg['TRAIN']['BATCH_SIZE'],
                                shuffle=True)

In [6]:
# Test split (train=0 selects cfg['DATASET']['TEST_ROOT']).
test_dataset = build_dataset(cfg, train=0)

# Kept in dataset order (no shuffling). Note that this loader uses the TRAIN
# batch size, whereas infer() builds its own loader with the TEST batch size.
test_dataloader = Data.DataLoader(test_dataset,
                                batch_size=cfg['TRAIN']['BATCH_SIZE'],
                                shuffle=False)

### 2.2 Load spatial graph 
We preprocessed the spatial graph into an adjacency list and saved it as a CSV file.

In [7]:
# Read the spatial graph from CSV: the first row is the header and the first
# column is the index, so only the remaining values end up in the tensor.
adj = pd.read_csv(cfg['DATASET']['ADJ_ROOT'], header=0, index_col=0)
adj = torch.from_numpy(adj.to_numpy()).float()
# Best F1 seen so far (re-initialised again right before the training loop).
best_f1 = 0

### 3. build model
We provide an interface in the yaml file so that end-users can build their own model structures.
1. Input Dim is the number of genes in your dataset. This is the input dimension of the MLP trunk. In this demo, the dimension is 4000.
2. Graph Dim is the number of genes in your dataset. This is the input dimension of the graph trunk. In this demo, the dimension is 4000.
3. Mlp_hid_dim is the hidden dimensions of the MLP layers.
4. Graph_hid_dim is the hidden dimensions of the graph layers.
5. save_path is the directory where the checkpoints are saved.

In [8]:
def build_model(cfg):
    """Instantiate the model described by the configuration.

    Args:
        cfg: Parsed configuration dict. ``cfg['TRAIN']['LR']`` is converted
            to ``float``; the ``MODEL`` section supplies ``NAME``,
            ``INPUT_DIM``, ``GRAPH_DIM``, ``MLP_HID_DIMS``,
            ``GRAPH_HID_DIMS`` and ``SAVE_PATH``.

    Returns:
        A ``TripletGraphModel`` built from those settings.

    Raises:
        NotImplementedError: If ``cfg['MODEL']['NAME']`` is anything other
            than ``'TripletGraphModel'``.
    """
    learning_rate = float(cfg['TRAIN']['LR'])
    model_cfg = cfg['MODEL']

    # Only one architecture is wired up; reject everything else up front.
    if model_cfg['NAME'] != 'TripletGraphModel':
        raise NotImplementedError

    return TripletGraphModel(
        lr=learning_rate,
        input_dim=model_cfg['INPUT_DIM'],
        graph_dim=model_cfg['GRAPH_DIM'],
        mlp_channels=model_cfg['MLP_HID_DIMS'],
        graph_channels=model_cfg['GRAPH_HID_DIMS'],
        save_path=model_cfg['SAVE_PATH'])

In [9]:
# Instantiate the model described by the MODEL / TRAIN sections of the config.
model = build_model(cfg)

### train and select best f1
### 3.1 inference
We first define the train/inference/prediction routines of SPACI.
1. `infer` is used to evaluate the performance of SPACI. When `load_model` is not given, or is None, the current parameters of the pre-defined `model` object are used. Otherwise, the saved checkpoint is loaded from disk.
2. `verbose` controls printing of the evaluation results. When verbose=True, the performance of the current model on the validation set is printed.

In [10]:
@torch.no_grad()
def infer(model, cfg, load_model=None, verbose=False):
    """Evaluate ``model`` on the test split and return its F1 score.

    Every call re-seeds the RNGs, rebuilds the test dataset (``train=0``),
    reloads the adjacency CSV and overwrites the prediction file at
    ``cfg['TEST']['PRED']`` with one row per evaluated pair.

    Args:
        model: Model exposing ``set_input``, ``inference`` and ``load``.
        cfg: Parsed configuration dict.
        load_model: Checkpoint name without extension. When given,
            ``<cfg['MODEL']['SAVE_PATH']>/<load_model>.pth`` is loaded into
            ``model`` before evaluation; when ``None`` the model is used
            as is.
        verbose: If true, print a summary table of the metrics.

    Returns:
        The F1 score as a Python ``float``; ``0.0`` when precision and
        recall are both zero.
    """
    # Re-seed so that repeated evaluations are reproducible.
    seed = 10
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dataset = build_dataset(cfg, train=0)
    dataloader = Data.DataLoader(dataset,
                                 batch_size=cfg['TEST']['BATCH_SIZE'],
                                 shuffle=False)
    if load_model is not None:
        model_path = os.path.join(cfg['MODEL']['SAVE_PATH'],
                                  load_model + '.pth')
        model.load(model_path)

    adj = pd.read_csv(cfg['DATASET']['ADJ_ROOT'], header=0, index_col=0)
    adj = torch.from_numpy(adj.to_numpy()).float()
    # Loop-invariant: scores strictly above the threshold count as positive.
    threshold = cfg['TEST']['THRESHOLD']

    # Confusion-matrix counters, kept as plain Python ints.
    TP = TN = FP = FN = 0
    label_tp = label_tn = 0

    # The context manager guarantees the file is flushed and closed, even if
    # inference raises part-way through.
    with open(cfg['TEST']['PRED'], 'w') as savepred:
        # Five columns, matching the five fields written for every row.
        savepred.write('ligand,receptor,truelabel,pred,score\n')

        for x1, x2, y, x1id, x2id in dataloader:
            inputs = {'x1': x1, 'x2': x2, 'label': y, 'adj': adj}
            model.set_input(inputs, istrain=0)
            dis = model.inference().detach().cpu()

            pred = torch.zeros(dis.shape)
            pred[dis > threshold] = 1

            TP += int(((pred == 1) & (y == 1)).sum())
            TN += int(((pred == 0) & (y == 0)).sum())
            FP += int(((pred == 1) & (y == 0)).sum())
            FN += int(((pred == 0) & (y == 1)).sum())
            label_tp += int((y == 1).sum())
            label_tn += int((y == 0).sum())

            for i, (id1, id2) in enumerate(zip(x1id, x2id)):
                savepred.write('%s,%s,%d,%d,%.4f\n' %
                               (id1, id2, y[i], int(pred[i]), dis[i]))

    # Every ratio falls back to 0.0 when its denominator is zero.
    precision = TP / (TP + FP) if (TP + FP) else 0.0
    recall = TP / (TP + FN) if (TP + FN) else 0.0
    sensitive = recall  # sensitivity uses the same formula as recall
    specity = TN / (TN + FP) if (TN + FP) else 0.0
    total = label_tp + label_tn
    acc = (TP + TN) / total if total else 0.0
    if precision + recall:
        F1 = (2 * precision * recall) / (precision + recall)
    else:
        F1 = 0.0

    if verbose:
        message = '\n------------------------results----------------------\n'
        message += '{:>10d}\t{:>10d}\n'.format(TP, label_tp)
        message += '{:>10d}\t{:>10d}\n'.format(TN, label_tn)
        message += '{:>10}\t{:>10.4f}\n'.format('acc:', acc)
        message += '{:>10}\t{:>10.4f}\n'.format('precision:', precision)
        message += '{:>10}\t{:>10.4f}\n'.format('recall:', recall)
        message += '{:>10}\t{:>10.4f}\n'.format('Specificity:', specity)
        message += '{:>10}\t{:>10.4f}\n'.format('Sensitivity:', sensitive)
        message += '{:>10}\t{:>10.4f}\n'.format('F1-measure:', F1)
        message += '------------------------------------------------------\n'
        print(message)
    return F1

### 3.2 train spaci
We evaluate the model after each epoch and save the model with the best F1 score as our final checkpoint.

In [11]:
# Train for the configured number of epochs. After every epoch the model is
# evaluated with infer(); whenever the F1 score improves, the model is saved
# under the name 'best_f1'. The last state is always saved as 'final'.
best_f1 = 0
best_epoch = None  # epoch index of the current 'best_f1' checkpoint
for epoch in tqdm(range(cfg['TRAIN']['EPOCHS'])):
    # train
    for batch, (a, p, n, aid, pid, nid) in enumerate(train_dataloader):
        # NOTE(review): A / P / N are presumably the anchor, positive and
        # negative samples of a triplet; confirm against TripletData.
        inputs = {'A': a, 'P': p, 'N': n, 'adj': adj}
        model.set_input(inputs, istrain=1)
        model.single_update()

    f1 = infer(model, cfg, verbose=False)
    # The first evaluated epoch is always checkpointed, so a 'best_f1' save
    # from this run exists even if no epoch ever scores an F1 above 0; the
    # evaluation cell that loads 'best_f1' afterwards depends on it.
    if best_epoch is None or f1 > best_f1:
        best_f1 = f1
        best_epoch = epoch
        model.save('best_f1')
model.save('final')

100%|██████████████████████████████████████████████████████████████████████████████| 10/10 [05:30<00:00, 33.02s/it]


### 3.3 show evaluation results
after training spaci, we print the performance of the model.
We will show:
1. accuracy
2. Precision
3. Recall
4. Specificity
5. Sensitivity
6. F1-score

In [12]:
# Load '<SAVE_PATH>/best_f1.pth' into the model, evaluate it on the test
# split and print the metric table (verbose=True).
f1 = infer(model, cfg, load_model='best_f1', verbose=True)



------------------------results----------------------
       103	       123
       980	       981
      acc:	    0.9810
precision:	    0.9904
   recall:	    0.8374
Specificity:	    0.9990
Sensitivity:	    0.8374
F1-measure:	    0.9075
------------------------------------------------------



### 3.4 saved the embeddings and predictions
1. The predictions will be saved in "results/predict.csv".
2. The embeddings will be saved in "results/embed_ligand.csv" and "results/embed_receptor.csv".
3. We set the threshold to 0.9, which means that a ligand-receptor pair whose cosine-similarity score is larger than 0.9 will be considered positive.

In [13]:
@torch.no_grad()
def predict(cfg, load_model=None):
    """Run prediction on the ``train=2`` split and save scores and embeddings.

    A fresh model is built from ``cfg``. Per-pair predictions are written to
    ``cfg['TEST']['PRED']``, and the two embedding tables returned by
    ``model.inference(return_intermediate=True)`` are written to
    ``cfg['TEST']['EMB1']`` and ``cfg['TEST']['EMB2']``, indexed by the
    corresponding ids. Nothing is written to the embedding files when the
    dataset yields no batches.

    Args:
        cfg: Parsed configuration dict.
        load_model: Checkpoint name without extension. When given,
            ``<cfg['MODEL']['SAVE_PATH']>/<load_model>.pth`` is loaded into
            the freshly built model; when ``None`` the model keeps its
            initial parameters.
    """
    # Re-seed so that repeated predictions are reproducible.
    seed = 10
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    model = build_model(cfg)

    dataset = build_dataset(cfg, train=2)
    dataloader = Data.DataLoader(dataset,
                                 batch_size=cfg['TEST']['BATCH_SIZE'],
                                 shuffle=False)
    if load_model is not None:
        model_path = os.path.join(cfg['MODEL']['SAVE_PATH'],
                                  load_model + '.pth')
        model.load(model_path)

    adj = pd.read_csv(cfg['DATASET']['ADJ_ROOT'], header=0, index_col=0)
    adj = torch.from_numpy(adj.to_numpy()).float()
    # Scores strictly above the threshold are predicted as positive.
    threshold = cfg['TEST']['THRESHOLD']

    # Collect per-batch pieces and concatenate once at the end, instead of
    # re-concatenating the growing arrays on every batch.
    emb1_batches, emb2_batches = [], []
    index1_batches, index2_batches = [], []

    # The context manager guarantees the predictions are flushed and the
    # file is closed, even if inference raises part-way through.
    with open(cfg['TEST']['PRED'], 'w') as savepred:
        # Five columns, matching the five fields written for every row.
        savepred.write('ligand,receptor,truelabel,pred,score\n')

        for x1, x2, y, x1id, x2id in dataloader:
            inputs = {'x1': x1, 'x2': x2, 'label': y, 'adj': adj}
            model.set_input(inputs, istrain=0)
            dis, emb1, emb2 = model.inference(return_intermediate=True)
            dis = dis.detach().cpu()

            emb1_batches.append(emb1.detach().cpu().numpy())
            emb2_batches.append(emb2.detach().cpu().numpy())
            index1_batches.append(x1id)
            index2_batches.append(x2id)

            pred = torch.zeros(dis.shape)
            pred[dis > threshold] = 1

            for i, (id1, id2) in enumerate(zip(x1id, x2id)):
                savepred.write('%s,%s,%d,%d,%.4f\n' %
                               (id1, id2, y[i], int(pred[i]), dis[i]))

    # Write the embedding tables once, after all batches have been processed.
    if emb1_batches:
        df1 = pd.DataFrame(np.concatenate(emb1_batches, axis=0),
                           index=np.concatenate(index1_batches, axis=0))
        df2 = pd.DataFrame(np.concatenate(emb2_batches, axis=0),
                           index=np.concatenate(index2_batches, axis=0))
        df1.to_csv(cfg['TEST']['EMB1'])
        df2.to_csv(cfg['TEST']['EMB2'])

    print('done')
In [14]:
# Run prediction with the 'best_f1' checkpoint; this writes the prediction
# CSV and the two embedding CSVs configured in the TEST section.
predict(cfg, load_model='best_f1')


done
