In [15]:
from sentence_transformers import SentenceTransformer,util
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from datasets import load_dataset
import pandas as pd
import torch
import json

datasetraw = load_dataset("csv",data_files="data/climate_fever.csv")
dataset = datasetraw["train"].train_test_split(test_size=0.2)

claimtest = dataset['test']['claim']
labeltest = dataset['test']['claim_label']

data = pd.read_csv("data/climate_fever_evidence_embedding.csv",header=None)

embds = []

for embd in data[1]:
    embds.append(json.loads(embd))

embds = torch.Tensor(embds)

model = SentenceTransformer('sentence-transformers/stsb-roberta-base-v2')



In [19]:
def topkRelatedSentence(k, inputEmb, dataEmb):
    similarityScore = util.cos_sim(inputEmb, dataEmb)
    return torch.topk(similarityScore, k)[1].reshape(-1)

def labelPrediction(input, data):
    indexes = topkRelatedSentence(5, model.encode(input), embds)

    topEvidences = data[0].iloc[indexes].tolist()

    pairs = []

    for evidence in topEvidences:
        pairs.append(json.dumps([input,evidence]))

    votes = []

    model_token = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model_voter_1 = DistilBertForSequenceClassification.from_pretrained("model/voter_1")
    model_voter_2 = DistilBertForSequenceClassification.from_pretrained("model/voter_2")
    model_voter_3 = DistilBertForSequenceClassification.from_pretrained("model/voter_3")
    model_voter_4 = DistilBertForSequenceClassification.from_pretrained("model/voter_4")
    model_voter_5 = DistilBertForSequenceClassification.from_pretrained("model/voter_5")

    model_voters = [model_voter_1, model_voter_2, model_voter_3, model_voter_4, model_voter_5]

    for pair in pairs:
        temp_vote = []
        for model_voter in model_voters:
            inputs = model_token(pair, return_tensors="pt")
            with torch.no_grad():
                logits = model_voter(**inputs).logits
            predicted_class_id = logits.argmax().item()
            temp_vote.append(predicted_class_id)
        votes.append(temp_vote)

    model_token_verdict = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model_voter_verdict = DistilBertForSequenceClassification.from_pretrained("model/verdict")

    inputs = model_token_verdict(json.dumps(votes), return_tensors="pt")
    with torch.no_grad():
        logits = model_voter_verdict(**inputs).logits

    predicted_class_id = logits.argmax().item()

    classes = ['NOT_ENOUGH_INFO','SUPPORTS', 'REFUTES', 'DISPUTED']

    return classes[predicted_class_id]

In [20]:
test = []

for claim in claimtest:
    test.append(labelPrediction(claim,data))
    print(len(test)/len(claimtest))

print(test)

0.003257328990228013
0.006514657980456026
0.009771986970684038
0.013029315960912053
0.016286644951140065
0.019543973941368076
0.02280130293159609
0.026058631921824105
0.029315960912052116
0.03257328990228013
0.035830618892508145
0.03908794788273615
0.04234527687296417
0.04560260586319218
0.048859934853420196
0.05211726384364821
0.05537459283387622
0.05863192182410423
0.06188925081433225
0.06514657980456026
0.06840390879478828
0.07166123778501629
0.0749185667752443
0.0781758957654723
0.08143322475570032
0.08469055374592833
0.08794788273615635
0.09120521172638436
0.09446254071661238
0.09771986970684039
0.10097719869706841
0.10423452768729642
0.10749185667752444
0.11074918566775244
0.11400651465798045
0.11726384364820847
0.12052117263843648
0.1237785016286645
0.1270358306188925
0.13029315960912052
0.13355048859934854
0.13680781758957655
0.14006514657980457
0.14332247557003258
0.1465798045602606
0.1498371335504886
0.15309446254071662
0.1563517915309446
0.15960912052117263
0.162866449511400

In [21]:
score = 0

for i in range(len(test)):
    if test[i] == labeltest[i]:
        score += 1

print(score/len(test))

0.739413680781759
