In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from timeit import default_timer as timer

import os
import sys

# Make the project-local helper module importable. This adds ../../Ant_Syn_Scraping/
# to sys.path once; os.path.abspath resolves it against the current working
# directory. The module is then imported as `functions`, which supplies
# Loss_Synonymy / Loss_Antonymy used by the evaluation cell below.
# NOTE(review): this only works when the notebook is launched from a directory two
# levels below the one containing Ant_Syn_Scraping.
module_path = os.path.abspath(os.path.join('../../Ant_Syn_Scraping/'))
if module_path not in sys.path:
    sys.path.append(module_path)
import model_functions_PhaseI as functions

In [4]:
def _threshold_accuracy(scores, targets, score_scale=0.8):
    """Return the percentage of (score, target) pairs with ``score * score_scale <= target``.

    Pairs are formed with ``zip`` along the first dimension, so only as many elements
    as the shorter input has are scored. Returns ``None`` when there are no pairs,
    because accuracy is undefined for an empty batch.
    """
    pairs = list(zip(scores, targets))
    if not pairs:
        return None
    correct = sum(1 for score, target in pairs if score * score_scale <= target)
    return correct / len(pairs) * 100


def Phase_I_eval_model(model, testing_data_set, optimizer, score_scale=0.8):
    """Evaluate a Phase I synonymy/antonymy model on a test set.

    The model is switched to eval mode (training mode is not restored) and run
    under ``torch.no_grad()``. For every batch it computes the synonymy and antonymy
    losses and a per-batch threshold accuracy for each head. These are then
    averaged over the batches.

    Args:
        model: callable returning ``(S1_out, S2_out, A1_out, A2_out,
            synonymy_score, antonymy_score)`` for a batch of inputs.
        testing_data_set: iterable of ``(inputs, labels)`` batches. It is assumed
            that ``labels[0]`` holds the synonymy targets and ``labels[1]`` the
            antonymy targets, each aligned with the score tensors along the first
            dimension (TODO confirm against the dataset/collate function).
        optimizer: not used by this function; kept for call-site compatibility.
        score_scale: a prediction counts as correct when
            ``score * score_scale <= target`` (default 0.8, the original constant).

    Returns:
        ``(test_epoch_loss, syn_test_epoch_loss, ant_test_epoch_loss,
        syn_epoch_acc, ant_epoch_acc)``. Losses are batch means of the criterion
        values. Accuracies are batch means of per-batch percentages (``nan`` if no
        batch contained any scorable element).

    Raises:
        ValueError: if ``testing_data_set`` yields no batches.
    """
    model.eval()

    syn_criterion = functions.Loss_Synonymy()
    ant_criterion = functions.Loss_Antonymy()

    test_losses = []
    syn_test_losses = []
    ant_test_losses = []
    syn_test_acc_list = []
    ant_test_acc_list = []

    # Evaluation only: no gradient tracking is needed.
    with torch.no_grad():
        for inputs, labels in testing_data_set:
            S1_out, S2_out, A1_out, A2_out, synonymy_score, antonymy_score = model(inputs)

            # Loss per batch of testing data.
            syn_test_loss = syn_criterion(S1_out, S2_out, synonymy_score)
            # NOTE(review): the antonymy loss is fed S2_out/A1_out, and A2_out is never
            # used. Confirm this pairing is intended (vs. A1_out/A2_out).
            ant_test_loss = ant_criterion(S2_out, A1_out, antonymy_score)

            test_losses.append((syn_test_loss + ant_test_loss).item())
            syn_test_losses.append(syn_test_loss.item())
            ant_test_losses.append(ant_test_loss.item())

            # Record accuracy for EVERY batch. Previously only the last batch's
            # value was kept (the append sat outside the loop) and was then divided
            # by the number of batches.
            syn_acc = _threshold_accuracy(synonymy_score, labels[0], score_scale)
            if syn_acc is not None:
                syn_test_acc_list.append(syn_acc)

            ant_acc = _threshold_accuracy(antonymy_score, labels[1], score_scale)
            if ant_acc is not None:
                ant_test_acc_list.append(ant_acc)

    test_total = len(test_losses)
    if test_total == 0:
        raise ValueError("testing_data_set yielded no batches; cannot compute evaluation metrics")

    test_epoch_loss = sum(test_losses) / test_total
    syn_test_epoch_loss = sum(syn_test_losses) / test_total
    ant_test_epoch_loss = sum(ant_test_losses) / test_total

    syn_epoch_acc = (sum(syn_test_acc_list) / len(syn_test_acc_list)
                     if syn_test_acc_list else float("nan"))
    ant_epoch_acc = (sum(ant_test_acc_list) / len(ant_test_acc_list)
                     if ant_test_acc_list else float("nan"))

    print(f"Total Epoch Testing Loss is: {test_epoch_loss}")
    print(f"Total Epoch Antonym Testing Accuracy is: {ant_epoch_acc}")
    print(f"Total Epoch Synonym Testing Accuracy is: {syn_epoch_acc}")

    return test_epoch_loss, syn_test_epoch_loss, ant_test_epoch_loss, syn_epoch_acc, ant_epoch_acc
