In [None]:
import pickle
import warnings

import networkx as nx
import numpy as np
import torch
from sklearn.exceptions import UndefinedMetricWarning
from torch_geometric.nn import GAE
from torch_geometric.utils.convert import from_networkx

from olga import OLGA
from oneclass import one_class_loss, one_class_masking, One_Class_GNN_prediction, EarlyStopping
# Silence sklearn's UndefinedMetricWarning (presumably triggered by the
# classification reports computed on every training epoch).
warnings.filterwarnings(action='ignore', category=UndefinedMetricWarning)

# NOTE(review): these flags are not referenced in this cell; presumably used elsewhere.
learn_radius = False
learn_center = False

# Loss combination used by the training loop: one of 'l1', 'l1l3', 'l2l3', 'l1l2l3'.
strategy = 'l2l3'

# Dataset name and k-setting; together they select the graph file below.
dt = 'musk'
k = 'k=1'

file = f'../../datasets/{dt}/{k}/{dt}_{k}_fold=0.gpickle'

# `nx.read_gpickle` was removed in NetworkX 3.0. A .gpickle file is a plain
# pickle of the graph object, so the stdlib loader reads it on every NetworkX
# version (and the `with` block closes the file handle).
# NOTE: unpickling can execute arbitrary code - only load trusted dataset files.
with open(file, 'rb') as fh:
    g = pickle.load(fh)

# Train/test and unsupervised node masks derived from the graph.
mask, t_mask, mask_unsup, t_mask_unsup = one_class_masking(g, True)

# PyTorch Geometric view of the full graph (node features + edge_index).
G = from_networkx(g)

# ---------------------------------------------------------------------------
# Model, one-class hypersphere and optimiser set-up
# ---------------------------------------------------------------------------
# Encoder layer widths; the last one (hidden2) is the embedding dimension.
hidden1, hidden2 = 48, 2

# Hypersphere for the one-class loss: centre at the origin of the
# hidden2-dimensional embedding space, radius 0.35.
c = torch.Tensor([0] * hidden2)
r = torch.Tensor([0.35])

# OLGA encoder (input width = length of the first node's feature vector),
# wrapped in a graph auto-encoder so that `model.recon_loss` is available.
model_ocl = OLGA(len(G.features[0]), [hidden1, hidden2])
model = GAE(model_ocl)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# Early stopping; presumably `patience` is the number of epochs tolerated
# without improvement - see EarlyStopping for the exact rule.
patience = 300
stopper = EarlyStopping(patience)

# Per-epoch histories, appended to by the training loop (one fresh list each).
(centers, embeddings, losses_ocl, losses_rec,
 accuracies, radiuss, losses) = ([] for _ in range(7))

# Best state so far; overwritten by `stopper.step` during training.
best_embeddings, best_radius, best_center = [], 0, []

# Subgraph induced by the unsupervised node set; its edge_index is used by
# the 'l1l3' reconstruction term.
g_unsup = g.subgraph(t_mask_unsup)
G_unsup = from_networkx(g_unsup)

# Loop state: both loss terms start at 0, the hypersphere starts at (c, r).
loss_ocl = recon_loss_unsup = 0
radius, center = r, c

def _to_history(value):
    """Return *value* detached from the autograd graph when it is a tensor.

    The history lists keep one entry per epoch; appending live tensors keeps
    a reference to each epoch's autograd graph for the whole run. Detaching
    stores only the values. Non-tensor values (e.g. the initial ``0`` loss
    placeholders) are returned unchanged.
    """
    return value.detach() if torch.is_tensor(value) else value


_STRATEGIES = ('l1', 'l1l3', 'l2l3', 'l1l2l3')
if strategy not in _STRATEGIES:
    # Fail fast: otherwise `loss` is never assigned in the loop below and the
    # run dies with a NameError (or back-propagates a stale `loss` left over
    # from an earlier notebook cell).
    raise ValueError(f'unknown strategy {strategy!r}; expected one of {_STRATEGIES}')

# Training loop: at most 10001 epochs, ended early when `stopper` says so.
for epoch in range(10001):
    optimizer.zero_grad()

    # Forward pass over the full graph.
    learned_representations = model.encode(G.features.float(), G.edge_index)

    # Training objective for the selected strategy.
    if strategy == 'l1':
        loss = one_class_loss(center, radius, learned_representations, mask)
    elif strategy == 'l1l3':
        loss_ocl = one_class_loss(center, radius, learned_representations, mask)
        recon_loss_unsup = model.recon_loss(learned_representations[mask_unsup], G_unsup.edge_index)
        loss = loss_ocl + recon_loss_unsup
    elif strategy == 'l2l3':
        loss = model.recon_loss(learned_representations, G.edge_index)
    else:  # 'l1l2l3' (validated above)
        loss_ocl = one_class_loss(center, radius, learned_representations, mask)
        loss_rec = model.recon_loss(learned_representations, G.edge_index)
        loss = loss_ocl + loss_rec

    # Macro-averaged F1 on the 'test' nodes; evaluation only, so no autograd
    # graph needs to be recorded for it.
    # NOTE(review): this F1 also drives early stopping / best-state selection,
    # i.e. model selection happens on the test split - consider a validation split.
    with torch.no_grad():
        report = One_Class_GNN_prediction(center, radius, learned_representations, g, 'test', True)
    f1 = report['macro avg']['f1-score']

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

    stop, best_embeddings, best_radius, best_center, best_epoch = stopper.step(f1, loss, epoch, radius, center,
                                                                               learned_representations)

    # NOTE(review): `loss_ocl` / `recon_loss_unsup` are only refreshed by the
    # strategies that compute them (e.g. for 'l2l3' they keep their previous
    # values, and 'l1l2l3' never records `loss_rec`) - confirm this is intended.
    embeddings.append(_to_history(learned_representations))
    losses_ocl.append(_to_history(loss_ocl))
    losses_rec.append(_to_history(recon_loss_unsup))
    losses.append(_to_history(loss))
    accuracies.append(f1)
    centers.append(center)
    radiuss.append(radius)

    if stop:
        break

# Final report on the 'test' nodes using the best state returned by the stopper.
One_Class_GNN_prediction(best_center, best_radius, best_embeddings, g, 'test', False)

              precision    recall  f1-score   support

          -1       0.96      0.98      0.97      2925
           1       0.15      0.08      0.10       123

    accuracy                           0.94      3048
   macro avg       0.55      0.53      0.54      3048
weighted avg       0.93      0.94      0.94      3048

