In [1]:
from model import get_classification_model
from torch_geometric.data import Data
import glob
from tqdm import tqdm
import torch

In [2]:
import pickle
import gzip

In [3]:
# Configuration handed to `get_classification_model`. The `features` and
# `classes` entries are not part of this literal; they are set separately.
BASELINE_GIN_CLASSIFIER = {
    "type": "GraphClassifier",
    "name": "BASELINE_GIN",
    # Graph-level encoder: a node-level encoder followed by a pooling step.
    "encoder": {
        "type": "GraphComposite",
        "pooling": {"type": "sum"},
        "encoder": {
            "num_layers": 3,
            "hidden_channels": 128,
            "layer_type": "CGIN",
            # The string "None", not the Python `None` object.
            "norm_type": "None",
        },
    },
    # Classification head settings.
    "classifier": {
        "layer_type": "MLP",
        "dropout": 0.5,
        "num_layers": 3,
    },
}

In [4]:
# Set the `features` and `classes` entries on the shared config dict
# (the dict is mutated in place).
BASELINE_GIN_CLASSIFIER.update(features=150, classes=1)

# Build the full classification model, then keep only the node-level encoder
# that sits inside its graph encoder.
model = (
    get_classification_model(BASELINE_GIN_CLASSIFIER)
    .encoder
    .node_level_encoder
)

In [5]:
# Load the saved weights into the node-level encoder.
# map_location="cpu" lets a checkpoint that was written from a GPU run be
# loaded on a machine without CUDA; load_state_dict then copies the values
# into the model's existing parameters.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load("14_model.chkpt", map_location="cpu"))

# Put the model in evaluation mode. The trailing semicolon keeps the notebook
# from echoing the module repr.
model.eval();

In [15]:
# Run the encoder over every cached graph and collect, per sample, the argmax
# of each of the three outputs the model returns, plus the ground-truth label.
preds_c = []
preds_o = []
preds_co = []
trues = []

# sorted(): glob order depends on the filesystem; sorting keeps the order of
# the collected lists reproducible between runs.
for name in tqdm(sorted(glob.glob("../cache/REVEAL3/*.cpg.pt.gz"))):
    # The label is the last "_"-separated token of the file name, up to the
    # first ".".
    filename = name.split("/")[-1]
    label = int(filename.split("_")[-1].split(".")[0])

    # Close each archive as soon as it is read instead of leaking one open
    # handle per sample.
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # cache files produced by a trusted pipeline.
    with gzip.open(name) as fh:
        object_file = pickle.load(fh)

    # Node features are the two stored encodings concatenated along dim 1.
    data = Data(
        x=torch.cat((object_file["astenc"], object_file["codeenc"]), dim=1),
        edge_index=object_file["edge_index"],
        y=object_file["y"],
    )
    data.edge_index = data.edge_index.long()
    data.x = data.x.float()

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        c_logs, o_logs, co_logs = model(data)

    preds_c.append(c_logs.squeeze().argmax().item())
    preds_o.append(o_logs.squeeze().argmax().item())
    preds_co.append(co_logs.squeeze().argmax().item())
    trues.append(label)

100%|██████████| 19218/19218 [10:44<00:00, 29.82it/s]


In [16]:
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score

def compute_metrics(pred, true):
    """Score binarised predictions against reference labels.

    Args:
        pred: Numeric predictions (anything `torch.as_tensor` accepts).
            Values strictly greater than 0.5 become 1, all others 0.
        true: Reference labels, passed to the sklearn scorers unchanged.

    Returns:
        Dict with the keys "MCC", "F1" (macro-averaged), "Acc" and "BAcc".
    """
    # Threshold at 0.5 and convert to a plain list of Python ints.
    binary_labels = torch.gt(torch.as_tensor(pred), 0.5).long().tolist()

    mcc = matthews_corrcoef(true, binary_labels)
    macro_f1 = f1_score(true, binary_labels, average='macro')
    accuracy = accuracy_score(true, binary_labels)
    balanced = balanced_accuracy_score(true, binary_labels)

    return {"MCC": mcc, "F1": macro_f1, "Acc": accuracy, "BAcc": balanced}

In [17]:
# Metrics for the predictions taken from the model's first output (`c_logs`).
compute_metrics(preds_c, trues)

{'MCC': -0.1354440204735427,
 'F1': 0.3315053048525221,
 'Acc': 0.4048808408783432,
 'BAcc': 0.38825360819983235}

In [18]:
# Metrics for the predictions taken from the model's second output (`o_logs`).
compute_metrics(preds_o, trues)

{'MCC': 0.1316591703907932,
 'F1': 0.4754380646104019,
 'Acc': 0.5850764907898844,
 'BAcc': 0.6089379838512485}

In [19]:
# Metrics for the predictions taken from the model's third output (`co_logs`).
compute_metrics(preds_co, trues)

{'MCC': 0.13880975014120056,
 'F1': 0.4649761640317789,
 'Acc': 0.5615048392132376,
 'BAcc': 0.6154728533588254}