In [2]:
import os
import torch
import warnings
import torchmetrics
import numpy as np
from loguru import logger
from torchmetrics.functional.regression import cosine_similarity

from const import *
from data import get_dataloader

warnings.filterwarnings('ignore')
start_threshold, end_threshold = 0.5, 0.851

  warn(f"Failed to load image Python extension: {e}")


# 测试数据

In [3]:
test_file_name = "triplet_both_test_5k.txt"
test_file_path = os.path.join(data_dir, test_file_name)
test_dl = get_dataloader(test_file_path, "eval", batch_size=128)

Loading data from /home/chennanye/ICDM-2023-Address/federated_contrastive_learning/datasets/triplet_both_test_5k.txt: : 5000it [00:00, 139078.58it/s]


In [11]:
def evaluate(model, dev_dl, start_threshold=0.50, end_threshold=1.01):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    # 计算pairs之间的相似度
    y_prob = []
    y_true = []
    with torch.no_grad():
        for batch_idx, source in enumerate(dev_dl):
            real_batch_num = source.get('input_ids').shape[0]
            input_ids = source.get('input_ids').view(real_batch_num * 3, -1).to(device)
            attention_mask = source.get('attention_mask').view(real_batch_num * 3, -1).to(device)
            token_type_ids = source.get('token_type_ids').view(real_batch_num * 3, -1).to(device)
            out = model(input_ids, attention_mask, token_type_ids)
            idx = torch.arange(out.shape[0], device=device).unsqueeze(1)
            # 三元组 (anchor, positive, negetive)
            anchor = out.index_select(0, torch.where(idx % 3 == 0)[0])
            positive = out.index_select(0, torch.where(idx % 3 == 1)[0])
            negative = out.index_select(0, torch.where(idx % 3 == 2)[0])

            # y_true、y_pred、y_prob
            sim_1 = cosine_similarity(anchor, positive, 'none')
            sim_0 = cosine_similarity(anchor, negative, 'none')
            y_prob.extend(sim_1)
            y_prob.extend(sim_0)
            y_true.extend([1] * len(sim_1))
            y_true.extend([0] * len(sim_0))
    y_prob = torch.tensor(y_prob).view(-1, 1).to(device)
    y_true = torch.tensor(y_true).view(-1, 1).to(device)
    
    max_f1 = 0
    for threshold in np.arange(start_threshold, end_threshold, 0.01):
        threshold = round(threshold, 2)
        y_pred = torch.where(y_prob > threshold, 1, 0).to(device)
        
        Accuracy = torchmetrics.Accuracy(threshold=threshold).to(device)
        Precision = torchmetrics.Precision(threshold=threshold, ignore_index=0).to(device)
        Recall = torchmetrics.Recall(threshold=threshold, ignore_index=0).to(device)
        F1Score = torchmetrics.F1Score(threshold=threshold, ignore_index=0).to(device)
        AUROC = torchmetrics.AUROC().to(device)
        SpearmanCorrCoef = torchmetrics.SpearmanCorrCoef().to(device)
        
        accuracy = round(Accuracy(y_pred, y_true).item(), 4)
        precision = round(Precision(y_pred, y_true).item(), 4)
        recall = round(Recall(y_pred, y_true).item(), 4)
        f1 = round(F1Score(y_pred, y_true).item(), 4)
        if f1 > max_f1:
            max_f1 = f1
        auroc = round(AUROC(y_prob, y_true).item(), 4)
        spearman = round(SpearmanCorrCoef(y_prob, y_true.type(torch.float32)).item(), 4)
        print(f"Threshold {threshold:.2f}: accuracy={accuracy:.4f}, precision={precision:.4f}, recall={recall:.4f}, f1={f1:.4f}, auroc={auroc:.4f}, spearman={spearman:.4f}")
        Accuracy.reset(), Precision.reset(), Recall.reset(), F1Score.reset(), AUROC.reset(), SpearmanCorrCoef.reset()
    print(f"max_f1={max_f1}")

# Residence

## Transformer

### Residence_Transformer_Sup-SimCSE

In [43]:
model_path = os.path.join(model_dir, "2022-06-03/Residence/Transformer/Sup-SimCSE/13:13:04-Residence_Transformer_Sup-SimCSE.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.6955, precision=0.6254, recall=0.9748, f1=0.7620, auroc=0.9194, spearman=0.7265
Threshold 0.51: accuracy=0.7032, precision=0.6322, recall=0.9716, f1=0.7660, auroc=0.9194, spearman=0.7265
Threshold 0.52: accuracy=0.7127, precision=0.6405, recall=0.9694, f1=0.7714, auroc=0.9194, spearman=0.7265
Threshold 0.53: accuracy=0.7209, precision=0.6480, recall=0.9670, f1=0.7760, auroc=0.9194, spearman=0.7265
Threshold 0.54: accuracy=0.7294, precision=0.6561, recall=0.9640, f1=0.7808, auroc=0.9194, spearman=0.7265
Threshold 0.55: accuracy=0.7393, precision=0.6659, recall=0.9606, f1=0.7865, auroc=0.9194, spearman=0.7265
Threshold 0.56: accuracy=0.7489, precision=0.6756, recall=0.9578, f1=0.7923, auroc=0.9194, spearman=0.7265
Threshold 0.57: accuracy=0.7577, precision=0.6849, recall=0.9544, f1=0.7975, auroc=0.9194, spearman=0.7265
Threshold 0.58: accuracy=0.7671, precision=0.6950, recall=0.9518, f1=0.8034, auroc=0.9194, spearman=0.7265
Threshold 0.59: accuracy=0.7753, prec

### Residence_Transformer_Sup-Triplet

In [44]:
model_path = os.path.join(model_dir, "2022-06-03/Residence/Transformer/Sup-Triplet/14:59:46-Residence_Transformer_Sup-Triplet.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.5203, precision=0.5113, recall=0.9178, f1=0.6567, auroc=0.5189, spearman=0.0342
Threshold 0.51: accuracy=0.5205, precision=0.5115, recall=0.9092, f1=0.6547, auroc=0.5189, spearman=0.0342
Threshold 0.52: accuracy=0.5189, precision=0.5107, recall=0.9038, f1=0.6526, auroc=0.5189, spearman=0.0342
Threshold 0.53: accuracy=0.5189, precision=0.5107, recall=0.9026, f1=0.6523, auroc=0.5189, spearman=0.0342
Threshold 0.54: accuracy=0.5187, precision=0.5106, recall=0.9022, f1=0.6521, auroc=0.5189, spearman=0.0342
Threshold 0.55: accuracy=0.5188, precision=0.5106, recall=0.9022, f1=0.6522, auroc=0.5189, spearman=0.0342
Threshold 0.56: accuracy=0.5189, precision=0.5107, recall=0.9020, f1=0.6522, auroc=0.5189, spearman=0.0342
Threshold 0.57: accuracy=0.5188, precision=0.5106, recall=0.9018, f1=0.6521, auroc=0.5189, spearman=0.0342
Threshold 0.58: accuracy=0.5184, precision=0.5104, recall=0.9010, f1=0.6517, auroc=0.5189, spearman=0.0342
Threshold 0.59: accuracy=0.5184, prec

### Residence_Transformer_Unsup-SimCSE

In [45]:
model_path = os.path.join(model_dir, "2022-06-04/Residence/Transformer/Unsup-SimCSE/01:24:49-Residence_Transformer_Unsup-SimCSE.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.4003, precision=0.2093, recall=0.0718, f1=0.1069, auroc=0.2847, spearman=-0.3730
Threshold 0.51: accuracy=0.4039, precision=0.2047, recall=0.0666, f1=0.1005, auroc=0.2847, spearman=-0.3730
Threshold 0.52: accuracy=0.4105, precision=0.2050, recall=0.0622, f1=0.0954, auroc=0.2847, spearman=-0.3730
Threshold 0.53: accuracy=0.4149, precision=0.2031, recall=0.0582, f1=0.0905, auroc=0.2847, spearman=-0.3730
Threshold 0.54: accuracy=0.4197, precision=0.2037, recall=0.0552, f1=0.0869, auroc=0.2847, spearman=-0.3730
Threshold 0.55: accuracy=0.4240, precision=0.2017, recall=0.0514, f1=0.0819, auroc=0.2847, spearman=-0.3730
Threshold 0.56: accuracy=0.4290, precision=0.2027, recall=0.0484, f1=0.0781, auroc=0.2847, spearman=-0.3730
Threshold 0.57: accuracy=0.4333, precision=0.2009, recall=0.0448, f1=0.0733, auroc=0.2847, spearman=-0.3730
Threshold 0.58: accuracy=0.4373, precision=0.2045, recall=0.0434, f1=0.0716, auroc=0.2847, spearman=-0.3730
Threshold 0.59: accuracy=0.4

## Transformer + GE-type

### Residence_Transformer+GE-type_Sup-SimCSE

In [4]:
model_path = os.path.join(model_dir, "2022-06-03/Residence/Transformer+GE-type/Sup-SimCSE/13:15:44-Residence_Transformer+GE-type_Sup-SimCSE.pt")
model = torch.load(model_path)

for threshold in np.arange(start_threshold, end_threshold, 0.02):
    evaluate(model, test_dl, criterion=simcse_sup_loss, threshold=threshold)

2022-06-04 00:15:17.533 | INFO     | __main__:<cell line: 4>:4 - model: 
 GeographicalAttentionNetwork(
  (pretrained_bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 256, padding_idx=0)
      (position_embeddings): Embedding(512, 256)
      (token_type_embeddings): Embedding(16, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_fea

## Transformer + GE-type + Entity-Pooler

### Residence_Transformer+GE-type+Entity-Pooler_Sup-SimCSE

In [8]:
model_path = os.path.join(model_dir, "2022-06-04/Residence/Transformer+GE-type+Entity-Pooler/Sup-SimCSE/12:20:15-Residence_Transformer+GE-type+Entity-Pooler_Sup-SimCSE.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.7764, precision=0.6932, recall=0.9918, f1=0.8160, auroc=0.9733, spearman=0.8198
Threshold 0.51: accuracy=0.7843, precision=0.7012, recall=0.9908, f1=0.8212, auroc=0.9733, spearman=0.8198
Threshold 0.52: accuracy=0.7915, precision=0.7088, recall=0.9896, f1=0.8260, auroc=0.9733, spearman=0.8198
Threshold 0.53: accuracy=0.7984, precision=0.7163, recall=0.9882, f1=0.8306, auroc=0.9733, spearman=0.8198
Threshold 0.54: accuracy=0.8055, precision=0.7240, recall=0.9874, f1=0.8354, auroc=0.9733, spearman=0.8198
Threshold 0.55: accuracy=0.8131, precision=0.7324, recall=0.9868, f1=0.8408, auroc=0.9733, spearman=0.8198
Threshold 0.56: accuracy=0.8204, precision=0.7410, recall=0.9852, f1=0.8458, auroc=0.9733, spearman=0.8198
Threshold 0.57: accuracy=0.8278, precision=0.7495, recall=0.9846, f1=0.8511, auroc=0.9733, spearman=0.8198
Threshold 0.58: accuracy=0.8368, precision=0.7606, recall=0.9830, f1=0.8576, auroc=0.9733, spearman=0.8198
Threshold 0.59: accuracy=0.8433, prec

### Residence_Transformer+GE-type+Entity-Pooler_Sup-Triplet

In [7]:
model_path = os.path.join(model_dir, "2022-06-04/Residence/Transformer+GE-type+Entity-Pooler/Sup-Triplet/12:21:28-Residence_Transformer+GE-type+Entity-Pooler_Sup-Triplet.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.8207, precision=0.8021, recall=0.8514, f1=0.8260, auroc=0.8977, spearman=0.6889
Threshold 0.51: accuracy=0.8218, precision=0.8049, recall=0.8496, f1=0.8266, auroc=0.8977, spearman=0.6889
Threshold 0.52: accuracy=0.8234, precision=0.8079, recall=0.8486, f1=0.8277, auroc=0.8977, spearman=0.6889
Threshold 0.53: accuracy=0.8232, precision=0.8091, recall=0.8460, f1=0.8271, auroc=0.8977, spearman=0.6889
Threshold 0.54: accuracy=0.8236, precision=0.8106, recall=0.8446, f1=0.8272, auroc=0.8977, spearman=0.6889
Threshold 0.55: accuracy=0.8258, precision=0.8148, recall=0.8432, f1=0.8288, auroc=0.8977, spearman=0.6889
Threshold 0.56: accuracy=0.8264, precision=0.8171, recall=0.8410, f1=0.8289, auroc=0.8977, spearman=0.6889
Threshold 0.57: accuracy=0.8266, precision=0.8194, recall=0.8378, f1=0.8285, auroc=0.8977, spearman=0.6889
Threshold 0.58: accuracy=0.8283, precision=0.8226, recall=0.8372, f1=0.8298, auroc=0.8977, spearman=0.6889
Threshold 0.59: accuracy=0.8292, prec

### Residence_Transformer+GE-type+Entity-Pooler_Unsup-SimCSE

In [49]:
model_path = os.path.join(model_dir, "2022-06-04/Residence/Transformer+GE-type+Entity-Pooler/Unsup-SimCSE/01:25:40-Residence_Transformer+GE-type+Entity-Pooler_Unsup-SimCSE.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.3930, precision=0.2398, recall=0.0986, f1=0.1397, auroc=0.2925, spearman=-0.3594
Threshold 0.51: accuracy=0.3957, precision=0.2343, recall=0.0920, f1=0.1321, auroc=0.2925, spearman=-0.3594
Threshold 0.52: accuracy=0.3970, precision=0.2261, recall=0.0850, f1=0.1235, auroc=0.2925, spearman=-0.3594
Threshold 0.53: accuracy=0.4002, precision=0.2219, recall=0.0796, f1=0.1172, auroc=0.2925, spearman=-0.3594
Threshold 0.54: accuracy=0.4021, precision=0.2146, recall=0.0736, f1=0.1096, auroc=0.2925, spearman=-0.3594
Threshold 0.55: accuracy=0.4059, precision=0.2112, recall=0.0688, f1=0.1038, auroc=0.2925, spearman=-0.3594
Threshold 0.56: accuracy=0.4080, precision=0.2040, recall=0.0634, f1=0.0967, auroc=0.2925, spearman=-0.3594
Threshold 0.57: accuracy=0.4114, precision=0.1999, recall=0.0590, f1=0.0911, auroc=0.2925, spearman=-0.3594
Threshold 0.58: accuracy=0.4134, precision=0.1933, recall=0.0546, f1=0.0852, auroc=0.2925, spearman=-0.3594
Threshold 0.59: accuracy=0.4

# Institution

## Transformer

### Institution_Transformer_Sup-SimCSE

In [50]:
model_path = os.path.join(model_dir, "2022-06-03/Institution/Transformer/Sup-SimCSE/13:18:53-Institution_Transformer_Sup-SimCSE.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.6713, precision=0.6056, recall=0.9826, f1=0.7493, auroc=0.9199, spearman=0.7273
Threshold 0.51: accuracy=0.6824, precision=0.6141, recall=0.9816, f1=0.7555, auroc=0.9199, spearman=0.7273
Threshold 0.52: accuracy=0.6908, precision=0.6209, recall=0.9798, f1=0.7601, auroc=0.9199, spearman=0.7273
Threshold 0.53: accuracy=0.6992, precision=0.6282, recall=0.9762, f1=0.7644, auroc=0.9199, spearman=0.7273
Threshold 0.54: accuracy=0.7080, precision=0.6358, recall=0.9736, f1=0.7693, auroc=0.9199, spearman=0.7273
Threshold 0.55: accuracy=0.7176, precision=0.6442, recall=0.9722, f1=0.7749, auroc=0.9199, spearman=0.7273
Threshold 0.56: accuracy=0.7273, precision=0.6529, recall=0.9704, f1=0.7806, auroc=0.9199, spearman=0.7273
Threshold 0.57: accuracy=0.7358, precision=0.6608, recall=0.9688, f1=0.7857, auroc=0.9199, spearman=0.7273
Threshold 0.58: accuracy=0.7441, precision=0.6693, recall=0.9650, f1=0.7904, auroc=0.9199, spearman=0.7273
Threshold 0.59: accuracy=0.7551, prec

### Institution_Transformer_Sup-Triplet

In [51]:
model_path = os.path.join(model_dir, "2022-06-04/Institution/Transformer/Sup-Triplet/00:40:36-Institution_Transformer_Sup-Triplet.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.5020, spearman=0.0040
Threshold 0.51: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.5020, spearman=0.0040
Threshold 0.52: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.5020, spearman=0.0040
Threshold 0.53: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.5020, spearman=0.0040
Threshold 0.54: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.5020, spearman=0.0040
Threshold 0.55: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.5020, spearman=0.0040
Threshold 0.56: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.5020, spearman=0.0040
Threshold 0.57: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.5020, spearman=0.0040
Threshold 0.58: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.5020, spearman=0.0040
Threshold 0.59: accuracy=0.5000, prec

### Institution_Transformer_Unsup-SimCSE

In [52]:
model_path = os.path.join(model_dir, "2022-06-04/Institution/Transformer/Unsup-SimCSE/01:25:51-Institution_Transformer_Unsup-SimCSE.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.3910, precision=0.2141, recall=0.0816, f1=0.1182, auroc=0.2785, spearman=-0.3836
Threshold 0.51: accuracy=0.3978, precision=0.2120, recall=0.0752, f1=0.1110, auroc=0.2785, spearman=-0.3836
Threshold 0.52: accuracy=0.4030, precision=0.2120, recall=0.0714, f1=0.1068, auroc=0.2785, spearman=-0.3836
Threshold 0.53: accuracy=0.4077, precision=0.2085, recall=0.0660, f1=0.1003, auroc=0.2785, spearman=-0.3836
Threshold 0.54: accuracy=0.4101, precision=0.2001, recall=0.0600, f1=0.0923, auroc=0.2785, spearman=-0.3836
Threshold 0.55: accuracy=0.4155, precision=0.1997, recall=0.0562, f1=0.0877, auroc=0.2785, spearman=-0.3836
Threshold 0.56: accuracy=0.4198, precision=0.1962, recall=0.0518, f1=0.0820, auroc=0.2785, spearman=-0.3836
Threshold 0.57: accuracy=0.4248, precision=0.1953, recall=0.0482, f1=0.0773, auroc=0.2785, spearman=-0.3836
Threshold 0.58: accuracy=0.4283, precision=0.1928, recall=0.0450, f1=0.0730, auroc=0.2785, spearman=-0.3836
Threshold 0.59: accuracy=0.4

## Transformer+GE-type

## Transformer+GE-type+Entity-Pooler

### Institution_Transformer+GE-type+Entity-Pooler_Sup-SimCSE

In [9]:
model_path = os.path.join(model_dir, "2022-06-04/Institution/Transformer+GE-type+Entity-Pooler/Sup-SimCSE/12:20:51-Institution_Transformer+GE-type+Entity-Pooler_Sup-SimCSE.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.8149, precision=0.7337, recall=0.9886, f1=0.8423, auroc=0.9714, spearman=0.8164
Threshold 0.51: accuracy=0.8215, precision=0.7415, recall=0.9870, f1=0.8468, auroc=0.9714, spearman=0.8164
Threshold 0.52: accuracy=0.8286, precision=0.7499, recall=0.9860, f1=0.8519, auroc=0.9714, spearman=0.8164
Threshold 0.53: accuracy=0.8341, precision=0.7572, recall=0.9836, f1=0.8557, auroc=0.9714, spearman=0.8164
Threshold 0.54: accuracy=0.8408, precision=0.7660, recall=0.9814, f1=0.8604, auroc=0.9714, spearman=0.8164
Threshold 0.55: accuracy=0.8472, precision=0.7740, recall=0.9808, f1=0.8652, auroc=0.9714, spearman=0.8164
Threshold 0.56: accuracy=0.8532, precision=0.7821, recall=0.9792, f1=0.8696, auroc=0.9714, spearman=0.8164
Threshold 0.57: accuracy=0.8596, precision=0.7909, recall=0.9776, f1=0.8744, auroc=0.9714, spearman=0.8164
Threshold 0.58: accuracy=0.8642, precision=0.7983, recall=0.9746, f1=0.8777, auroc=0.9714, spearman=0.8164
Threshold 0.59: accuracy=0.8698, prec

### Institution_Transformer+GE-type+Entity-Pooler_Sup-Triplet

In [54]:
model_path = os.path.join(model_dir, "2022-06-04/Institution/Transformer+GE-type+Entity-Pooler/Sup-Triplet/00:41:27-Institution_Transformer+GE-type+Entity-Pooler_Sup-Triplet.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.8517, precision=0.8261, recall=0.8910, f1=0.8573, auroc=0.9149, spearman=0.7186
Threshold 0.51: accuracy=0.8527, precision=0.8286, recall=0.8894, f1=0.8579, auroc=0.9149, spearman=0.7186
Threshold 0.52: accuracy=0.8530, precision=0.8306, recall=0.8868, f1=0.8578, auroc=0.9149, spearman=0.7186
Threshold 0.53: accuracy=0.8540, precision=0.8336, recall=0.8846, f1=0.8583, auroc=0.9149, spearman=0.7186
Threshold 0.54: accuracy=0.8540, precision=0.8351, recall=0.8822, f1=0.8580, auroc=0.9149, spearman=0.7186
Threshold 0.55: accuracy=0.8544, precision=0.8373, recall=0.8798, f1=0.8580, auroc=0.9149, spearman=0.7186
Threshold 0.56: accuracy=0.8551, precision=0.8401, recall=0.8772, f1=0.8582, auroc=0.9149, spearman=0.7186
Threshold 0.57: accuracy=0.8552, precision=0.8417, recall=0.8750, f1=0.8580, auroc=0.9149, spearman=0.7186
Threshold 0.58: accuracy=0.8571, precision=0.8453, recall=0.8742, f1=0.8595, auroc=0.9149, spearman=0.7186
Threshold 0.59: accuracy=0.8573, prec

### Institution_Transformer+GE-type+Entity-Pooler_Unsup-SimCSE

In [55]:
model_path = os.path.join(model_dir, "2022-06-04/Institution/Transformer+GE-type+Entity-Pooler/Unsup-SimCSE/01:25:45-Institution_Transformer+GE-type+Entity-Pooler_Unsup-SimCSE.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.3337, precision=0.2980, recall=0.2454, f1=0.2692, auroc=0.2426, spearman=-0.4457
Threshold 0.51: accuracy=0.3333, precision=0.2933, recall=0.2366, f1=0.2619, auroc=0.2426, spearman=-0.4457
Threshold 0.52: accuracy=0.3321, precision=0.2881, recall=0.2282, f1=0.2547, auroc=0.2426, spearman=-0.4457
Threshold 0.53: accuracy=0.3327, precision=0.2835, recall=0.2190, f1=0.2471, auroc=0.2426, spearman=-0.4457
Threshold 0.54: accuracy=0.3306, precision=0.2749, recall=0.2068, f1=0.2360, auroc=0.2426, spearman=-0.4457
Threshold 0.55: accuracy=0.3319, precision=0.2713, recall=0.1994, f1=0.2299, auroc=0.2426, spearman=-0.4457
Threshold 0.56: accuracy=0.3311, precision=0.2659, recall=0.1918, f1=0.2228, auroc=0.2426, spearman=-0.4457
Threshold 0.57: accuracy=0.3324, precision=0.2625, recall=0.1852, f1=0.2172, auroc=0.2426, spearman=-0.4457
Threshold 0.58: accuracy=0.3321, precision=0.2574, recall=0.1782, f1=0.2106, auroc=0.2426, spearman=-0.4457
Threshold 0.59: accuracy=0.3

# Federated

## Transformer

### Federated_Transformer_Sup-SimCSE

In [56]:
model_path = os.path.join(model_dir, "2022-06-03/Federated/Transformer/Sup-SimCSE/23:52:29-Federated_Transformer_Sup-SimCSE-global.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.7151, precision=0.6393, recall=0.9870, f1=0.7760, auroc=0.9548, spearman=0.7878
Threshold 0.51: accuracy=0.7239, precision=0.6471, recall=0.9852, f1=0.7811, auroc=0.9548, spearman=0.7878
Threshold 0.52: accuracy=0.7348, precision=0.6567, recall=0.9842, f1=0.7877, auroc=0.9548, spearman=0.7878
Threshold 0.53: accuracy=0.7458, precision=0.6667, recall=0.9830, f1=0.7945, auroc=0.9548, spearman=0.7878
Threshold 0.54: accuracy=0.7546, precision=0.6753, recall=0.9806, f1=0.7998, auroc=0.9548, spearman=0.7878
Threshold 0.55: accuracy=0.7661, precision=0.6867, recall=0.9788, f1=0.8071, auroc=0.9548, spearman=0.7878
Threshold 0.56: accuracy=0.7767, precision=0.6974, recall=0.9776, f1=0.8141, auroc=0.9548, spearman=0.7878
Threshold 0.57: accuracy=0.7869, precision=0.7083, recall=0.9756, f1=0.8207, auroc=0.9548, spearman=0.7878
Threshold 0.58: accuracy=0.7975, precision=0.7199, recall=0.9740, f1=0.8279, auroc=0.9548, spearman=0.7878
Threshold 0.59: accuracy=0.8064, prec

### Federated_Transformer_Sup-Triplet

In [57]:
model_path = os.path.join(model_dir, "2022-05-30/Federated/Transformer/Sup-Triplet/22:24:51-Federated_Transformer_Sup-Triplet-global.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.4885, spearman=-0.0230
Threshold 0.51: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.4885, spearman=-0.0230
Threshold 0.52: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.4885, spearman=-0.0230
Threshold 0.53: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.4885, spearman=-0.0230
Threshold 0.54: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.4885, spearman=-0.0230
Threshold 0.55: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.4885, spearman=-0.0230
Threshold 0.56: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.4885, spearman=-0.0230
Threshold 0.57: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.4885, spearman=-0.0230
Threshold 0.58: accuracy=0.5000, precision=0.5000, recall=1.0000, f1=0.6667, auroc=0.4885, spearman=-0.0230
Threshold 0.59: accuracy=0.5

### Federated_Transformer_Unsup-SimCSE

In [5]:
model_path = os.path.join(model_dir, "2022-06-04/Federated/Transformer/Unsup-SimCSE/12:51:14-Federated_Transformer_Unsup-SimCSE-global.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.4103, precision=0.2059, recall=0.0628, f1=0.0962, auroc=0.2820, spearman=-0.3776
Threshold 0.51: accuracy=0.4160, precision=0.2079, recall=0.0598, f1=0.0929, auroc=0.2820, spearman=-0.3776
Threshold 0.52: accuracy=0.4220, precision=0.2124, recall=0.0576, f1=0.0906, auroc=0.2820, spearman=-0.3776
Threshold 0.53: accuracy=0.4277, precision=0.2160, recall=0.0550, f1=0.0877, auroc=0.2820, spearman=-0.3776
Threshold 0.54: accuracy=0.4318, precision=0.2130, recall=0.0506, f1=0.0818, auroc=0.2820, spearman=-0.3776
Threshold 0.55: accuracy=0.4339, precision=0.2067, recall=0.0466, f1=0.0761, auroc=0.2820, spearman=-0.3776
Threshold 0.56: accuracy=0.4378, precision=0.2066, recall=0.0438, f1=0.0723, auroc=0.2820, spearman=-0.3776
Threshold 0.57: accuracy=0.4423, precision=0.2047, recall=0.0400, f1=0.0669, auroc=0.2820, spearman=-0.3776
Threshold 0.58: accuracy=0.4447, precision=0.2049, recall=0.0384, f1=0.0647, auroc=0.2820, spearman=-0.3776
Threshold 0.59: accuracy=0.4

## Transformer+GE-type_Sup

### Federated_Transformer+GE-type_Sup-SimCSE

In [14]:
model_path = os.path.join(model_dir, "2022-05-29/configs/global-4Transformer_Type-Sup_SimCSE-Federated-2022.05.29-13:26:13.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.8000, precision=0.7185, recall=0.9864, f1=0.8314, auroc=0.9725, spearman=0.8183
Threshold 0.51: accuracy=0.8075, precision=0.7268, recall=0.9854, f1=0.8366, auroc=0.9725, spearman=0.8183
Threshold 0.52: accuracy=0.8146, precision=0.7347, recall=0.9848, f1=0.8416, auroc=0.9725, spearman=0.8183
Threshold 0.53: accuracy=0.8210, precision=0.7419, recall=0.9844, f1=0.8461, auroc=0.9725, spearman=0.8183
Threshold 0.54: accuracy=0.8272, precision=0.7493, recall=0.9834, f1=0.8505, auroc=0.9725, spearman=0.8183
Threshold 0.55: accuracy=0.8348, precision=0.7586, recall=0.9822, f1=0.8560, auroc=0.9725, spearman=0.8183
Threshold 0.56: accuracy=0.8425, precision=0.7683, recall=0.9808, f1=0.8616, auroc=0.9725, spearman=0.8183
Threshold 0.57: accuracy=0.8480, precision=0.7757, recall=0.9792, f1=0.8656, auroc=0.9725, spearman=0.8183
Threshold 0.58: accuracy=0.8545, precision=0.7842, recall=0.9782, f1=0.8705, auroc=0.9725, spearman=0.8183
Threshold 0.59: accuracy=0.8600, prec

### Federated_Transformer+GE-type_Sup-Triplet

In [58]:
model_path = os.path.join(model_dir, "2022-05-30/Federated/Transformer+GE-type/Sup-Triplet/22:24:34-Federated_Transformer+GE-type_Sup-Triplet-global.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.9082, precision=0.9016, recall=0.9164, f1=0.9089, auroc=0.9691, spearman=0.8125
Threshold 0.51: accuracy=0.9083, precision=0.9047, recall=0.9128, f1=0.9087, auroc=0.9691, spearman=0.8125
Threshold 0.52: accuracy=0.9087, precision=0.9073, recall=0.9104, f1=0.9089, auroc=0.9691, spearman=0.8125
Threshold 0.53: accuracy=0.9087, precision=0.9094, recall=0.9078, f1=0.9086, auroc=0.9691, spearman=0.8125
Threshold 0.54: accuracy=0.9085, precision=0.9115, recall=0.9048, f1=0.9082, auroc=0.9691, spearman=0.8125
Threshold 0.55: accuracy=0.9083, precision=0.9138, recall=0.9016, f1=0.9077, auroc=0.9691, spearman=0.8125
Threshold 0.56: accuracy=0.9088, precision=0.9171, recall=0.8988, f1=0.9079, auroc=0.9691, spearman=0.8125
Threshold 0.57: accuracy=0.9092, precision=0.9198, recall=0.8966, f1=0.9080, auroc=0.9691, spearman=0.8125
Threshold 0.58: accuracy=0.9101, precision=0.9232, recall=0.8946, f1=0.9087, auroc=0.9691, spearman=0.8125
Threshold 0.59: accuracy=0.9100, prec

## Transformer+GE-type+Entity-Pooler

### Federated_Transformer+GE-type+Entity-Pooler_Sup-SimCSE

In [59]:
model_path = os.path.join(model_dir, "2022-06-03/Federated/Transformer+GE-type+Entity-Pooler/Sup-SimCSE/21:17:10-Federated_Transformer+GE-type+Entity-Pooler_Sup-SimCSE-global.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.7964, precision=0.7113, recall=0.9978, f1=0.8305, auroc=0.9842, spearman=0.8387
Threshold 0.51: accuracy=0.8058, precision=0.7211, recall=0.9974, f1=0.8370, auroc=0.9842, spearman=0.8387
Threshold 0.52: accuracy=0.8155, precision=0.7313, recall=0.9974, f1=0.8439, auroc=0.9842, spearman=0.8387
Threshold 0.53: accuracy=0.8242, precision=0.7409, recall=0.9970, f1=0.8501, auroc=0.9842, spearman=0.8387
Threshold 0.54: accuracy=0.8322, precision=0.7501, recall=0.9964, f1=0.8559, auroc=0.9842, spearman=0.8387
Threshold 0.55: accuracy=0.8415, precision=0.7610, recall=0.9956, f1=0.8627, auroc=0.9842, spearman=0.8387
Threshold 0.56: accuracy=0.8518, precision=0.7736, recall=0.9946, f1=0.8703, auroc=0.9842, spearman=0.8387
Threshold 0.57: accuracy=0.8594, precision=0.7832, recall=0.9940, f1=0.8761, auroc=0.9842, spearman=0.8387
Threshold 0.58: accuracy=0.8669, precision=0.7931, recall=0.9928, f1=0.8818, auroc=0.9842, spearman=0.8387
Threshold 0.59: accuracy=0.8732, prec

### Federated_Transformer+GE-type+Entity-Pooler_Sup-Triplet

In [12]:
model_path = os.path.join(model_dir, "2022-06-04/Federated/Transformer+GE-type+Entity-Pooler/Sup-Triplet/12:45:15-Federated_Transformer+GE-type+Entity-Pooler_Sup-Triplet-global.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.8979, precision=0.8865, recall=0.9126, f1=0.8994, auroc=0.9628, spearman=0.8016
Threshold 0.51: accuracy=0.8979, precision=0.8882, recall=0.9104, f1=0.8992, auroc=0.9628, spearman=0.8016
Threshold 0.52: accuracy=0.8990, precision=0.8918, recall=0.9082, f1=0.8999, auroc=0.9628, spearman=0.8016
Threshold 0.53: accuracy=0.8992, precision=0.8946, recall=0.9050, f1=0.8998, auroc=0.9628, spearman=0.8016
Threshold 0.54: accuracy=0.8997, precision=0.8976, recall=0.9024, f1=0.9000, auroc=0.9628, spearman=0.8016
Threshold 0.55: accuracy=0.8990, precision=0.8992, recall=0.8988, f1=0.8990, auroc=0.9628, spearman=0.8016
Threshold 0.56: accuracy=0.8990, precision=0.9027, recall=0.8944, f1=0.8985, auroc=0.9628, spearman=0.8016
Threshold 0.57: accuracy=0.8989, precision=0.9038, recall=0.8928, f1=0.8983, auroc=0.9628, spearman=0.8016
Threshold 0.58: accuracy=0.8983, precision=0.9062, recall=0.8886, f1=0.8973, auroc=0.9628, spearman=0.8016
Threshold 0.59: accuracy=0.8980, prec

### Federated_Transformer+GE-type+Entity-Pooler_Unsup-SimCSE

In [6]:
model_path = os.path.join(model_dir, "2022-06-04/Federated/Transformer+GE-type+Entity-Pooler/Unsup-SimCSE/12:56:14-Federated_Transformer+GE-type+Entity-Pooler_Unsup-SimCSE-global.pt")
model = torch.load(model_path)
evaluate(model, test_dl)

Threshold 0.50: accuracy=0.3320, precision=0.2965, recall=0.2448, f1=0.2682, auroc=0.2427, spearman=-0.4456
Threshold 0.51: accuracy=0.3310, precision=0.2914, recall=0.2360, f1=0.2608, auroc=0.2427, spearman=-0.4456
Threshold 0.52: accuracy=0.3321, precision=0.2884, recall=0.2288, f1=0.2552, auroc=0.2427, spearman=-0.4456
Threshold 0.53: accuracy=0.3311, precision=0.2825, recall=0.2194, f1=0.2470, auroc=0.2427, spearman=-0.4456
Threshold 0.54: accuracy=0.3287, precision=0.2742, recall=0.2080, f1=0.2366, auroc=0.2427, spearman=-0.4456
Threshold 0.55: accuracy=0.3295, precision=0.2698, recall=0.1998, f1=0.2296, auroc=0.2427, spearman=-0.4456
Threshold 0.56: accuracy=0.3285, precision=0.2623, recall=0.1892, f1=0.2198, auroc=0.2427, spearman=-0.4456
Threshold 0.57: accuracy=0.3281, precision=0.2569, recall=0.1816, f1=0.2128, auroc=0.2427, spearman=-0.4456
Threshold 0.58: accuracy=0.3264, precision=0.2497, recall=0.1732, f1=0.2045, auroc=0.2427, spearman=-0.4456
Threshold 0.59: accuracy=0.3