In [1]:
!source .venv/bin/activate
import warnings
warnings.filterwarnings('ignore')


In [2]:
import os
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from copy import deepcopy

#printing
from torch_geometric import utils
import networkx as nx

# Torch
from torch_geometric.nn import GCNConv, TransformerConv
from torch.utils.data import random_split, SubsetRandomSampler, Subset
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from typing import Dict, Tuple


import pandas as pd
import numpy as np
import matplotlib as plt


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [4]:
Edges = Dict[Tuple[int, int], Tuple[int, ...]]
Nodes = Dict[int, Tuple[int, ...]]
all_edges: Dict[Tuple[int, int], Tuple[int, ...]] = {}

class CustomDataset(Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __getitem__(self, index):
        return self.data_list[index]

    def __len__(self):
        return len(self.data_list)

def datasets(data_directory):
    datasets = {}
    for domain_name in os.listdir(data_directory):
        domain_path = os.path.join(data_directory, domain_name)
        data_loader = dataset_from_domain(domain_path)
        datasets[domain_name] = data_loader
    return datasets

def dataset_from_domain(domain_path):
    domain_name = os.path.basename(domain_path)
    dataset = []
    number_of_problems = 0
    for problem_name in os.listdir(domain_path):
        if problem_name == "empty_causal_graphs":
            continue
        number_of_problems+=1
        problem_path = os.path.join(domain_path, problem_name)
        data = problem_path_to_data(problem_path)
        dataset.append(data)
        # Generate list of data objects from our problem path
        
        # Iterate over all the problems in the domain
        # Generate a data object for each problem
        # train_test_split everything
        # train the model
    return dataset




def split_dataset(dataset, test_size=0.2, batch_size=8, shuffle=True, random_seed=42) -> Tuple[DataLoader, DataLoader]:
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(0.2 * dataset_size))
    if True :
        np.random.seed(42)
        np.random.shuffle(indices)
    train_indices, test_indicies = indices[split:], indices[:split]

    train_set = []
    test_set = []

    for i in range(len(satellite_dataset)):
        if i in train_indices:
            train_set.append(satellite_dataset[i])
        else:
            test_set.append(satellite_dataset[i])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)
    return train_loader, test_set

    

def problem_path_to_data(problem_path):
    # print(problem_path)
    cg_df = pd.read_csv(os.path.join(problem_path, 'cg.csv'), index_col=[0, 1])
    cg_df.sort_index(inplace=True)
    nodes_df = pd.read_csv(os.path.join(problem_path, 'nodes.csv'), index_col=0)

    edges = cg_df.index
    edge_features_list = cg_df[['type_pre_eff', 'type_eff_eff']].values
    edge_labels = cg_df['label'].values
    edge_dict = {}

    # Unlucky naming, but the edge_features is a vector representing features of a single edge
    # edge_feature_list is the dictionary of all the edges and their respective features
    for edge, edge_features, label in zip(edges, edge_features_list, edge_labels):
        edge_dict[tuple(edge)] = (edge_features, label)

    edge_features, edge_labels = zip(*[edge_dict[edge] for edge in sorted(edge_dict.keys())])


    data = Data(
        x=torch.tensor(nodes_df.values, dtype=torch.float),
        edge_index=torch.tensor(list(sorted(edge_dict.keys())), dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_features, dtype=torch.float),
        y=torch.tensor(edge_labels, dtype=torch.bool)
    )
    return data

    
def draw_graph(data: Data):
    g = utils.to_networkx(data)

    color = ['green' if data.y[i] else 'red' for i in range(data.y.size(0))]
    a = nx.draw_networkx(g,node_size=200, pos=nx.spectral_layout(g), edge_color=color, node_color='green', with_labels=True)

satellite_dataset = dataset_from_domain('graph_training_data/satellite')
train_loader, test_set = split_dataset(satellite_dataset)



In [9]:
class Net(torch.nn.Module):
    def __init__(self, features_num):
        super(Net, self).__init__()
        self.conv1 = TransformerConv(
            in_channels=features_num,
            out_channels=128,
            edge_dim=2
        )
        self.conv2 = TransformerConv(
            in_channels=128,
            out_channels=64,
            edge_dim=2)

    def encode(self, data: Data):
        x = self.conv1(
            x=data.x,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr) # convolution 1
        x = x.relu()
        return self.conv2(
            x=x, 
            edge_index=data.edge_index,
            edge_attr=data.edge_attr) # convolution 2



    def decode(self, z, edge_index): # only pos and neg edges
        #TODO  edge_index[0] 7 4 9
        #TODO edge_index[1] 5 3 9
        #     print("z shape: ", z.shape)

        # Multiply adjecency matrix with latent space using the COO format of 
        # Edge index[0] and Edge index[1]
        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)  # dot product 


        return logits

    def decode_all(self, z): 
        prob_adj = z @ z.t() # get adj NxN
        return (prob_adj > 0).nonzero(as_tuple=False).t() # get predicted edge_list 

In [10]:
data_loader = train_loader
test_loader = DataLoader(test_set, batch_size=8)

num_node_features = next(iter(data_loader)).x.shape[1]
model = Net(num_node_features)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
batch = next(iter(data_loader))
print(batch[0].x.shape)
print(batch[1].x.shape)

torch.Size([15, 1])
torch.Size([11, 1])


In [11]:
# class DomainTrainer:
#     def __init__(self, train_loader, test_set):
#         self.train_loader = train_loader
#         self.test_set = test_set

#         data_loader = train_loader

#         num_node_features = next(iter(data_loader)).x.shape[1]
#         model = Net(num_node_features)
#         optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
#         batch = next(iter(data_loader))
#         print(batch[0].x.shape)
#         print(batch[1].x.shape)

def train():
    model.train()  # Flag to modify the gradient

    batch = next(iter(data_loader))  # This is next level shit
    # print(batch[0])
    # print(batch[1])
    edge_index = batch.edge_index
    optimizer.zero_grad()
    z = model.encode(batch) 
    link_logits = model.decode(z, edge_index) # decode
    print(link_logits)
    # print(link_logits)
    link_labels = batch.y
    # print(link_labels)
    link_labels = link_labels.type(torch.float)
    # print(link_labels)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss


@torch.no_grad()
def test():
    model.eval()
    test_data = next(iter(test_loader))
    z = model.encode(test_data) # encode train
    link_logits = model.decode(z, test_data.edge_index) # decode test or val
    link_probs = link_logits.sigmoid() # apply sigmoid
    link_labels = test_data.y.type(torch.float)
    
    return roc_auc_score(link_labels, link_probs) #compute roc_auc score

In [12]:
best_val_perf = test_perf = 0
for epoch in range(1, 300):
    train_loss = train()

    test_perf = test()
    # if val_perf > best_val_perf:
    #     best_val_perf = val_perf
    #     test_perf = tmp_test_perf
    # log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    log = 'Epoch: {:03d}, Loss: {:.4f}, Test: {:.4f}'
    if epoch % 10 == 0:

        print(log.format(epoch, train_loss, test_perf))

tensor([132.0721, 137.3179, 132.0721, 137.3179, 140.4741, 132.0721, 137.3179,
        132.0721, 137.3179, 140.4741, 140.4741, 140.4741, 140.4741, 137.3700,
        136.1938, 137.3700, 136.1938, 136.1938, 136.1938, 136.1938, 137.3179,
        137.3179, 181.1575, 137.3179, 137.3179, 181.1575, 181.1575, 181.1575,
        181.1575, 106.6725, 106.7557, 101.2937,  99.7515, 106.6725,  99.7515,
        105.9212,  99.7515, 106.6725, 106.7557, 101.2937,  99.7515, 106.6725,
        106.7557, 101.2937,  99.7515, 106.6725,  99.7515, 106.6725, 105.9212,
        105.9212,  78.8733,  76.9742,  78.8733,  78.8733,  76.9742, 106.7557,
        105.9212, 116.6104, 106.7557, 105.9212, 116.6104, 106.7557, 105.9212,
        116.6104, 109.4420, 106.7700, 109.4420, 106.7700, 103.3471, 106.6725,
        106.7557, 103.3461,  99.7515, 106.6725,  99.7515, 105.9212,  78.7563,
         78.3577,  78.8733,  78.3577, 106.7700, 106.7700, 121.4116, 106.7557,
        105.9212, 121.4037,  99.4785,  99.4785, 105.2286, 106.75

In [None]:
test_data

In [None]:
£