In [1]:
import sys
import os
import yaml
import argparse
import logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(message)s')

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
import torch
from torch_geometric.loader import DataLoader

from Networks.Embedding.embedding import Embedding
from Networks.utils import *

INFO:Loading faiss with AVX512 support.
INFO:Successfully loaded faiss with AVX512 support.


In [2]:
# Load config file
with open('emb_config.yaml') as c:
    cl = yaml.load(c, Loader=yaml.FullLoader)
    config = cl['metric_learning_configs']

In [3]:
model = Embedding(config)

In [4]:
d = load_dataset('Datasets')
dl = DataLoader(d,batch_size=1)
trainer = Trainer()
res = trainer.predict(model,dl,ckpt_path='EMB_out/testrun.ckpt')

  rank_zero_warn(
INFO:GPU available: False, used: False
INFO:TPU available: False, using: 0 TPU cores
INFO:IPU available: False, using: 0 IPUs
INFO:HPU available: False, using: 0 HPUs
INFO:Restoring states from the checkpoint path at EMB_out/testrun.ckpt
INFO:Loaded model weights from checkpoint at EMB_out/testrun.ckpt
  rank_zero_warn(


Predicting DataLoader 0: 100%|███████████████████████████████████████████████████████████| 1183/1183 [00:08<00:00, 142.42it/s]


In [5]:
print(res[0]['truth'].shape)

torch.Size([1562])


In [7]:
res[0]['distances'].shape

torch.Size([1562])

In [14]:
res[0]

{'loss': tensor(0.0098),
 'distances': tensor([9.6245e-05, 1.3376e-04, 5.3936e-05,  ..., 1.6741e-04, 3.1212e-04,
         3.5718e-04]),
 'preds': tensor([[ 0,  0,  0,  ..., 40, 40, 40],
         [ 1,  2,  3,  ..., 37, 38, 39]]),
 'truth': tensor([False, False, False,  ..., False, False, False]),
 'truth_graph': tensor([[ 3,  3,  4,  4, 16, 16,  4, 16,  3, 16,  3,  4],
         [ 4, 16,  3, 16,  3,  4,  3,  3,  4,  4, 16, 16]]),
 'eff': tensor(0.5000),
 'pur': tensor(3.)}

In [15]:
d[0].true_edges

tensor([[ 3,  3,  4,  4, 16, 16],
        [ 4, 16,  3, 16,  3,  4]])

In [13]:
for i in range(len(d)):
    d[i].pred_edges_emb = res[i]['preds']
    d[i].distances_emb = res[i]['distances']
    d[i].edges_y = res[i]['truth']

In [9]:
d

[Data(x=[41, 2], true_edges=[2, 6], pred_edges_emb=[2, 1562], distances_emb=[1562]),
 Data(x=[32, 2], true_edges=[2, 6], pred_edges_emb=[2, 992], distances_emb=[992]),
 Data(x=[38, 2], true_edges=[2, 6], pred_edges_emb=[2, 1406], distances_emb=[1406]),
 Data(x=[16, 2], true_edges=[2, 2], pred_edges_emb=[2, 240], distances_emb=[240]),
 Data(x=[55, 2], true_edges=[2, 30], pred_edges_emb=[2, 2970], distances_emb=[2970]),
 Data(x=[49, 2], true_edges=[2, 20], pred_edges_emb=[2, 2352], distances_emb=[2352]),
 Data(x=[47, 2], true_edges=[2, 32], pred_edges_emb=[2, 2162], distances_emb=[2162]),
 Data(x=[52, 2], true_edges=[2, 6], pred_edges_emb=[2, 2652], distances_emb=[2652]),
 Data(x=[35, 2], true_edges=[2, 2], pred_edges_emb=[2, 1190], distances_emb=[1190]),
 Data(x=[35, 2], true_edges=[2, 8], pred_edges_emb=[2, 1122], distances_emb=[1122]),
 Data(x=[42, 2], true_edges=[2, 14], pred_edges_emb=[2, 1722], distances_emb=[1722]),
 Data(x=[43, 2], true_edges=[2, 8], pred_edges_emb=[2, 1806], dis

In [14]:
torch.save(d,'EMB_out/data_emb.pt')

In [11]:
d_new = torch.load('EMB_out/data_emb.pt')

In [12]:
d_new

[Data(x=[41, 2], true_edges=[2, 6], pred_edges_emb=[2, 1562], distances_emb=[1562]),
 Data(x=[32, 2], true_edges=[2, 6], pred_edges_emb=[2, 992], distances_emb=[992]),
 Data(x=[38, 2], true_edges=[2, 6], pred_edges_emb=[2, 1406], distances_emb=[1406]),
 Data(x=[16, 2], true_edges=[2, 2], pred_edges_emb=[2, 240], distances_emb=[240]),
 Data(x=[55, 2], true_edges=[2, 30], pred_edges_emb=[2, 2970], distances_emb=[2970]),
 Data(x=[49, 2], true_edges=[2, 20], pred_edges_emb=[2, 2352], distances_emb=[2352]),
 Data(x=[47, 2], true_edges=[2, 32], pred_edges_emb=[2, 2162], distances_emb=[2162]),
 Data(x=[52, 2], true_edges=[2, 6], pred_edges_emb=[2, 2652], distances_emb=[2652]),
 Data(x=[35, 2], true_edges=[2, 2], pred_edges_emb=[2, 1190], distances_emb=[1190]),
 Data(x=[35, 2], true_edges=[2, 8], pred_edges_emb=[2, 1122], distances_emb=[1122]),
 Data(x=[42, 2], true_edges=[2, 14], pred_edges_emb=[2, 1722], distances_emb=[1722]),
 Data(x=[43, 2], true_edges=[2, 8], pred_edges_emb=[2, 1806], dis

In [2]:
d = torch.load('EMB_out/data_emb.pt')

In [3]:
d[0]

Data(x=[41, 2], true_edges=[2, 6], pred_edges_emb=[2, 1640], distances_emb=[1640], edges_y=[1640])

In [4]:
dist = d[0].distances_emb
y = d[0].edges_y

In [5]:
N_true = sum(y==True)
N_false = sum(y==False)

In [7]:
N_true

tensor(6)

In [8]:
s_true = N_true/(dist[y==True]).sum()

In [10]:
w_true = s_true*dist[y==True]

In [33]:
s_false = N_false/(1/dist[y==False]).sum()
s_true = N_true/(1/dist[y==True]).sum()

In [37]:
w_false = s_false/dist[y==False]
w_true = s_true/dist[y==True]

In [12]:
w_true

tensor([0.7874, 1.9480, 0.7874, 0.2646, 1.9480, 0.2646])

In [11]:
dist[y==True]

tensor([3.0154e-05, 7.4602e-05, 3.0154e-05, 1.0135e-05, 7.4602e-05, 1.0135e-05])

In [16]:
1/(1/dist[y==True]).sum()

tensor(3.4426e-06)

In [23]:
input = torch.tensor([-3,-5,15,7,-3,-5,15,7]).float()
target = torch.tensor([0,0,1,1,0,0,1,1]).float()

In [24]:
torch.nn.functional.binary_cross_entropy_with_logits(input,target)

tensor(0.0141)

In [48]:
edge_weight = torch.ones_like(y).float()

In [53]:
edge_weight[y==True] = s_true/dist[y==True]
edge_weight[y==False] = s_false/dist[y==False]

In [54]:
edge_weight[y==False]

tensor([0.2141, 0.1397, 0.2656,  ..., 0.3037, 0.0661, 0.0524])

In [52]:
dist[y==True]

tensor([3.0154e-05, 7.4602e-05, 3.0154e-05, 1.0135e-05, 7.4602e-05, 1.0135e-05])