# 2024 COMP90042 Project
*Make sure you change the file name with your group id.*

# Readme
1. The training can be quite slow under the environment of free Colab. It takes us 2.5 hours to run 5 epochs.
2. At the end of the notebook, some json files are strored for the next part --- claim classification.

**We use pytorch, nltk, scikit-learn in this project.**

# 1.DataSet Processing


## 1.1 Download the data from github

In [2]:
import os

# the repository link:
repository_url = 'https://github.com/drcarenhan/COMP90042_2024.git'

# clone the repository
os.system(f'git clone {repository_url}')


32768

In [3]:
save_path = '/content/COMP90042_2024/data'
os.makedirs(save_path, exist_ok=True)

output_file_path = os.path.join(save_path, 'evidence.json')

!gdown --id '1JlUzRufknsHzKzvrEjgw8D3n_IRpjzo6' -O {output_file_path}

Downloading...
From (original): https://drive.google.com/uc?id=1JlUzRufknsHzKzvrEjgw8D3n_IRpjzo6
From (redirected): https://drive.google.com/uc?id=1JlUzRufknsHzKzvrEjgw8D3n_IRpjzo6&confirm=t&uuid=c124f8eb-c9a6-47fc-8e0b-c16fbca66293
To: /content/COMP90042_2024/data/evidence.json
100% 174M/174M [00:01<00:00, 160MB/s]


In [4]:
cd /content/COMP90042_2024/

/content/COMP90042_2024


## 1.2 PreProcess for evidence and claims

### 1.2.1 preprocessing function

This code of stemming, lemmatizing and stopword removal are referred from tutorial

In [1]:
import string
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer

# Download necessary NLTK data files
nltk.download('wordnet')
nltk.download('stopwords')
nltk.download('punkt')

# Initialize the lemmatizer and stopwords
lemmatizer = WordNetLemmatizer()
stopwords_set = set(stopwords.words('english'))

# Lemmatizer function
def lemmatize(word):
    lemma = lemmatizer.lemmatize(word, 'v')
    if lemma == word:
        lemma = lemmatizer.lemmatize(word, 'n')
    return lemma

# Text preprocessing function
def text_preprocessing(text):
    # Lowercasing
    text = text.lower()

    # Tokenizing
    words = word_tokenize(text)

    # Lemmatizing and removing stopwords
    new_words = [lemmatize(w) for w in words if w not in stopwords_set]

    return " ".join(new_words)


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


### 1.2.2 read files

Auxilary functions for reading and pre-processing the data.

In [6]:
import json

def process_claims(claims, evidences_id_dict):
    """
    Process claims data to extract relevant information and map evidence IDs.

    Args:
    claims (dict): A dictionary of claims where each key is a claim ID and each value is a dictionary containing claim details.
    evidences_id_dict (dict): A dictionary mapping evidence IDs to indices for quick access.

    Returns:
    tuple: Contains lists of claim IDs, claim texts, preprocessed claim texts, associated evidence indices, and claim labels.
    """
    ids = []
    texts = []
    processed_texts = []
    evidences = []
    labels = []

    for claim_id, data in claims.items():
        ids.append(claim_id)
        texts.append(data["claim_text"])
        processed_texts.append(text_preprocessing(data["claim_text"]))
        labels.append(data.get("claim_label", None))  # Test data have no labels.
        evidences.append([evidences_id_dict[i] for i in data.get("evidences", [])])

    return ids, texts, processed_texts, evidences, labels

def process_evidences(evidences):
    """
    Process evidences data to extract relevant information and create a mapping from evidence IDs to indices.

    Args:
    evidences (dict): A dictionary of evidences where each key is an evidence ID and the value is the evidence text.

    Returns:
    tuple: Contains lists of evidence IDs, evidence texts, preprocessed evidence texts, and a dictionary mapping IDs to indices.
    """
    ids = []
    texts = []
    processed_texts = []
    id_dict = {}

    for idx, (evidence_id, evidence_text) in enumerate(evidences.items()):
        ids.append(evidence_id)
        texts.append(evidence_text)
        processed_texts.append(text_preprocessing(evidence_text))
        id_dict[evidence_id] = idx

    return ids, texts, processed_texts, id_dict

Use the functions to read the data.

In [7]:
# Load data from files
with open(save_path+'/train-claims.json', 'r') as file:
    train_claims = json.load(file)

with open(save_path+'/evidence.json', 'r') as file:
    evidences = json.load(file)

with open(save_path+'/dev-claims.json', 'r') as file:
    dev_claims = json.load(file)

with open(save_path+'/test-claims-unlabelled.json', 'r') as file:
    test_claims = json.load(file)

# Process evidence data to prepare for linkage with claims
evidences_ids, evidences_texts, evidences_p_texts, evidences_id_dict = process_evidences(evidences)

# Process claims data for training, development, and test sets using the evidence dictionary
train_ids, train_claim_texts, train_p_claim_texts, train_evidences, train_labels = process_claims(train_claims, evidences_id_dict)
dev_ids, dev_claim_texts, dev_p_claim_texts, dev_evidences, dev_labels = process_claims(dev_claims, evidences_id_dict)
test_ids, test_claim_texts, test_p_claim_texts, _, _ = process_claims(test_claims, evidences_id_dict)

### 1.2.3 Construct TFIDF representations

We use the TFIDF to pre-process the data to find potential relevant evidences that can be used for later training.

Here, we transform them into TFIDF representations first.

In [8]:
from sklearn.feature_extraction.text import TfidfVectorizer

# Initialize a TfidfVectorizer.
# max_features limits the number of features (i.e., maximum number of distinct words) to consider to 500,000.
# This helps in managing memory usage and improving processing speed.
vectorizer = TfidfVectorizer(max_features=500000)

# Fit the TfidfVectorizer to the entire set of claim and evidence texts.
# This will compute the IDF (Inverse Document Frequency) values across all given texts,
# and determine the vocabulary from the evidence texts.
# Use all the processed claim and evidence texts to fit TFIDF
vectorizer.fit(evidences_p_texts+train_p_claim_texts)

# Transform the preprocessed texts of training claims into a TF-IDF-weighted document-term matrix.
train_claim_tfidf = vectorizer.transform(train_p_claim_texts)

# Similarly, transform the development and test set preprocessed texts into their respective
# TF-IDF-weighted document-term matrices.
dev_claim_tfidf = vectorizer.transform(dev_p_claim_texts)
test_claim_tfidf = vectorizer.transform(test_p_claim_texts)

# Transform the preprocessed texts of the evidences into a TF-IDF-weighted document-term matrix.
# This allows comparison of claims with evidences based on their textual content.
evidence_tfidf = vectorizer.transform(evidences_p_texts)


### 1.2.4 Find the potential relevant evidences for each claim

We also sort our potential evidences for each claim according to TFIDF. This is can provide potential relevant evidences for later model to train and predict. We calculate the similarity matrix in this function by TFIDF and we do this by splitting the claim into many parts with the length of "claim_splits_length" to avoid the crush of RAM in colab.

In [9]:
import numpy as np

def sort_evidence_candidates(claim_tfidf, evidence_tfidf, claim_splits_length):
    """
    Sorts evidence candidates for claims based on cosine similarity scores between claim and evidence TF-IDF vectors.

    Args:
    claim_tfidf (sparse matrix): A TF-IDF weighted document-term matrix for the claims.
    evidence_tfidf (sparse matrix): A TF-IDF weighted document-term matrix for the evidences.
    claim_splits_length (int): The number of claims to process in each batch.

    Returns:
    list of lists: Each sublist contains the indices of the top 1000 most similar evidences for each claim.
    """
    i = 0
    potential_evidences = []
    # Loop through each batch of claims
    while i * claim_splits_length < claim_tfidf.shape[0]:
        # Calculate cosine similarity between a batch of claims and all evidences.
        # The result is a matrix where each row corresponds to a claim and each column to an evidence.
        cos_sims = np.dot(
            claim_tfidf[i * claim_splits_length:min((i + 1) * claim_splits_length, claim_tfidf.shape[0])],
            evidence_tfidf.transpose()
        ).toarray()

        # For each claim in the current batch, find the indices of the top 1000 most similar evidences.
        for j in range(cos_sims.shape[0]):
            top_potential_evidence_ids = np.argsort(-cos_sims[j]).tolist()[:1000]
            potential_evidences.append(top_potential_evidence_ids)
        i += 1

    return potential_evidences

Use the defined function to get the potential evidences for claims

In [10]:
train_sort_potential_evidences = sort_evidence_candidates(train_claim_tfidf, evidence_tfidf, 100)
dev_sort_potential_evidences = sort_evidence_candidates(dev_claim_tfidf, evidence_tfidf, 100)
test_sort_potential_evidences = sort_evidence_candidates(test_claim_tfidf, evidence_tfidf, 100)

### 1.2.5 Construct vocab and indexing

In this part we construc the vocab indexing for our model. The related code are referred from workshops.

In [11]:
def build_vocabulary(texts, min_count=3):
    """
    Build a vocabulary from a list of texts, filtering words by a minimum count threshold.

    Args:
    texts (list of str): A list of sentences from which to build the vocabulary.
    min_count (int): Minimum occurrence threshold for words to be included in the vocabulary.

    Returns:
    tuple: Two dictionaries - idx2word (maps index to word) and word2idx (maps word to index).
    """
    # Initialize word count dictionary and predefined special tokens.
    wordcount = {}
    idx2word = ["<pad>", "<cls>", "<sep>", "<unk>"]
    word2idx = {"<pad>": 0, "<cls>": 1, "<sep>": 2, "<unk>": 3}

    # Count occurrences of each word in the texts.
    for text in texts:
        for word in text.split():
            wordcount[word] = wordcount.get(word, 0) + 1

    # Start indexing for new words from 4 since 0-3 are reserved for special tokens.
    idx = len(idx2word)

    # Include words in the vocabulary only if they meet the minimum count criteria.
    for word, count in wordcount.items():
        if count > min_count:
            idx2word.append(word)
            word2idx[word] = idx
            idx += 1

    return idx2word, word2idx

# Use the function to build the vocabulary from training and evidence texts.
idx2word, word2idx = build_vocabulary(train_claim_texts + evidences_texts, min_count=3)

In [12]:
def convert_to_indices(text_data, word2idx):
    """
    Convert a list of sentences into lists of indices based on a given word-to-index mapping.

    Args:
    text_data (list of str): A list of sentences to be converted.
    word2idx (dict): A dictionary mapping words to their corresponding indices.

    Returns:
    list of list of int: A list where each sentence is represented as a list of indices.
    """
    # Initialize the list that will store the converted sentences.
    idx_data = []

    # Iterate over each sentence in the input list.
    for text in text_data:
        # Convert each word in the sentence to its corresponding index.
        # If the word is not found in the dictionary, use the index for "<unk>".
        indices = [word2idx.get(word, word2idx["<unk>"]) for word in text.split()]

        # Append the list of indices to the main list.
        idx_data.append(indices)

    return idx_data

Use the defined function to convert the words in texts into indices.

In [13]:
train_claim_text_idx = convert_to_indices(train_claim_texts, word2idx)
dev_claim_text_idx = convert_to_indices(dev_claim_texts, word2idx)
test_claim_text_idx = convert_to_indices(test_claim_texts, word2idx)
evidences_text_idx = convert_to_indices(evidences_texts, word2idx)

Perform a statistics on the length of claim texts and evidence texts. This help us determine the length for padding and truncating.

In [14]:
import numpy as np

def calculate_statistics(text_indices):
    """
    Calculate the average and median lengths of lists of indices.

    Args:
    text_indices (list of list of int): A list where each inner list contains indices for a sentence.

    Returns:
    tuple: A tuple containing the average length and the median length of the sentences.
    """
    # Calculate the lengths of all sentences
    lengths = [len(sentence) for sentence in text_indices]

    # Compute the average length of the sentences
    average_length = np.mean(lengths)

    # Compute the median length of the sentences
    median_length = np.median(lengths)

    max_length = max(lengths)
    return average_length, median_length, max_length

# Calculate statistics for each dataset
train_avg, train_med, train_max = calculate_statistics(train_claim_text_idx)
dev_avg, dev_med, dev_max = calculate_statistics(dev_claim_text_idx)
test_avg, test_med, test_max = calculate_statistics(test_claim_text_idx)
evidence_avg, evidence_med, evidence_max = calculate_statistics(evidences_text_idx)

# Print the statistics
print(f"Train - Average: {train_avg}, Median: {train_med},Max: {train_max}")
print(f"Dev - Average: {dev_avg}, Median: {dev_med},Max: {dev_max}")
print(f"Test - Average: {test_avg}, Median: {test_med},Max: {test_max}")
print(f"Evidence - Average: {evidence_avg}, Median: {evidence_med},Max: {evidence_max}")


Train - Average: 20.09771986970684, Median: 19.0,Max: 67
Dev - Average: 21.084415584415584, Median: 18.0,Max: 65
Test - Average: 20.03921568627451, Median: 19.0,Max: 53
Evidence - Average: 19.691925312720514, Median: 18.0,Max: 479


So we can set the reasonable length for padding and truncating according to our statistics result shown above. But this hyperparameter can also be tuned by experimenting with different values later.

In [15]:
claim_pad_len = 50
evidences_pad_len = 70

In [16]:
def construct_input_text(text_indices, padding_length, word2idx):
    """
    Construct input text arrays with special tokens and padding.

    Args:
    text_indices (list of list of int): A list of sentences, where each sentence is represented by a list of word indices.
    padding_length (int): The fixed length to which all input texts should be padded or truncated.
    word2idx (dict): A dictionary mapping words to their corresponding indices, must include special tokens.

    Returns:
    list of list of int: A list of processed texts, each converted to a list of indices with special tokens and padding.
    """
    idx_data = []
    # Iterate over each sentence represented by its indices.
    for indices in text_indices:
        # Check if the length of the current sentence is less than the padding length.
        if len(indices) < padding_length:
            # If less, pad the sentence. Start with the <cls> token, followed by the original sentence,
            # and a <sep> token, then pad with <pad> tokens up to the required length.
            padded_sentence = ([word2idx["<cls>"]] + indices + [word2idx["<sep>"]] +
                               [word2idx["<pad>"]] * (padding_length - len(indices)))
        else:
            # If the sentence is longer or equal to the padding length, truncate it after adding the initial <cls> token
            # and end it with a <sep> token within the specified padding length.
            padded_sentence = ([word2idx["<cls>"]] + indices[:padding_length] + [word2idx["<sep>"]])

        # Add the processed sentence to the list.
        idx_data.append(padded_sentence)

    return idx_data

In [17]:
train_claim_input = construct_input_text(train_claim_text_idx, claim_pad_len, word2idx)
dev_claim_input = construct_input_text(dev_claim_text_idx, claim_pad_len, word2idx)
test_claim_input = construct_input_text(test_claim_text_idx, claim_pad_len, word2idx)
evidences_input = construct_input_text(evidences_text_idx, evidences_pad_len, word2idx)

Print the length to check whether the length is correct. Here the length is as expected because we added CLS and SEP token no matter whether the length is longer or shorter than padding length.

In [18]:
print(max([len(i) for i in train_claim_input]), max([len(i) for i in dev_claim_input]), max([len(i) for i in test_claim_input]), max([len(i) for i in evidences_input]))

52 52 52 72


## 1.3 Construct the dataloader

In this part, we define how to load the batch of data for training, including the strategy for pairing a claim with negative and positive evidences.

In [19]:
import torch
from torch.utils.data import Dataset
import random

random.seed(90042)

class TrainDataset(Dataset):
    """
    A dataset class for training the machine learning model that requires claim inputs,
    potential relevant evidences, truly associated positive evidences, and a mechanism for handling negative sampling.
    """
    def __init__(self, claim_input_data, evidence_input_data, sort_potential_evidences, positive_evidence, negative_num=10,negative_sample_start_idx=10):
        """
        Initialize the dataset with text and evidence data.

        Args:
        claim_input_data (list): List of text inputs.
        evidence_input_data (list): List of evidence inputs (all evidences).
        sort_potential_evidences (list): List of potential evidence indices sorted by their TF-IDF similarity with claim.
        positive_evidence (list): List of positive evidence for the claim input. (In this retrieval part, the positive means relavant)
        negative_num (int): Number of negative samples of evidences to include.
        """
        self.claim_input_data = claim_input_data
        self.evidence_input_data = evidence_input_data
        self.sort_potential_evidences = sort_potential_evidences
        self.positive_evidence = positive_evidence
        self.negative_num = negative_num
        self.evidence_len = len(evidence_input_data[0]) if evidence_input_data else 0
        self.claim_text_len = len(claim_input_data[0]) if claim_input_data else 0
        self.negative_sample_start_idx = negative_sample_start_idx

    def __len__(self):
        """Return the total number of items in the dataset."""
        return len(self.claim_input_data)

    def __getitem__(self, idx):
        """
        Retrieves an item from the dataset at the specified index.

        Args:
        idx (int): The index of the item.

        Returns:
        tuple: A tuple containing the claim input, selected negative evidences, and corresponding positive evidences.
        """
        # Select negative evidences by sampling from tfidf sorted potential evidence indices excluding top 10.
        negative_samples = random.sample(self.sort_potential_evidences[idx][self.negative_sample_start_idx: self.negative_num*20], self.negative_num)
        return [self.claim_input_data[idx], negative_samples, self.positive_evidence[idx]]

    def collate_fn(self, batch):
        """
        Custom collate function to process a batch of data.

        Args:
        batch (list of tuples): A list of tuples from the __getitem__ method.

        Returns:
        dict: A dictionary containing tensors for claim, evidences, positions, and labels.
        """
        claim, claim_pos, positive_evidences, evidences = [], [], [], []

        # Unpack batch and process data
        for claim_data, negative_samples, positive_evidence in batch:
            claim.append(claim_data)
            claim_pos.append(list(range(self.claim_text_len)))
            positive_evidences.append(positive_evidence)
            evidences.extend(positive_evidence + negative_samples)

        # Create a new unique list of evidences inside this batch
        unique_evidences, evidences_pos = list(set(evidences)), []
        # Map the positive evidence to the indices of this new unique list for convenience of later calculating loss in training phase
        evidences2idx = {evid: idx for idx, evid in enumerate(unique_evidences)}
        positive_evidences = [[evidences2idx[evid] for evid in positive_evidence_set] for positive_evidence_set in positive_evidences]
        # get all evidences used in this batch (Both negative and positive)
        batch_evidences = [self.evidence_input_data[evid] for evid in unique_evidences]
        batch_evidences_pos = [list(range(self.evidence_len)) for _ in unique_evidences]

        # Pack everything into a dictionary as a instance of a batch
        batch_data = {
            "claim_queries": torch.LongTensor(claim),
            "claim_queries_pos": torch.LongTensor(claim_pos),

            "batch_evidences": torch.LongTensor(batch_evidences),
            "batch_evidences_pos": torch.LongTensor(batch_evidences_pos),

            "positive_evidences": positive_evidences
        }

        return batch_data

In [21]:
train_set = TrainDataset(train_claim_input, evidences_input, train_sort_potential_evidences,
                         train_evidences, negative_num=10,negative_sample_start_idx=10)
from torch.utils.data import DataLoader

dataloader = DataLoader(train_set, batch_size=5, shuffle=True, num_workers=1, collate_fn=train_set.collate_fn)

# 2.Model Implementation


Define our transformer based encoder. The code of this part is referred from workshops.

In [22]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """
    A Transformer-based encoder module.

    Args:
    vocab_size (int): Size of the vocabulary.
    embed_dim (int): Dimensionality of the embeddings.
    hidden_size (int): Size of the hidden layer.
    nhead (int): Number of attention heads in the TransformerEncoder.
    num_layers (int): Number of layers in the TransformerEncoder.
    max_position (int): Maximum sequence length for positional embeddings.
    """
    def __init__(self, vocab_size, embed_dim, hidden_size, nhead, num_layers, max_position=180):
        super(Encoder, self).__init__()

        # Hidden size attribute for possible external use.
        self.hidden_size = hidden_size

        # Embedding layers for vocabulary and positional encodings.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(max_position, embed_dim)

        # Transformer encoder layer and complete encoder.
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=nhead, batch_first=True, dropout=0.1)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=nn.LayerNorm(hidden_size))

    def forward(self, text_data, position_text):
        """
        Forward pass of the encoder model.

        Args:
        text_data (Tensor): Input text data (token indices).
        position_text (Tensor): Positional indices associated with the text data.

        Returns:
        Tensor: Encoded output from the Transformer encoder.
        """
        # Mask for padding tokens (usually zero).
        mask_ = text_data == 0

        # Add word embeddings and positional embeddings.
        text_x = self.embedding(text_data) + self.pos_embedding(position_text)

        # Apply the Transformer encoder.
        x_encoded = self.encoder(text_x, src_key_padding_mask=mask_)
        return x_encoded

In [23]:
vocab_size = len(idx2word)
print(vocab_size)

197728


Initialize our encoder

In [24]:
trans_encoder = Encoder(vocab_size=vocab_size, embed_dim=512, hidden_size=512, nhead=8, num_layers=5, max_position=200)
trans_encoder.cuda()

Encoder(
  (embedding): Embedding(197728, 512)
  (pos_embedding): Embedding(200, 512)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-4): 5 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
)

### 2.1 Set the related parameters before training

In [25]:
import os
import torch.optim as optim
torch.manual_seed(90042)
torch.cuda.manual_seed_all(90042)
random.seed(90042)
# setting the optimizer for encoder
encoder_optimizer = optim.Adam(trans_encoder.parameters())
# setting the max learning rate
max_lr = 1e-2
for param_group in encoder_optimizer.param_groups:
    param_group['lr'] = max_lr

save_dir = "model_ckpts"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

### 2.2 Training and validation preparation

Validation functions: The functions in this part are used during the training to evaluate the current encoder on dev dataset.

In [26]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

def get_embeddings(text_idx, model, text_len):
    """
    Generate embeddings for texts using the provided encoder model.

    Args:
    text_idx (list): List of text indices to be embedded.
    model (torch.nn.Module): Encoder model to generate embeddings.
    text_len (int): Fixed length of the text inputs.

    Returns:
    torch.Tensor: Normalized embeddings of the input texts.
    """
    model.eval()
    embeddings = []
    batch_size = 1000

    # Process text indices in batches without tqdm
    for start_idx in range(0, len(text_idx), batch_size):
        end_idx = min(start_idx + batch_size, len(text_idx))
        cur_text = torch.LongTensor(text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        cur_text_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        # Generate embeddings
        embedding = model(cur_text, cur_text_pos)
        embedding = embedding[:, 0, :].detach()
        normalized_embedding = F.normalize(embedding, p=2, dim=1).cpu()
        del embedding, cur_text, cur_text_pos
        embeddings.append(normalized_embedding)

    return torch.cat(embeddings, dim=0)

def calculate_fscore(claim_embeddings, evidence_embeddings, dev_sort_potential_evidences, dev_evidences, dev_candis_num, retrieval_num):
    """
    Calculate the F-score for evidence retrieval.

    Args:
    claim_embeddings (torch.Tensor): claim embeddings.
    evidence_embeddings (torch.Tensor): Evidence embeddings.
    dev_sort_potential_evidences (list): Sorted potential evidences.
    dev_evidences (list): Actual evidences.
    dev_candis_num (int): Number of candidate evidences considered.
    retrieval_num (int): Number of retrieved evidences for evaluation.

    Returns:
    float: The average F-score.
    """
    fscores = []
    batch_size = 1000

    # Evaluate F-score in batches
    for start_idx in tqdm(range(0, len(claim_embeddings), batch_size), desc="Calculating F-scores"):
        end_idx = min(start_idx + batch_size, len(claim_embeddings))
        scores = torch.mm(claim_embeddings[start_idx:end_idx], evidence_embeddings.t())

        for i in range(scores.size(0)):
            current_scores = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_potential_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(current_scores).tolist()
            select_ids = topk_ids[:retrieval_num]

            pred_evidences = [dev_sort_potential_evidences[start_idx+i][j] for j in select_ids]
            label = dev_evidences[start_idx+i]
            evidence_correct = sum(1 for eid in label if eid in pred_evidences)

            if evidence_correct > 0:
                recall = evidence_correct / len(label)
                precision = evidence_correct / retrieval_num
                fscores.append(2 * precision * recall / (precision + recall))
            else:
                fscores.append(0)

    return np.mean(fscores)

def validate(dev_claim_text_idx, evidence_text_idx, dev_sort_potential_evidences, dev_evidences, encoder_model,
             dev_candis_num = 10, retrieval_num = 5, pre_computed_evidence_embeddings=None):
    """
    Validate the encoder model by computing the F-score for evidence retrieval.

    Args:
    dev_claim_text_idx (list): Development claim text indices.
    evidence_text_idx (list): Evidence text indices.
    dev_sort_potential_evidences (list): Sorted potential evidences.
    dev_evidences (list): Actually relevant evidences.
    encoder_model (torch.nn.Module): Encoder model to use for embeddings.
    pre_computed_evidence_embeddings: Pre-computed evidence embeddings.

    Returns:
    float: The F-score of evidence retrieval.
    """
    encoder_model.eval()
    # Get embeddings for evidences and claims
    if pre_computed_evidence_embeddings is None:
      evidence_embeddings = get_embeddings(evidence_text_idx, encoder_model, len(evidence_text_idx[0]))
    else:
      evidence_embeddings = pre_computed_evidence_embeddings
    claim_query_embeddings = get_embeddings(dev_claim_text_idx, encoder_model, len(dev_claim_text_idx[0]))

    # Calculate F-score for evidence retrieval
    fscore = calculate_fscore(claim_query_embeddings, evidence_embeddings, dev_sort_potential_evidences, dev_evidences, dev_candis_num, retrieval_num)

    print("Evidence Retrieval F-score: {:.3f}\n".format(fscore))

    return fscore


Training starts here. We first define some functions for this training and then use them to train our encoder.

Define loss function which compute the mean value of probabilities of positive evidences.


The model learns to increase the scores (and thus the probabilities) of positive evidences while decreasing those of negative ones.

In [27]:
def compute_loss(claim_embeddings, evidence_embeddings, positive_evidences):
    """
    Compute the training loss given the embeddings and positive evidence indices.
    """
    cos_sims = torch.mm(claim_embeddings, evidence_embeddings.t())
    # softmax turn the similarity into probabilities
    scores = - torch.nn.functional.log_softmax(cos_sims / 0.1, dim=1)
    # the loss we use is the mean value of probabilities of positive evidences.
    losses = [torch.mean(torch.index_select(scores[i], 0, torch.LongTensor(positive_evidences[i]).cuda())) for i in range(len(positive_evidences))]
    return torch.stack(losses).mean()

Here are two auxilary functions used in the training function.

In [28]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

def adjust_learning_rate(optimizer, step_cnt, warmup_steps, max_lr):
    """
    Adjust the learning rate based on the current step count.
    Args:
    optimizer (torch.optim.Optimizer): The optimizer for which to adjust the learning rate.
    step_cnt (int): The current step count in the training process.
    warmup_steps (int): The number of steps to linearly increase the learning rate.
    max_lr (float): The maximum learning rate after warmup.
    """
    if step_cnt <= warmup_steps:
        lr = step_cnt * (max_lr - 1e-8) / warmup_steps + 1e-8
    else:
        lr = max_lr - (step_cnt - warmup_steps) * 1e-5

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def perform_model_forward_pass(model, batch):
    """
    Perform the forward pass for the model.
    """
    model.train()
    claim_query_embeddings = model(batch["claim_queries"].cuda(), batch["claim_queries_pos"].cuda())
    evidence_embeddings = model(batch["batch_evidences"].cuda(), batch["batch_evidences_pos"].cuda())
    claim_query_embeddings = torch.nn.functional.normalize(claim_query_embeddings[:, 0, :], p=2, dim=1)
    evidence_embeddings = torch.nn.functional.normalize(evidence_embeddings[:, 0, :], p=2, dim=1)
    return claim_query_embeddings, evidence_embeddings


The training function.

In [30]:
def train_model(model, dataloader, optimizer, epochs, accumulate_step, grad_norm, warmup_steps, report_freq, eval_interval, save_dir):
    """
    Train the model using the provided data loader and optimizer according to the given parameters.
    Args:
    model (torch.nn.Module): The model to train.
    dataloader (torch.utils.data.DataLoader): The DataLoader for providing training data.
    optimizer (torch.optim.Optimizer): The optimizer to use for training.
    epochs (int): Total number of epochs to train.
    accumulate_step (int): Steps to accumulate gradients before performing an update.
    grad_norm (float): Max norm for gradient clipping.
    warmup_steps (int): Number of warm-up steps for learning rate adjustment.
    report_freq (int): Frequency of reporting training statistics.
    eval_interval (int): Interval to perform validation checks.
    save_dir (str): Directory to save model checkpoints.
    """
    scaler = GradScaler()
    step_cnt = 0
    all_step_cnt = 0
    avg_loss = 0
    maximum_f_score = 0

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for epoch in range(epochs):
        epoch_step = 0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}"):
            step_cnt += 1
            with autocast():
                # Forward pass to get embeddings of claims and evidences and then compute loss
                claim_query_embeddings, evidence_embeddings = perform_model_forward_pass(model, batch)
                loss = compute_loss(claim_query_embeddings, evidence_embeddings, batch["positive_evidences"])
                loss = loss / accumulate_step
            # Backward pass
            scaler.scale(loss).backward()
            avg_loss += loss.item()

            if step_cnt == accumulate_step:
                # Gradient clipping and optimizer step
                if grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), grad_norm)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                # Learning rate adjustment
                lr = adjust_learning_rate(optimizer, all_step_cnt, warmup_steps, max_lr)

                step_cnt = 0
                epoch_step += 1
                all_step_cnt += 1

                # Report training status
                if all_step_cnt % report_freq == 0:
                    print(f"\nEpoch: {epoch + 1}, Step: {epoch_step}, Avg Loss: {avg_loss / report_freq:.6f}, Learning Rate: {lr:.6f}\n")
                    avg_loss = 0

                # Evaluation and checkpointing
                if all_step_cnt % eval_interval == 0 and all_step_cnt != 0:
                    f_score = validate(dev_claim_input, evidences_input, dev_sort_potential_evidences, dev_evidences, model)
                    # turn back the model to train mode after evaluation
                    model.train()
                    if f_score > maximum_f_score:
                        maximum_f_score = f_score
                        torch.save(model.state_dict(), os.path.join(save_dir, "best_ckpt.bin"))
                        print(f"New best F-score: {f_score:.4f} at Epoch: {epoch + 1}, Step: {epoch_step}")

            # Clean up to save memory
            del loss, claim_query_embeddings, evidence_embeddings

## 2.3 Training start here!

In [34]:
train_model(trans_encoder, dataloader, encoder_optimizer, epochs=5, accumulate_step=2, grad_norm=4,
            warmup_steps=200, report_freq=15, eval_interval=50, save_dir="model_ckpts")

Epoch 1:   9%|▉         | 22/246 [00:04<00:21, 10.58it/s]


Epoch: 1, Step: 10, Avg Loss: 4.191335, Learning Rate: 0.000450



Epoch 1:  17%|█▋        | 42/246 [00:06<00:17, 11.37it/s]


Epoch: 1, Step: 20, Avg Loss: 4.211146, Learning Rate: 0.000950



Epoch 1:  25%|██▌       | 62/246 [00:07<00:16, 11.14it/s]


Epoch: 1, Step: 30, Avg Loss: 4.197593, Learning Rate: 0.001450



Epoch 1:  33%|███▎      | 82/246 [00:09<00:14, 11.24it/s]


Epoch: 1, Step: 40, Avg Loss: 4.193512, Learning Rate: 0.001950



Epoch 1:  40%|███▉      | 98/246 [00:10<00:12, 11.39it/s]


Epoch: 1, Step: 50, Avg Loss: 4.183458, Learning Rate: 0.002450



Epoch 1:  40%|███▉      | 98/246 [00:30<00:12, 11.39it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.43s/it]


Evidence Retrieval F-score: 0.072



Epoch 1:  41%|████      | 101/246 [14:54<4:23:53, 109.20s/it]

New best F-score: 0.0724 at Epoch: 1, Step: 50


Epoch 1:  49%|████▉     | 121/246 [14:56<05:40,  2.72s/it]


Epoch: 1, Step: 60, Avg Loss: 4.161486, Learning Rate: 0.002950



Epoch 1:  57%|█████▋    | 141/246 [14:58<00:17,  6.07it/s]


Epoch: 1, Step: 70, Avg Loss: 4.194533, Learning Rate: 0.003450



Epoch 1:  65%|██████▌   | 161/246 [15:00<00:08, 10.51it/s]


Epoch: 1, Step: 80, Avg Loss: 4.223937, Learning Rate: 0.003950



Epoch 1:  74%|███████▎  | 181/246 [15:01<00:05, 10.86it/s]


Epoch: 1, Step: 90, Avg Loss: 4.210349, Learning Rate: 0.004450



Epoch 1:  81%|████████  | 199/246 [15:03<00:04, 10.93it/s]


Epoch: 1, Step: 100, Avg Loss: 4.204441, Learning Rate: 0.004950



Epoch 1:  81%|████████  | 199/246 [15:20<00:04, 10.93it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.37s/it]


Evidence Retrieval F-score: 0.073



Epoch 1:  82%|████████▏ | 201/246 [29:59<1:34:44, 126.33s/it]

New best F-score: 0.0729 at Epoch: 1, Step: 100


Epoch 1:  90%|████████▉ | 221/246 [30:01<01:09,  2.77s/it]


Epoch: 1, Step: 110, Avg Loss: 4.216748, Learning Rate: 0.005450



Epoch 1:  98%|█████████▊| 241/246 [30:03<00:00,  5.83it/s]


Epoch: 1, Step: 120, Avg Loss: 4.217640, Learning Rate: 0.005950



Epoch 1: 100%|██████████| 246/246 [30:03<00:00,  7.33s/it]
Epoch 2:   6%|▌         | 15/246 [00:01<00:21, 10.91it/s]


Epoch: 2, Step: 7, Avg Loss: 4.177513, Learning Rate: 0.006450



Epoch 2:  14%|█▍        | 35/246 [00:03<00:18, 11.12it/s]


Epoch: 2, Step: 17, Avg Loss: 4.172003, Learning Rate: 0.006950



Epoch 2:  22%|██▏       | 53/246 [00:04<00:17, 10.87it/s]


Epoch: 2, Step: 27, Avg Loss: 4.188939, Learning Rate: 0.007450



Epoch 2:  22%|██▏       | 53/246 [00:16<00:17, 10.87it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.91s/it]


Evidence Retrieval F-score: 0.092



Epoch 2:  22%|██▏       | 54/246 [14:59<8:25:35, 158.00s/it]

New best F-score: 0.0918 at Epoch: 2, Step: 27


Epoch 2:  31%|███       | 76/246 [15:01<07:47,  2.75s/it]


Epoch: 2, Step: 37, Avg Loss: 4.206278, Learning Rate: 0.007950



Epoch 2:  39%|███▉      | 96/246 [15:03<00:24,  6.03it/s]


Epoch: 2, Step: 47, Avg Loss: 4.228678, Learning Rate: 0.008450



Epoch 2:  47%|████▋     | 116/246 [15:05<00:12, 10.37it/s]


Epoch: 2, Step: 57, Avg Loss: 4.213851, Learning Rate: 0.008950



Epoch 2:  55%|█████▌    | 136/246 [15:07<00:10, 10.80it/s]


Epoch: 2, Step: 67, Avg Loss: 4.219225, Learning Rate: 0.009450



Epoch 2:  62%|██████▏   | 152/246 [15:08<00:08, 11.24it/s]


Epoch: 2, Step: 77, Avg Loss: 4.173221, Learning Rate: 0.009950



Epoch 2:  62%|██████▏   | 152/246 [15:26<00:08, 11.24it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:03<00:00,  3.84s/it]
Epoch 2:  63%|██████▎   | 154/246 [29:55<3:24:08, 133.14s/it]

Evidence Retrieval F-score: 0.080



Epoch 2:  72%|███████▏  | 176/246 [29:58<02:16,  1.95s/it]


Epoch: 2, Step: 87, Avg Loss: 4.185212, Learning Rate: 0.009910



Epoch 2:  80%|███████▉  | 196/246 [30:00<00:07,  6.85it/s]


Epoch: 2, Step: 97, Avg Loss: 4.204511, Learning Rate: 0.009810



Epoch 2:  88%|████████▊ | 216/246 [30:02<00:02, 10.79it/s]


Epoch: 2, Step: 107, Avg Loss: 4.202397, Learning Rate: 0.009710



Epoch 2:  96%|█████████▌| 236/246 [30:03<00:00, 10.94it/s]


Epoch: 2, Step: 117, Avg Loss: 4.222156, Learning Rate: 0.009610



Epoch 2: 100%|██████████| 246/246 [30:05<00:00,  7.34s/it]
Epoch 3:   3%|▎         | 7/246 [00:00<00:26,  9.14it/s]


Epoch: 3, Step: 4, Avg Loss: 4.169816, Learning Rate: 0.009510




Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.80s/it]


Evidence Retrieval F-score: 0.085



Epoch 3:  12%|█▏        | 30/246 [14:48<07:01,  1.95s/it]


Epoch: 3, Step: 14, Avg Loss: 4.165401, Learning Rate: 0.009410



Epoch 3:  20%|██        | 50/246 [14:50<00:27,  7.08it/s]


Epoch: 3, Step: 24, Avg Loss: 4.187390, Learning Rate: 0.009310



Epoch 3:  28%|██▊       | 70/246 [14:52<00:15, 11.11it/s]


Epoch: 3, Step: 34, Avg Loss: 4.198667, Learning Rate: 0.009210



Epoch 3:  37%|███▋      | 90/246 [14:54<00:14, 11.09it/s]


Epoch: 3, Step: 44, Avg Loss: 4.190304, Learning Rate: 0.009110



Epoch 3:  43%|████▎     | 106/246 [14:55<00:12, 11.11it/s]


Epoch: 3, Step: 54, Avg Loss: 4.212715, Learning Rate: 0.009010



Epoch 3:  43%|████▎     | 106/246 [15:11<00:12, 11.11it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.93s/it]


Evidence Retrieval F-score: 0.074



Epoch 3:  52%|█████▏    | 129/246 [29:41<05:18,  2.72s/it]


Epoch: 3, Step: 64, Avg Loss: 4.192585, Learning Rate: 0.008910



Epoch 3:  61%|██████    | 149/246 [29:43<00:16,  5.99it/s]


Epoch: 3, Step: 74, Avg Loss: 4.197869, Learning Rate: 0.008810



Epoch 3:  69%|██████▊   | 169/246 [29:45<00:07, 10.68it/s]


Epoch: 3, Step: 84, Avg Loss: 4.197557, Learning Rate: 0.008710



Epoch 3:  77%|███████▋  | 189/246 [29:47<00:05, 11.01it/s]


Epoch: 3, Step: 94, Avg Loss: 4.201171, Learning Rate: 0.008610



Epoch 3:  84%|████████▍ | 207/246 [29:49<00:03, 11.03it/s]


Epoch: 3, Step: 104, Avg Loss: 4.204408, Learning Rate: 0.008510



Epoch 3:  84%|████████▍ | 207/246 [30:01<00:03, 11.03it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.71s/it]


Evidence Retrieval F-score: 0.090



Epoch 3:  93%|█████████▎| 229/246 [44:34<00:46,  2.73s/it]


Epoch: 3, Step: 114, Avg Loss: 4.210955, Learning Rate: 0.008410



Epoch 3: 100%|██████████| 246/246 [44:36<00:00, 10.88s/it]
Epoch 4:   1%|          | 3/246 [00:00<00:27,  8.69it/s]


Epoch: 4, Step: 1, Avg Loss: 4.187757, Learning Rate: 0.008310



Epoch 4:   9%|▉         | 23/246 [00:02<00:20, 11.06it/s]


Epoch: 4, Step: 11, Avg Loss: 4.183949, Learning Rate: 0.008210



Epoch 4:  17%|█▋        | 43/246 [00:04<00:18, 10.98it/s]


Epoch: 4, Step: 21, Avg Loss: 4.198098, Learning Rate: 0.008110



Epoch 4:  25%|██▍       | 61/246 [00:05<00:17, 10.80it/s]


Epoch: 4, Step: 31, Avg Loss: 4.191273, Learning Rate: 0.008010



Epoch 4:  25%|██▍       | 61/246 [00:24<00:17, 10.80it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:03<00:00,  3.69s/it]


Evidence Retrieval F-score: 0.095



Epoch 4:  26%|██▌       | 63/246 [15:08<6:28:18, 127.32s/it]

New best F-score: 0.0953 at Epoch: 4, Step: 31


Epoch 4:  34%|███▎      | 83/246 [15:10<07:34,  2.79s/it]


Epoch: 4, Step: 41, Avg Loss: 4.174054, Learning Rate: 0.007910



Epoch 4:  42%|████▏     | 103/246 [15:12<00:23,  5.98it/s]


Epoch: 4, Step: 51, Avg Loss: 4.195454, Learning Rate: 0.007810



Epoch 4:  50%|█████     | 123/246 [15:14<00:12, 10.09it/s]


Epoch: 4, Step: 61, Avg Loss: 4.191637, Learning Rate: 0.007710



Epoch 4:  58%|█████▊    | 143/246 [15:16<00:09, 10.39it/s]


Epoch: 4, Step: 71, Avg Loss: 4.208059, Learning Rate: 0.007610



Epoch 4:  65%|██████▌   | 161/246 [15:17<00:07, 10.83it/s]


Epoch: 4, Step: 81, Avg Loss: 4.197741, Learning Rate: 0.007510



Epoch 4:  65%|██████▌   | 161/246 [15:34<00:07, 10.83it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.71s/it]
Epoch 4:  66%|██████▌   | 162/246 [30:01<3:38:25, 156.02s/it]

Evidence Retrieval F-score: 0.076



Epoch 4:  75%|███████▍  | 184/246 [30:03<02:48,  2.72s/it]


Epoch: 4, Step: 91, Avg Loss: 4.210366, Learning Rate: 0.007410



Epoch 4:  83%|████████▎ | 204/246 [30:05<00:07,  5.97it/s]


Epoch: 4, Step: 101, Avg Loss: 4.203722, Learning Rate: 0.007310



Epoch 4:  91%|█████████ | 224/246 [30:07<00:02, 10.60it/s]


Epoch: 4, Step: 111, Avg Loss: 4.182894, Learning Rate: 0.007210



Epoch 4:  99%|█████████▉| 244/246 [30:09<00:00, 11.15it/s]


Epoch: 4, Step: 121, Avg Loss: 4.179206, Learning Rate: 0.007110



Epoch 4: 100%|██████████| 246/246 [30:09<00:00,  7.36s/it]
Epoch 5:   6%|▌         | 15/246 [00:01<00:21, 10.76it/s]


Epoch: 5, Step: 8, Avg Loss: 4.174708, Learning Rate: 0.007010



Epoch 5:   6%|▌         | 15/246 [00:15<00:21, 10.76it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]
Epoch 5:   7%|▋         | 16/246 [14:43<10:33:32, 165.27s/it]

Evidence Retrieval F-score: 0.079



Epoch 5:  15%|█▌        | 38/246 [14:45<09:25,  2.72s/it]


Epoch: 5, Step: 18, Avg Loss: 4.203152, Learning Rate: 0.006910



Epoch 5:  24%|██▎       | 58/246 [14:47<00:30,  6.10it/s]


Epoch: 5, Step: 28, Avg Loss: 4.190424, Learning Rate: 0.006810



Epoch 5:  32%|███▏      | 78/246 [14:49<00:16, 10.47it/s]


Epoch: 5, Step: 38, Avg Loss: 4.211583, Learning Rate: 0.006710



Epoch 5:  40%|███▉      | 98/246 [14:51<00:13, 10.89it/s]


Epoch: 5, Step: 48, Avg Loss: 4.176031, Learning Rate: 0.006610



Epoch 5:  46%|████▋     | 114/246 [14:52<00:12, 10.93it/s]


Epoch: 5, Step: 58, Avg Loss: 4.207781, Learning Rate: 0.006510



Epoch 5:  46%|████▋     | 114/246 [15:05<00:12, 10.93it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]
Epoch 5:  47%|████▋     | 116/246 [29:33<4:46:28, 132.22s/it]

Evidence Retrieval F-score: 0.075



Epoch 5:  56%|█████▌    | 138/246 [29:35<04:51,  2.70s/it]


Epoch: 5, Step: 68, Avg Loss: 4.178936, Learning Rate: 0.006410



Epoch 5:  64%|██████▍   | 158/246 [29:37<00:14,  6.07it/s]


Epoch: 5, Step: 78, Avg Loss: 4.192685, Learning Rate: 0.006310



Epoch 5:  72%|███████▏  | 178/246 [29:39<00:06, 10.88it/s]


Epoch: 5, Step: 88, Avg Loss: 4.196143, Learning Rate: 0.006210



Epoch 5:  80%|████████  | 198/246 [29:41<00:04, 10.71it/s]


Epoch: 5, Step: 98, Avg Loss: 4.201349, Learning Rate: 0.006110



Epoch 5:  87%|████████▋ | 214/246 [29:42<00:02, 10.83it/s]


Epoch: 5, Step: 108, Avg Loss: 4.190344, Learning Rate: 0.006010



Epoch 5:  87%|████████▋ | 214/246 [29:55<00:02, 10.83it/s]
Calculating F-scores:   0%|          | 0/1 [00:00<?, ?it/s][A
Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.35s/it]
Epoch 5:  88%|████████▊ | 216/246 [44:26<1:06:17, 132.59s/it]

Evidence Retrieval F-score: 0.087



Epoch 5:  97%|█████████▋| 238/246 [44:28<00:21,  2.71s/it]


Epoch: 5, Step: 118, Avg Loss: 4.206174, Learning Rate: 0.005910



Epoch 5: 100%|██████████| 246/246 [44:29<00:00, 10.85s/it]


# 3.Testing and Evaluation


### 3.1 Pre-compute Evidence embeddings

Load the best check point that we got from the training

In [31]:
import os
trans_encoder.load_state_dict(torch.load(os.path.join(save_dir, "best_ckpt.bin")))

trans_encoder.cuda()
trans_encoder.eval()

Encoder(
  (embedding): Embedding(197728, 512)
  (pos_embedding): Embedding(200, 512)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-4): 5 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
)

Now to make prediction and evaluation, we need to encode all evidences using the trained encoder for later prediction on dev and test dataset.

In [32]:
evidence_embeddings = []
evidence_embeddings = get_embeddings(evidences_input, trans_encoder, len(evidences_input[0]))

### 3.2 Evaluate on dev dataset

Use the validate function that we have defined in the training part to evaluate our trained model on dev dataset.

In [34]:
fscore = validate(dev_claim_input, evidences_input, dev_sort_potential_evidences, dev_evidences, trans_encoder,
                  dev_candis_num=8, retrieval_num=5, pre_computed_evidence_embeddings=evidence_embeddings)
print(fscore)

Calculating F-scores: 100%|██████████| 1/1 [00:02<00:00,  2.82s/it]

Evidence Retrieval F-score: 0.102

0.10166975881261596





### 3.3 Retrieve the relevant evidences for test dataset

We then define a function for predicting the retrieved relevant evidences.

In [33]:
def evidence_predicts(dev_text_idx, evidences_embeddings, dev_sort_evidences, evidences_ids, encoder_model, dev_candis_num, retrieval_num, batch_size=100):
    """
    Predict the most relevant evidences for a set of development texts using batch processing.

    Args:
    dev_text_idx (list of list of int): Development text indices for queries.
    evidences_embeddings (torch.Tensor): Pre-computed embeddings for all possible evidences.
    dev_sort_evidences (list of list of int): Sorted indices of potential evidences for each claim, based on TFIDF.
    evidences_ids (list): List of original IDs corresponding to the evidences.
    encoder_model (torch.nn.Module): Model to generate embeddings for the queries.
    dev_candis_num (int): Number of candidate evidences to consider for final ranking.
    retrieval_num (int): Number of top evidences to retrieve.
    batch_size (int): The size of each batch to process.

    Returns:
    list of list: A list containing the lists of predicted evidence IDs for each development text.
    """
    encoder_model.eval()
    text_len = len(dev_text_idx[0])
    preds = []

    # Processing in batches to manage memory consumption
    for start_idx in range(0, len(dev_text_idx), batch_size):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))

        # Generate embeddings for the current batch of queries
        batch_text_idx = dev_text_idx[start_idx:end_idx]
        claim_embeddings = get_embeddings(batch_text_idx, encoder_model, text_len)

        # Compute cosine similarity scores between batch query embeddings and all evidence embeddings
        scores = torch.mm(claim_embeddings, evidences_embeddings.t())

        # Determine the top evidences based on the scores for each query in the batch
        for i in range(scores.size(0)):
            # Select top candidate evidences scores
            candidate_scores = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx + i][:dev_candis_num]))
            topk_ids = torch.argsort(candidate_scores, descending=True).tolist()
            select_ids = topk_ids[:retrieval_num]

            # Map back to original evidence IDs
            pred_evidences = [evidences_ids[j] for j in select_ids]
            preds.append(pred_evidences)

    return preds

Use this function to get the retrieved evidences for dev and test dataset.

In [35]:
dev_evidences_ids_predicted = evidence_predicts(dev_claim_input, evidence_embeddings, dev_sort_potential_evidences, evidences_ids,
                                                trans_encoder,dev_candis_num=10,retrieval_num=5)
test_evidences_ids_predicted = evidence_predicts(test_claim_input, evidence_embeddings, test_sort_potential_evidences, evidences_ids,
                                                 trans_encoder,dev_candis_num=10,retrieval_num=5)

Store the predictions into the dictionary for dev and test claims

In [36]:
dev_claims = json.load(open("data/dev-claims.json", "r"))
test_claims = json.load(open("data/test-claims-unlabelled.json", "r"))

def store_predictions(claims_data, ids, predicted_evidences_ids):
    """
    Populate a dictionary with claims data where each claim is updated with associated evidence IDs.

    Args:
    claims_data (dict): Dictionary of claims.
    ids (list): List of claim IDs.
    evidences_ids (list of list of int): List of evidence IDs for each claim.

    Returns:
    dict: A dictionary with updated claims including their associated evidences.
    """
    predictions = {}
    for idx, evidence_ids in enumerate(predicted_evidences_ids):

        cur_data = claims_data[ids[idx]].copy()  # Use copy to avoid modifying the original data
        cur_data['evidences'] = evidence_ids
        predictions[ids[idx]] = cur_data
    return predictions

pred_dev_claims = store_predictions(dev_claims, dev_ids, dev_evidences_ids_predicted)
pred_test_claims = store_predictions(test_claims, test_ids, test_evidences_ids_predicted)

Store the prediction result for evidence retrieval as a json file

In [37]:
json.dump(pred_dev_claims, open("pred_dev_claims_retrieval.json", "w"))
json.dump(pred_test_claims, open("pred_test_claims_retrieval.json", "w"))

### 3.4 Preparation for Claim Classification

At the last of evidence retrieval, we also need to retrieve the relevant evidences for training dataset fot the next part ---- Claim Classification.

We will need to use the retrieved evidences and the claim together to train our next claim classification model.

Here, we only keep the wrongly predicted relevant evidences. We will use the truly relevant evidences for training claim classification but also include some of these wrongly predicted relevant evidences as we will have to use the predicted relevant evidences for test so we hope our classification model should have seen these wrongly predicted evidences during training.

In [38]:
train_evidences_ids_predicted = evidence_predicts(train_claim_input, evidence_embeddings, train_sort_potential_evidences,
                                                  evidences_ids, trans_encoder,dev_candis_num=8,retrieval_num=5)

def get_negative_evidences(evidences_ids_predicted, true_evidences, evidences_id_dict):
    """
    Identify negative evidences for each set of predictions based on the true evidence IDs.

    Args:
    evidences_ids_predicted (list of list of int): Predicted evidence IDs for each training example.
    true_evidences (list of list of int): Actual evidence IDs for each training example.
    evidences_id_dict (dict): Dictionary mapping evidence IDs to their respective indices or transformed IDs.

    Returns:
    list of list of int: List containing lists of negative evidence IDs for each training example.
    """
    pred_negative_evidences = []
    for idx, evidence_ids in enumerate(evidences_ids_predicted):
        # Temp list to store negative evidences for the current training example
        temp_ = []
        # Convert the set of true evidences for current training example for faster lookup
        true_evidence_set = set(true_evidences[idx])
        for i in evidence_ids:
            # Check if the mapped evidence ID is not in the set of true evidences
            if evidences_id_dict[i] not in true_evidence_set:
                temp_.append(evidences_id_dict[i])
        pred_negative_evidences.append(temp_)
    return pred_negative_evidences

train_wrongly_pred_evidences = get_negative_evidences(train_evidences_ids_predicted, train_evidences, evidences_id_dict)

In [39]:
## save wrongly prediction data
json.dump(train_wrongly_pred_evidences, open("pred_train_wrongly_pred_evidences.json", "w"))

### 3.5 Saving retrieved evidences for dev and test

We also save the retrieved data for later claim classification.

Here, we concatenate the retirieved evidences together with the claim.

In [40]:
def prepare_classification_data(claim_text_idx, evidences_ids_predicted, evidences_text_idx, evidences_id_dict, labels=None):
    """
    Prepare data for classification by processing claim and evidence indices, adding special tokens, and padding.

    Args:
    claim_text_idx (list of list of int): Indices of claims.
    evidences_ids_predicted (list of list of int): Predicted evidence IDs for each claim.
    evidences_text_idx (dict): Dictionary mapping evidence IDs to their text indices.
    evidences_id_dict (dict): Dictionary mapping evidence original IDs to their indices in `evidences_text_idx`.
    labels (list of int, optional): Labels for each claim, if available (for training data).

    Returns:
    list of dict: List of dictionaries, each containing processed text and optionally a label.
    """
    concatenated_claim_evidences = []
    text_max_len = 60
    evidence_max_len = 100
    # limit the max length for the concatenated claim and evidences
    # claim length + evidence length * 5 + special tokens * 6
    # 60 + 100*5 + 6 = 566
    all_max_len = 570

    for idx, claim in enumerate(claim_text_idx):
        cur_data = {}
        if labels:
            cur_data['label'] = labels[idx]
        temp_text = [word2idx["<cls>"]] + claim[:text_max_len]

        # concatenate the retrieved evidences together with the claim to construct the input for claim classification
        for predicted_evidence_id in evidences_ids_predicted[idx]:
            temp_text.extend([word2idx["<sep>"]] + evidences_text_idx[evidences_id_dict[predicted_evidence_id]][:evidence_max_len])
        temp_text.append(word2idx["<sep>"])

        # Padding the sequence if necessary
        if len(temp_text) < all_max_len:
            temp_text.extend([word2idx["<pad>"]] * (all_max_len - len(temp_text)))

        cur_data['text'] = temp_text
        concatenated_claim_evidences.append(cur_data)

    return concatenated_claim_evidences

In [41]:
# Use the function to prepare data
dev_concatenated_claim_evidences = prepare_classification_data(dev_claim_text_idx, dev_evidences_ids_predicted, evidences_text_idx, evidences_id_dict, dev_labels)
test_concatenated_claim_evidences = prepare_classification_data(test_claim_text_idx, test_evidences_ids_predicted, evidences_text_idx, evidences_id_dict)

# Write the processed data to JSON files
json.dump(dev_concatenated_claim_evidences, open("dev_concatenated_claim_evidences.json", "w"))
json.dump(test_concatenated_claim_evidences, open("test_concatenated_claim_evidences.json", "w"))

## Object Oriented Programming codes here

*You can use multiple code snippets. Just add more if needed*