In [2]:
import sys
sys.argv = ["main", "data_use_0.8.pkl", "results_sample.txt", "sample.pt"]

In [67]:
import pickle, torch, os, sys, random
import numpy as np
from math import ceil
import torch.optim as optim
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from sklearn.metrics import accuracy_score

dataset_file = os.path.abspath("../Input/" + sys.argv[1])
results_file = os.path.abspath("../Results/" + sys.argv[2])
model_file =  os.path.abspath("../Models/" + sys.argv[3])

f = open(dataset_file, "rb")
(nodes_train, paths_train, counts_train, targets_train, 
 nodes_test, paths_test, counts_test, targets_test,
 nodes_instances, paths_instances, counts_instances, targets_instances,
 nodes_knocked, paths_knocked, counts_knocked, targets_knocked,
 emb_indexer, emb_indexer_inv, emb_vals, 
 pos_indexer, dep_indexer, dir_indexer, rel_indexer) = pickle.load(f)

def write(statement):
    op_file = open(results_file, "w+")
    op_file.write("\n" + str(statement) + "\n")
    op_file.close()

POS_DIM = 4
DEP_DIM = 6
DIR_DIM = 3
NUM_RELATIONS = len(rel_indexer)
NULL_EDGE = [0, 0, 0, 0]

torch.set_default_dtype(torch.float64)
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

flatten = lambda l: [item for sublist in l for item in sublist]

class RelationPredictor(nn.Module):

    def __init__(self, emb_vals):
        
        super(RelationPredictor, self).__init__()

        self.EMBEDDING_DIM = np.array(emb_vals).shape[1]
        self.n_directions = 2 if bidirectional else 1
        
        self.input_dim = POS_DIM + DEP_DIM + self.EMBEDDING_DIM + DIR_DIM
        self.output_dim = self.n_directions * HIDDEN_DIM + 2 * self.EMBEDDING_DIM
        # self.layer1_dim = LAYER1_DIM
        # self.W1 = nn.Linear(self.hidden_dim, self.layer1_dim)
        # self.W2 = nn.Linear(self.layer1_dim, NUM_RELATIONS)

        self.dropout_layer = nn.Dropout(p=dropout)
        self.log_softmax = nn.LogSoftmax()
        
        self.name_embeddings = nn.Embedding(len(emb_vals), self.EMBEDDING_DIM)
        self.name_embeddings.load_state_dict({'weight': torch.from_numpy(np.array(emb_vals))})
        self.name_embeddings.weight.requires_grad = False

        self.pos_embeddings = nn.Embedding(len(pos_indexer), POS_DIM)
        self.dep_embeddings = nn.Embedding(len(dep_indexer), DEP_DIM)
        self.dir_embeddings = nn.Embedding(len(dir_indexer), DIR_DIM)

        nn.init.xavier_uniform_(self.pos_embeddings.weight)
        nn.init.xavier_uniform_(self.dep_embeddings.weight)
        nn.init.xavier_uniform_(self.dir_embeddings.weight)
        
        self.lstm = nn.LSTM(self.input_dim, HIDDEN_DIM, NUM_LAYERS, bidirectional=bidirectional, batch_first=True)

        self.W = nn.Linear(self.output_dim, NUM_RELATIONS)

    def forward(self, nodes, paths, counts, edgecounts, max_paths, max_edges):
        '''
            nodes: batch_size * 2
            paths: batch_size * max_paths * max_edges * 4
            counts: batch_size * max_paths
            edgecounts: batch_size * max_paths
        '''
        word_embed = self.dropout_layer(self.name_embeddings(paths[:,:,:,0]))
        pos_embed = self.dropout_layer(self.pos_embeddings(paths[:,:,:,1]))
        dep_embed = self.dropout_layer(self.dep_embeddings(paths[:,:,:,2]))
        dir_embed = self.dropout_layer(self.dir_embeddings(paths[:,:,:,3]))
        paths_embed = torch.cat((word_embed, pos_embed, dep_embed, dir_embed), dim=-1)
        nodes_embed = self.dropout_layer(self.name_embeddings(nodes)).reshape(-1, 2*self.EMBEDDING_DIM)
        
        print (word_embed.shape, pos_embed.shape, dep_embed.shape, dir_embed.shape, paths_embed.shape, nodes_embed.shape)

        paths_embed = paths_embed.reshape((-1, max_edges, self.input_dim))

        paths_packed = pack_padded_sequence(paths_embed, torch.flatten(edgecounts), batch_first=True, enforce_sorted=False)
        _, (hidden_state, _) = self.lstm(paths_packed)
        paths_output = hidden_state.permute(1,2,0)
        paths_output_reshaped = paths_output.reshape(-1, max_paths, HIDDEN_DIM*NUM_LAYERS*self.n_directions)
        # paths_output has dim (batch_size, max_paths, HIDDEN_DIM, NUM_LAYERS*self.n_directions)

        paths_weighted = torch.bmm(paths_output_reshaped.permute(0,2,1), counts.unsqueeze(-1)).squeeze(-1)
        representation = torch.cat((nodes_embed, paths_weighted), dim=-1)
        probabilities = self.log_softmax(self.W(representation))
        return probabilities

def to_list(seq):
    for item in seq:
        if isinstance(item, tuple):
            yield list(to_list(item))
        elif isinstance(item, list):
            yield [list(to_list(elem)) for elem in item]
        else:
            yield item

def pad_paths(paths, max_paths, max_edges):
    paths_edgepadded = [[path + [NULL_EDGE for i in range(max_edges-len(path))]
        for path in element]
    for element in paths]
    NULL_PATH = [NULL_EDGE for i in range(max_edges)]
    paths_padded = [element + [NULL_PATH for i in range(max_paths-len(element))] 
        for element in paths_edgepadded]
    return np.array(paths_padded)
        
def pad_counts(counts, max_paths):
    return np.array([elem + [0 for i in range(max_paths - len(elem))] for elem in counts])

def pad_edgecounts(edgecounts, max_paths):
    return np.array([elem + [1 for i in range(max_paths - len(elem))] for elem in edgecounts])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

HIDDEN_DIM = 250
# LAYER1_DIM = 120
NUM_LAYERS = 1
num_epochs = 1
batch_size = 32
bidirectional = True

lr = 0.001
dropout = 0.3
weight_decay = 0.001

model = nn.DataParallel(RelationPredictor(emb_vals)).to(device)
criterion = nn.NLLLoss()
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

for epoch in range(num_epochs):
    
    all_inp = list(zip(nodes_train, paths_train, counts_train, targets_train))
    all_inp_shuffled = random.sample(all_inp, len(all_inp))
    nodes_train, paths_train, counts_train, targets_train = list(zip(*all_inp_shuffled[:100]))

    num_edges_all = [[len(path) for path in element] for element in paths_train]
    max_edges = max(flatten(num_edges_all))
    max_paths = max([len(elem) for elem in counts_train])

    dataset_size = len(nodes_train)
    batch_size = min(batch_size, dataset_size)
    num_batches = int(ceil(dataset_size/batch_size))

    for batch_idx in range(num_batches):
        
        batch_start = batch_idx * batch_size
        batch_end = (batch_idx+1) * batch_size
        
        nodes = torch.LongTensor(nodes_train[batch_start:batch_end]).to(device)
        paths = torch.LongTensor(pad_paths(paths_train[batch_start:batch_end], max_paths, max_edges)).to(device)
        counts = torch.DoubleTensor(pad_counts(counts_train[batch_start:batch_end], max_paths)).to(device)
        edgecounts = torch.LongTensor(pad_edgecounts(num_edges_all[batch_start:batch_end], max_paths)).to(device)
        targets = torch.LongTensor(targets_train[batch_start:batch_end]).to(device)
        
        # Backprop and perform Adam optimisation
        optimizer.zero_grad()

        # Run the forward pass
        outputs = model(nodes, paths, counts, edgecounts, max_paths, max_edges)

        #loss = log_loss(outputs, torch.LongTensor(labels).to(device))
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        print("Epoch: {} Idx: {} Loss: {}".format(epoch, batch_idx, loss.item()))  

print("Training Complete!")

def calculate_recall(true, pred):
    true_f, pred_f = [], []
    for i,elem in enumerate(true):
        if elem!=4:
            true_f.append(elem)
            pred_f.append(pred[i])
    return accuracy_score(true_f, pred_f)

def calculate_precision(true, pred):
    true_f, pred_f = [], []
    for i,elem in enumerate(pred):
        if elem!=4:
            pred_f.append(elem)
            true_f.append(true[i])
    return accuracy_score(true_f, pred_f)

def test(nodes_test, paths_test, counts_test, targets_test, message):
    predictedLabels, trueLabels = [], []

    num_edges_all = [[len(path) for path in element] for element in paths_test]
    max_edges = max(flatten(num_edges_all))
    max_paths = max([len(elem) for elem in counts_test])
    print (max_paths, max_edges)
    dataset_size = len(nodes_test)
    batch_size = min(32, dataset_size)
    num_batches = int(ceil(dataset_size/batch_size))

    for batch_idx in range(num_batches):
        
        batch_start = batch_idx * batch_size
        batch_end = (batch_idx+1) * batch_size

        nodes = torch.LongTensor(nodes_test[batch_start:batch_end])
        paths = torch.LongTensor(pad_paths(paths_test[batch_start:batch_end], max_paths, max_edges))
        counts = torch.DoubleTensor(pad_counts(counts_test[batch_start:batch_end], max_paths))
        edgecounts = torch.LongTensor(pad_edgecounts(num_edges_all[batch_start:batch_end], max_paths))
        targets = torch.LongTensor(targets_test[batch_start:batch_end])
        
        print (nodes.shape, paths.shape, counts.shape, edgecounts.shape, targets.shape)
        
        outputs = model(nodes, paths, counts, edgecounts, max_paths, max_edges)
        _, predicted = torch.max(outputs, 1)
        predicted = [el.item() for el in predicted]
        targets = [el.item() for el in targets]
        predictedLabels.extend(predicted)
        trueLabels.extend(targets)
    
    print (trueLabels, predictedLabels)
    accuracy = accuracy_score(trueLabels, predictedLabels)
    recall = calculate_recall(trueLabels, predictedLabels)
    precision = calculate_precision(trueLabels, predictedLabels)
    final_metrics = [accuracy, precision, recall, 2 * (precision * recall/(precision + recall))]
    print("Final Results ({}): [{}]".format(message, ", ".join([str(el) for el in final_metrics])))

model.eval()
with torch.no_grad():
    test(nodes_test, paths_test, counts_test, targets_test, "Test")
    test(nodes_instances, paths_instances, counts_instances, targets_instances, "Instances")
    test(nodes_knocked, paths_knocked, counts_knocked, targets_knocked, "Knocked out")


torch.Size([32, 6, 5, 512]) torch.Size([32, 6, 5, 4]) torch.Size([32, 6, 5, 6]) torch.Size([32, 6, 5, 3]) torch.Size([32, 6, 5, 525]) torch.Size([32, 1024])
Epoch: 0 Idx: 0 Loss: 1.6065702650937592
torch.Size([32, 6, 5, 512]) torch.Size([32, 6, 5, 4]) torch.Size([32, 6, 5, 6]) torch.Size([32, 6, 5, 3]) torch.Size([32, 6, 5, 525]) torch.Size([32, 1024])
Epoch: 0 Idx: 1 Loss: 1.589469174306449
torch.Size([32, 6, 5, 512]) torch.Size([32, 6, 5, 4]) torch.Size([32, 6, 5, 6]) torch.Size([32, 6, 5, 3]) torch.Size([32, 6, 5, 525]) torch.Size([32, 1024])




Epoch: 0 Idx: 2 Loss: 1.5615394039610633
torch.Size([4, 6, 5, 512]) torch.Size([4, 6, 5, 4]) torch.Size([4, 6, 5, 6]) torch.Size([4, 6, 5, 3]) torch.Size([4, 6, 5, 525]) torch.Size([4, 1024])
Epoch: 0 Idx: 3 Loss: 1.563932163269547
Training Complete!
238 8
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 

torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch

torch.Size([32, 2]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11]) torch.Size([32, 11]) torch.Size([32])
torch.Size([32, 11, 6, 512]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11, 6, 6]) torch.Size([32, 11, 6, 3]) torch.Size([32, 11, 6, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11]) torch.Size([32, 11]) torch.Size([32])
torch.Size([32, 11, 6, 512]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11, 6, 6]) torch.Size([32, 11, 6, 3]) torch.Size([32, 11, 6, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11]) torch.Size([32, 11]) torch.Size([32])
torch.Size([32, 11, 6, 512]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11, 6, 6]) torch.Size([32, 11, 6, 3]) torch.Size([32, 11, 6, 525]) torch.Size([32, 1024])
torch.Size([19, 2]) torch.Size([19, 11, 6, 4]) torch.Size([19, 11]) torch.Size([19, 11]) torch.Size([19])
torch.Size([19, 11, 6, 512]) torch.Size([19, 11, 6, 4]) torch.Size([19, 11, 6, 6]) torch.S



torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

In [84]:
results = """results_threshold_0.59.txt:Final Results (Test): [0.8011695906432749, 0.757123473541384, 0.7469879518072289, 0.7520215633423182]
results_threshold_0.59.txt:Final Results (Instances): [0.22545454545454546, 0.1306122448979592, 0.9142857142857143, 0.22857142857142856]
results_threshold_0.59.txt:Final Results (Knocked out): [0.6942939689418562, 0.9236127792457363, 0.6942939689418562, 0.7927017833213071]
results_threshold_0.5.txt:Final Results (Test): [0.8011695906432749, 0.757123473541384, 0.7469879518072289, 0.7520215633423182]
results_threshold_0.5.txt:Final Results (Instances): [0.22545454545454546, 0.1306122448979592, 0.9142857142857143, 0.22857142857142856]
results_threshold_0.5.txt:Final Results (Knocked out): [0.6942939689418562, 0.9236127792457363, 0.6942939689418562, 0.7927017833213071]
results_threshold_0.65.txt:Final Results (Test): [0.8011695906432749, 0.757123473541384, 0.7469879518072289, 0.7520215633423182]
results_threshold_0.65.txt:Final Results (Instances): [0.22545454545454546, 0.1306122448979592, 0.9142857142857143, 0.22857142857142856]
results_threshold_0.65.txt:Final Results (Knocked out): [0.6942939689418562, 0.9236127792457363, 0.6942939689418562, 0.7927017833213071]
results_threshold_0.66.txt:Final Results (Test): [0.7986633249791144, 0.7599451303155007, 0.7416331994645248, 0.7506775067750678]
results_threshold_0.66.txt:Final Results (Instances): [0.24, 0.13278008298755187, 0.9142857142857143, 0.23188405797101447]
results_threshold_0.66.txt:Final Results (Knocked out): [0.6816540267244493, 0.923660386591632, 0.6816540267244493, 0.7844155844155845]
results_threshold_0.67.txt:Final Results (Test): [0.7986633249791144, 0.7599451303155007, 0.7416331994645248, 0.7506775067750678]
results_threshold_0.67.txt:Final Results (Instances): [0.24, 0.13278008298755187, 0.9142857142857143, 0.23188405797101447]
results_threshold_0.67.txt:Final Results (Knocked out): [0.6816540267244493, 0.923660386591632, 0.6816540267244493, 0.7844155844155845]
results_threshold_0.68.txt:Final Results (Test): [0.8036758563074352, 0.7629427792915532, 0.749665327978581, 0.7562457798784605]
results_threshold_0.68.txt:Final Results (Instances): [0.23636363636363636, 0.1322314049586777, 0.9142857142857143, 0.23104693140794227]
results_threshold_0.68.txt:Final Results (Knocked out): [0.6778620440592271, 0.9248583394924859, 0.6778620440592271, 0.7823278107742002]
results_threshold_0.69.txt:Final Results (Test): [0.8036758563074352, 0.7629427792915532, 0.749665327978581, 0.7562457798784605]
results_threshold_0.69.txt:Final Results (Instances): [0.23636363636363636, 0.1322314049586777, 0.9142857142857143, 0.23104693140794227]
results_threshold_0.69.txt:Final Results (Knocked out): [0.6778620440592271, 0.9248583394924859, 0.6778620440592271, 0.7823278107742002]
results_threshold_0.6.txt:Final Results (Test): [0.8011695906432749, 0.757123473541384, 0.7469879518072289, 0.7520215633423182]
results_threshold_0.6.txt:Final Results (Instances): [0.22545454545454546, 0.1306122448979592, 0.9142857142857143, 0.22857142857142856]
results_threshold_0.6.txt:Final Results (Knocked out): [0.6942939689418562, 0.9236127792457363, 0.6942939689418562, 0.7927017833213071]
results_threshold_0.71.txt:Final Results (Test): [0.7969924812030075, 0.7493224932249323, 0.7402945113788487, 0.7447811447811449]
results_threshold_0.71.txt:Final Results (Instances): [0.22181818181818183, 0.13008130081300814, 0.9142857142857143, 0.22775800711743777]
results_threshold_0.71.txt:Final Results (Knocked out): [0.6868905742145178, 0.9224054316197866, 0.6868905742145178, 0.7874146139515629]
results_threshold_0.72.txt:Final Results (Test): [0.8045112781954887, 0.7664835164835165, 0.7469879518072289, 0.7566101694915256]
results_threshold_0.72.txt:Final Results (Instances): [0.23636363636363636, 0.13278008298755187, 0.9142857142857143, 0.23188405797101447]
results_threshold_0.72.txt:Final Results (Knocked out): [0.6807511737089202, 0.9256076602013258, 0.6807511737089202, 0.7845177400894808]
results_threshold_0.73.txt:Final Results (Test): [0.7953216374269005, 0.7452830188679245, 0.7402945113788487, 0.7427803895231698]
results_threshold_0.73.txt:Final Results (Instances): [0.22181818181818183, 0.13008130081300814, 0.9142857142857143, 0.22775800711743777]
results_threshold_0.73.txt:Final Results (Knocked out): [0.6912242686890574, 0.9221874247169357, 0.6912242686890574, 0.7901744246052224]
results_threshold_0.74.txt:Final Results (Test): [0.8028404344193818, 0.7643835616438356, 0.7469879518072289, 0.7555856465809073]
results_threshold_0.74.txt:Final Results (Instances): [0.24, 0.1297071129707113, 0.8857142857142857, 0.22627737226277375]
results_threshold_0.74.txt:Final Results (Knocked out): [0.678945467677862, 0.9229258713794797, 0.678945467677862, 0.7823553890969621]
results_threshold_0.75.txt:Final Results (Test): [0.8011695906432749, 0.7665745856353591, 0.7429718875502008, 0.7545887151597553]
results_threshold_0.75.txt:Final Results (Instances): [0.23636363636363636, 0.13278008298755187, 0.9142857142857143, 0.23188405797101447]
results_threshold_0.75.txt:Final Results (Knocked out): [0.6764174792343807, 0.9260815822002472, 0.6764174792343807, 0.7818011061254304]
results_threshold_0.76.txt:Final Results (Test): [0.8020050125313283, 0.7637362637362637, 0.7443105756358769, 0.7538983050847459]
results_threshold_0.76.txt:Final Results (Instances): [0.24, 0.13278008298755187, 0.9142857142857143, 0.23188405797101447]
results_threshold_0.76.txt:Final Results (Knocked out): [0.6793066088840737, 0.9243243243243243, 0.6793066088840737, 0.7830974188176519]
results_threshold_0.77.txt:Final Results (Test): [0.7986633249791144, 0.7575342465753425, 0.7402945113788487, 0.7488151658767774]
results_threshold_0.77.txt:Final Results (Instances): [0.22545454545454546, 0.12757201646090535, 0.8857142857142857, 0.22302158273381295]
results_threshold_0.77.txt:Final Results (Knocked out): [0.6767786204405922, 0.9240631163708086, 0.6767786204405922, 0.7813216593704398]
results_threshold_0.78.txt:Final Results (Test): [0.8011695906432749, 0.7634112792297112, 0.7429718875502008, 0.7530529172320216]
results_threshold_0.78.txt:Final Results (Instances): [0.23272727272727273, 0.13168724279835392, 0.9142857142857143, 0.2302158273381295]
results_threshold_0.78.txt:Final Results (Knocked out): [0.6775009028530156, 0.9252774352651049, 0.6775009028530156, 0.7822370478473888]
results_threshold_0.79.txt:Final Results (Test): [0.8036758563074352, 0.7661623108665749, 0.7456492637215528, 0.7557666214382632]
results_threshold_0.79.txt:Final Results (Instances): [0.23272727272727273, 0.13168724279835392, 0.9142857142857143, 0.2302158273381295]
results_threshold_0.79.txt:Final Results (Knocked out): [0.6776814734561214, 0.9246119733924612, 0.6776814734561214, 0.7821194123163488]
results_threshold_0.7.txt:Final Results (Test): [0.8036758563074352, 0.7629427792915532, 0.749665327978581, 0.7562457798784605]
results_threshold_0.7.txt:Final Results (Instances): [0.23636363636363636, 0.1322314049586777, 0.9142857142857143, 0.23104693140794227]
results_threshold_0.7.txt:Final Results (Knocked out): [0.6778620440592271, 0.9248583394924859, 0.6778620440592271, 0.7823278107742002]
results_threshold_0.81.txt:Final Results (Test): [0.8011695906432749, 0.7641379310344828, 0.7416331994645248, 0.7527173913043479]
results_threshold_0.81.txt:Final Results (Instances): [0.23636363636363636, 0.13278008298755187, 0.9142857142857143, 0.23188405797101447]
results_threshold_0.81.txt:Final Results (Knocked out): [0.6784037558685446, 0.923095823095823, 0.6784037558685446, 0.7820566194837636]
results_threshold_0.8.txt:Final Results (Test): [0.7953216374269005, 0.7632311977715878, 0.7336010709504686, 0.7481228668941979]
results_threshold_0.8.txt:Final Results (Instances): [0.24, 0.13278008298755187, 0.9142857142857143, 0.23188405797101447]
results_threshold_0.8.txt:Final Results (Knocked out): [0.674070061394005, 0.9251548946716233, 0.674070061394005, 0.7799018071659876]
results_threshold_0.82.txt:Final Results (Test): [0.8053467000835421, 0.7706043956043956, 0.751004016064257, 0.760677966101695]
results_threshold_0.82.txt:Final Results (Instances): [0.24363636363636362, 0.13389121338912133, 0.9142857142857143, 0.2335766423357664]
results_threshold_0.82.txt:Final Results (Knocked out): [0.6836403033586133, 0.9222898903775884, 0.6836403033586133, 0.7852328113657576]
results_threshold_0.83.txt:Final Results (Test): [0.7994987468671679, 0.7547683923705722, 0.7416331994645248, 0.7481431465226198]
results_threshold_0.83.txt:Final Results (Instances): [0.22545454545454546, 0.12757201646090535, 0.8857142857142857, 0.22302158273381295]
results_threshold_0.83.txt:Final Results (Knocked out): [0.6924882629107981, 0.9234288466169035, 0.6924882629107981, 0.7914559900939016]
results_threshold_0.84.txt:Final Results (Test): [0.8061821219715957, 0.76775956284153, 0.7523427041499331, 0.759972954699121]
results_threshold_0.84.txt:Final Results (Instances): [0.23272727272727273, 0.12863070539419086, 0.8857142857142857, 0.22463768115942026]
results_threshold_0.84.txt:Final Results (Knocked out): [0.6821957385337667, 0.925073457394711, 0.6821957385337667, 0.7852837247973394]
results_threshold_0.85.txt:Final Results (Test): [0.8053467000835421, 0.7712328767123288, 0.7536813922356091, 0.7623561272850371]
results_threshold_0.85.txt:Final Results (Instances): [0.23636363636363636, 0.12863070539419086, 0.8857142857142857, 0.22463768115942026]
results_threshold_0.85.txt:Final Results (Knocked out): [0.678945467677862, 0.9240599655935119, 0.678945467677862, 0.7827625689601332]
results_threshold_0.86.txt:Final Results (Test): [0.8061821219715957, 0.7720994475138122, 0.7483266398929049, 0.760027192386132]
results_threshold_0.86.txt:Final Results (Instances): [0.2545454545454545, 0.13559322033898305, 0.9142857142857143, 0.23616236162361626]
results_threshold_0.86.txt:Final Results (Knocked out): [0.69068255687974, 0.925477861117832, 0.69068255687974, 0.791024713059663]
results_threshold_0.87.txt:Final Results (Test): [0.8061821219715957, 0.7713498622589532, 0.749665327978581, 0.7603530210454854]
results_threshold_0.87.txt:Final Results (Instances): [0.24, 0.12916666666666668, 0.8857142857142857, 0.2254545454545455]
results_threshold_0.87.txt:Final Results (Knocked out): [0.6634163958107621, 0.9224202862164198, 0.6634163958107621, 0.7717676714630817]
results_threshold_0.88.txt:Final Results (Test): [0.8045112781954887, 0.7754532775453278, 0.7443105756358769, 0.7595628415300547]
results_threshold_0.88.txt:Final Results (Instances): [0.2509090909090909, 0.1350210970464135, 0.9142857142857143, 0.23529411764705882]
results_threshold_0.88.txt:Final Results (Knocked out): [0.6540267244492597, 0.9230377166156982, 0.6540267244492597, 0.765588670471359]
results_threshold_0.89.txt:Final Results (Test): [0.8070175438596491, 0.7739251040221914, 0.7469879518072289, 0.7602179836512262]
results_threshold_0.89.txt:Final Results (Instances): [0.24363636363636362, 0.13333333333333333, 0.9142857142857143, 0.23272727272727273]
results_threshold_0.89.txt:Final Results (Knocked out): [0.6538461538461539, 0.9230180983940861, 0.6538461538461539, 0.7654581968079485]
results_threshold_0.91.txt:Final Results (Test): [0.808688387635756, 0.774281805745554, 0.7576974564926372, 0.7658998646820027]
results_threshold_0.91.txt:Final Results (Instances): [0.22545454545454546, 0.1306122448979592, 0.9142857142857143, 0.22857142857142856]
results_threshold_0.91.txt:Final Results (Knocked out): [0.6652221018418202, 0.9219219219219219, 0.6652221018418202, 0.7728130899937069]
results_threshold_0.92.txt:Final Results (Test): [0.8036758563074352, 0.7710344827586207, 0.7483266398929049, 0.7595108695652173]
results_threshold_0.92.txt:Final Results (Instances): [0.2509090909090909, 0.13445378151260504, 0.9142857142857143, 0.2344322344322344]
results_threshold_0.92.txt:Final Results (Knocked out): [0.6561935716865295, 0.9218670725520041, 0.6561935716865295, 0.7666666666666667]
results_threshold_0.93.txt:Final Results (Test): [0.8036758563074352, 0.7710344827586207, 0.7483266398929049, 0.7595108695652173]
results_threshold_0.93.txt:Final Results (Instances): [0.2509090909090909, 0.13445378151260504, 0.9142857142857143, 0.2344322344322344]
results_threshold_0.93.txt:Final Results (Knocked out): [0.6561935716865295, 0.9218670725520041, 0.6561935716865295, 0.7666666666666667]
results_threshold_0.94.txt:Final Results (Test): [0.8045112781954887, 0.7712328767123288, 0.7536813922356091, 0.7623561272850371]
results_threshold_0.94.txt:Final Results (Instances): [0.24, 0.12916666666666668, 0.8857142857142857, 0.2254545454545455]
results_threshold_0.94.txt:Final Results (Knocked out): [0.6612495485734923, 0.9219536757301108, 0.6612495485734923, 0.770136698212408]
results_threshold_0.95.txt:Final Results (Test): [0.8028404344193818, 0.7681755829903978, 0.749665327978581, 0.7588075880758807]
results_threshold_0.95.txt:Final Results (Instances): [0.24, 0.12916666666666668, 0.8857142857142857, 0.2254545454545455]
results_threshold_0.95.txt:Final Results (Knocked out): [0.6625135427952329, 0.921627731725697, 0.6625135427952329, 0.770879294043492]
results_threshold_0.96.txt:Final Results (Test): [0.8028404344193818, 0.7675378266850069, 0.7469879518072289, 0.757123473541384]
results_threshold_0.96.txt:Final Results (Instances): [0.24363636363636362, 0.13333333333333333, 0.9142857142857143, 0.23272727272727273]
results_threshold_0.96.txt:Final Results (Knocked out): [0.6655832430480317, 0.9221916437327996, 0.6655832430480317, 0.7731515469323543]
results_threshold_0.97.txt:Final Results (Test): [0.8028404344193818, 0.7675378266850069, 0.7469879518072289, 0.757123473541384]
results_threshold_0.97.txt:Final Results (Instances): [0.24363636363636362, 0.13333333333333333, 0.9142857142857143, 0.23272727272727273]
results_threshold_0.97.txt:Final Results (Knocked out): [0.6655832430480317, 0.9221916437327996, 0.6655832430480317, 0.7731515469323543]
results_threshold_0.98.txt:Final Results (Test): [0.8028404344193818, 0.7675378266850069, 0.7469879518072289, 0.757123473541384]
results_threshold_0.98.txt:Final Results (Instances): [0.24363636363636362, 0.13333333333333333, 0.9142857142857143, 0.23272727272727273]
results_threshold_0.98.txt:Final Results (Knocked out): [0.6655832430480317, 0.9221916437327996, 0.6655832430480317, 0.7731515469323543]
results_threshold_0.99.txt:Final Results (Test): [0.8028404344193818, 0.7675378266850069, 0.7469879518072289, 0.757123473541384]
results_threshold_0.99.txt:Final Results (Instances): [0.24363636363636362, 0.13333333333333333, 0.9142857142857143, 0.23272727272727273]
results_threshold_0.99.txt:Final Results (Knocked out): [0.6655832430480317, 0.9221916437327996, 0.6655832430480317, 0.7731515469323543]
results_threshold_0.9.txt:Final Results (Test): [0.8061821219715957, 0.7701778385772914, 0.7536813922356091, 0.7618403247631934]
results_threshold_0.9.txt:Final Results (Instances): [0.24363636363636362, 0.1297071129707113, 0.8857142857142857, 0.22627737226277375]
results_threshold_0.9.txt:Final Results (Knocked out): [0.6590827013362225, 0.9221829206670036, 0.6590827013362225, 0.7687447346251053]
results_threshold_1.0.txt:Final Results (Test): [0.8061821219715957, 0.7741046831955923, 0.7523427041499331, 0.7630685675492193]
results_threshold_1.0.txt:Final Results (Instances): [0.23272727272727273, 0.1322314049586777, 0.9142857142857143, 0.23104693140794227]
results_threshold_1.0.txt:Final Results (Knocked out): [0.6686529433008306, 0.9234413965087281, 0.6686529433008306, 0.7756598240469208]"""

rownames, colnames, data = list(zip(*[(float(line.split(".txt")[0].split("_")[-1]), line.split("(")[-1].split(")")[0], float(line.split("[")[-1].split("]")[0].split(",")[-1].strip())) for line in results.split("\n")]))
from orderedset import OrderedSet
rownames, colnames = OrderedSet(rownames), OrderedSet(colnames)
data = np.array(data).reshape((len(rownames), len(colnames))).tolist()
import pandas as pd
pd.DataFrame(data, columns=colnames, index=rownames).sort_index()

Unnamed: 0,Test,Instances,Knocked out
0.5,0.752022,0.228571,0.792702
0.59,0.752022,0.228571,0.792702
0.6,0.752022,0.228571,0.792702
0.65,0.752022,0.228571,0.792702
0.66,0.750678,0.231884,0.784416
0.67,0.750678,0.231884,0.784416
0.68,0.756246,0.231047,0.782328
0.69,0.756246,0.231047,0.782328
0.7,0.756246,0.231047,0.782328
0.71,0.744781,0.227758,0.787415


In [70]:
targets

tensor([3, 0, 3, 4])

In [63]:
max_paths

289

In [48]:
[el.item() for el in torch.Tensor([1,2,3,4])]

[1.0, 2.0, 3.0, 4.0]