In [1]:
# dependencies
import os
import json
from tqdm import tqdm
import wandb
from datetime import datetime

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from torch.nn.utils.rnn import pad_sequence

# Set the TOKENIZERS_PARALLELISM environment variable to "false" before any
# tokenizer is created further down in the notebook.
# NOTE(review): presumably this is meant to silence the Hugging Face tokenizers
# parallelism warning when tokenizing inside DataLoader collate functions - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# wandb.init(project="nlp_project_rtv", name="bert base uncased test run")

In [2]:
# PATHS
# Every dataset file lives in the same directory, so each path is derived from it.
DATA_DIR = "/Users/taylortang/Life-at-UniMelb/Semester_3/COMP90042_NLP/Project_2/code/data"
DEV_CLAIMS_BASELINE_JSON_PATH = f"{DATA_DIR}/dev-claims-baseline.json"
DEV_CLAIMS_JSON_PATH = f"{DATA_DIR}/dev-claims.json"
EVIDENCE_JSON_PATH = f"{DATA_DIR}/evidence.json"
SMALL_EVIDENCE_JSON_PATH = f"{DATA_DIR}/small_evidence.json"
TINY_EVIDENCE_JSON_PATH = f"{DATA_DIR}/tiny_evidence.json"
CODE_DEV_EVIDENCE_JSON_PATH = f"{DATA_DIR}/code_dev_evidence.json"
TEST_CLAIMS_UNLABELLED_JSON_PATH = f"{DATA_DIR}/test-claims-unlabelled.json"
TRAIN_CLAIMS_JSON_PATH = f"{DATA_DIR}/train-claims.json"

# ARGS
BATCH_SIZE = 4                          # batch size of every DataLoader in this notebook
EPOCH = 1                               # number of passes of the training loop
MODEL_NAME = "distilbert-base-uncased"  # checkpoint name given to from_pretrained
MAX_LR = 2e-5                           # learning rate handed to the AdamW optimizer
MAX_LENGTH = 64                         # max_length used by the tokenizer (with truncation)
RETRIEVAL_NUM = 3                       # evidences retrieved per claim in validate/predict
K = 4                                   # candidates inspected per claim when mining negatives

In [3]:
# Encoder and tokenizer are both loaded from the MODEL_NAME checkpoint.
model = AutoModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Device preference: CUDA first, then Apple MPS, and CPU as the fallback.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("using cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("using mps")
else:
    device = torch.device("cpu")
    print("using cpu")

# The encoder must live on the chosen device before the optimizer sees its parameters.
model = model.to(device)

# AdamW over every encoder parameter; gradients start out cleared.
optimizer = torch.optim.AdamW(model.parameters(), lr=MAX_LR, weight_decay=1e-4)
optimizer.zero_grad()

# StepLR multiplies the learning rate by 0.1 every 100 calls to scheduler.step();
# the training loop calls scheduler.step() once per batch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


using mps


In [4]:
# Optionally warm-start the encoder from a previously saved state dict.
pretrained_model = "/Users/taylortang/Life-at-UniMelb/Semester_3/COMP90042_NLP/Project_2/code/rtv/model/rtv_13_05_2023/rtv_checkpoint.bin"
load_model = True
if load_model:
    # map_location remaps the stored tensors onto the device selected for this
    # session, so a checkpoint written on a different backend (e.g. CUDA) still
    # loads when the notebook runs on MPS or CPU.
    model.load_state_dict(torch.load(pretrained_model, map_location=device))

In [5]:
# small util
def makedir(sub_dir):
    """Create (if missing) and return the dated output directory for `sub_dir`.

    The directory is ``./model/<sub_dir>_<dd_mm_YYYY>``, where the date is
    today's date. Calling it again on the same day returns the same path
    without raising, because existing directories are tolerated.
    """
    today = datetime.now().strftime("%d_%m_%Y")
    target = "./model/{}_{}".format(sub_dir, today)
    os.makedirs(target, exist_ok=True)
    return target

In [6]:
# small util
def pre_process(sentence):
    """Lower-case `sentence` and keep only alphanumeric and whitespace characters.

    Every other character (punctuation, symbols, ...) is removed without
    inserting a replacement, so surrounding characters become adjacent.
    """
    lowered = sentence.lower()
    return "".join(filter(lambda ch: ch.isalnum() or ch.isspace(), lowered))

In [7]:
# small util
def tokenization(text):
    """Tokenize `text` with the module-level `tokenizer`.

    Inputs are truncated to MAX_LENGTH, padded (``padding=True``) and returned
    as PyTorch tensors (``return_tensors="pt"``). The tokenizer's own return
    value is passed back unchanged.
    """
    return tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=True,
        return_tensors="pt",
    )

In [8]:
# Function to compute embeddings for either claim or evidence
def _normalized_first_token(input_ids, attention_mask):
    """Run the module-level `model` and return its L2-normalised first-token vectors."""
    hidden_states = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    # Index 0 along the second axis picks the first token of each sequence
    # (presumably the [CLS] token for this tokenizer - confirm).
    first_token = hidden_states[:, 0, :]
    return torch.nn.functional.normalize(first_token, p=2, dim=1)


def extract_embeddings(batch, prefix):
    """Embed the ``batched_<prefix>_*`` tensors of a collated batch.

    `prefix` selects which pair of entries is read from `batch`:
    ``batched_<prefix>_input_ids`` and ``batched_<prefix>_attention_mask``.
    Returns one unit-length (L2-normalised along dim 1) vector per row.
    """
    return _normalized_first_token(
        batch[f"batched_{prefix}_input_ids"],
        batch[f"batched_{prefix}_attention_mask"],
    )


def extract_origin_embedding(claim_token):
    """Embed a tokenizer output directly, reading its `input_ids` and `attention_mask`.

    Returns one unit-length (L2-normalised along dim 1) vector per row.
    """
    return _normalized_first_token(claim_token.input_ids, claim_token.attention_mask)

In [9]:
class TestSet(Dataset):
    """Claims loaded from TEST_CLAIMS_UNLABELLED_JSON_PATH, indexed by position."""

    def __init__(self):
        # The JSON file maps claim id -> claim record.
        with open(TEST_CLAIMS_UNLABELLED_JSON_PATH, "r") as handle:
            self.data = json.load(handle)
        self.all_claim_keys = list(self.data.keys())

    def __len__(self):
        """Number of claims."""
        return len(self.all_claim_keys)

    def __getitem__(self, id):
        """Return ``[claim record, claim id, raw claim text]`` for position `id`."""
        key = self.all_claim_keys[id]
        record = self.data.get(key)
        return [record, key, record["claim_text"]]

    def collate_fn(self, batch):
        """Collate items into claim records, ids and tokenized (cleaned) claim texts."""
        # Claim texts are cleaned here, at collate time, rather than in __getitem__.
        rows = [(record, key, pre_process(text)) for record, key, text in batch]
        claims, claim_ids, claim_texts = zip(*rows)
        encoded = tokenization(claim_texts)

        return {
            "batched_claims": claims,
            "batched_claim_ids": claim_ids,
            "batched_claim_input_ids": encoded.input_ids,
            "batched_claim_attention_mask": encoded.attention_mask,
        }

In [10]:
class ValidationSet(Dataset):
    """Labelled claims loaded from DEV_CLAIMS_JSON_PATH, used to score retrieval."""

    def __init__(self):
        # The JSON file maps claim id -> claim record.
        with open(DEV_CLAIMS_JSON_PATH, "r") as handle:
            self.data = json.load(handle)
        self.all_claim_keys = list(self.data.keys())

    def __len__(self):
        """Number of claims."""
        return len(self.all_claim_keys)

    def __getitem__(self, id):
        """Return ``[claim record, claim id, cleaned claim text]`` for position `id`."""
        key = self.all_claim_keys[id]
        record = self.data.get(key)
        return [record, key, pre_process(record["claim_text"])]

    def collate_fn(self, batch):
        """Collate items into claim records, ids, tokenized texts and gold evidence ids."""
        rows = [(record, key, text) for record, key, text in batch]
        claims, claim_ids, claim_texts = zip(*rows)
        encoded = tokenization(claim_texts)

        return {
            "batched_claims": claims,
            "batched_claim_ids": claim_ids,
            "batched_claim_input_ids": encoded.input_ids,
            "batched_claim_attention_mask": encoded.attention_mask,
            # One list of gold evidence ids per claim, in batch order.
            "batched_claim_evidences": [record["evidences"] for record in claims],
        }

In [11]:
class EvidenceSet(Dataset):
    """Evidence passages loaded from CODE_DEV_EVIDENCE_JSON_PATH, indexed by position."""

    def __init__(self):
        # The JSON file maps evidence id -> evidence text.
        with open(CODE_DEV_EVIDENCE_JSON_PATH, "r") as handle:
            self.data = json.load(handle)
        self.all_evidence_keys = list(self.data.keys())

    def __len__(self):
        """Number of evidence passages."""
        return len(self.all_evidence_keys)

    def __getitem__(self, id):
        """Return ``[evidence id, cleaned evidence text]`` for position `id`."""
        key = self.all_evidence_keys[id]
        return [key, pre_process(self.data[key])]

    def collate_fn(self, batch):
        """Collate items into evidence ids and tokenized evidence texts."""
        evidence_ids, evidence_texts = zip(*batch)
        encoded = tokenization(evidence_texts)

        return {
            "batched_evidence_ids": evidence_ids,
            "batched_evidence_input_ids": encoded.input_ids,
            "batched_evidence_attention_mask": encoded.attention_mask,
        }

In [12]:
class TrainSet(Dataset):
    """Claims paired with their gold evidences and mined hard-negative evidences.

    On construction every claim is embedded with the current encoder and
    compared against the supplied evidence embeddings; the most similar
    evidences that are not gold for the claim are stored as its
    ``negative_evidences``.

    Args:
        evidence_embeddings: 2-D tensor with one row per evidence (it is used
            as the right operand of a matrix product with claim embeddings).
        evidence_ids: evidence ids; ``evidence_ids[i]`` names row ``i``.

    NOTE(review): claims are read from DEV_CLAIMS_JSON_PATH although
    TRAIN_CLAIMS_JSON_PATH is defined - confirm this is intended.
    """

    def __init__(self, evidence_embeddings, evidence_ids):
        self.evidence_embeddings = evidence_embeddings.to(device)
        self.evidence_ids = evidence_ids
        self.data = self.in_batch_negative_samples()
        self.all_claim_keys = list(self.data.keys())
        with open(CODE_DEV_EVIDENCE_JSON_PATH, "r") as f:
            self.evidences = json.load(f)
        self.all_evidence_keys = list(self.evidences.keys())

    def __len__(self):
        """Number of claims."""
        return len(self.all_claim_keys)

    def __getitem__(self, id):
        """Return the claim record (including ``negative_evidences``) at position `id`."""
        claim_id = self.all_claim_keys[id]
        return self.data.get(claim_id)

    def in_batch_negative_samples(self, batch_size=1):
        """Load the claims and attach mined ``negative_evidences`` to each one.

        Args:
            batch_size: number of claims embedded per encoder call.

        Returns:
            dict mapping claim id -> claim record, where each record gained a
            ``negative_evidences`` list holding the ids of the top-K most
            similar evidences that are not in the claim's ``evidences``.
        """
        with open(DEV_CLAIMS_JSON_PATH, "r") as f:
            train_set = json.load(f)
        train_with_negative = dict()

        # Materialise the items once instead of rebuilding the list per batch.
        all_items = list(train_set.items())
        # topk raises when k exceeds the number of candidates, so clamp it.
        top_k = min(K, self.evidence_embeddings.size(0))

        # Mining only ranks candidates; no gradients are needed, so skip
        # building autograd graphs for these forward passes.
        with torch.no_grad():
            for start in range(0, len(all_items), batch_size):
                batch_items = all_items[start:start + batch_size]
                claim_texts = [pre_process(claim["claim_text"]) for _, claim in batch_items]

                claim_token = tokenization(claim_texts).to(device)
                claim_embedding = extract_origin_embedding(claim_token)

                # Rows: claims of this mini-batch; columns: all evidences.
                similarity = torch.mm(claim_embedding, self.evidence_embeddings.t())
                similar_ids = torch.topk(similarity, k=top_k, dim=1).indices.tolist()

                for row, (claim_id, claim) in enumerate(batch_items):
                    claim["negative_evidences"] = [
                        self.evidence_ids[col]
                        for col in similar_ids[row]
                        if self.evidence_ids[col] not in claim["evidences"]
                    ]
                    train_with_negative[claim_id] = claim

        return train_with_negative

    def collate_fn(self, batch):
        """Collate claim records into tokenized claims, tokenized evidences and labels.

        Evidence texts are ordered with all gold evidences first (claim by
        claim, in batch order) followed by all negative evidences. ``labels``
        holds one list per claim containing a ``1`` for each of its gold
        evidences, i.e. it follows the same order as the gold evidence rows.
        """
        claims = list(batch)

        related_ids = [evidence_id for claim in claims for evidence_id in claim["evidences"]]
        unrelated_ids = [evidence_id for claim in claims for evidence_id in claim["negative_evidences"]]
        labels = [[1] * len(claim["evidences"]) for claim in claims]

        claim_tokens = tokenization([pre_process(claim["claim_text"]) for claim in claims])

        evidence_texts = [pre_process(self.evidences[evidence_id]) for evidence_id in related_ids + unrelated_ids]
        evidence_tokens = tokenization(evidence_texts)

        return {
            "batched_claim_input_ids": claim_tokens.input_ids,
            "batched_claim_attention_mask": claim_tokens.attention_mask,
            "batched_evidence_input_ids": evidence_tokens.input_ids,
            "batched_evidence_attention_mask": evidence_tokens.attention_mask,
            "labels": labels,
        }


In [13]:
# make prediction
# Switch: when enabled, evidence embeddings and ids are read back from a saved run.
load_emb = False
if load_emb:
    best_evidence_embeddings = torch.load("/Users/taylortang/Life-at-UniMelb/Semester_3/COMP90042_NLP/Project_2/code/rtv/model/rtv_13_05_2023/evidence_embeddings")
    best_evidence_ids = torch.load("/Users/taylortang/Life-at-UniMelb/Semester_3/COMP90042_NLP/Project_2/code/rtv/model/rtv_13_05_2023/evidence_ids")

In [14]:
def embed_evidence(evi_dataloader, model):
    """Embed every evidence yielded by `evi_dataloader`.

    Args:
        evi_dataloader: DataLoader whose batches contain
            ``batched_evidence_ids`` plus the ``batched_evidence_*`` tensors.
        model: encoder; switched to eval mode and used for its hidden size.
            The forward pass itself goes through ``extract_embeddings``, which
            uses the module-level ``model``.

    Returns:
        ``(evidence_embeddings, evidence_ids)`` where row ``i`` of the tensor
        is the embedding of ``evidence_ids[i]``.
    """
    model.eval()

    # Pre-allocate the output directly on the target device.
    total_samples = len(evi_dataloader.dataset)
    embedding_dim = model.config.hidden_size
    evidence_embeddings = torch.zeros(total_samples, embedding_dim, device=device)
    evidence_ids = []

    # Running write position. Tracking it explicitly (instead of deriving it
    # from the batch index and `batch_size`) stays correct when the loader has
    # no fixed batch size, e.g. when it was built with a batch_sampler.
    cursor = 0

    with torch.no_grad():
        for batch in tqdm(evi_dataloader):
            batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}

            evidence_embedding = extract_embeddings(batch, "evidence")
            next_cursor = cursor + evidence_embedding.size(0)
            evidence_embeddings[cursor:next_cursor] = evidence_embedding
            evidence_ids.extend(batch["batched_evidence_ids"])
            cursor = next_cursor

    # Keep rows and ids aligned even if the loader yielded fewer samples than
    # the dataset holds (e.g. drop_last=True); otherwise this is the full tensor.
    return evidence_embeddings[:cursor], evidence_ids

In [15]:
def _evidence_fscore(gold_ids, predicted_ids):
    """F1 between one claim's gold evidence ids and its retrieved evidence ids."""
    correct = sum(1 for evidence_id in gold_ids if evidence_id in predicted_ids)
    if correct == 0:
        return 0

    recall = float(correct) / len(gold_ids)
    precision = float(correct) / len(predicted_ids)
    return (2 * precision * recall) / (precision + recall)


def validate(val_dataloader, evidence_embeddings, evidence_ids, model):
    """Return the mean per-claim evidence-retrieval F-score.

    For every claim the RETRIEVAL_NUM most similar evidences (dot product of
    the normalised embeddings) are retrieved and compared with the claim's
    gold evidence ids.

    Args:
        val_dataloader: DataLoader whose batches contain ``batched_claims``,
            ``batched_claim_evidences`` and the ``batched_claim_*`` tensors.
        evidence_embeddings: 2-D tensor with one row per evidence.
        evidence_ids: evidence ids; ``evidence_ids[i]`` names row ``i``.
        model: encoder; switched to eval mode. The forward pass itself goes
            through ``extract_embeddings``, which uses the module-level ``model``.

    Returns:
        Mean F-score over all claims, or ``0.0`` when the loader yields no claims.
    """
    f_scores = []
    model.eval()
    # Transposed once so each batch needs a single matrix product.
    evidence_embeddings = evidence_embeddings.to(device).t()
    # topk raises when k exceeds the number of candidates, so clamp it.
    top_k = min(RETRIEVAL_NUM, evidence_embeddings.size(1))

    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
            batched_claim_embeddings = extract_embeddings(batch, "claim")

            similarity = torch.mm(batched_claim_embeddings, evidence_embeddings)
            retrieved_ids = torch.topk(similarity, k=top_k, dim=1).indices.tolist()

            for idx, _ in enumerate(batch["batched_claims"]):
                pred_evidences = [evidence_ids[i] for i in retrieved_ids[idx]]
                f_scores.append(_evidence_fscore(batch["batched_claim_evidences"][idx], pred_evidences))

    # An empty validation loader would otherwise divide by zero.
    if not f_scores:
        return 0.0

    return sum(f_scores) / len(f_scores)

In [16]:
def predict(
    evidence_embeddings,
    evidence_ids,
    output_path="/Users/taylortang/Life-at-UniMelb/Semester_3/COMP90042_NLP/Project_2/code/prediction/retrieval-test-claims.json",
):
    """Retrieve evidences for every test claim and write them to a JSON file.

    Each test claim record gets an ``evidences`` entry holding the ids of its
    RETRIEVAL_NUM most similar evidences; the mapping claim id -> record is
    dumped to `output_path`.

    Args:
        evidence_embeddings: 2-D tensor with one row per evidence.
        evidence_ids: evidence ids; ``evidence_ids[i]`` names row ``i``.
        output_path: destination JSON file (defaults to the project's
            prediction file).
    """
    test_set = TestSet()
    dataloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, collate_fn=test_set.collate_fn)

    # Inference only: put the encoder in eval mode so the ranking is
    # deterministic, and skip autograd bookkeeping.
    model.eval()
    # Embeddings loaded from disk may sit on another device than the encoder.
    evidence_matrix = evidence_embeddings.to(device).t()
    # topk raises when k exceeds the number of candidates, so clamp it.
    top_k = min(RETRIEVAL_NUM, evidence_matrix.size(1))

    predicted = {}
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
            claim_embedding = extract_embeddings(batch, "claim")

            similarity = torch.mm(claim_embedding, evidence_matrix)
            topk_ids = torch.topk(similarity, k=top_k, dim=1).indices.tolist()

            for idx, claim in enumerate(batch["batched_claims"]):
                claim["evidences"] = [evidence_ids[i] for i in topk_ids[idx]]
                predicted[batch["batched_claim_ids"][idx]] = claim

    with open(output_path, 'w') as file:
        json.dump(predicted, file)

In [17]:
# Function to compute the loss
def compute_loss(claim_embeddings, evidence_embeddings, batch, temperature=0.1):
    """Contrastive loss: mean negative log-likelihood of the gold claim/evidence pairs.

    A temperature-scaled softmax is taken over all evidences of the batch for
    every claim; the loss is the average of ``-log p`` over the gold pairs.

    Args:
        claim_embeddings: 2-D tensor, one row per claim.
        evidence_embeddings: 2-D tensor, one row per evidence. Gold evidences
            must come first, grouped claim by claim in batch order - the same
            order in which ``batch["labels"]`` lists them.
        batch: collated batch; only ``batch["labels"]`` is read. It holds one
            list per claim with one entry per gold evidence of that claim.
        temperature: divisor applied to the similarities before the softmax.

    Returns:
        Scalar tensor. The result is already averaged over all gold pairs, so
        it does not depend on how many other entries `batch` carries. When the
        batch has no gold pair at all, a zero that is still connected to the
        graph is returned so ``backward()`` remains valid.
    """
    similarity = torch.mm(claim_embeddings, evidence_embeddings.t())
    negative_log_likelihood = -torch.nn.functional.log_softmax(similarity / temperature, dim=1)

    # Gold evidence columns are consumed left to right: claim 0 owns the first
    # len(labels[0]) columns, claim 1 the next len(labels[1]), and so on.
    rows, cols = [], []
    column = 0
    for row, labels in enumerate(batch["labels"]):
        for _ in labels:
            rows.append(row)
            cols.append(column)
            column += 1

    if not rows:
        return similarity.sum() * 0.0

    return negative_log_likelihood[rows, cols].mean()

In [18]:
def train(val_dataloader, evi_dataloader, save_dir):
    """Fine-tune the module-level encoder and keep the best-scoring checkpoint.

    Each epoch: mine negatives with the current evidence embeddings, train on
    the resulting TrainSet, re-embed all evidences and evaluate on the
    validation loader. Whenever the F-score beats the best one so far, the
    model state dict, evidence embeddings and evidence ids are saved under
    `save_dir`. Training stops early after `patience` consecutive epochs
    without improvement.

    Args:
        val_dataloader: validation DataLoader passed to ``validate``.
        evi_dataloader: evidence DataLoader passed to ``embed_evidence``.
        save_dir: directory receiving ``rtv_checkpoint.bin``,
            ``evidence_embeddings`` and ``evidence_ids``.

    Returns:
        ``(best_evidence_embeddings, best_evidence_ids)`` of the best epoch, or
        ``(None, None)`` if no epoch scored above 0.
    """
    # Baseline evaluation before any update; also provides the embeddings
    # used to mine negatives for the first epoch.
    print("\nGenerating evidence embedding:\n")
    evidence_embeddings, evidence_ids = embed_evidence(evi_dataloader, model)
    print("\nEvaluate evidence embedding, f-score:\n")
    f_score = validate(val_dataloader, evidence_embeddings, evidence_ids, model)
    # wandb.log({"f_score": f_score}, step=0)
    print(f_score)

    # Early stopping parameters
    maximum_f_score = 0
    patience = 5  # Number of epochs to wait for improvement before stopping
    patience_counter = 0  # Counter to keep track of non-improving epochs

    # Initialised once, outside the loop, so the best result survives later
    # non-improving epochs and the names exist even when EPOCH is 0.
    best_evidence_embeddings, best_evidence_ids = None, None

    # Training loop
    for epoch in range(EPOCH):

        print("Generating training dataset with negative samples")
        train_set = TrainSet(evidence_embeddings, evidence_ids)
        train_dataloader = DataLoader(train_set, BATCH_SIZE, shuffle=False, collate_fn=train_set.collate_fn)
        del evidence_embeddings, evidence_ids

        # embed_evidence/validate leave the encoder in eval mode; switch back
        # to training mode before taking gradient steps.
        model.train()

        print("Starting epoch: ", epoch)
        for batch in tqdm(train_dataloader):

            # Move tensors to device
            batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}

            # Compute embeddings for claim and evidence
            in_batch_claim_embeddings = extract_embeddings(batch, "claim")
            in_batch_evidence_embeddings = extract_embeddings(batch, "evidence")

            # Compute the loss and perform backpropagation
            loss = compute_loss(in_batch_claim_embeddings, in_batch_evidence_embeddings, batch)
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            del in_batch_claim_embeddings, in_batch_evidence_embeddings

        print("Finishing epoch: ", epoch)

        # Evaluate the model every epoch
        print("\nGenerating evidence embedding:")
        evidence_embeddings, evidence_ids = embed_evidence(evi_dataloader, model)
        print("\nEvaluate evidence embedding, f-score:")
        f_score = validate(val_dataloader, evidence_embeddings, evidence_ids, model)
        print("FSCORE", f_score)
        # wandb.log({"f_score": f_score}, step=epoch)

        if f_score > maximum_f_score:

            torch.save(model.state_dict(), os.path.join(save_dir, "rtv_checkpoint.bin"))
            torch.save(evidence_embeddings, os.path.join(save_dir, "evidence_embeddings"))
            torch.save(evidence_ids, os.path.join(save_dir, "evidence_ids"))

            best_evidence_embeddings, best_evidence_ids = evidence_embeddings, evidence_ids
            maximum_f_score = f_score
            # An improvement restarts the early-stopping window.
            patience_counter = 0
            print("maximum_f_score", f_score)

        else:

            patience_counter += 1
            print("No improvement in f_score, patience: ", patience_counter)
            if patience_counter >= patience:
                print("Early stopping triggered - no improvement in f_score for {} epochs".format(patience))
                break

    return best_evidence_embeddings, best_evidence_ids

In [19]:
# Validation claims are served in file order, BATCH_SIZE at a time, and
# collated by the dataset's own collate_fn.
validation = ValidationSet()
val_dataloader = DataLoader(
    validation,
    batch_size=BATCH_SIZE,
    collate_fn=validation.collate_fn,
    shuffle=False,
)

In [20]:
# Evidence passages are served in file order (no shuffling), BATCH_SIZE at a
# time, and collated by the dataset's own collate_fn.
evidence = EvidenceSet()
evi_dataloader = DataLoader(
    evidence,
    batch_size=BATCH_SIZE,
    collate_fn=evidence.collate_fn,
    shuffle=False,
)

In [None]:
# Run training; checkpoints go to a dated "rtv" directory under ./model.
save_dir = makedir("rtv")
best_evidence_embeddings, best_evidence_ids = train(
    val_dataloader=val_dataloader,
    evi_dataloader=evi_dataloader,
    save_dir=save_dir,
)

In [22]:
# make prediction
# Switch: when enabled, evidence embeddings and ids are read back from a saved
# run instead of using the ones returned by training.
load_emb = False
if load_emb:
    # map_location places the stored tensor on the device used in this session.
    best_evidence_embeddings = torch.load("/Users/taylortang/Life-at-UniMelb/Semester_3/COMP90042_NLP/Project_2/code/rtv/model/rtv_13_05_2023/evidence_embeddings", map_location=device)
    best_evidence_ids = torch.load("/Users/taylortang/Life-at-UniMelb/Semester_3/COMP90042_NLP/Project_2/code/rtv/model/rtv_13_05_2023/evidence_ids")

# The switch has its own name: binding a boolean to `predict` would replace
# the predict() function, and enabling it would then try to call a bool.
run_prediction = False
if run_prediction:
    predict(best_evidence_embeddings, best_evidence_ids)