In [1]:
import sys
import os

module_path=".."
sys.path.append(module_path)

import torch

import transformers
from transformers import AutoTokenizer, AutoModel

from datasets import load_dataset

from models import SimCSE

import numpy as np
from scipy import spatial, stats
from sklearn.metrics import classification_report

In [2]:
# Silence transformers log messages below ERROR level
transformers.logging.set_verbosity(transformers.logging.ERROR)

In [3]:
def load_trained_model(device, model_path):
    """
    Return Trained Tokenizer, Model

    The backbone size is parsed from the checkpoint file name: the part
    before the first "_" is split on "-" and its last piece must be
    "base" or "large" (e.g. "simcse-base_general_..._step250.pth").

    Args:
        device: torch device (or device string) the model is placed on.
        model_path: path to a checkpoint file holding a SimCSE state dict.

    Returns:
        (tokenizer, model) tuple; the model has the checkpoint weights
        loaded and has been moved to `device`.

    Raises:
        ValueError: if the size parsed from the file name is not supported.
    """
    # Pre-Trained Backbone per Model Size
    backbones={
        "base": "bert-base-uncased",
        "large": "bert-large-uncased",
    }

    # Model Size
    model_size=model_path.split("/")[-1].split("_")[0].split("-")[-1]
    if model_size not in backbones:
        # Fail here with a clear message instead of hitting an unbound
        # `pretrained` further down.
        raise ValueError(
            f"Cannot infer model size from {model_path!r}: got {model_size!r}, "
            f"expected one of {sorted(backbones)}"
        )

    # Load Pre-Trained
    tokenizer=AutoTokenizer.from_pretrained(backbones[model_size])
    pretrained=AutoModel.from_pretrained(backbones[model_size]).to(device)

    # Load Trained
    model=SimCSE(pretrained=pretrained)
    # map_location remaps tensors saved on another device (e.g. a different
    # GPU index) onto `device`.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model=model.to(device)

    return tokenizer, model

In [4]:
def evaluate_on_sts(device, model_path, split):
    """
    Evaluate Trained Model on STS Benchmark

    Each non-empty line of the tab-separated benchmark file is scored by
    the cosine similarity of its two sentence embeddings; the Pearson
    correlation matrix and the Spearman correlation between the
    similarities and the gold scores are printed.

    Args:
        device: torch device used for the model and its inputs.
        model_path: checkpoint path, forwarded to load_trained_model.
        split: "dev" or "test".

    Raises:
        ValueError: if `split` is neither "dev" nor "test".
    """
    # Dataset Path (checked first so a bad split fails before model loading)
    dataset_paths={
        "dev": "../dataset/stsbenchmark/sts-dev.csv",
        "test": "../dataset/stsbenchmark/sts-test.csv",
    }
    if split not in dataset_paths:
        raise ValueError(f'Unknown split {split!r}: expected "dev" or "test"')

    # Load Trained Tokenizer, Model
    tokenizer, model=load_trained_model(device=device, model_path=model_path)

    # Load Dataset
    # The context manager closes the file; blank lines are skipped so the
    # last example is kept whether or not the file ends with a newline.
    with open(dataset_paths[split], "r", encoding="utf-8") as f:
        rows=[row for row in f.read().split("\n") if row.strip()]

    def embed(sentence):
        # Single-example batch of token ids
        token_ids=torch.tensor([tokenizer.encode(sentence)]).to(device)
        # Flattened because scipy's cosine distance is defined for 1-D input
        # (presumably get_embedding returns shape (1, dim) - TODO confirm)
        return np.array(model.get_embedding(token_ids).detach().cpu()).reshape(-1)

    # Evaluate
    preds=[]
    labels=[]

    model.eval()
    with torch.no_grad():
        for row in rows:
            # Parse: columns 4, 5, 6 hold the gold score and the two sentences
            label, sent1, sent2=row.split('\t')[4:7]

            # Prediction
            pred=1-spatial.distance.cosine(embed(sent1), embed(sent2))
            preds.append(pred)
            # Labels
            labels.append(float(label))

    # Results
    print(np.corrcoef(preds, labels))
    print(stats.spearmanr(preds, labels))

In [5]:
def evaluate_on_casehold(device, model_path, split):
    """
    Evaluate Trained Model on CaseHOLD

    For every example the context and each candidate ending are embedded,
    the ending with the highest cosine similarity to the context becomes
    the prediction, and a classification report against the gold labels
    is printed.

    Args:
        device: torch device used for the model and its inputs.
        model_path: checkpoint path, forwarded to load_trained_model.
        split: "dev" (validation split) or "test".

    Raises:
        ValueError: if `split` is neither "dev" nor "test".
    """
    # Dataset Split (checked first so a bad split fails before model loading)
    split_names={"dev": "validation", "test": "test"}
    if split not in split_names:
        raise ValueError(f'Unknown split {split!r}: expected "dev" or "test"')

    # Load Trained Tokenizer, Model
    tokenizer, model=load_trained_model(device=device, model_path=model_path)

    # Load Dataset
    dataset=load_dataset("lex_glue", "case_hold")[split_names[split]]

    def embed(text):
        # truncation=True caps the input at the tokenizer's maximum length
        # so long contexts do not exceed what the encoder accepts.
        token_ids=torch.tensor([tokenizer.encode(text, truncation=True)]).to(device)
        # Flattened because scipy's cosine distance is defined for 1-D input
        # (presumably get_embedding returns shape (1, dim) - TODO confirm)
        return np.array(model.get_embedding(token_ids).detach().cpu()).reshape(-1)

    # Evaluate
    preds=[]
    labels=[]

    model.eval()
    with torch.no_grad():
        for data in dataset:
            # Context (embedded and converted once per example)
            embd_context=embed(data["context"])

            # Prediction
            pred=-1
            # Start at -inf so an ending whose similarity is <= -1 can still
            # be selected instead of leaving the invalid label -1.
            max_sim=float("-inf")
            for idx, ending in enumerate(data["endings"]):
                sim=1-spatial.distance.cosine(embd_context, embed(ending))
                # Strict ">" keeps the first ending on ties
                if sim>max_sim:
                    pred=idx
                    max_sim=sim
            preds.append(pred)
            # Labels
            labels.append(data["label"])

    # Results
    print(classification_report(labels, preds, digits=4))

In [6]:
# Device
# Use the preferred GPU when the host actually has it; otherwise fall back
# to CPU so the notebook still runs on machines with fewer GPUs or no CUDA.
GPU_INDEX=3
device=torch.device(
    f"cuda:{GPU_INDEX}"
    if torch.cuda.is_available() and torch.cuda.device_count()>GPU_INDEX
    else "cpu"
)

# Checkpoint Name Prefix (checkpoint files are named "<model_name>_step<N>.pth")
model_name="simcse-base_general_batch64_lr7e-05"

In [None]:
# Evaluation
# Walk the checkpoints in step order (250, 500, ...) and stop at the first
# step whose checkpoint file is missing.
for step in range(250, 100000, 250):
    model_path=f'{model_name}_step{step}.pth'
    # The checkpoint directory is listed again on every iteration
    if model_path not in os.listdir("../model/"):
        break
    print("===\n"+model_path+"\n---")

    # Replace evaluate_on_sts with evaluate_on_casehold to score on CaseHOLD
    evaluate_on_sts(
        device=device,
        model_path="../model/"+model_path,
        split="dev"
    )

===
simcse-base_general_batch64_lr7e-05_step250.pth
---
[[1.         0.62667342]
 [0.62667342 1.        ]]
SpearmanrResult(correlation=0.6188642971916182, pvalue=2.8139713663700435e-159)
===
simcse-base_general_batch64_lr7e-05_step500.pth
---
[[1.         0.61957803]
 [0.61957803 1.        ]]
SpearmanrResult(correlation=0.611379324054628, pvalue=1.8765526707007926e-154)
===
simcse-base_general_batch64_lr7e-05_step750.pth
---
[[1.         0.62852028]
 [0.62852028 1.        ]]
SpearmanrResult(correlation=0.6213916921742345, pvalue=6.176909433078792e-161)
===
simcse-base_general_batch64_lr7e-05_step1000.pth
---
[[1.         0.61926282]
 [0.61926282 1.        ]]
SpearmanrResult(correlation=0.6117845081520796, pvalue=1.036406400832974e-154)
===
simcse-base_general_batch64_lr7e-05_step1250.pth
---
