In [1]:
import os
import sys
from pprint import pprint
sys.path.append('.')

import pandas as pd
import numpy as np
from tqdm import tqdm
import optuna
from sklearn import metrics
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim

from models.gat.gat_pytorch import GAT
from models.gat import params as gat_params
from utils.utils import *
from runners import tools

In [2]:
# Project root. Set the PSGAT_ROOT environment variable to run from another
# checkout instead of editing this hard-coded path.
# NOTE(review): the imports in the first cell (models.*, utils.*, runners.*)
# are resolved against the kernel's working directory *before* this chdir
# runs, so the kernel must already be started from the project root.
PROJECT_ROOT = os.environ.get('PSGAT_ROOT', '/home/lyz/co-phase-separation/PSGAT/')
os.chdir(PROJECT_ROOT)
DATA_ROOT = './data'

In [3]:
def dim_reduction_cor(X, y, k=20, absolute=False):
    """Select the ``k`` features most correlated (Pearson) with ``y``.

    Args:
        X: 2-D array; rows are samples, columns are features.
        y: 1-D array of targets, one per row of ``X``.
        k: number of features to keep (all are kept if ``k`` exceeds the
            number of columns).
        absolute: rank features by |r| instead of r. The default (False)
            keeps the original behaviour, which only favours positively
            correlated features; with a binary label a strongly negative
            correlation can be just as informative, so ``absolute=True``
            may be preferable.

    Returns:
        features: boolean mask of length ``X.shape[1]``, True for kept columns.
        cors: per-feature Pearson correlation with ``y``; 0 where it is
            undefined (e.g. a constant column).
    """
    n_features = X.shape[1]
    cors = np.zeros(n_features)

    # Constant columns make corrcoef divide 0/0; silence the warning and
    # leave those entries at 0 instead of NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        for i in range(n_features):
            cor = np.corrcoef(X[:, i], y)[0, 1]
            if not np.isnan(cor):
                cors[i] = cor

    scores = np.abs(cors) if absolute else cors
    features = np.zeros(n_features, dtype=bool)
    features[np.argsort(-scores)[:k]] = True

    return features, cors

In [4]:
def data(
    ppi, # ['integrate', 'biogrid-all', 'biogrid-htp']
    seqemb=False, # ProSE sequence embedding
    expr=False, # Gene expression (currently unused)
    subloc=False, # Sublocalization (currently unused)
    gomf=False, # GO molecular function (currently unused)
    comp=False, # GO cellular component (currently unused)
    weights=False,
    weights_thr=200, # edge weight threshold
    max_feat_dim=150, # k in feature selection
    seed=0, # random state of the splits and random features
    neg_ratio=1, # negatives sampled per positive label
    neg_seed=0, # random state of the negative sampling
):
    """Build the PPI graph, node features and train/val/test splits.

    Steps: read the PPI edge list (optionally weighted for STRING), keep the
    labelled genes that occur in the network, down-sample negatives to
    ``neg_ratio`` per positive, split 80/20 into train/test and then 95/5
    into train/val (all stratified on the label), build and z-score the
    node features, pick at most ``max_feat_dim`` of them by correlation with
    the *training* labels, and convert everything to torch tensors.

    Returns:
        (edge_index, edge_weights), X, (train_idx, train_y),
        (val_idx, val_y), (test_idx, test_y), genes
        where ``edge_index`` is a (2, E) long tensor, ``edge_weights`` is a
        float tensor or None, ``X`` is a float32 (N, F) tensor, the ``*_idx``
        are lists of node positions into ``genes``, and the ``*_y`` are
        float32 label tensors.
    """
    # ---- edges and (optional) edge weights ----
    ppi_path = os.path.join(DATA_ROOT, f'PPIN/{ppi.upper()}.csv')
    edges = pd.read_csv(ppi_path)
    edge_weights = None

    # Only the STRING network carries a confidence score usable as weight.
    if ppi in ['string'] and weights:
        key = 'combined_score'
        edges = edges[edges.loc[:, key] > weights_thr].reset_index()
        edge_weights = edges[key] / 1000
        print(f'Filtered {ppi} network with thresh:', weights_thr)

    edges = edges[['A', 'B']].dropna()
    if edge_weights is not None:
        # index labels survive dropna(), keeping weights aligned with edges
        edge_weights = edge_weights.loc[edges.index].to_numpy()
    edges = edges.to_numpy()
    ppi_genes = np.union1d(edges[:, 0], edges[:, 1])

    # ---- labels ----
    label_path = os.path.join(DATA_ROOT, 'Label/labels.csv')
    labels = pd.read_csv(label_path).set_index('Gene')

    ## filter labels not in the PPI network
    print('Number of labels before filtering:', len(labels))
    labels = labels.loc[np.intersect1d(labels.index, ppi_genes)].copy()
    print('Number of labels after filtering:', len(labels))

    # positive:negative sample ratio 1:neg_ratio (1:1 by default)
    pos_samples = labels[labels['label'] == 1]
    neg_samples = labels[labels['label'] == 0].sample(
        n=pos_samples.shape[0] * neg_ratio, random_state=neg_seed)
    labels = pd.concat([pos_samples, neg_samples]).sort_index()

    ## train and test split, stratified on the label column only
    train_ds, test_ds = train_test_split(
        labels,
        test_size=.2,
        random_state=seed,
        stratify=labels['label'],
    )

    genes = np.union1d(labels.index, ppi_genes)
    print('Total number of genes:', len(genes))

    # ---- node attributes ----
    X = pd.DataFrame(np.zeros((len(genes), 0)), index=genes)

    ## ProSE sequence embedding
    if seqemb:
        seqemb_path = os.path.join(DATA_ROOT, 'NodeFeat/SeqEmb/seqemb_80d.pkl')
        seqemb_feat = pd.read_pickle(seqemb_path).set_index('entry')
        seqemb_feat.columns = [f'seqemb_{i}' for i in range(seqemb_feat.shape[1])]
        X = X.join(seqemb_feat, how='left')
        print('Sequence Embedding dataset shape:', seqemb_feat.shape)

    # Genes without an embedding get zero features. Convert to numpy *now*:
    # the random padding below produces an ndarray, and the old code called
    # DataFrame.to_numpy() on it afterwards, which crashed.
    X = X.fillna(0).to_numpy()

    N = len(X)
    mapping = dict(zip(genes, range(N)))

    # ---- edge preprocessing ----
    ## remove self-loops
    mask = edges[:, 0] != edges[:, 1]
    edges = edges[mask]
    if edge_weights is not None:
        edge_weights = edge_weights[mask]

    ## remove duplicated (undirected) connections: A-B equals B-A.
    # Compare the sorted endpoint pair column-wise; concatenating the names
    # could merge distinct pairs (e.g. 'AB'+'C' and 'A'+'BC').
    pair = np.sort(edges.astype(str), axis=1)
    keep = ~pd.DataFrame(pair).duplicated().to_numpy()
    edges = edges[keep]
    if edge_weights is not None:
        edge_weights = torch.tensor(edge_weights[keep], dtype=torch.float32)

    # gene names -> node positions, shape (E, 2); also works for E == 0
    edge_index = np.array(
        [mapping[g] for g in edges.ravel()], dtype=np.int64).reshape(-1, 2)

    ## node attribute matrix X
    # Pad with 50 random columns when there are fewer than 50 real ones,
    # seeded so that runs are reproducible.
    if X.shape[1] < 50:
        rng = np.random.default_rng(seed)
        X = np.concatenate([X, rng.random((N, 50))], axis=1)

    X = (X - X.mean(0, keepdims=True)) / (X.std(0, keepdims=True) + 1e-8) # normalization

    ## train and val split
    train, val = train_test_split(
        train_ds,
        test_size=.05,
        random_state=seed,
        stratify=train_ds['label'],
    )

    train_idx = [mapping[t] for t in train.index]
    val_idx = [mapping[v] for v in val.index]
    test_idx = [mapping[v] for v in test_ds.index]

    # Feature selection on training nodes only: using val/test labels here
    # would leak evaluation labels into the model inputs.
    feats, cors = dim_reduction_cor(
        X[train_idx],
        train['label'].to_numpy(dtype=np.float32),
        k=max_feat_dim,
    )
    X = X[:, feats]

    # ---- torch ----
    edge_index = torch.from_numpy(edge_index.T).to(torch.long).contiguous()

    X = torch.from_numpy(X).to(torch.float32)
    train_y = torch.tensor(train['label'].to_numpy(dtype=int), dtype=torch.float32)
    val_y = torch.tensor(val['label'].to_numpy(dtype=int), dtype=torch.float32)
    test_y = torch.tensor(test_ds['label'].to_numpy(dtype=int), dtype=torch.float32)

    print(f'\nNumber of edges in graph: {len(edges)}')
    print(f'Number of nodes in graph: {len(X)}')
    print(f'Shape of node features: {X.shape[0], X.shape[1]}\n')
    print('Using Edge Weights' if edge_weights is not None else 'Not using edge weights')

    return (edge_index, edge_weights), X, (train_idx, train_y), (val_idx, val_y), \
            (test_idx, test_y), genes

In [5]:
# (edge_index, edge_weights), X, (train_idx, train_y), (val_idx, val_y), (test_idx, test_y), genes = data(ppi='biogrid-all', seqemb=True, weights_thr=250, weights=False)

In [6]:
# GAT hyper-parameters. The whole dict is forwarded to GAT(**params) in
# `train`, which also reads 'lr' and 'weight_decay' for the Adam optimizer.
gat_0 = dict(
    lr=0.005,
    weight_decay=5e-4,
    h_feats=[8, 1],   # presumably per-layer hidden sizes — confirm in GAT
    heads=[8, 1],     # presumably per-layer attention heads — confirm in GAT
    dropout=0.6,
    negative_slope=0.2,
)

In [7]:
# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
# Loss class
class Loss():
    """Class-balanced binary cross-entropy over a labelled subset of nodes.

    The positive and negative terms are averaged separately and summed, so
    each class contributes equally regardless of its sample count.

    Args:
        y: 1-D tensor of 0/1 targets, aligned element-wise with ``idx``.
        idx: node indices (rows of the model output) labelled by ``y``.
    """

    def __init__(self, y, idx):
        self.y = y
        idx = np.asarray(idx)

        self.y_pos = y[y == 1]
        self.y_neg = y[y == 0]

        # masks built on the CPU so they can index the numpy index array
        self.pos = idx[(y == 1).cpu().numpy()]
        self.neg = idx[(y == 0).cpu().numpy()]

    @staticmethod
    def _bce(out, rows, target):
        """Mean BCE-with-logits over ``rows``; zero when the class is empty."""
        if len(rows) == 0:
            # the mean over zero elements is NaN and would poison the sum
            return out.new_zeros(())
        # reshape(-1) rather than squeeze(): squeeze() turns a single-row
        # (1, 1) selection into a 0-d tensor that no longer matches target.
        return nn.functional.binary_cross_entropy_with_logits(
            out[rows].reshape(-1), target)

    def __call__(self, out):
        loss_p = self._bce(out, self.pos, self.y_pos)
        loss_n = self._bce(out, self.neg, self.y_neg)
        return loss_p + loss_n

# AUC calculation
def evalAUC(model, X, A, y, mask, logits=None, edge_weights=None):
    """ROC-AUC and ROC curve for a subset of nodes.

    Either ``model`` is given, and it is run in eval mode on ``(X, A)`` with
    its output rows selected by ``mask``; or ``model`` is None and the
    already-selected ``logits`` are scored directly (``X``, ``A`` and
    ``mask`` are then ignored).

    Args:
        model: trained model, or None to use ``logits``.
        X, A: node features and edge index for the forward pass.
        y: tensor of 0/1 labels for the selected nodes.
        mask: indices/mask selecting the nodes to score.
        logits: pre-computed logits for the selected nodes.
        edge_weights: optional edge weights forwarded to the model. Pass the
            weights used during training; otherwise the model is scored on
            the unweighted graph.

    Returns:
        (auc, fpr, tpr)
    """
    assert model is not None or logits is not None, 'need a model or precomputed logits'
    if model is not None:
        model.eval()
        extra = {} if edge_weights is None else {'edge_weights': edge_weights}
        with torch.no_grad():
            logits = model(X, A, **extra)
            logits = logits[mask]
    # sklearn wants 1-D scores; the model output may be (n, 1)
    probs = torch.sigmoid(logits).reshape(-1).cpu().numpy()
    y = y.cpu().numpy()
    auc = metrics.roc_auc_score(y, probs)
    fpr, tpr, _ = metrics.roc_curve(y, probs)
    return auc, fpr, tpr

# Model training
def train(
    params,
    X, A,
    edge_weights,
    train_y, train_idx,
    val_y, val_idx,
    savepath='',
    epochs=1000,
):
    """Train a fresh GAT on the full graph and score it on the validation nodes.

    Args:
        params: hyper-parameter dict; forwarded whole to ``GAT(**params)``,
            and its 'lr' / 'weight_decay' entries configure Adam.
        X: node feature tensor, one row per node.
        A: edge index tensor passed to the model.
        edge_weights: optional per-edge weight tensor (None = unweighted).
        train_y, train_idx: labels and node indices of the training nodes.
        val_y, val_idx: labels and node indices of the validation nodes.
        savepath: currently unused; kept for interface compatibility.
        epochs: number of optimisation steps, one forward pass over the
            whole graph each (default 1000).

    Returns:
        (model, auc_dict) where auc_dict holds the final validation 'auc',
        'fpr' and 'tpr', computed in eval mode with the training edge weights.
    """
    model = GAT(in_feats=X.shape[1], **params)
    model.to(DEVICE)
    X = X.to(DEVICE)
    A = A.to(DEVICE)
    train_y = train_y.to(DEVICE)
    val_y = val_y.to(DEVICE)
    if edge_weights is not None:
        edge_weights = edge_weights.to(DEVICE)

    optimizer = optim.Adam(
        model.parameters(),
        lr=params['lr'],
        weight_decay=params['weight_decay']
    )
    loss_fnc = tools.Loss(train_y, train_idx)
    val_loss_fnc = tools.Loss(val_y, val_idx)

    progress = tqdm(range(epochs))
    for _ in progress:
        model.train()
        logits = model(X, A, edge_weights=edge_weights)

        optimizer.zero_grad()
        loss = loss_fnc(logits)
        loss.backward()
        optimizer.step()

        # Progress display only. These logits come from the training-mode
        # forward pass (dropout active), so the per-epoch validation numbers
        # are noisier than the final eval-mode score below.
        logits = logits.detach()
        val_loss = val_loss_fnc(logits)
        train_auc, _, _ = evalAUC(None, 0, 0, train_y, 0, logits[train_idx])
        val_auc, _, _ = evalAUC(None, 0, 0, val_y, 0, logits[val_idx])

        progress.set_description(
            'Loss: %.4f; Val Loss %.4f; Train AUC %.4f. Validation AUC: %.4f'
            % (float(loss), float(val_loss), train_auc, val_auc)
        )

    # Final validation score in eval mode. Use the same edge weights as in
    # training; scoring without them would evaluate a different graph.
    model.eval()
    with torch.no_grad():
        if edge_weights is None:
            logits = model(X, A)
        else:
            logits = model(X, A, edge_weights=edge_weights)
    score, fpr, tpr = evalAUC(None, 0, 0, val_y, 0, logits[val_idx])
    auc_dict = {
        'auc': score,
        'fpr': fpr,
        'tpr': tpr
    }
    print(f'Val AUC: {score}')

    return model, auc_dict

# Model testing
def test(model, X, A, test_ds=None, edge_weights=None):
    """Predict node probabilities and optionally score a labelled test split.

    Args:
        model: trained model; moved to DEVICE and put in eval mode.
        X, A: node features and edge index for the forward pass.
        test_ds: optional ``(test_idx, test_y)`` pair to compute ROC metrics on.
        edge_weights: optional edge weights forwarded to the model. Pass the
            weights used in training; otherwise predictions are made on the
            unweighted graph.

    Returns:
        ``probs`` (sigmoid of the model output for every node, as numpy), or
        ``(probs, auc_dict)`` with 'auc', 'fpr', 'tpr' when ``test_ds`` is given.
    """
    model.to(DEVICE).eval()
    X = X.to(DEVICE)
    A = A.to(DEVICE)
    extra = {}
    if edge_weights is not None:
        extra['edge_weights'] = edge_weights.to(DEVICE)

    with torch.no_grad():
        logits = model(X, A, **extra)
    probs = torch.sigmoid(logits).cpu().numpy()

    if test_ds is None:
        return probs

    test_idx, test_y = test_ds
    y_true = test_y.cpu().numpy()
    test_probs = probs[test_idx]
    fpr, tpr, _ = metrics.roc_curve(y_true, test_probs)
    auc_dict = {
        'auc': metrics.roc_auc_score(y_true, test_probs),
        'fpr': fpr,
        'tpr': tpr
    }
    return probs, auc_dict

In [9]:
# Number of independent runs (one random seed each) and per-run result stores.
N = 10
models, preds, val_aucs, test_aucs = {}, {}, {}, {}

In [10]:
# Repeat split -> train -> test for N seeds on the INTEGRATE network with
# ProSE sequence embeddings and no edge weights.
for i in range(N):
    graph, X, train_part, val_part, test_part, genes = data(
        ppi='integrate',
        seqemb=True,
        weights_thr=0,
        weights=False,
        seed=i,
    )
    edge_index, edge_weights = graph
    train_idx, train_y = train_part
    val_idx, val_y = val_part
    test_idx, test_y = test_part

    model, val_auc = train(gat_0, X, edge_index, edge_weights, train_y, train_idx, val_y, val_idx)
    pred, test_auc = test(model, X, edge_index, (test_idx, test_y))

    models[i], preds[i] = model, pred
    val_aucs[i], test_aucs[i] = val_auc, test_auc

Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)

Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.8917; Val Loss 1.5704; Train AUC 0.8697. Validation AUC: 0.6779: 100%|██████████| 1000/1000 [00:26<00:00, 37.52it/s]


Val AUC: 0.8303571428571428
Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)


Loss: 1.4208; Val Loss 1.5226; Train AUC 0.6159. Validation AUC: 0.5064:   0%|          | 4/1000 [00:00<00:26, 37.12it/s]


Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.9337; Val Loss 0.9991; Train AUC 0.8579. Validation AUC: 0.8157: 100%|██████████| 1000/1000 [00:25<00:00, 38.49it/s]


Val AUC: 0.8660714285714285
Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)


Loss: 1.5558; Val Loss 1.4904; Train AUC 0.5945. Validation AUC: 0.5268:   0%|          | 4/1000 [00:00<00:27, 35.67it/s]


Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.9138; Val Loss 1.1333; Train AUC 0.8708. Validation AUC: 0.7615: 100%|██████████| 1000/1000 [00:25<00:00, 38.89it/s]


Val AUC: 0.7397959183673468
Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)


Loss: 1.5312; Val Loss 1.4657; Train AUC 0.5463. Validation AUC: 0.5816:   0%|          | 4/1000 [00:00<00:29, 33.24it/s]


Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.9184; Val Loss 1.1346; Train AUC 0.8680. Validation AUC: 0.7755: 100%|██████████| 1000/1000 [00:27<00:00, 36.88it/s]


Val AUC: 0.8290816326530612
Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)


Loss: 1.3816; Val Loss 1.2185; Train AUC 0.6350. Validation AUC: 0.7392:   0%|          | 4/1000 [00:00<00:27, 36.50it/s]


Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.8945; Val Loss 1.6566; Train AUC 0.8750. Validation AUC: 0.7755: 100%|██████████| 1000/1000 [00:26<00:00, 37.16it/s]


Val AUC: 0.7920918367346939
Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)


Loss: 1.5745; Val Loss 1.8866; Train AUC 0.5656. Validation AUC: 0.5191:   0%|          | 4/1000 [00:00<00:32, 31.04it/s]


Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.9296; Val Loss 1.8334; Train AUC 0.8723. Validation AUC: 0.6352: 100%|██████████| 1000/1000 [00:27<00:00, 36.21it/s]


Val AUC: 0.7385204081632654
Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)


Loss: 1.5470; Val Loss 1.4941; Train AUC 0.5500. Validation AUC: 0.5497:   0%|          | 4/1000 [00:00<00:27, 36.59it/s]


Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.8784; Val Loss 1.7418; Train AUC 0.8864. Validation AUC: 0.7589: 100%|██████████| 1000/1000 [00:25<00:00, 39.04it/s]


Val AUC: 0.7551020408163266
Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)


Loss: 1.5359; Val Loss 1.3786; Train AUC 0.5688. Validation AUC: 0.6556:   0%|          | 4/1000 [00:00<00:26, 37.76it/s]


Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.9284; Val Loss 1.7365; Train AUC 0.8734. Validation AUC: 0.7494: 100%|██████████| 1000/1000 [00:25<00:00, 38.79it/s]


Val AUC: 0.8048469387755102
Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)


Loss: 1.5859; Val Loss 1.3429; Train AUC 0.5084. Validation AUC: 0.6333:   0%|          | 4/1000 [00:00<00:26, 37.69it/s]


Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.9667; Val Loss 1.1934; Train AUC 0.8389. Validation AUC: 0.8074: 100%|██████████| 1000/1000 [00:25<00:00, 39.19it/s]


Val AUC: 0.8596938775510204
Number of labels before filtering: 20398
Number of labels after filtering: 15939
Total number of genes: 15939
Sequence Embedding dataset shape: (20398, 80)


Loss: 1.3648; Val Loss 1.4780; Train AUC 0.6545. Validation AUC: 0.5599:   0%|          | 4/1000 [00:00<00:27, 36.73it/s]


Number of edges in graph: 241464
Number of nodes in graph: 15939
Shape of node features: (15939, 80)

Not using edge weights


Loss: 0.8742; Val Loss 1.8566; Train AUC 0.8722. Validation AUC: 0.6244: 100%|██████████| 1000/1000 [00:25<00:00, 38.70it/s]


Val AUC: 0.7257653061224489


In [11]:
test_aucs

{0: {'auc': 0.8068940531028415,
  'fpr': array([0.        , 0.        , 0.        , 0.00719424, 0.00719424,
         0.01438849, 0.01438849, 0.02158273, 0.02158273, 0.02877698,
         0.02877698, 0.03597122, 0.03597122, 0.04316547, 0.04316547,
         0.05035971, 0.05035971, 0.07194245, 0.07194245, 0.07913669,
         0.07913669, 0.10071942, 0.10071942, 0.10791367, 0.10791367,
         0.11510791, 0.11510791, 0.12230216, 0.12230216, 0.15107914,
         0.15107914, 0.15827338, 0.15827338, 0.17266187, 0.17266187,
         0.17985612, 0.17985612, 0.18705036, 0.18705036, 0.1942446 ,
         0.1942446 , 0.20143885, 0.20143885, 0.23021583, 0.23021583,
         0.23741007, 0.23741007, 0.25179856, 0.25179856, 0.25899281,
         0.25899281, 0.26618705, 0.26618705, 0.28057554, 0.28057554,
         0.29496403, 0.29496403, 0.30215827, 0.30215827, 0.31654676,
         0.31654676, 0.32374101, 0.32374101, 0.33093525, 0.33093525,
         0.34532374, 0.34532374, 0.35251799, 0.35251799, 0.37410

In [12]:
from sklearn.metrics import confusion_matrix, recall_score, precision_score, precision_recall_curve, auc

# help(precision_recall_curve)

# Precision-recall AUC of the most recently trained model on its test split.
# Note: sklearn suggests `average_precision_score` as an alternative summary
# of the PR curve to trapezoidal `auc(recall, precision)`.
edge_index = edge_index.to(DEVICE)
X = X.to(DEVICE)
model.to(DEVICE).eval()

with torch.no_grad():
    logits = model(X, edge_index)
probs = torch.sigmoid(logits).cpu().numpy()

test_probs = probs[test_idx]
precision, recall, thresholds = precision_recall_curve(test_y, test_probs, pos_label=1)

auc(recall, precision)

0.7578969880060369

In [12]:
# Exploratory re-run of the loading steps of `data()` to inspect the labels.
# The weighting options are defined explicitly here; they were previously
# read from undefined globals and only avoided a NameError because
# `ppi in ['string']` short-circuited the condition.
ppi = 'integrate'
weights = False
weights_thr = 0

# edges and edge_weights
ppi_path = os.path.join(DATA_ROOT, f'PPIN/{ppi.upper()}.csv')
edges = pd.read_csv(ppi_path)
edge_weights = None

if ppi in ['string'] and weights:
    key = 'combined_score'
    edges = edges[edges.loc[:, key] > weights_thr].reset_index()
    edge_weights = edges[key] / 1000
    print(f'Filtered {ppi} network with thresh:', weights_thr)

edges = edges[['A', 'B']].dropna()
edges, index = edges.values, edges.index.values
ppi_genes = np.union1d(edges[:, 0], edges[:, 1])
if edge_weights is not None:
    edge_weights = edge_weights.iloc[index].values

# labels
label_path = os.path.join(DATA_ROOT, 'Label/labels.csv')
labels = pd.read_csv(label_path).set_index('Gene')

## filter labels not in the PPI network
print('Number of labels before filtering:', len(labels))
labels = labels.loc[np.intersect1d(labels.index, ppi_genes)].copy()
print('Number of labels after filtering:', len(labels))

Number of labels before filtering: 20398
Number of labels after filtering: 15939


In [14]:
# Down-sample negatives so there are `ratio` negatives per positive label.
ratio = 1
pos_samples = labels.loc[labels['label'] == 1]
n_neg = len(pos_samples) * ratio
neg_samples = labels.loc[labels['label'] == 0].sample(n=n_neg, random_state=0)
labels = pd.concat([pos_samples, neg_samples]).sort_index()

In [16]:
# Class balance after negative down-sampling.
class_counts = labels.value_counts()
class_counts

label
0        694
1        694
dtype: int64

In [57]:
## train and test split
train_ds, test_ds = train_test_split(
    labels,
    test_size=.2,
    random_state=1234,
    stratify=labels
)

In [71]:
# Re-balance the training split towards `neg_per_pos` negatives per positive.
# After the 1:1 down-sampling the split holds only about as many negatives
# as positives, so asking for 3x without replacement raised a ValueError;
# cap the request at the available pool instead.
neg_per_pos = 3
pos_samples = train_ds[train_ds['label'] == 1]
neg_pool = train_ds[train_ds['label'] == 0]
n_neg = min(len(neg_pool), len(pos_samples) * neg_per_pos)
if n_neg < len(pos_samples) * neg_per_pos:
    print(f'Only {len(neg_pool)} negatives available; using all of them.')
neg_samples = neg_pool.sample(n=n_neg, random_state=0)
train_ds = pd.concat([pos_samples, neg_samples]).sort_index()

In [13]:
# average AUC for validation set
avgValAUC = np.mean([v['auc'] for v in val_aucs.values()])
# average AUC for test set
avgTestAUC = np.mean([v['auc'] for v in test_aucs.values()])

In [14]:
print(f'Average AUC for val: {avgValAUC:.2f}, test: {avgTestAUC:.2f}')

Average AUC for val: 0.79, test: 0.79


In [12]:
# mkdir
SAVE_ROOT = './saves/INTEGRATE_ProSE80d/'
if not os.path.exists(SAVE_ROOT):
    os.mkdir(SAVE_ROOT)
    os.mkdir(os.path.join(SAVE_ROOT, f'embeddings/'))
    os.mkdir(os.path.join(SAVE_ROOT, f'pairwise_cosine/'))
    os.mkdir(os.path.join(SAVE_ROOT, f'edge_cosine/'))
    os.mkdir(os.path.join(SAVE_ROOT, f'models/'))
    
# save models
for idx, model in tqdm(models.items()):
    torch.save(model, os.path.join(SAVE_ROOT, f'models/model_{idx}.pt'))
    torch.save(model.featuremap1.cpu(), os.path.join(SAVE_ROOT, f'embeddings/model_{idx}.pt'))

# save results
import pickle
dict1 = {
    'preds': preds,
    'val_aucs': val_aucs,
    'test_aucs': test_aucs,
    'genes': genes
}
for key, val in dict1.items():
    with open(os.path.join(SAVE_ROOT, f'{key}.pkl'), 'wb') as f:
        pickle.dump(val, f)

# with open('saved_dictionary.pkl', 'rb') as f:
#     loaded_dict = pickle.load(f)

100%|██████████| 10/10 [00:00<00:00, 42.56it/s]


In [13]:
from sklearn.metrics.pairwise import cosine_distances, cosine_similarity

def save_models(ppi, model, gene_names=None):
    """Cosine similarities of a trained model's `featuremap1` embeddings.

    Despite the name nothing is written to disk; the caller pickles the
    returned tables.

    Args:
        ppi: network name; edges are read from DATA_ROOT/PPIN/<PPI>.csv.
        model: trained model exposing a ``featuremap1`` tensor whose rows are
            assumed to follow the order of ``gene_names`` — TODO confirm.
        gene_names: node order of the embedding rows; defaults to the
            notebook-level ``genes`` returned by ``data()``.

    Returns:
        featuremap_cosine: DataFrame of pairwise cosine similarities, indexed
            and labelled by gene.
        edges: the raw edge table with an added 'cosim' column clipped to [0, 1].

    Raises:
        KeyError: if an edge endpoint is not in ``gene_names``.
    """
    if gene_names is None:
        gene_names = genes

    ppi_path = os.path.join(DATA_ROOT, f'PPIN/{ppi.upper()}.csv')
    edges = pd.read_csv(ppi_path)

    featuremap = model.featuremap1.detach().cpu()
    cosine = cosine_similarity(featuremap)
    featuremap_cosine = pd.DataFrame(cosine, index=gene_names, columns=gene_names)

    # Vectorised lookup instead of a per-row iterrows loop. The old code read
    # featuremap_cosine[A][B], i.e. row B / column A; keep that orientation.
    positions = pd.Index(gene_names)
    rows = positions.get_indexer(edges['B'])
    cols = positions.get_indexer(edges['A'])
    if (rows < 0).any() or (cols < 0).any():
        raise KeyError('edge endpoints missing from gene_names')
    edges['cosim'] = np.clip(cosine[rows, cols], 0, 1)

    return featuremap_cosine, edges

In [14]:
# save node embeddings
for idx, model in tqdm(models.items()):
    featuremap_cosine, featuremap_edge = save_models('integrate', model)
    featuremap_cosine.to_pickle(os.path.join(SAVE_ROOT, f'pairwise_cosine/model_{idx}.pkl'))
    featuremap_edge.to_pickle(os.path.join(SAVE_ROOT, f'edge_cosine/model_{idx}.pkl'))

100%|██████████| 10/10 [02:52<00:00, 17.23s/it]
