# In [None]:  (Jupyter cell marker left over from the notebook export)
import os.path as osp

import torch
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch_geometric.utils import negative_sampling 
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import load_data

from sklearn.metrics import roc_auc_score 

import ctraining_data

# Use Apple's Metal backend when this torch build exposes it and it is
# available at runtime; otherwise stay on the CPU.
device = torch.device(
    'mps'
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else 'cpu'
)


class GCN(torch.nn.Module):
    """Two-layer GCN encoder with a dot-product edge decoder.

    GCN layers similar to https://arxiv.org/abs/1609.02907
    - altered for link regression
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Create the two graph-convolution layers.

        Args:
            in_channels: width of the input node features.
            hidden_channels: width of the intermediate node representation.
            out_channels: width of the final node embedding.
        """
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Compute node embeddings.

        Args:
            x: node feature matrix passed to the first convolution.
            edge_index: graph connectivity passed to both convolutions.

        Returns:
            The output of the second convolution applied to the ReLU-activated
            output of the first.
        """
        h = self.conv1(x, edge_index)
        h = h.relu()
        # The second layer must consume the hidden activations `h`; feeding it
        # the raw input `x` would throw away the first layer entirely.
        return self.conv2(h, edge_index)

    def decode(self, z: torch.Tensor, edge_label_index: torch.Tensor) -> torch.Tensor:
        """Score candidate edges by the dot product of their endpoint embeddings.

        Args:
            z: node embedding matrix, indexed by node along the first axis.
            edge_label_index: pair of index tensors (source row, destination
                row) naming the node pairs to score.

        Returns:
            One score per node pair.
        """
        src, dst = edge_label_index
        return (z[src] * z[dst]).sum(dim=1)

    def decode_all(self, z: torch.Tensor) -> torch.Tensor:
        """Return the index pairs of all nodes whose dot-product score is positive.

        Mirrors `decode_all` from the upstream PyTorch Geometric
        link-prediction example.
        NOTE(review): a threshold of 0 suits classification logits; confirm it
        is the intended cut-off for the regression variant of this model.

        Args:
            z: node embedding matrix, indexed by node along the first axis.

        Returns:
            A two-row tensor whose columns are the (source, destination) pairs
            with a strictly positive score.
        """
        scores = z @ z.t()
        return (scores > 0).nonzero(as_tuple=False).t()

# ---------------------------------------------------------------------------
# Import the training datasets
# ---------------------------------------------------------------------------
# One ColumnLatticeDataset is built per folder listed below.
datafolders = ['data/train/process18427', 'data/train/process31072', 'data/train/process53165']
datasets = [ctraining_data.ColumnLatticeDataset(folder) for folder in datafolders]
batch_size = 1000

# The first two datasets are used for training; the third is held out for
# validation.
training_datasets = [datasets[0], datasets[1]]
validation_set = datasets[2]

# Original variable names, kept so that any code relying on them still works.
training_data = training_datasets
validation_data = [validation_set]

dataloaders = [DataLoader(dataset, batch_size=batch_size) for dataset in training_datasets]
validation_loader = DataLoader(validation_set, batch_size=batch_size)

N_training_examples = sum(len(dataset) for dataset in training_datasets)
N_validation_examples = len(validation_set)

# Model, optimiser and loss.
# NOTE(review): the original GCN(...) call was left unfinished; the layer
# widths below are placeholders and should be tuned.
hidden_channels = 128
out_channels = 64
# Assumes every dataset exposes `num_features` and that all of them share the
# same value — TODO confirm against ColumnLatticeDataset.
model = GCN(datasets[0].num_features, hidden_channels, out_channels)
# NOTE(review): `device` is selected earlier in the file, but neither the
# model nor the data is moved to it here.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

def train():
    """Run one optimisation step on `train_data` and return the loss.

    Uses the module-level `model`, `optimizer`, `criterion` and `train_data`.
    NOTE(review): `train_data` is not assigned in the visible part of this
    file — confirm where it is expected to come from.

    Returns:
        The loss value produced by `criterion` for this step.
    """
    model.train()
    optimizer.zero_grad()

    embeddings = model.encode(train_data.x, train_data.edge_index)
    targets = train_data.edge_label

    # Flatten the decoder output so it lines up with the label vector.
    predictions = model.decode(embeddings, train_data.edge_label_index).view(-1)
    loss = criterion(predictions, targets)

    loss.backward()
    optimizer.step()
    return loss

@torch.no_grad()
def test(data):
    """Return the ROC AUC of the model's edge scores on `data`.

    The module-level `model` is switched to evaluation mode, the edges in
    `data.edge_label_index` are scored, the scores are passed through a
    sigmoid, and the result is compared against `data.edge_label`.

    Args:
        data: object providing `x`, `edge_index`, `edge_label_index` and
            `edge_label`.

    Returns:
        The value computed by `roc_auc_score`.
    """
    model.eval()
    embeddings = model.encode(data.x, data.edge_index)
    scores = model.decode(embeddings, data.edge_label_index).view(-1)
    squashed = torch.sigmoid(scores)
    targets = data.edge_label.cpu().numpy()
    return roc_auc_score(targets, squashed.cpu().numpy())


# Train for a fixed number of epochs. The reported "final" test AUC is the one
# measured at the epoch with the best validation AUC, so the test split is
# never used for model selection.
# NOTE(review): `val_data` and `test_data` are not assigned in the visible
# part of this file — confirm where they are expected to come from.
num_epochs = 100
best_val_auc = final_test_auc = 0
for epoch in range(1, num_epochs + 1):
    loss = train()
    val_auc = test(val_data)
    test_auc = test(test_data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

print(f'Final Test: {final_test_auc:.4f}')

# Final inference on the test graph. No gradients are needed here, so skip
# building the autograd graph. Requires the model to expose `decode_all`.
with torch.no_grad():
    z = model.encode(test_data.x, test_data.edge_index)
    final_edge_index = model.decode_all(z)
    





