In [2]:
import sys
sys.argv = ["main", "data_use_0.8.pkl", "results_sample.txt", "sample.pt"]

In [67]:
import pickle, torch, os, sys, random
import numpy as np
from math import ceil
import torch.optim as optim
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from sklearn.metrics import accuracy_score

dataset_file = os.path.abspath("../Input/" + sys.argv[1])
results_file = os.path.abspath("../Results/" + sys.argv[2])
model_file =  os.path.abspath("../Models/" + sys.argv[3])

f = open(dataset_file, "rb")
(nodes_train, paths_train, counts_train, targets_train, 
 nodes_test, paths_test, counts_test, targets_test,
 nodes_instances, paths_instances, counts_instances, targets_instances,
 nodes_knocked, paths_knocked, counts_knocked, targets_knocked,
 emb_indexer, emb_indexer_inv, emb_vals, 
 pos_indexer, dep_indexer, dir_indexer, rel_indexer) = pickle.load(f)

def write(statement):
    op_file = open(results_file, "w+")
    op_file.write("\n" + str(statement) + "\n")
    op_file.close()

POS_DIM = 4
DEP_DIM = 6
DIR_DIM = 3
NUM_RELATIONS = len(rel_indexer)
NULL_EDGE = [0, 0, 0, 0]

torch.set_default_dtype(torch.float64)
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

flatten = lambda l: [item for sublist in l for item in sublist]

class RelationPredictor(nn.Module):

    def __init__(self, emb_vals):
        
        super(RelationPredictor, self).__init__()

        self.EMBEDDING_DIM = np.array(emb_vals).shape[1]
        self.n_directions = 2 if bidirectional else 1
        
        self.input_dim = POS_DIM + DEP_DIM + self.EMBEDDING_DIM + DIR_DIM
        self.output_dim = self.n_directions * HIDDEN_DIM + 2 * self.EMBEDDING_DIM
        # self.layer1_dim = LAYER1_DIM
        # self.W1 = nn.Linear(self.hidden_dim, self.layer1_dim)
        # self.W2 = nn.Linear(self.layer1_dim, NUM_RELATIONS)

        self.dropout_layer = nn.Dropout(p=dropout)
        self.log_softmax = nn.LogSoftmax()
        
        self.name_embeddings = nn.Embedding(len(emb_vals), self.EMBEDDING_DIM)
        self.name_embeddings.load_state_dict({'weight': torch.from_numpy(np.array(emb_vals))})
        self.name_embeddings.weight.requires_grad = False

        self.pos_embeddings = nn.Embedding(len(pos_indexer), POS_DIM)
        self.dep_embeddings = nn.Embedding(len(dep_indexer), DEP_DIM)
        self.dir_embeddings = nn.Embedding(len(dir_indexer), DIR_DIM)

        nn.init.xavier_uniform_(self.pos_embeddings.weight)
        nn.init.xavier_uniform_(self.dep_embeddings.weight)
        nn.init.xavier_uniform_(self.dir_embeddings.weight)
        
        self.lstm = nn.LSTM(self.input_dim, HIDDEN_DIM, NUM_LAYERS, bidirectional=bidirectional, batch_first=True)

        self.W = nn.Linear(self.output_dim, NUM_RELATIONS)

    def forward(self, nodes, paths, counts, edgecounts, max_paths, max_edges):
        '''
            nodes: batch_size * 2
            paths: batch_size * max_paths * max_edges * 4
            counts: batch_size * max_paths
            edgecounts: batch_size * max_paths
        '''
        word_embed = self.dropout_layer(self.name_embeddings(paths[:,:,:,0]))
        pos_embed = self.dropout_layer(self.pos_embeddings(paths[:,:,:,1]))
        dep_embed = self.dropout_layer(self.dep_embeddings(paths[:,:,:,2]))
        dir_embed = self.dropout_layer(self.dir_embeddings(paths[:,:,:,3]))
        paths_embed = torch.cat((word_embed, pos_embed, dep_embed, dir_embed), dim=-1)
        nodes_embed = self.dropout_layer(self.name_embeddings(nodes)).reshape(-1, 2*self.EMBEDDING_DIM)
        
        print (word_embed.shape, pos_embed.shape, dep_embed.shape, dir_embed.shape, paths_embed.shape, nodes_embed.shape)

        paths_embed = paths_embed.reshape((-1, max_edges, self.input_dim))

        paths_packed = pack_padded_sequence(paths_embed, torch.flatten(edgecounts), batch_first=True, enforce_sorted=False)
        _, (hidden_state, _) = self.lstm(paths_packed)
        paths_output = hidden_state.permute(1,2,0)
        paths_output_reshaped = paths_output.reshape(-1, max_paths, HIDDEN_DIM*NUM_LAYERS*self.n_directions)
        # paths_output has dim (batch_size, max_paths, HIDDEN_DIM, NUM_LAYERS*self.n_directions)

        paths_weighted = torch.bmm(paths_output_reshaped.permute(0,2,1), counts.unsqueeze(-1)).squeeze(-1)
        representation = torch.cat((nodes_embed, paths_weighted), dim=-1)
        probabilities = self.log_softmax(self.W(representation))
        return probabilities

def to_list(seq):
    for item in seq:
        if isinstance(item, tuple):
            yield list(to_list(item))
        elif isinstance(item, list):
            yield [list(to_list(elem)) for elem in item]
        else:
            yield item

def pad_paths(paths, max_paths, max_edges):
    paths_edgepadded = [[path + [NULL_EDGE for i in range(max_edges-len(path))]
        for path in element]
    for element in paths]
    NULL_PATH = [NULL_EDGE for i in range(max_edges)]
    paths_padded = [element + [NULL_PATH for i in range(max_paths-len(element))] 
        for element in paths_edgepadded]
    return np.array(paths_padded)
        
def pad_counts(counts, max_paths):
    return np.array([elem + [0 for i in range(max_paths - len(elem))] for elem in counts])

def pad_edgecounts(edgecounts, max_paths):
    return np.array([elem + [1 for i in range(max_paths - len(elem))] for elem in edgecounts])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

HIDDEN_DIM = 250
# LAYER1_DIM = 120
NUM_LAYERS = 1
num_epochs = 1
batch_size = 32
bidirectional = True

lr = 0.001
dropout = 0.3
weight_decay = 0.001

model = nn.DataParallel(RelationPredictor(emb_vals)).to(device)
criterion = nn.NLLLoss()
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

for epoch in range(num_epochs):
    
    all_inp = list(zip(nodes_train, paths_train, counts_train, targets_train))
    all_inp_shuffled = random.sample(all_inp, len(all_inp))
    nodes_train, paths_train, counts_train, targets_train = list(zip(*all_inp_shuffled[:100]))

    num_edges_all = [[len(path) for path in element] for element in paths_train]
    max_edges = max(flatten(num_edges_all))
    max_paths = max([len(elem) for elem in counts_train])

    dataset_size = len(nodes_train)
    batch_size = min(batch_size, dataset_size)
    num_batches = int(ceil(dataset_size/batch_size))

    for batch_idx in range(num_batches):
        
        batch_start = batch_idx * batch_size
        batch_end = (batch_idx+1) * batch_size
        
        nodes = torch.LongTensor(nodes_train[batch_start:batch_end]).to(device)
        paths = torch.LongTensor(pad_paths(paths_train[batch_start:batch_end], max_paths, max_edges)).to(device)
        counts = torch.DoubleTensor(pad_counts(counts_train[batch_start:batch_end], max_paths)).to(device)
        edgecounts = torch.LongTensor(pad_edgecounts(num_edges_all[batch_start:batch_end], max_paths)).to(device)
        targets = torch.LongTensor(targets_train[batch_start:batch_end]).to(device)
        
        # Backprop and perform Adam optimisation
        optimizer.zero_grad()

        # Run the forward pass
        outputs = model(nodes, paths, counts, edgecounts, max_paths, max_edges)

        #loss = log_loss(outputs, torch.LongTensor(labels).to(device))
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        print("Epoch: {} Idx: {} Loss: {}".format(epoch, batch_idx, loss.item()))  

print("Training Complete!")

def calculate_recall(true, pred):
    true_f, pred_f = [], []
    for i,elem in enumerate(true):
        if elem!=4:
            true_f.append(elem)
            pred_f.append(pred[i])
    return accuracy_score(true_f, pred_f)

def calculate_precision(true, pred):
    true_f, pred_f = [], []
    for i,elem in enumerate(pred):
        if elem!=4:
            pred_f.append(elem)
            true_f.append(true[i])
    return accuracy_score(true_f, pred_f)

def test(nodes_test, paths_test, counts_test, targets_test, message):
    predictedLabels, trueLabels = [], []

    num_edges_all = [[len(path) for path in element] for element in paths_test]
    max_edges = max(flatten(num_edges_all))
    max_paths = max([len(elem) for elem in counts_test])
    print (max_paths, max_edges)
    dataset_size = len(nodes_test)
    batch_size = min(32, dataset_size)
    num_batches = int(ceil(dataset_size/batch_size))

    for batch_idx in range(num_batches):
        
        batch_start = batch_idx * batch_size
        batch_end = (batch_idx+1) * batch_size

        nodes = torch.LongTensor(nodes_test[batch_start:batch_end])
        paths = torch.LongTensor(pad_paths(paths_test[batch_start:batch_end], max_paths, max_edges))
        counts = torch.DoubleTensor(pad_counts(counts_test[batch_start:batch_end], max_paths))
        edgecounts = torch.LongTensor(pad_edgecounts(num_edges_all[batch_start:batch_end], max_paths))
        targets = torch.LongTensor(targets_test[batch_start:batch_end])
        
        print (nodes.shape, paths.shape, counts.shape, edgecounts.shape, targets.shape)
        
        outputs = model(nodes, paths, counts, edgecounts, max_paths, max_edges)
        _, predicted = torch.max(outputs, 1)
        predicted = [el.item() for el in predicted]
        targets = [el.item() for el in targets]
        predictedLabels.extend(predicted)
        trueLabels.extend(targets)
    
    print (trueLabels, predictedLabels)
    accuracy = accuracy_score(trueLabels, predictedLabels)
    recall = calculate_recall(trueLabels, predictedLabels)
    precision = calculate_precision(trueLabels, predictedLabels)
    final_metrics = [accuracy, precision, recall, 2 * (precision * recall/(precision + recall))]
    print("Final Results ({}): [{}]".format(message, ", ".join([str(el) for el in final_metrics])))

model.eval()
with torch.no_grad():
    test(nodes_test, paths_test, counts_test, targets_test, "Test")
    test(nodes_instances, paths_instances, counts_instances, targets_instances, "Instances")
    test(nodes_knocked, paths_knocked, counts_knocked, targets_knocked, "Knocked out")


torch.Size([32, 6, 5, 512]) torch.Size([32, 6, 5, 4]) torch.Size([32, 6, 5, 6]) torch.Size([32, 6, 5, 3]) torch.Size([32, 6, 5, 525]) torch.Size([32, 1024])
Epoch: 0 Idx: 0 Loss: 1.6065702650937592
torch.Size([32, 6, 5, 512]) torch.Size([32, 6, 5, 4]) torch.Size([32, 6, 5, 6]) torch.Size([32, 6, 5, 3]) torch.Size([32, 6, 5, 525]) torch.Size([32, 1024])
Epoch: 0 Idx: 1 Loss: 1.589469174306449
torch.Size([32, 6, 5, 512]) torch.Size([32, 6, 5, 4]) torch.Size([32, 6, 5, 6]) torch.Size([32, 6, 5, 3]) torch.Size([32, 6, 5, 525]) torch.Size([32, 1024])




Epoch: 0 Idx: 2 Loss: 1.5615394039610633
torch.Size([4, 6, 5, 512]) torch.Size([4, 6, 5, 4]) torch.Size([4, 6, 5, 6]) torch.Size([4, 6, 5, 3]) torch.Size([4, 6, 5, 525]) torch.Size([4, 1024])
Epoch: 0 Idx: 3 Loss: 1.563932163269547
Training Complete!
238 8
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 

torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238, 8, 6]) torch.Size([32, 238, 8, 3]) torch.Size([32, 238, 8, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 238, 8, 4]) torch.Size([32, 238]) torch.Size([32, 238]) torch.Size([32])
torch.Size([32, 238, 8, 512]) torch.Size([32, 238, 8, 4]) torch

torch.Size([32, 2]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11]) torch.Size([32, 11]) torch.Size([32])
torch.Size([32, 11, 6, 512]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11, 6, 6]) torch.Size([32, 11, 6, 3]) torch.Size([32, 11, 6, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11]) torch.Size([32, 11]) torch.Size([32])
torch.Size([32, 11, 6, 512]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11, 6, 6]) torch.Size([32, 11, 6, 3]) torch.Size([32, 11, 6, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11]) torch.Size([32, 11]) torch.Size([32])
torch.Size([32, 11, 6, 512]) torch.Size([32, 11, 6, 4]) torch.Size([32, 11, 6, 6]) torch.Size([32, 11, 6, 3]) torch.Size([32, 11, 6, 525]) torch.Size([32, 1024])
torch.Size([19, 2]) torch.Size([19, 11, 6, 4]) torch.Size([19, 11]) torch.Size([19, 11]) torch.Size([19])
torch.Size([19, 11, 6, 512]) torch.Size([19, 11, 6, 4]) torch.Size([19, 11, 6, 6]) torch.S



torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311, 10, 6]) torch.Size([32, 311, 10, 3]) torch.Size([32, 311, 10, 525]) torch.Size([32, 1024])
torch.Size([32, 2]) torch.Size([32, 311, 10, 4]) torch.Size([32, 311]) torch.Size([32, 311]) torch.Size([32])
torch.Size([32, 311, 10, 512]) torch.Size([3

In [94]:
results = """results_dropout_0.3_0_0.15_.txt:Final Results (Test): [0.8011695906432749, 0.7585266030013642, 0.7443105756358769, 0.7513513513513513]
results_dropout_0.3_0_0.15_.txt:Final Results (Instances): [0.23636363636363636, 0.1322314049586777, 0.9142857142857143, 0.23104693140794227]
results_dropout_0.3_0_0.15_.txt:Final Results (Knocked out): [0.7170458649331889, 0.9230590423059042, 0.7170458649331889, 0.8071138211382113]
results_dropout_0.3_0.2_0.15_.txt:Final Results (Test): [0.8095238095238095, 0.773224043715847, 0.7576974564926372, 0.7653820148749154]
results_dropout_0.3_0.2_0.15_.txt:Final Results (Instances): [0.23636363636363636, 0.13278008298755187, 0.9142857142857143, 0.23188405797101447]
results_dropout_0.3_0.2_0.15_.txt:Final Results (Knocked out): [0.6979053810039726, 0.9163110478899953, 0.6979053810039726, 0.7923329233292333]
results_dropout_0.3_0.7_0.15_.txt:Final Results (Test): [0.808688387635756, 0.76797829036635, 0.7576974564926372, 0.7628032345013476]
results_dropout_0.3_0.7_0.15_.txt:Final Results (Instances): [0.22545454545454546, 0.12757201646090535, 0.8857142857142857, 0.22302158273381295]
results_dropout_0.3_0.7_0.15_.txt:Final Results (Knocked out): [0.7181292885518238, 0.919750231267345, 0.7181292885518238, 0.8065301155952138]
results_dropout_0.3_0.8_0.15_.txt:Final Results (Test): [0.8045112781954887, 0.7681755829903978, 0.749665327978581, 0.7588075880758807]
results_dropout_0.3_0.8_0.15_.txt:Final Results (Instances): [0.23272727272727273, 0.12863070539419086, 0.8857142857142857, 0.22463768115942026]
results_dropout_0.3_0.8_0.15_.txt:Final Results (Knocked out): [0.7186710003611412, 0.9206569511913023, 0.7186710003611412, 0.8072203630463441]
results_dropout_0.3_0.8_0_.txt:Final Results (Test): [0.808688387635756, 0.7701778385772914, 0.7536813922356091, 0.7618403247631934]
results_dropout_0.3_0.8_0_.txt:Final Results (Instances): [0.21818181818181817, 0.13008130081300814, 0.9142857142857143, 0.22775800711743777]
results_dropout_0.3_0.8_0_.txt:Final Results (Knocked out): [0.7040447815095703, 0.9208786017949929, 0.7040447815095703, 0.7979942693409743]
results_dropout_0.35_0_0.15_.txt:Final Results (Test): [0.8011695906432749, 0.7547425474254743, 0.7456492637215528, 0.7501683501683502]
results_dropout_0.35_0_0.15_.txt:Final Results (Instances): [0.2290909090909091, 0.13168724279835392, 0.9142857142857143, 0.2302158273381295]
results_dropout_0.35_0_0.15_.txt:Final Results (Knocked out): [0.7248104008667389, 0.9214876033057852, 0.7248104008667389, 0.8114008489993936]
results_dropout_0.35_0.2_0.15_.txt:Final Results (Test): [0.8028404344193818, 0.7654320987654321, 0.7469879518072289, 0.7560975609756098]
results_dropout_0.35_0.2_0.15_.txt:Final Results (Instances): [0.24, 0.13278008298755187, 0.9142857142857143, 0.23188405797101447]
results_dropout_0.35_0.2_0.15_.txt:Final Results (Knocked out): [0.7080173347778982, 0.9195590994371482, 0.7080173347778982, 0.8000408079983676]
results_dropout_0.35_0.7_0.15_.txt:Final Results (Test): [0.808688387635756, 0.76797829036635, 0.7576974564926372, 0.7628032345013476]
results_dropout_0.35_0.7_0.15_.txt:Final Results (Instances): [0.21454545454545454, 0.12601626016260162, 0.8857142857142857, 0.2206405693950178]
results_dropout_0.35_0.7_0.15_.txt:Final Results (Knocked out): [0.7143373058866017, 0.9170143718127028, 0.7143373058866017, 0.8030856678846934]
results_dropout_0.35_0.8_0.15_.txt:Final Results (Test): [0.8078529657477026, 0.7676630434782609, 0.7563587684069611, 0.7619689817936616]
results_dropout_0.35_0.8_0.15_.txt:Final Results (Instances): [0.2109090909090909, 0.12195121951219512, 0.8571428571428571, 0.21352313167259784]
results_dropout_0.35_0.8_0.15_.txt:Final Results (Knocked out): [0.720837847598411, 0.9206642066420664, 0.720837847598411, 0.8085882114644519]
results_dropout_0.35_0.8_0_.txt:Final Results (Test): [0.8162071846282373, 0.7827868852459017, 0.7670682730923695, 0.7748478701825559]
results_dropout_0.35_0.8_0_.txt:Final Results (Instances): [0.23272727272727273, 0.128099173553719, 0.8857142857142857, 0.22382671480144403]
results_dropout_0.35_0.8_0_.txt:Final Results (Knocked out): [0.704225352112676, 0.9215500945179584, 0.704225352112676, 0.7983623336745138]
results_dropout_0.3_0.2_0_.txt:Final Results (Test): [0.8028404344193818, 0.7612551159618008, 0.7469879518072289, 0.754054054054054]
results_dropout_0.3_0.2_0_.txt:Final Results (Instances): [0.22181818181818183, 0.1306122448979592, 0.9142857142857143, 0.22857142857142856]
results_dropout_0.3_0.2_0_.txt:Final Results (Knocked out): [0.6787648970747562, 0.9215493993625888, 0.6787648970747562, 0.7817406675678487]
results_dropout_0.3_0.7_0_.txt:Final Results (Test): [0.8095238095238095, 0.7693351424694709, 0.7590361445783133, 0.7641509433962265]
results_dropout_0.3_0.7_0_.txt:Final Results (Instances): [0.22181818181818183, 0.13008130081300814, 0.9142857142857143, 0.22775800711743777]
results_dropout_0.3_0.7_0_.txt:Final Results (Knocked out): [0.6946551101480679, 0.9201148050705573, 0.6946551101480679, 0.7916452309908427]
results_dropout_0.35_0.2_0_.txt:Final Results (Test): [0.8011695906432749, 0.7640603566529492, 0.7456492637215528, 0.7547425474254743]
results_dropout_0.35_0.2_0_.txt:Final Results (Instances): [0.2545454545454545, 0.13559322033898305, 0.9142857142857143, 0.23616236162361626]
results_dropout_0.35_0.2_0_.txt:Final Results (Knocked out): [0.6794871794871795, 0.9245700245700246, 0.6794871794871795, 0.7833055786844297]
results_dropout_0.35_0.7_0_.txt:Final Results (Test): [0.8078529657477026, 0.771117166212534, 0.7576974564926372, 0.7643484132343012]
results_dropout_0.35_0.7_0_.txt:Final Results (Instances): [0.24363636363636362, 0.13333333333333333, 0.9142857142857143, 0.23272727272727273]
results_dropout_0.35_0.7_0_.txt:Final Results (Knocked out): [0.6986276634163958, 0.9231686948222382, 0.6986276634163958, 0.7953540960016445]"""

rownames, colnames, data = list(zip(*[(",".join(line.split("_.txt")[0].split("_")[-3:]), line.split("(")[-1].split(")")[0], float(line.split("[")[-1].split("]")[0].split(",")[-1].strip())) for line in results.split("\n")]))
from orderedset import OrderedSet
rownames, colnames = OrderedSet(rownames), OrderedSet(colnames)
data = np.array(data).reshape((len(rownames), len(colnames))).tolist()
import pandas as pd
pd.DataFrame(data, columns=colnames, index=rownames).sort_index()

Unnamed: 0,Test,Instances,Knocked out
"0.3,0,0.15",0.751351,0.231047,0.807114
"0.3,0.2,0",0.754054,0.228571,0.781741
"0.3,0.2,0.15",0.765382,0.231884,0.792333
"0.3,0.7,0",0.764151,0.227758,0.791645
"0.3,0.7,0.15",0.762803,0.223022,0.80653
"0.3,0.8,0",0.76184,0.227758,0.797994
"0.3,0.8,0.15",0.758808,0.224638,0.80722
"0.35,0,0.15",0.750168,0.230216,0.811401
"0.35,0.2,0",0.754743,0.236162,0.783306
"0.35,0.2,0.15",0.756098,0.231884,0.800041


In [20]:
import torch
import torch.nn.functional as F
counts = torch.randn (32, 8)
div = torch.randn (8)
torch.div(counts, div)

tensor([[ 1.4354e-01, -2.4744e+00, -1.2262e+00, -4.0278e-01,  1.9652e+00,
          2.9542e+00,  6.6998e-02, -7.3829e+01],
        [-4.5266e-01, -1.3811e+00, -4.9185e-01, -1.9605e+00, -3.2036e-01,
         -1.2381e+00,  1.3872e-01, -6.3207e+01],
        [ 6.9792e-01, -1.2247e-01,  7.7042e-01,  3.2037e-02, -6.5696e-01,
         -4.9648e-01,  5.9166e-02,  6.9360e+01],
        [ 2.8257e-01, -2.9700e+00,  1.3231e-01, -2.7354e+00,  3.3597e-01,
         -6.2222e-01, -2.9061e-01, -1.4742e+01],
        [ 1.6055e-01,  9.0650e-01,  9.2916e-01,  5.3029e-01, -1.2657e+00,
         -1.0209e+00,  2.7596e-02, -7.0289e+01],
        [ 2.3818e-01, -1.1293e+00,  9.9822e-01, -8.4411e-01, -1.1915e+00,
          3.1715e+00,  3.4325e-01,  1.4783e+02],
        [ 1.2325e-01, -1.3863e+00,  1.7848e+00, -1.6062e-01, -2.9663e-01,
         -1.6308e+00, -2.3644e-01, -3.5159e+01],
        [-5.4035e-01,  2.4307e+00, -6.3870e-01,  1.3965e+00,  4.7659e-01,
         -1.7653e+00,  8.4513e-01,  4.2152e+01],
        [ 4.7109

In [18]:
a

tensor([3, 2, 3, 5, 7, 2, 5, 2, 6, 0, 5, 3, 2, 2, 7, 4, 5, 5, 5, 0, 6, 2, 0, 5,
        7, 2, 6, 7, 1, 4, 0, 7])

In [48]:
[el.item() for el in torch.Tensor([1,2,3,4])]

[1.0, 2.0, 3.0, 4.0]