In [9]:
# Project star imports come first so the explicit imports below take precedence on name clashes.
# NOTE(review): `util_classifier` is imported both as `notebooks.util_classifier` and as
# `util_classifier`; the star import may shadow the class imported on the first line — confirm.
from notebooks.util_classifier import MultiHomogeneousGraphTripletDataset
from training_classifier import *
from util_classifier import *
from util import plot_losses, save_training_results, save_dict_to_json
from gat_models import *

import os
import random
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss
from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# Seed every random number generator used in this notebook so that results are reproducible.
# Covered: Python's `random`, NumPy's global RNG, the PyTorch CPU RNG and all CUDA devices.
SEED = 40
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

In [10]:
# Output size of each pre-trained GAT encoder and the number of encoders (one per edge type
# in `edge_spec`, defined in the next cell). The classifier input is their concatenation, so
# its width is derived here instead of being hard-coded; keep NUM_GAT_ENCODERS == len(edge_spec).
GAT_OUT_CHANNELS = 8
NUM_GAT_ENCODERS = 4

# Hyper-parameters of the experiment. This dict is also written to `training_data.json`
# together with the training results.
config = {
    'experiment': 'GAT Classifier Training',
    'max_hops': 2,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    # Architecture of the (frozen) pre-trained GAT encoders; must match the checkpoints.
    'hidden_channels': 32,
    'out_channels': GAT_OUT_CHANNELS,
    'num_heads': 8,
    # Embedding network trained on top of the concatenated GAT embeddings.
    'classifier_in_channels': NUM_GAT_ENCODERS * GAT_OUT_CHANNELS,
    'classifier_hidden_channels': 16,
    'classifier_out_channels': 8,
    'classifier_dropout': 0.3,
    # Triplet loss and optimisation.
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

In [11]:
# Graph sampling configurations
node_properties = [
    'id',
    'feature_vec',
]

node_spec = [
    NodeType.PUBLICATION
]

edge_spec = [
    # EdgeType.SIM_TITLE,
    EdgeType.SIM_ABSTRACT,
    EdgeType.SIM_AUTHOR,
    EdgeType.SIM_ORG,
    EdgeType.SAME_AUTHOR,
]

# Checkpoint of the pre-trained homogeneous GAT encoder for each edge type.
gat_encoder_paths = {
    EdgeType.SIM_ABSTRACT: './data/results/homogeneous (abstract) full_emb linear_layer dropout 32h 8out/gat_encoder.pt',
    EdgeType.SIM_AUTHOR: './data/results/homogeneous (similar co-authors) full_emb linear_layer dropout small_graph low_dim/gat_encoder.pt',
    EdgeType.SIM_ORG: './data/results/homogeneous (org) full_emb linear_layer dropout 32h 8out/gat_encoder.pt',
    EdgeType.SAME_AUTHOR: './data/results/homogeneous (same author) full_emb linear_layer dropout low_dim/gat_encoder.pt'
}
# TripletNet looks up an encoder for every edge type in `edge_spec`; fail early if one is missing.
missing_encoders = [edge_type for edge_type in edge_spec if edge_type not in gat_encoder_paths]
assert not missing_encoders, f"No GAT encoder checkpoint for edge types: {missing_encoders}"


database = 'small-graph'
gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database=database
)

# The device must be known before the encoders are created and loaded onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model configurations
# Load pre-trained GAT encoders into a separate dict so the path mapping is never overwritten
# (re-running this cell therefore works).
gat_encoders = {}
for edge_key, gat_path in gat_encoder_paths.items():
    gat_encoder = HomoGATv2Encoder(config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
    # map_location allows checkpoints saved on a GPU to be loaded on a CPU-only machine.
    gat_encoder.load_state_dict(torch.load(gat_path, map_location=device))
    gat_encoders[edge_key] = gat_encoder
# Later cells pass `gat_list` to TripletNet; keep that name bound to the loaded encoders.
gat_list = gat_encoders

loss_fn = TripletMarginLoss(margin=config['margin'])

# TODO: Adjust result folder name!
result_folder_name = 'classifier full_emb (abstract, org, sim_author, same_author edges) low dim 2 layers'
result_folder_path = f'./data/results/{result_folder_name}'
# makedirs also creates missing parent directories and is a no-op if the folder exists.
os.makedirs(result_folder_path, exist_ok=True)

Using default edge type: SimilarAbstract for homogeneous graph sampling.


## Embedding Network
**This network takes the concatenated GAT embeddings of a node (one per edge type) and maps them to a lower-dimensional, L2-normalized embedding.**

In [12]:
class EmbeddingNet(nn.Module):
    """Project a concatenated feature vector to an L2-normalized embedding.

    Pipeline: Dropout -> Linear(input_size, hidden_size) -> BatchNorm1d -> ReLU
    -> Dropout -> Linear(hidden_size, embedding_size), then L2 normalization over the
    last dimension.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Width of the hidden layer.
        embedding_size: Size of the produced embedding.
        dropout: Dropout probability applied before each Linear layer.
    """

    def __init__(self, input_size, hidden_size=128, embedding_size=16, dropout=0.2):
        super().__init__()
        # The layer order defines the checkpoint format (state_dict keys are `fc.<index>.*`).
        layers = [
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, embedding_size),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.fc(x)
        return F.normalize(projected, p=2, dim=-1)

## Triplet Network
**This network takes three inputs: an anchor, a positive example, and a negative example, each given as one subgraph per edge type. It encodes them with the frozen GAT encoders and the embedding network, and returns the three embeddings. It is used to simplify triplet-loss training.**

In [13]:
class TripletNet(nn.Module):
    """Frozen per-edge-type GAT encoders followed by a trainable embedding network.

    For each edge type in `edge_spec`, the anchor, positive and negative subgraphs are
    encoded by that edge type's GAT encoder, and the embedding of each subgraph's central
    node is selected. The per-edge-type embeddings are concatenated along dim 1 and passed
    through `embedding_net`.

    Note: `gat_encoders` is stored as a plain dict, so the encoders are NOT registered as
    submodules. `.to()`, `.train()`, `.eval()`, `.parameters()` and `.state_dict()` of this
    module do not touch them. They must already be on the right device. They are put into
    eval mode once here and therefore stay in eval mode during training.
    """

    def __init__(self, embedding_net: EmbeddingNet, edge_spec: list[EdgeType], gat_encoders: dict[EdgeType, nn.Module]):
        super().__init__()
        self.edge_spec = edge_spec
        self.gat_encoders = gat_encoders
        self.embedding_net = embedding_net

        # Freeze the pre-trained encoders: eval mode and no gradient updates.
        for gat in self.gat_encoders.values():
            gat.eval()
            for param in gat.parameters():
                param.requires_grad = False

    def _central_node_embedding(self, edge_type: EdgeType, graph):
        """Encode `graph` with the GAT encoder of `edge_type` and select its central node(s)."""
        node_embeddings = self.gat_encoders[edge_type](graph)
        return node_embeddings[graph.central_node_id]

    def forward(self, data_dict: dict):
        """Embed an (anchor, positive, negative) triplet.

        Args:
            data_dict: Maps every edge type in `edge_spec` to an indexable whose items 0, 1
                and 2 are the anchor, positive and negative graphs; each graph must expose
                `central_node_id`.

        Returns:
            Tuple `(anchor, positive, negative)` of outputs of `embedding_net`.
        """
        # One list per role (anchor, positive, negative), filled edge type by edge type.
        per_role = ([], [], [])
        for edge_type in self.edge_spec:
            graphs = data_dict[edge_type]
            for role, collected in enumerate(per_role):
                collected.append(self._central_node_embedding(edge_type, graphs[role]))

        anchor, positive, negative = (torch.cat(parts, dim=1) for parts in per_role)

        output_anchor = self.embedding_net(anchor)
        output_positive = self.embedding_net(positive)
        output_negative = self.embedding_net(negative)
        return output_anchor, output_positive, output_negative

    def get_embedding(self, x):
        """Apply only the trainable embedding network to already-concatenated GAT embeddings."""
        return self.embedding_net(x)

## Pair Classifier (Old)
**This network was intended for the actual classification task (the AND pipeline). (Deprecated)**

In [14]:
class PairClassifier(nn.Module):
    """(Deprecated) Score how likely two inputs belong to the same class.

    Both inputs go through the shared `embedding_net`. The Euclidean distance `d` between
    the outputs is mapped to `sigmoid(distance_offset - d)`.

    With the default `distance_offset=0.0` (the original behavior), the score is
    `sigmoid(-d) <= 0.5` for every pair because `d >= 0`. It can therefore never exceed a
    0.5 decision threshold. A positive `distance_offset` (a distance threshold) centres
    the score so that pairs closer than the offset score above 0.5.

    Args:
        embedding_net: Module applied to both inputs.
        distance_offset: Distance at which the score equals 0.5.
    """

    def __init__(self, embedding_net, distance_offset: float = 0.0):
        super().__init__()
        self.embedding_net = embedding_net
        self.distance_offset = distance_offset

    def forward(self, embedding_1, embedding_2):
        out_1 = self.embedding_net(embedding_1)
        out_2 = self.embedding_net(embedding_2)

        # Smaller distance -> larger score.
        distance = F.pairwise_distance(out_1, out_2)
        similarity_prediction = torch.sigmoid(self.distance_offset - distance)

        return similarity_prediction

    def get_embedding(self, x):
        """Apply only the shared embedding network."""
        return self.embedding_net(x)

## Training

In [15]:
db = DatabaseWrapper(database=database)
data_harvester = ClassifierTripletDataHarvester(db=db, gs=gs, edge_spec=edge_spec, config=config, valid_triplets_save_file='valid_triplets_classifier_small_graph', transformer_model='sentence-transformers/all-MiniLM-L6-v2')


# Split the triplets into train / test / eval.

# The eval set holds every triplet that touches a paper of the first `num_eval_authors` authors
# of the WhoIsWho training data. Those authors' papers therefore never occur in train or test.
num_eval_authors = 3
eval_papers = set()
train_data = WhoIsWhoDataset.parse_train()
for author_data in list(train_data.values())[:num_eval_authors]:
    eval_papers.update(author_data['normal_data'])

# Single pass: each triplet goes to eval if any of its three papers is an eval paper,
# otherwise to the train/test pool (original order preserved).
eval_triplets = []
train_test_triplets = []
for triplet in data_harvester.triplets:
    if triplet[0] in eval_papers or triplet[1] in eval_papers or triplet[2] in eval_papers:
        eval_triplets.append(triplet)
    else:
        train_test_triplets.append(triplet)

random.shuffle(train_test_triplets)

train_size = int(0.85 * len(train_test_triplets))
test_size = len(train_test_triplets) - train_size

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
assert len(train_triplets) + len(test_triplets) + len(eval_triplets) == len(data_harvester.triplets)
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the triplets (distinct triplets for training, testing and evaluation)
train_dataset = MultiHomogeneousGraphTripletDataset(train_triplets, gs, edge_spec=edge_spec, config=config)
test_dataset = MultiHomogeneousGraphTripletDataset(test_triplets, gs, edge_spec=edge_spec, config=config)
eval_dataset = MultiHomogeneousGraphTripletDataset(eval_triplets, gs, edge_spec=edge_spec, config=config)

# Create the DataLoaders (only the training data is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Record the node / edge types used, for the saved config
metadata = (
    [n.value for n in node_spec],
    [edge_pyg_key_vals[r] for r in edge_spec]
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]

# Embedding model (the only trainable part)
embedding_net = EmbeddingNet(
    input_size=config['classifier_in_channels'],
    hidden_size=config['classifier_hidden_channels'],
    embedding_size=config['classifier_out_channels'],
    dropout=config['classifier_dropout']
).to(device)

# Triplet training model: frozen GAT encoders + embedding model
triplet_net = TripletNet(
    embedding_net=embedding_net,
    edge_spec=edge_spec,
    gat_encoders=gat_list
).to(device)

# Optimizer over the embedding network only; the GAT encoders are frozen.
optimizer = torch.optim.Adam(triplet_net.embedding_net.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-12-13 10:39:15,434 - DatabaseWrapper - INFO - Connecting to the database ...
2024-12-13 10:39:15,434 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 6284 triplets.
Train size: 4985, Test size: 880, Eval size: 419


In [16]:
from collections import defaultdict

num_epochs = config['num_epochs']
results = defaultdict(list)

# The model is evaluated twice per epoch (first and middle batch), so every test/eval curve
# has this many points per epoch.
EVALS_PER_EPOCH = 2


def _evaluate_split(split: str, dataloader, header: str, plot_name: str, x_label: str) -> None:
    """Evaluate `triplet_net` on one split, append its metrics to `results` and re-plot its curves.

    Args:
        split: Prefix of the result keys and plot file names ('test' or 'eval').
        dataloader: DataLoader of the split.
        header: Split name printed in the console header ("    <header> Results:").
        plot_name: Split name used in the plot titles.
        x_label: X-axis label of the accuracy plot.
    """
    print(f"    {header} Results:")
    loss, accuracy, correct_pos, correct_neg, precision, recall, f1 = test_and_eval(
        model=triplet_net,
        loss_fn=loss_fn,
        dataloader=dataloader,
        margin=config['margin']
    )
    metrics = {
        'total_loss': loss,
        'accuracies': accuracy,
        'accuracies_correct_pos': correct_pos,
        'accuracies_correct_neg': correct_neg,
        'precision': precision,
        'recall': recall,
        'F1': f1,
    }
    for metric_name, value in metrics.items():
        results[f'{split}_{metric_name}'].append(value)

    plot_losses(
        losses=[results[f'{split}_total_loss']],
        epoch_len=EVALS_PER_EPOCH,
        plot_title=f'{plot_name} Loss',
        plot_file=f'{result_folder_path}/{split}_loss.png',
        line_labels=["Triplet Loss"]
    )

    plot_losses(
        [results[f'{split}_accuracies'], results[f'{split}_accuracies_correct_pos'], results[f'{split}_accuracies_correct_neg']],
        epoch_len=EVALS_PER_EPOCH,
        plot_title=f'{plot_name} Accuracy',
        x_label=x_label,
        y_label='Accuracy',
        line_labels=['Total Accuracy', 'Correct Pos', 'Correct Neg'],
        plot_file=f'{result_folder_path}/{split}_accuracy.png'
    )


# 1-based batch indices at which the model is evaluated.
eval_batches = {1, len(train_dataloader) // 2}

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    for current_batch, data_dict in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"), start=1):
        if current_batch in eval_batches:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            _evaluate_split('test', test_dataloader, header='Test', plot_name='Test', x_label='Test Iterations')
            _evaluate_split('eval', eval_dataloader, header='Eval', plot_name='Evaluation', x_label='Eval Iterations')

            # Save config and training results
            config['results'] = results
            save_dict_to_json(config, result_folder_path + '/training_data.json')

            # Save the model whenever eval accuracy reaches a new maximum. The first measurement
            # is always saved so that a checkpoint from this run exists even if accuracy never
            # improves (the result folder may contain a stale checkpoint from an earlier run).
            eval_accuracies = results['eval_accuracies']
            if len(eval_accuracies) == 1 or eval_accuracies[-1] > max(eval_accuracies[:-1]):
                print(f"Saving model at epoch {epoch}...")
                torch.save(triplet_net.embedding_net.state_dict(), result_folder_path + '/embedding_net.pt')

        loss = train(
            triplet_classifier_model=triplet_net,
            loss_fn=loss_fn,
            data_dict=data_dict,
            optimizer=optimizer
        )
        results['train_loss'].append(loss)

        # Note: the training-loss plot is re-rendered after every batch.
        plot_loss(results['train_loss'], epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')




Epoch 1/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 880 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 880 (50.00%)
        Test/Eval Loss: 0.9547, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
    Eval Results:
        Correct positive: 419 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 419 (50.00%)
        Test/Eval Loss: 0.9522, Test/Eval Accuracy: 0.5000
        Precision: 0.5000, Recall: 1.0000, F1: 0.6667
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 528 (60.00%), Correct negative: 648 (73.64%)
        Total correct: 1176 (66.82%)
        Test/Eval Loss: 0.6473, Test/Eval Accuracy: 0.6682
        Precision: 0.6947, Recall: 0.6000, F1: 0.6439
    Eval Results:
        Correct positive: 329 (78.52%), Correct negative: 161 (38.42%)
        Total correct: 490 (58.47%)
        Test/Eval Loss: 0.7192, Test/Eval Accuracy: 0.5847
    

Epoch 2/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 580 (65.91%), Correct negative: 639 (72.61%)
        Total correct: 1219 (69.26%)
        Test/Eval Loss: 0.6089, Test/Eval Accuracy: 0.6926
        Precision: 0.7065, Recall: 0.6591, F1: 0.6820
    Eval Results:
        Correct positive: 321 (76.61%), Correct negative: 162 (38.66%)
        Total correct: 483 (57.64%)
        Test/Eval Loss: 0.7380, Test/Eval Accuracy: 0.5764
        Precision: 0.5554, Recall: 0.7661, F1: 0.6439
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 548 (62.27%), Correct negative: 648 (73.64%)
        Total correct: 1196 (67.95%)
        Test/Eval Loss: 0.5971, Test/Eval Accuracy: 0.6795
        Precision: 0.7026, Recall: 0.6227, F1: 0.6602
    Eval Results:
        Correct positive: 267 (63.72%), Correct negative: 160 (38.19%)
        Total correct: 427 (50.95%)
        Test/Eval Loss: 0.7526, Test/Eval Accuracy: 0.5095

Epoch 3/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 579 (65.80%), Correct negative: 672 (76.36%)
        Total correct: 1251 (71.08%)
        Test/Eval Loss: 0.5864, Test/Eval Accuracy: 0.7108
        Precision: 0.7357, Recall: 0.6580, F1: 0.6947
    Eval Results:
        Correct positive: 258 (61.58%), Correct negative: 215 (51.31%)
        Total correct: 473 (56.44%)
        Test/Eval Loss: 0.7656, Test/Eval Accuracy: 0.5644
        Precision: 0.5584, Recall: 0.6158, F1: 0.5857
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 565 (64.20%), Correct negative: 649 (73.75%)
        Total correct: 1214 (68.98%)
        Test/Eval Loss: 0.6075, Test/Eval Accuracy: 0.6898
        Precision: 0.7098, Recall: 0.6420, F1: 0.6742
    Eval Results:
        Correct positive: 260 (62.05%), Correct negative: 239 (57.04%)
        Total correct: 499 (59.55%)
        Test/Eval Loss: 0.6804, Test/Eval Accuracy: 0.5955

Epoch 4/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 574 (65.23%), Correct negative: 662 (75.23%)
        Total correct: 1236 (70.23%)
        Test/Eval Loss: 0.5868, Test/Eval Accuracy: 0.7023
        Precision: 0.7247, Recall: 0.6523, F1: 0.6866
    Eval Results:
        Correct positive: 256 (61.10%), Correct negative: 224 (53.46%)
        Total correct: 480 (57.28%)
        Test/Eval Loss: 0.7463, Test/Eval Accuracy: 0.5728
        Precision: 0.5676, Recall: 0.6110, F1: 0.5885
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 565 (64.20%), Correct negative: 657 (74.66%)
        Total correct: 1222 (69.43%)
        Test/Eval Loss: 0.5849, Test/Eval Accuracy: 0.6943
        Precision: 0.7170, Recall: 0.6420, F1: 0.6775
    Eval Results:
        Correct positive: 267 (63.72%), Correct negative: 165 (39.38%)
        Total correct: 432 (51.55%)
        Test/Eval Loss: 0.6991, Test/Eval Accuracy: 0.5155

Epoch 5/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 570 (64.77%), Correct negative: 666 (75.68%)
        Total correct: 1236 (70.23%)
        Test/Eval Loss: 0.5895, Test/Eval Accuracy: 0.7023
        Precision: 0.7270, Recall: 0.6477, F1: 0.6851
    Eval Results:
        Correct positive: 248 (59.19%), Correct negative: 357 (85.20%)
        Total correct: 605 (72.20%)
        Test/Eval Loss: 0.6534, Test/Eval Accuracy: 0.7220
        Precision: 0.8000, Recall: 0.5919, F1: 0.6804
Saving model at epoch 5...
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 556 (63.18%), Correct negative: 646 (73.41%)
        Total correct: 1202 (68.30%)
        Test/Eval Loss: 0.6036, Test/Eval Accuracy: 0.6830
        Precision: 0.7038, Recall: 0.6318, F1: 0.6659
    Eval Results:
        Correct positive: 302 (72.08%), Correct negative: 193 (46.06%)
        Total correct: 495 (59.07%)
        Test/Eval Loss: 0.6588,

Epoch 6/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 594 (67.50%), Correct negative: 668 (75.91%)
        Total correct: 1262 (71.70%)
        Test/Eval Loss: 0.5616, Test/Eval Accuracy: 0.7170
        Precision: 0.7370, Recall: 0.6750, F1: 0.7046
    Eval Results:
        Correct positive: 278 (66.35%), Correct negative: 211 (50.36%)
        Total correct: 489 (58.35%)
        Test/Eval Loss: 0.6655, Test/Eval Accuracy: 0.5835
        Precision: 0.5720, Recall: 0.6635, F1: 0.6144
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 576 (65.45%), Correct negative: 668 (75.91%)
        Total correct: 1244 (70.68%)
        Test/Eval Loss: 0.5699, Test/Eval Accuracy: 0.7068
        Precision: 0.7310, Recall: 0.6545, F1: 0.6906
    Eval Results:
        Correct positive: 261 (62.29%), Correct negative: 347 (82.82%)
        Total correct: 608 (72.55%)
        Test/Eval Loss: 0.6307, Test/Eval Accuracy: 0.7255

Epoch 7/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 573 (65.11%), Correct negative: 666 (75.68%)
        Total correct: 1239 (70.40%)
        Test/Eval Loss: 0.5739, Test/Eval Accuracy: 0.7040
        Precision: 0.7281, Recall: 0.6511, F1: 0.6875
    Eval Results:
        Correct positive: 282 (67.30%), Correct negative: 349 (83.29%)
        Total correct: 631 (75.30%)
        Test/Eval Loss: 0.5827, Test/Eval Accuracy: 0.7530
        Precision: 0.8011, Recall: 0.6730, F1: 0.7315
Saving model at epoch 7...
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 579 (65.80%), Correct negative: 666 (75.68%)
        Total correct: 1245 (70.74%)
        Test/Eval Loss: 0.5704, Test/Eval Accuracy: 0.7074
        Precision: 0.7301, Recall: 0.6580, F1: 0.6922
    Eval Results:
        Correct positive: 258 (61.58%), Correct negative: 364 (86.87%)
        Total correct: 622 (74.22%)
        Test/Eval Loss: 0.6511,

Epoch 8/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 581 (66.02%), Correct negative: 683 (77.61%)
        Total correct: 1264 (71.82%)
        Test/Eval Loss: 0.5618, Test/Eval Accuracy: 0.7182
        Precision: 0.7468, Recall: 0.6602, F1: 0.7008
    Eval Results:
        Correct positive: 280 (66.83%), Correct negative: 223 (53.22%)
        Total correct: 503 (60.02%)
        Test/Eval Loss: 0.6572, Test/Eval Accuracy: 0.6002
        Precision: 0.5882, Recall: 0.6683, F1: 0.6257
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 585 (66.48%), Correct negative: 656 (74.55%)
        Total correct: 1241 (70.51%)
        Test/Eval Loss: 0.5697, Test/Eval Accuracy: 0.7051
        Precision: 0.7231, Recall: 0.6648, F1: 0.6927
    Eval Results:
        Correct positive: 316 (75.42%), Correct negative: 166 (39.62%)
        Total correct: 482 (57.52%)
        Test/Eval Loss: 0.6525, Test/Eval Accuracy: 0.5752

Epoch 9/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 577 (65.57%), Correct negative: 677 (76.93%)
        Total correct: 1254 (71.25%)
        Test/Eval Loss: 0.5735, Test/Eval Accuracy: 0.7125
        Precision: 0.7397, Recall: 0.6557, F1: 0.6952
    Eval Results:
        Correct positive: 270 (64.44%), Correct negative: 260 (62.05%)
        Total correct: 530 (63.25%)
        Test/Eval Loss: 0.6413, Test/Eval Accuracy: 0.6325
        Precision: 0.6294, Recall: 0.6444, F1: 0.6368
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 597 (67.84%), Correct negative: 674 (76.59%)
        Total correct: 1271 (72.22%)
        Test/Eval Loss: 0.5582, Test/Eval Accuracy: 0.7222
        Precision: 0.7435, Recall: 0.6784, F1: 0.7094
    Eval Results:
        Correct positive: 286 (68.26%), Correct negative: 179 (42.72%)
        Total correct: 465 (55.49%)
        Test/Eval Loss: 0.7031, Test/Eval Accuracy: 0.5549

Epoch 10/10:   0%|          | 0/156 [00:00<?, ?it/s]

___ Current Batch 1/156 _________________________
    Test Results:
        Correct positive: 580 (65.91%), Correct negative: 662 (75.23%)
        Total correct: 1242 (70.57%)
        Test/Eval Loss: 0.5636, Test/Eval Accuracy: 0.7057
        Precision: 0.7268, Recall: 0.6591, F1: 0.6913
    Eval Results:
        Correct positive: 248 (59.19%), Correct negative: 363 (86.63%)
        Total correct: 611 (72.91%)
        Test/Eval Loss: 0.6276, Test/Eval Accuracy: 0.7291
        Precision: 0.8158, Recall: 0.5919, F1: 0.6860
___ Current Batch 78/156 _________________________
    Test Results:
        Correct positive: 571 (64.89%), Correct negative: 677 (76.93%)
        Total correct: 1248 (70.91%)
        Test/Eval Loss: 0.5701, Test/Eval Accuracy: 0.7091
        Precision: 0.7377, Recall: 0.6489, F1: 0.6904
    Eval Results:
        Correct positive: 261 (62.29%), Correct negative: 358 (85.44%)
        Total correct: 619 (73.87%)
        Test/Eval Loss: 0.5527, Test/Eval Accuracy: 0.7387