In [1]:
import os
import argparse
import gc
import torch
from tqdm import tqdm
from rdkit import Chem
import numpy as np
import json
import copy
from utils import *
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score
from model.predictor import Predictor
from model.sampler import Sampler
from torch_geometric.utils import degree
from torch.utils.data.distributed import DistributedSampler
from data_process import smile_to_graph, read_smiles, read_interactions, generate_node_subgraphs, read_network
from sklearn.model_selection import StratifiedKFold, KFold
from train_eval import train, test, eval
import random
from main import *
import pdb
import sys

parser = argparse.ArgumentParser(description='TIGER')
args, unknown = parser.parse_known_args()

args.model_name = 'TIGER'
args.dataset = 'drugbank'
args.folds = 5
args.layer = 1
args.predictor_lr = 0.001
args.sampler_lr = 0.1
args.weight_decay = 0.0001
args.batch_size = 128
args.epoch = 50
# args.extractor = 'khop-subtree'
args.extractor = 'RL'
args.graph_fixed_num = 1
args.khop = 1
args.fixed_num = 32
args.d_dim = 64
args.fixed_num = 32
args.num_heads = 4
args.max_smiles_degree = 300
args.max_graph_degree = 600
args.dropout = 0.2
args.k_step = 10
args.sub_coeff = 0.1
args.mi_coeff = 0.1
args.s_type = 'random'
args.pos = 1
args.neg = 1

In [4]:
dataset = args.dataset
data_path = "/bigdat2/user/xiejc/zhangc/dataset/TIGER/dataset/" + dataset + "/"
ligands = read_smiles(os.path.join(data_path, "drug_smiles.txt"))
smile_graph, num_rel_mol_update, max_smiles_degree = smile_to_graph(data_path, ligands)
interactions_label, all_contained_drgus = read_interactions(os.path.join(data_path, "ddi.txt"), smile_graph)


Read /bigdat2/user/xiejc/zhangc/dataset/TIGER/dataset/drugbank/drug_smiles.txt!
10630
10630


In [5]:
interactions_label

array([[  819,   758,     1,     1],
       [  822,   758,     1,     1],
       [  826,   758,     1,     1],
       ...,
       [ 8311, 38519,     1,     0],
       [29607, 26875,     1,     0],
       [32738,  6294,     1,     0]])

In [6]:
from collections import defaultdict

nodes = defaultdict(list)
for i, (drug1, drug2, _, label) in enumerate(interactions_label):
    nodes[drug1].append((drug2, label))
    nodes[drug2].append((drug1, label))

In [None]:
import random
labels = {}
for node in nodes:
    labels[node] = {0:0, 1:1}

for node in nodes:
    for _, label in nodes[node]:
        labels[node][label] += 1

train_nodes = set()

train_size = 0.6*len(interactions_label)
val_size = 0.2*len(interactions_label)
test_size = 0.2*len(interactions_label)

train_s = 0
val_s = 0
test_s = 0

t = 0
f = 0
test_node = set()
while(test_s < test_size):
    node = random.choice(list(nodes.keys()))
    if node in test_node:
        continue
    if t>f and labels[node][0] < labels[node][1]:
        continue
    if f>t and labels[node][1] < labels[node][0]:
        continue
    t += labels[node][1]
    f += labels[node][0]
    test_s += (labels[node][1]+labels[node][0])
    test_node.add(node)

t = 0
f = 0
val_node = set()
while(val_s < val_size):
    node = random.choice(list(nodes.keys()))
    if node in test_node or node in val_node:
        continue
    if t>f and labels[node][0] < labels[node][1]:
        continue
    if f>t and labels[node][1] < labels[node][0]:
        continue
    t += labels[node][1]
    f += labels[node][0]
    val_s += (labels[node][1]+labels[node][0])
    val_node.add(node)

for node in nodes:
    if node in test_node or node in val_node:
        continue
    train_nodes.add(node)

test_set = set()
for node in test_node:
    for p in nodes[node]:
        if (p[0], node, p[1]) in test_set or (node, p[0], p[1]) in test_set:
            continue
        test_set.add((node, p[0], p[1]))

val_set = set()
for node in val_node:
    for p in nodes[node]:
        if (p[0], node, p[1]) in val_set or (node, p[0], p[1]) in val_set:
            continue
        if (p[0], node, p[1]) in test_set or (node, p[0], p[1]) in test_set:
            continue 
        val_set.add((node, p[0], p[1]))

train_set = set()
for node in train_nodes:
    for p in nodes[node]:
        if (p[0], node, p[1]) in train_set or (node, p[0], p[1]) in train_set:
            continue
        if (p[0], node, p[1]) in val_set or (node, p[0], p[1]) in val_set:
            continue
        if (p[0], node, p[1]) in test_set or (node, p[0], p[1]) in test_set:
            continue 
        train_set.add((node, p[0], p[1]))

updated_train_node = set()
for u,v, label in train_set:
    updated_train_node.add(u)
    updated_train_node.add(v)

to_remove = set()
for tri in test_set:
    if tri[0] in updated_train_node and tri[1] in updated_train_node:
        to_remove.add(tri)
test_set = test_set - to_remove

to_remove = set()
for tri in val_set:
    if tri[0] in updated_train_node and tri[1] in updated_train_node:
        to_remove.add(tri)
val_set = val_set - to_remove

print(f'train size: {len(train_set)}', f'val size: {len(val_set)}', f'test size: {len(test_set)}')

def f(data):
    t, f = 0, 0
    for u,v, label in data:
        if label == 1:
            t += 1
        else:
            f += 1
    return t, f

print(f(train_set), f(val_set), f(test_set))




train size: 13128 val size: 3428 test size: 3973
(6939, 6189) (1724, 1704) (1954, 2019)


In [62]:
def get_index(interactions_label, data_set):
    index = []
    for u,v,label in data_set:

        matches = np.all(interactions_label == [u,v,1,label], axis=1)
        indices = np.where(matches)[0]
        if len(indices) > 0:
            index.append(indices[0])
        else:
            matches = np.all(interactions_label == [v,u,1,label], axis=1)
            indices = np.where(matches)[0]
            index.append(indices[0])
    return index

train_index = get_index(interactions_label, train_set)
val_index = get_index(interactions_label, val_set)
test_index = get_index(interactions_label, test_set)
        

In [64]:
train_data = interactions_label[train_index]
val_data = interactions_label[val_index]
test_data = interactions_label[test_index]

In [67]:
updated_train_node = set()
for u,v,_, label in train_data:
    updated_train_node.add(u)
    updated_train_node.add(v)

t1, t2, t3 = 0, 0, 0
for tri in val_data:
    if tri[0] in updated_train_node and tri[1] in updated_train_node:
        t1 += 1
    elif tri[0] not in updated_train_node and tri[1] not in updated_train_node:
        t2 += 1
    else:
        t3 += 1

print('both seen:', t1, 'both unseen:', t2, 'one seen:', t3)

both seen: 0 both unseen: 210 one seen: 3218


In [None]:
import torch
import pickle

data = {'train': [torch.tensor(train_index)],
        'val': [torch.tensor(val_index)],
        'test': [torch.tensor(test_index)]}

with open('/home/xiejc/Code/TIGER/RL/data/drugbank/inductive_split.pkl', 'wb') as f:
    pickle.dump(data, f)


In [9]:
def split_fold(folds, dataset, labels, scenario_type='random'):

    test_indices, train_indices, val_indices = [], [], []

    if scenario_type == 'random':
        skf = StratifiedKFold(folds, shuffle=True, random_state=2023)
        train_indices, test_indices, val_indices = k_fold(dataset, skf, folds, labels)

    return train_indices, test_indices, val_indices

tmp = split_fold(5, interactions_label[:,:2], interactions_label[:,3])
print(tmp)

([tensor([    0,     1,     2,  ..., 20800, 20801, 20802]), tensor([    0,     1,     2,  ..., 20803, 20805, 20807]), tensor([    0,     1,     3,  ..., 20805, 20806, 20807]), tensor([    4,     8,    11,  ..., 20805, 20806, 20807]), tensor([    2,     5,     6,  ..., 20802, 20804, 20806])], [array([   14,    15,    25, ..., 20797, 20804, 20806]), array([    8,    11,    13, ..., 20791, 20800, 20802]), array([    2,     5,     6, ..., 20798, 20799, 20801]), array([    0,     1,     3, ..., 20779, 20786, 20792]), array([    4,    18,    27, ..., 20803, 20805, 20807])], [array([    4,    18,    27, ..., 20803, 20805, 20807]), array([   14,    15,    25, ..., 20797, 20804, 20806]), array([    8,    11,    13, ..., 20791, 20800, 20802]), array([    2,     5,     6, ..., 20798, 20799, 20801]), array([    0,     1,     3, ..., 20779, 20786, 20792])])


In [42]:
data, labels, smile_graph, node_graph, dataset_statistics, adj_matrix, edge_rel = load_data(args)

edge_index = torch.tensor(adj_matrix).T.cuda()
edge_rel = torch.tensor(edge_rel).cuda()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

setup_seed(42)

tmp = list(zip(*split_fold(5, data, labels, args.s_type)))
train_idx, test_idx, val_idx = tmp[0]

# train_idx = train_idx.numpy()

train_data = DTADataset(x=data[train_idx], y=labels[train_idx], sub_graph=node_graph, smile_graph=smile_graph)
test_data = DTADataset(x=data[test_idx], y=labels[test_idx], sub_graph=node_graph, smile_graph=smile_graph)
eval_data = DTADataset(x=data[val_idx], y=labels[val_idx], sub_graph=node_graph, smile_graph=smile_graph)

# train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False, collate_fn=collate) 
train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, collate_fn=collate) 
test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, collate_fn=collate) 
eval_loader = torch.utils.data.DataLoader(eval_data, batch_size=args.batch_size, shuffle=False, collate_fn=collate) 

DDI_predictor, DDI_sampler, predictor_optim, sampler_optim = init_model(args, dataset_statistics)
DDI_predictor.to(device)
DDI_sampler.to(device)
DDI_predictor.reset_parameters()

# DDI_predictor.load_state_dict(torch.load('/home/xiejc/Code/TIGER/RL/best_save/tiger/drugbank/RL/fold_0/0.00000/DDI_predictor.pt'))
# DDI_predictor.load_state_dict(torch.load('/home/xiejc/Code/TIGER/RL/best_save/tiger/drugbank/RL/1.0-1.0-s2/fold_0/67.35455/DDI_predictor.pt'))
DDI_predictor.load_state_dict(torch.load('/home/xiejc/Code/TIGER/RL/best_save/tiger/drugbank/khop-subtree/fold_0/0.87003/DDI_predictor.pt'))
# DDI_sampler.load_state_dict(torch.load('/home/xiejc/Code/TIGER/RL/best_save/tiger/drugbank/RL/1.0-1.0-s2/fold_0/67.35455/DDI_sampler.pt'))


Read /bigdat2/user/xiejc/zhangc/dataset/TIGER/dataset/drugbank/drug_smiles.txt!
load drug smiles graphs!!
load networks !!
load DDI samples!!
10630
10630
generate subgraphs!!
{'num_nodes': 391116, 'num_rel_mol': 133, 'num_rel_graph': 71, 'num_interactions': 20808, 'num_drugs_DDI': 1052, 'max_degree_graph': 7, 'max_degree_node': 69}


<All keys matched successfully>

In [43]:
import pickle

with open('/home/xiejc/Code/TIGER/RL/data/drugbank/khop-subtree/hop_1_old.pkl', 'rb') as f:
    h1 = pickle.load(f)

with open('/home/xiejc/Code/TIGER/RL/data/drugbank/khop-subtree/hop_1.pkl', 'rb') as f:
    h2 = pickle.load(f)


d1, d2 = data[0]




In [55]:
h1[(d1, d2)].x

tensor([    49,    819,   3906,  10879,  10886,  10887,  16113,  32172,  39877,
        101648, 161862, 161868, 161869, 161874, 161875, 161877, 161880, 161884,
        161885, 161894, 161899, 161903, 161906, 161907, 161911, 161912, 161913,
        161914, 161916, 161940,      9,     60,     84,    758,    759,    761,
           762,    763,    764,    767,    771,    772,    774,    776,    780,
           784,    786,    788,    789,    791,    793,    804,    830,    833,
           834,    835,    837,    839])

In [57]:
h2[(d1, d2)].x

tensor([     9,     49,     60,     84,    758,    759,    761,    762,    763,
           764,    767,    771,    772,    774,    776,    780,    784,    786,
           788,    789,    791,    793,    804,    819,    830,    833,    834,
           835,    837,    839,   3906,  10879,  10886,  10887,  16113,  32172,
         39877, 101648, 161862, 161868, 161869, 161874, 161875, 161877, 161880,
        161884, 161885, 161894, 161899, 161903, 161906, 161907, 161911, 161912,
        161913, 161914, 161916, 161940])

In [6]:
train_loaders = list(train_loader)

In [None]:
s = []
for i in range(len(train_loaders)):
    data_subgraph = train_loaders[i][2]
    nodes_with_id_1 = torch.where(data_subgraph.id == 1)[0]

    neighbors = []
    for node in nodes_with_id_1:
        src_neighbors = data_subgraph.edge_index[1][data_subgraph.edge_index[0] == node]
        dst_neighbors = data_subgraph.edge_index[0][data_subgraph.edge_index[1] == node]
        node_neighbors = torch.cat([src_neighbors, dst_neighbors]).unique()
        neighbors.append(node_neighbors)
    print(data_subgraph.x)
    print(data_subgraph.edge_index)
    print(train_loaders[i][3])
    neighbors = torch.cat(neighbors).unique()
    print(neighbors)
    print(len(neighbors))
    asfdasf
    s.append(len(neighbors))

print(max(s), min(s))


In [4]:
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve, accuracy_score, auc
def get_score(label_all, prob_all):

    predicts_label = [1 if prob >= 0.5 else 0 for prob in prob_all]

    acc = accuracy_score(label_all, predicts_label)
    f1 = f1_score(label_all, predicts_label)
    auroc = roc_auc_score(label_all, prob_all)
    p, r, t = precision_recall_curve(label_all, prob_all)
    auprc = auc(r, p)

    return acc, f1, auroc, auprc

In [5]:
pred1 = []
label = []

total_reward = 0

num_nodes = dataset_statistics['num_nodes']

for data in tqdm(train_loader):
    data_mol1 = data[0].cuda()
    data_mol2 = data[1].cuda()
    data_subgraph = data[2].cuda()
    data_idx = data[3].cuda() 

    predicts, loss = DDI_predictor(data_mol1, data_mol2, data_subgraph)

    pred1.append(predicts)

    label.append(data_mol1.y)


label_all = torch.concat(label).cpu().detach().numpy()
pred1_all = torch.concat(pred1).cpu().detach().numpy()

acc1, f11, auc1, aupr1 = get_score(label_all, pred1_all)
print('Default Model: acc: %.4f, f1: %.4f, auroc: %.4f, auprc: %.4f' % (acc1, f11, auc1, aupr1))



100%|██████████| 98/98 [00:55<00:00,  1.78it/s]

Default Model: acc: 0.8804, f1: 0.8888, auroc: 0.9500, auprc: 0.9436





In [None]:
pred1 = []
pred2 = []
label = []

total_reward = 0

num_nodes = dataset_statistics['num_nodes']

for data in tqdm(train_loader):
    data_mol1 = data[0].cuda()
    data_mol2 = data[1].cuda()
    data_subgraph = data[2].cuda()
    data_idx = data[3].cuda() 

    selected_subgraph_list, batch = DDI_sampler.generate_default_subgraph(data_idx, edge_index, edge_rel)
    selected_subgraph_list = selected_subgraph_list.cuda()
    batch = batch.cuda()
    pred_default = DDI_predictor.pred(data_mol1, data_mol2, data_subgraph, batch)

    pred1.append(pred_default)

    embeddings = DDI_predictor.drug_node_feature.node_encoder(torch.tensor(range(num_nodes)).cuda())
    selected_subgraph_list, batch = DDI_sampler.predict(data_idx, edge_index, edge_rel, embeddings) 
    selected_subgraph_list = selected_subgraph_list.cuda()
    batch = batch.cuda()
    
    reward_batch, predicts = DDI_predictor.get_reward(data_mol1, data_mol2, selected_subgraph_list, batch, pred_default)
    total_reward += torch.sum(reward_batch).item()

    pred2.append(predicts)
    label.append(data_mol1.y)


label_all = torch.concat(label).cpu().detach().numpy()
pred1_all = torch.concat(pred1).cpu().detach().numpy()
pred2_all = torch.concat(pred2).cpu().detach().numpy()

acc1, f11, auc1, aupr1 = get_score(label_all, pred1_all)
acc2, f12, auc2, aupr2 = get_score(label_all, pred2_all)
print('Default Model: acc: %.4f, f1: %.4f, auroc: %.4f, auprc: %.4f' % (acc1, f11, auc1, aupr1))
print('RL Model: acc: %.4f, f1: %.4f, auroc: %.4f, auprc: %.4f' % (acc2, f12, auc2, aupr2))



100%|██████████| 98/98 [12:38<00:00,  7.74s/it]

Default Model: acc: 0.5706, f1: 0.5940, auroc: 0.6051, auprc: 0.6070
RL Model: acc: 0.5990, f1: 0.6266, auroc: 0.6378, auprc: 0.6333





In [7]:
idx = 200

data_mol1 = train_loaders[idx][0].cuda()
data_mol2 = train_loaders[idx][1].cuda()
data_subgraph = train_loaders[idx][2].cuda()
data_idx = train_loaders[idx][3].cuda() # batch_size * 2

DDI_predictor.eval()
DDI_sampler.eval()

with torch.no_grad():
    batch = torch.tensor(range(data_idx.shape[0])).cuda()
    pred_default = DDI_predictor.pred(data_mol1, data_mol2, data_subgraph, batch)

print(pred_default, data_mol1.y)

tensor([0.9665], device='cuda:0') tensor([1], device='cuda:0')


In [39]:
src, tgt = data_subgraph.edge_index
src = src.cpu().detach().numpy().tolist()
tgt = tgt.cpu().detach().numpy().tolist()
b = list(zip(src, tgt))
print(len(b))
print(len(set(b)))

364
178


In [40]:
src, tgt = G_sub.edge_index
src = src.cpu().detach().numpy().tolist()
tgt = tgt.cpu().detach().numpy().tolist()
a = list(zip(src, tgt))
print(len(a))
print(len(set(a)))

242
238


In [25]:
from torch_geometric.utils import degree, subgraph, k_hop_subgraph

idx = 200

data_mol1 = train_loaders[idx][0].cuda()
data_mol2 = train_loaders[idx][1].cuda()
data_subgraph = train_loaders[idx][2].cuda()
data_idx = train_loaders[idx][3].cuda() # batch_size * 2

DDI_predictor.eval()
DDI_sampler.eval()

x = data_subgraph.x

edge_index2, edge_rel2 = subgraph(x, edge_index, edge_rel, relabel_nodes=True)

mapping_id = torch.zeros(len(x), dtype=torch.long)
mapping_id[torch.where(x == data_idx[0][0])[0]] = 1
mapping_id[torch.where(x == data_idx[0][1])[0]] = 1

G_sub = DATA.Data(x=x,
            edge_index=edge_index2,
            id=mapping_id,
            rel_index=edge_rel2,
            sp_edge_index=edge_index2,
            sp_value=torch.ones(edge_index2.size(1), dtype=torch.float),
            sp_edge_rel=edge_rel2
        )

G_sub = Batch.from_data_list([G_sub]).cuda()

with torch.no_grad():
    batch = torch.tensor(range(data_idx.shape[0])).cuda()
    pred_default = DDI_predictor.pred(data_mol1, data_mol2, G_sub, batch)

print(pred_default, data_mol1.y)

tensor([0.9059], device='cuda:0') tensor([1], device='cuda:0')


In [11]:
selected_subgraph_list.to_data_list()

[Data(x=[64], edge_index=[2, 246], id=[64], rel_index=[246], sp_edge_index=[2, 246], sp_value=[246], sp_edge_rel=[246]),
 Data(x=[65], edge_index=[2, 250], id=[65], rel_index=[250], sp_edge_index=[2, 250], sp_value=[250], sp_edge_rel=[250]),
 Data(x=[66], edge_index=[2, 254], id=[66], rel_index=[254], sp_edge_index=[2, 254], sp_value=[254], sp_edge_rel=[254]),
 Data(x=[67], edge_index=[2, 258], id=[67], rel_index=[258], sp_edge_index=[2, 258], sp_value=[258], sp_edge_rel=[258]),
 Data(x=[68], edge_index=[2, 262], id=[68], rel_index=[262], sp_edge_index=[2, 262], sp_value=[262], sp_edge_rel=[262]),
 Data(x=[69], edge_index=[2, 268], id=[69], rel_index=[268], sp_edge_index=[2, 268], sp_value=[268], sp_edge_rel=[268]),
 Data(x=[70], edge_index=[2, 272], id=[70], rel_index=[272], sp_edge_index=[2, 272], sp_value=[272], sp_edge_rel=[272]),
 Data(x=[71], edge_index=[2, 276], id=[71], rel_index=[276], sp_edge_index=[2, 276], sp_value=[276], sp_edge_rel=[276]),
 Data(x=[72], edge_index=[2, 280

In [9]:
num_nodes = dataset_statistics['num_nodes']
with torch.no_grad():
    embeddings = DDI_predictor.drug_node_feature.node_encoder(torch.tensor(range(num_nodes)).cuda())
selected_subgraph_list, selected_subgraph_prob_list, batch, tmp = DDI_sampler(data_idx, edge_index, edge_rel, embeddings, data_subgraph)

selected_subgraph_list = selected_subgraph_list.cuda()
batch = batch.cuda()

with torch.no_grad():
    reward_batch, predicts = DDI_predictor.get_reward(data_mol1, data_mol2, selected_subgraph_list, batch, pred_default)

a = [(tmp[i], selected_subgraph_prob_list[i].item(), reward_batch[i].item()) for i in range(len(tmp))]
print([f"({x[0]}, {x[1]:.4f}, {x[2]:.4f})" for x in a])
print(torch.sum(reward_batch).item())
print(predicts)

['((0, 27743), 0.0088, -0.9665)', '((0, 27668), 0.0079, -0.9664)', '((0, 27680), 0.0080, -0.9664)', '((0, 151057), 0.0083, -0.9664)', '((0, 151084), 0.0082, -0.9664)', '((0, 25412), 0.0092, -0.9664)', '((0, 151058), 0.0086, -0.9664)', '((0, 27692), 0.0090, -0.9664)', '((0, 27687), 0.0090, -0.9664)', '((0, 27858), 0.0088, -0.9664)']
-9.66413688659668
tensor([9.8239e-05, 1.0850e-04, 1.1072e-04, 1.1109e-04, 1.1802e-04, 1.3134e-04,
        1.2990e-04, 1.8765e-04, 1.9280e-04, 1.7292e-04], device='cuda:0')


In [19]:
a = torch.tensor([1,2,3,4,5])
b = torch.tensor([3,2,1,7,6])

c = torch.cat([a,b], dim=0)
c = torch.unique(c, dim=0)
c

tensor([1, 2, 3, 4, 5, 6, 7])

In [None]:
from torch_geometric.utils import degree, subgraph, k_hop_subgraph

graph_embeddings = torch.concat((embeddings, DDI_sampler.done_embedding.weight), dim=0) 

current_nodes = data_idx[0]

for _ in range(1):
    current_neighbors = DDI_sampler.get_neighbors(current_nodes, edge_index) 
    mask = ~torch.isin(current_neighbors, current_nodes)
    current_neighbors = current_neighbors[mask]

    neighbors_embeddings =  graph_embeddings[current_neighbors]
    k_embeddings = graph_embeddings[current_nodes].mean(dim=0)

    logits = DDI_sampler.prior(k_embeddings.expand_as(neighbors_embeddings), neighbors_embeddings)

    print(neighbors_embeddings.shape)



torch.Size([213, 64])


In [None]:
logits

tensor([1.0830e-10, 6.3097e-10, 2.0367e-10, 6.0915e-11, 2.9689e-09, 6.3208e-11,
        7.5545e-11, 9.1292e-04, 2.6761e-06, 5.5650e-11, 5.1834e-10, 2.8352e-10,
        1.3596e-09, 3.6235e-11, 2.6743e-10, 3.2621e-10, 3.9418e-11, 6.6574e-11,
        8.5086e-11, 7.7675e-11, 4.3505e-10, 7.8275e-09, 1.9513e-10, 4.6439e-10,
        6.1970e-11, 1.7582e-10, 1.8908e-08, 1.5779e-10, 8.2300e-11, 1.3202e-10,
        5.6626e-11, 2.6970e-10, 2.0408e-09, 2.5742e-10, 6.8789e-08, 4.3487e-10,
        2.2408e-10, 1.1551e-10, 5.9214e-11, 4.3066e-10, 7.2286e-11, 4.4758e-10,
        1.4138e-03, 1.5043e-07, 2.6613e-09, 3.6469e-10, 5.1665e-05, 2.0351e-10,
        7.3577e-11, 3.1566e-09, 2.4525e-10, 4.9033e-11, 2.2919e-08, 4.8477e-11,
        7.5948e-11, 1.4439e-10, 2.4066e-10, 1.7304e-10, 2.5364e-09, 3.8400e-09,
        5.4374e-11, 8.3430e-11, 2.4670e-10, 6.0068e-11, 8.0299e-11, 3.3665e-10,
        1.2282e-10, 1.8414e-10, 1.0142e-10, 1.7689e-10, 6.3985e-11, 1.7394e-10,
        1.3243e-10, 1.3957e-10, 9.5974e-

In [None]:
e1 = k_embeddings.expand_as(neighbors_embeddings)
e2 = neighbors_embeddings

e3 = DDI_sampler.prior.fc1(torch.cat((e1, e2), dim=-1))
e = e1 + e3

e_l1 = DDI_sampler.prior.fc_layers[0](e)
e_l2 = DDI_sampler.prior.fc_layers[1](e_l1)
e_l3 = DDI_sampler.prior.fc_layers[2](e_l2)
e_l4 = DDI_sampler.prior.fc_layers[3](e_l3).squeeze()
e_l4[80]

tensor(10.9395, device='cuda:0', grad_fn=<SelectBackward0>)

In [None]:
DDI_sampler.prior.fc_layers[3].weight

Parameter containing:
tensor([[ 0.4100, -0.5864, -0.4886, -0.5318,  0.4054,  0.4538, -0.4167,  0.5638,
         -0.4315,  0.4092,  0.4140, -0.5129, -0.5220,  0.5087, -0.4700,  0.4654,
          0.5506, -0.3612,  0.4555,  0.3850,  0.4648, -0.4004, -0.4736, -0.5325,
          0.3799,  0.5322, -0.5124, -0.4700, -0.4387, -0.4650,  0.3855, -0.3797]],
       device='cuda:0', requires_grad=True)

In [None]:
e_l3[80:82]

tensor([[ 0.9748, -0.4366, -0.6336, -0.6649,  0.9594,  0.9516, -0.4009,  0.4172,
         -0.3794,  0.9258,  0.6849, -0.7392, -0.5632,  0.9083, -0.8041,  0.6441,
          0.9533, -0.9631,  0.8250,  0.4897,  0.5320, -0.9417, -0.9602, -0.8695,
          0.9606,  0.4458, -0.9518, -0.8146, -0.9534, -0.3829,  0.9191, -0.9409],
        [-0.8228,  0.9628,  0.9647,  0.9706, -0.8601, -0.8854,  0.8317, -0.9705,
          0.9311, -0.8923, -0.9306,  0.5511,  0.9716, -0.8803,  0.7787, -0.9118,
         -0.7977,  0.5250, -0.5871, -0.9763, -0.9459,  0.7911,  0.8878,  0.7249,
         -0.7988, -0.9823,  0.7826,  0.9244,  0.2728,  0.9364, -0.4391,  0.4792]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [None]:
import torch
reward_all = []
batch = torch.tensor([0,0,0,1,1,1,1,2,2])
reward_batch = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
for b in torch.unique(batch):
    idx = (batch == b).nonzero(as_tuple=True)[0]
    batch_reward = reward_batch[idx]
    reward = torch.zeros(batch_reward.size(), device=batch_reward.device)
    R = 0
    n = batch_reward.size(0) - 1
    for i, r in enumerate(batch_reward.flip(0)):
        R = r + 1 * R
        reward[n-i] = R
    reward_all.append(reward)
reward_batch = torch.concat(reward_all)

print(reward_batch)

tensor([0.6000, 0.5000, 0.3000, 2.2000, 1.8000, 1.3000, 0.7000, 1.7000, 0.9000])
