In [1]:
import pandas as pd
import json
import os
import numpy as np
import utils

In [2]:
tweets, test_tweets = utils.get_tweets()
test_tweets = test_tweets[1:]
train_conns, dev_conns, test_conns = utils.get_qrels()
claims = utils.get_claims()

In [3]:
from sentence_transformers import SentenceTransformer

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model = SentenceTransformer("sentence-transformers/sentence-t5-base")

In [4]:
from functools import partial

MAX_LENGTH = 256
BATCH_SIZE = 16

tokenize = partial(model.tokenizer, **dict(
    truncation=True, 
    max_length=MAX_LENGTH, 
    padding="max_length", 
    return_attention_mask=True
))

In [5]:
import dataloaders
import importlib
importlib.reload(dataloaders)

train_dl = dataloaders.get_clef2021_dataloader(tokenize, claims, tweets, train_conns, 
                                               {'batch_size':BATCH_SIZE, 'shuffle':True})    
dev_dl = dataloaders.get_clef2021_dataloader(tokenize, claims, tweets, dev_conns, 
                                               {'batch_size':BATCH_SIZE, 'shuffle':False})    

In [6]:
import torch
import torch.optim as optim
nn = torch.nn

LR = 1e-5

optimizer = optim.AdamW(model.parameters(), lr=LR)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [7]:
torch.cuda.is_available()

True

In [8]:
CE = nn.CrossEntropyLoss()
temp = .05
def MNR_loss(left_tensors, right_tensors, negatives=None):
    logits = torch.einsum("bd,cd->bc", left_tensors, right_tensors)
    return CE(logits / temp, torch.arange(logits.shape[0]).to(device))
    

In [41]:
len(train_dl.dataset)

999

In [9]:
EPOCHS = 3
PRINT_STEPS = 5

model.to(device)
for epoch in range(EPOCHS):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_dl, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = [elmt.to(device) for elmt in inputs]
        labels = [elmt.to(device) for elmt in labels]
        current_batch_size = inputs[0].shape[0]
        inpt_dict = {
            "input_ids":torch.cat([inputs[0], labels[0]]),
            "attention_mask":torch.cat([inputs[1], labels[1]])
        }

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inpt_dict)
        embeddings = outputs['sentence_embedding']
        loss = MNR_loss(embeddings[:current_batch_size], embeddings[current_batch_size:])
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % PRINT_STEPS  == PRINT_STEPS-1:    # print every 2000 mini-batches
            print(f'TRAIN [{epoch + 1}, {i + 1:5d}] loss: {running_loss / PRINT_STEPS:.3f}')
            running_loss = 0.0
    
    running_loss = 0.0       
    for i, data in enumerate(dev_dl, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = [elmt.to(device) for elmt in inputs]
        labels = [elmt.to(device) for elmt in labels]
        current_batch_size = inputs[0].shape[0]
        inpt_dict = {
            "input_ids":torch.cat([inputs[0], labels[0]]),
            "attention_mask":torch.cat([inputs[1], labels[1]])
        }

        with torch.no_grad():
            outputs = model(inpt_dict)
            embeddings = outputs['sentence_embedding']
            loss = MNR_loss(embeddings[:current_batch_size], embeddings[current_batch_size:])
            running_loss += (loss * embeddings.shape[0]).item()
            
    print(f'DEV [{epoch + 1}, {i + 1:5d}] loss: {running_loss / len(dev_dl.dataset):.3f}')

print('Finished Training')

TRAIN [1,     5] loss: 0.893
TRAIN [1,    10] loss: 0.776
TRAIN [1,    15] loss: 0.688
TRAIN [1,    20] loss: 0.657
TRAIN [1,    25] loss: 0.422
TRAIN [1,    30] loss: 0.448
TRAIN [1,    35] loss: 0.389
TRAIN [1,    40] loss: 0.354
TRAIN [1,    45] loss: 0.317
TRAIN [1,    50] loss: 0.297
TRAIN [1,    55] loss: 0.277
TRAIN [1,    60] loss: 0.298
DEV [1,     5] loss: 0.580
DEV [1,    10] loss: 0.447
TRAIN [2,     5] loss: 0.239
TRAIN [2,    10] loss: 0.210
TRAIN [2,    15] loss: 0.181
TRAIN [2,    20] loss: 0.167
TRAIN [2,    25] loss: 0.231
TRAIN [2,    30] loss: 0.155
TRAIN [2,    35] loss: 0.125
TRAIN [2,    40] loss: 0.140
TRAIN [2,    45] loss: 0.166
TRAIN [2,    50] loss: 0.128
TRAIN [2,    55] loss: 0.133
TRAIN [2,    60] loss: 0.179
DEV [2,     5] loss: 0.506
DEV [2,    10] loss: 0.382
TRAIN [3,     5] loss: 0.091
TRAIN [3,    10] loss: 0.133
TRAIN [3,    15] loss: 0.141
TRAIN [3,    20] loss: 0.107
TRAIN [3,    25] loss: 0.099
TRAIN [3,    30] loss: 0.097
TRAIN [3,    35] loss:

In [33]:
embs = model.encode(claims.vclaim.to_list())
# embs = model.encode(claims[["title", "subtitle", "vclaim"]].apply(lambda x: f"title: {x[0]}\nsubtitle: {x[1]}\nclaim: {x[2]}", axis=1).to_list())

In [34]:
def get_idx(connections, claims, tweets):
    run_tweets = tweets.join(connections.set_index("tweet_id"), on="id", how="inner")
    run_tweets = run_tweets.join(claims.set_index("vclaim_id"), on="claim_id", how="inner")
    run_tweets = run_tweets[["tweet", "vclaim"]].reset_index()
    claim_idx = [claims.vclaim.to_list().index(t_claim) for t_claim in run_tweets.vclaim.to_list()]
    return run_tweets, claim_idx

run_tweets, claim_idx = get_idx(test_conns, claims, test_tweets)
tweet_embs = model.encode(run_tweets.tweet.to_list())
scores = tweet_embs @ embs.T
ranks = [score.argsort()[::-1] for score in scores]

def avg_prec(gold, rankings, n):
    is_rel = (np.array(rankings)[:n] == gold).astype(float)
    return (is_rel/np.arange(1,n+1)).sum()

def mean_avg_prec(golds, rankings, n):
    avg_precs = [avg_prec(gold, rlist, n) for gold, rlist in zip(golds, rankings)]
    return np.array(avg_precs).mean()

In [36]:
map_5 = mean_avg_prec(claim_idx, ranks, 1)
map_5

0.8316831683168316

In [38]:
inpt_dict["input_ids"].device

device(type='cuda', index=0)

In [46]:
import configparser
config = configparser.ConfigParser()
config.read("experiments/finetune_st5_base_claims/config.ini")

['experiments/finetune_st5_base_claims/config.ini']

In [52]:
config["training"].getfloat("lr")

1.0

In [54]:
config["model"].get("model_string")

'sentence-transformers/sentence-t5-base'