# 2024 COMP90042 Project
*Make sure you change the file name with your group id.*

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Readme
*If there is something to be noted for the marker, please mention here.*

*If you are planning to implement a program with Object Oriented Programming style, please put those the bottom of this ipynb file*

**We use pytorch, nltk, scikit-learn in this project.**

# 1.DataSet Processing
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

## PreProcess for evidence and claims

### preprocessing function

In [2]:
import torch
print("CUDA available:", torch.cuda.is_available())

CUDA available: True


In [3]:
print("Current CUDA device:", torch.cuda.current_device())
print("Device count:", torch.cuda.device_count())
print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

Current CUDA device: 0
Device count: 1
Device name: NVIDIA L4


### read files

In [4]:
import json
import nltk
import string
import re
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from collections import Counter
from statistics import mean

nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')

with open('data/train-claims.json', 'r') as input_file:
    train_claims = json.load(input_file)

# Read in development data (claim)
with open('data/dev-claims.json', 'r') as input_file:
    dev_claims = json.load(input_file)

# Read in test data (claim)
with open('data/test-claims-unlabelled.json', 'r') as input_file:
    test_claims = json.load(input_file)

# Read in evidence data
with open('data/evidence.json', 'r') as input_file:
    evidences = json.load(input_file)

#EDA
claim_count = 0
evi_count = 0
claim_length = []
evidence_count = []
evidence_length = []
labels = []

for key,value in train_claims.items():
    claim_count+=1
    claim_length.append(len(value["claim_text"]))
    evidence_count.append(len(value["evidences"]))
    evidence_length += [len(evidences[x]) for x in value["evidences"]]
    labels.append(value["claim_label"])

for key,value in evidences.items():
    evi_count+=1

print("claim count: ",claim_count)
print("evidence count: ",evi_count)
print("max claim length: ",max(claim_length))
print("min claim length: ",min(claim_length))
print("mean claim length: ",mean(claim_length))
print("max evidence count: ",max(evidence_count))
print("min evidence count: ",min(evidence_count))
print("mean evidence count: ",mean(evidence_count))
print("max evidence length: ",max(evidence_length))
print("min evidence length: ",min(evidence_length))
print("mean evidence length: ",mean(evidence_length))
print(Counter(labels))

inside = 0
outside = 0

train_evi_id = []
for claim_id,claim_value in train_claims.items():
    train_evi_id=train_evi_id+claim_value['evidences']

for claim_id,claim_value in dev_claims.items():
    test_evi_id=claim_value['evidences']
    for e in test_evi_id:
        if e in train_evi_id:
            inside += 1
        else:
            outside += 1
print("Dev evi inside train evi", inside)
print("Dev evi outside train evi", outside)

full_evidence_id = list(evidences.keys())
full_evidence_text  = list(evidences.values())
train_claim_id = list(train_claims.keys())
train_claim_text  = [ v["claim_text"] for v in train_claims.values()]
print("Train claim count: ",len(train_claim_id))

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...


claim count:  1228
evidence count:  1208827
max claim length:  332
min claim length:  26
mean claim length:  122.95521172638436
max evidence count:  5
min evidence count:  1
mean evidence count:  3.3566775244299674
max evidence length:  1979
min evidence length:  13
mean evidence length:  173.5
Counter({'SUPPORTS': 519, 'NOT_ENOUGH_INFO': 386, 'REFUTES': 199, 'DISPUTED': 124})
Dev evi inside train evi 163
Dev evi outside train evi 328
Train claim count:  1228


In [5]:
lemmatizer = nltk.stem.wordnet.WordNetLemmatizer()
stopwords = set(stopwords.words('english'))

def lemmatize_text(text):
    words = nltk.word_tokenize(text)
    lemmatized_words = [lemmatizer.lemmatize(word) for word in words]
    lemmatized_text = ' '.join(lemmatized_words)
    return lemmatized_text

# def is_pure_english(text):
#     english_letters = set(string.ascii_letters)
#     cleaned_text = ''.join(char for char in text if char.isalpha() or char.isspace())
#     return all(char in english_letters or char.isspace() for char in cleaned_text)

# def remove_non_eng(dictionary):
#     eng_data = {}
#     for key, value in dictionary.items():
#         if is_pure_english(value):
#             eng_data[key] = value
#     return eng_data

# def contains_climate_keywords(text, keywords):
#     text = text.lower()
#     for keyword in keywords:
#         if re.search(r"\b" + re.escape(keyword) + r"\b", text):
#             return True
#     return False

# def filter_climate_related(dictionary, keywords):
#     cs_data = {}
#     for key, value in dictionary.items():
#         if contains_climate_keywords(value, keywords):
#             cs_data[key] = value
#     return cs_data

def preprocessing_text(text):

     # Convert to lowercase
    text = text.lower()

     # Lemmatize the text
    text = lemmatize_text(text)

    # Remove leading/trailing whitespaces
    text = text.strip()
    return text

In [6]:
climate_keywords = [
    "climate", "environment", "global warming", "greenhouse effect", "carbon", "co2", "carbon dioxide",
    "methane", "renewable energy", "sustainability", "ecology", "biodiversity", "fossil fuels",
    "emissions", "air quality", "ozone", "solar energy", "wind energy", "climate change", "climate crisis",
    "climate adaptation", "climate mitigation", "ocean", "sea levels", "ice melting", "deforestation",
    "reforestation", "pollution"," electricity","energy","solar","wind","renewable","fossil","fuel","emission","air","quality","ozone","solar","wind","climate","change","crisis","adaptation","mitigation","ocean","sea","level","ice","melt","deforestation",
]


# def preprocess_claim_data(claim_data, existed_evidences_id=None):
#     claim_data = remove_non_eng(claim_data)
#     claim_data_text = []
#     claim_data_id = []
#     claim_data_label = []
#     claim_evidences = []
#     for key in claim_data.keys():
#         claim_data[key]["claim_text"] = preprocessing_text(claim_data[key]["claim_text"])
#         claim_data_text.append(claim_data[key]["claim_text"])
#         claim_data_id.append(key)
#         if "claim_label" in claim_data[key]:
#             claim_data_label.append(claim_data[key]["claim_label"])
#         else:
#             claim_data_label.append(None)
#         if existed_evidences_id and "evidences" in claim_data[key]:
#             valid_evidences = [existed_evidences_id[i] for i in claim_data[key]["evidences"] if i in existed_evidences_id]
#             claim_evidences.append(valid_evidences)
#         else:
#             claim_evidences.append([])
#     return claim_data_text, claim_data_id, claim_data_label, claim_evidences

def preprocess_claim_data(claim_data, existed_evidences_id=None):
    claim_data_text = []
    claim_data_id = []
    claim_data_label = []
    claim_evidences = []
    for key in claim_data.keys():
        claim_data[key]["claim_text"] = preprocessing_text(claim_data[key]["claim_text"])
        claim_data_text.append(claim_data[key]["claim_text"])
        claim_data_id.append(key)
        if "claim_label" in claim_data[key]:
            claim_data_label.append(claim_data[key]["claim_label"])
        else:
            claim_data_label.append(None)
        if existed_evidences_id and "evidences" in claim_data[key]:
            valid_evidences = [existed_evidences_id[i] for i in claim_data[key]["evidences"] if i in existed_evidences_id]
            claim_evidences.append(valid_evidences)
        else:
            claim_evidences.append([])
    return claim_data_text, claim_data_id, claim_data_label, claim_evidences

# def preprocess_evi_data(evi_data, climate_keywords):
#     evi_data = remove_non_eng(evi_data)
#     cs_evi_data = filter_climate_related(evi_data, climate_keywords)

#     for key in cs_evi_data.keys():
#         cs_evi_data[key] = preprocessing_text(cs_evi_data[key])

#     cleaned_evidence_text = list(cs_evi_data.values())
#     cleaned_evidence_id = list(cs_evi_data.keys())
#     return cleaned_evidence_text, cleaned_evidence_id

def preprocess_evi_data(evi_data):
    cleaned_evidence_text = []
    cleaned_evidence_id = []
    for key, value in evi_data.items():
        cleaned_text = preprocessing_text(value)
        cleaned_evidence_text.append(cleaned_text)
        cleaned_evidence_id.append(key)
    return cleaned_evidence_text, cleaned_evidence_id

In [7]:
cleaned_evidence_text, cleaned_evidence_id = preprocess_evi_data(evidences)

evidences_id_dict = {evidence_id: idx for idx, evidence_id in enumerate(cleaned_evidence_id)}

train_claim_text, train_claim_id, train_claim_label, train_claim_evidences = preprocess_claim_data(train_claims, evidences_id_dict)

dev_claim_text, dev_claim_id, dev_claim_label, dev_claim_evidences = preprocess_claim_data(dev_claims, evidences_id_dict)

test_claim_text, test_claim_id, _, _ = preprocess_claim_data(test_claims)

In [8]:
print("Number of claims after removing non-English:", len(train_claim_evidences))
print("Number of claims after preprocessing:", len(train_claim_text))

Number of claims after removing non-English: 1228
Number of claims after preprocessing: 1228


In [9]:
print(train_claim_text[0:10])

['not only is there no scientific evidence that co2 is a pollutant , higher co2 concentration actually help ecosystem support more plant and animal life .', 'el niño drove record high in global temperature suggesting rise may not be down to man-made emission .', 'in 1946 , pdo switched to a cool phase .', 'weather channel co-founder john coleman provided evidence that convincingly refutes the concept of anthropogenic global warming .', '`` january 2008 capped a 12 month period of global temperature drop on all of the major well respected indicator .', 'the last time the planet wa even four degree warmer , peter brannen point out in the end of the world , his new history of the planet ’ s major extinction event , the ocean were hundred of foot higher .', 'tree-ring proxy reconstruction are reliable before 1960 , tracking closely with the instrumental record and other independent proxy .', 'under the most ambitious scenario , they found a strong likelihood that antarctica would remain fa

### tfidf retrieval

In [10]:

from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer()
vectorizer.fit(cleaned_evidence_text)

train_tfidf = vectorizer.transform(train_claim_text)
dev_tfidf = vectorizer.transform(dev_claim_text)
test_tfidf = vectorizer.transform(test_claim_text)
evidence_tfidf = vectorizer.transform(cleaned_evidence_text)


In [11]:
train_cos_sims = cosine_similarity(train_tfidf, evidence_tfidf)
dev_cos_sims = cosine_similarity(dev_tfidf, evidence_tfidf)
test_cos_sims = cosine_similarity(test_tfidf, evidence_tfidf)
print(train_cos_sims.shape)

(1228, 1208827)


In [12]:
def test_retrieval_topk(k, cur_scores, cur_labels):
    ACC = []
    top_ids = torch.topk(torch.FloatTensor(cur_scores), k, -1).indices.tolist()
    for i in range(len(cur_labels)):
        all_count = 0
        recall_count = 0
        for cur_ in cur_labels[i]:
            if cur_ in top_ids[i]:
                recall_count += 1
            all_count += 1
        if all_count == 0:
            all_count = 1e-9  # to avoid division by zero
        ACC.append(recall_count / all_count)
    print(sum(ACC) / len(ACC))

topK = 30
test_retrieval_topk(topK, train_cos_sims, train_claim_evidences)
test_retrieval_topk(topK, dev_cos_sims, dev_claim_evidences)

0.26230998914223635
0.30573593073593086


In [13]:
def sort_evidence_candidates(cos_sims):
    top_ids = np.argsort(-cos_sims, axis=1)[:, :10000]
    return top_ids.tolist()

In [14]:
dev_sort_evidences = sort_evidence_candidates(dev_cos_sims)
test_sort_evidences = sort_evidence_candidates(test_cos_sims)
train_sort_evidences = sort_evidence_candidates(train_cos_sims)

### construct vocab and indexing

In [15]:
from collections import defaultdict

min_count = 5
wordcount = defaultdict(int)
idxword = ["<cls>",  "<sep>", "<pad>", "<unk>"]
wordidx = {"<cls>": 0, "<sep>": 1, "<pad>":2, "<unk>": 3}

for texts in train_claim_text + cleaned_evidence_text:
    for word in texts.split():
        wordcount[word] += 1

idx = 4
for word, count in wordcount.items():
    if count > min_count:
        idxword.append(word)
        wordidx[word] = idx
        idx += 1

In [16]:
def convert2idx(text_data, wordidx_):
    idx_data = []
    for texts in text_data:
        temp_idx = []
        for word in texts.split():
            temp_idx.append(wordidx_.get(word, wordidx_["<unk>"]))
        idx_data.append(temp_idx)
    return idx_data

In [17]:
train_text_idx = convert2idx(train_claim_text, wordidx)
dev_text_idx = convert2idx(dev_claim_text, wordidx)
test_text_idx = convert2idx(test_claim_text, wordidx)
evidences_text_idx = convert2idx(cleaned_evidence_text, wordidx)

In [18]:
print(max([len(i) for i in train_text_idx]), max([len(i) for i in dev_text_idx]), max([len(i) for i in test_text_idx]), max([len(i) for i in evidences_text_idx]))

76 73 60 636


In [19]:
text_pad_len = 50
evidences_pad_len = 80

In [20]:
def construct_input_text(text_idx, padding_len, wordidx_):
    idx_data = []
    for texts in text_idx:
        if len(texts) < padding_len:
            idx_data.append([wordidx_["<cls>"]] + texts + [wordidx_["<sep>"]] + [wordidx_["<pad>"]] * (padding_len - len(texts)))
        else:
            idx_data.append([wordidx_["<cls>"]] + texts[:padding_len] + [wordidx_["<sep>"]])
    return idx_data

In [21]:
train_input = construct_input_text(train_text_idx, text_pad_len, wordidx)
dev_input = construct_input_text(dev_text_idx, text_pad_len, wordidx)
test_input = construct_input_text(test_text_idx, text_pad_len, wordidx)
evidences_input = construct_input_text(evidences_text_idx, evidences_pad_len, wordidx)

In [22]:
print(len(train_input[0]), len(evidences_input[0]))

52 82


In [23]:
vocab_size = len(idxword)
print(vocab_size)

90097


In [24]:
from torch.utils.data import Dataset
import random

class TrainDataset(Dataset):
    def __init__(self, text_input_data, evidence_input_data, tfidf_sort_evidences, evidence_label, negative_num=10):
        self.text_input_data = text_input_data
        self.evidence_input_data = evidence_input_data
        self.tfidf_sort_evidences = tfidf_sort_evidences
        self.evidence_label = evidence_label
        self.negative_num = 10
        self.evidence_len = len(evidence_input_data[0])
        self.text_len = len(text_input_data[0])

    def __len__(self):
        return len(self.text_input_data)

    def __getitem__(self, idx):
        # please note the negative evidences
        return [self.text_input_data[idx], random.sample(self.tfidf_sort_evidences[idx][10: self.negative_num*10], self.negative_num), self.evidence_label[idx]]

    def collate_fn(self, batch):
        queries = []
        queries_pos = []
        evidences = []
        temp_labels = []

        for i, j, k in batch:
            queries.append(i)
            queries_pos.append(list(range(self.text_len)))
            temp_labels.append(k)
            evidences.extend(k + j)

        evidences = list(set(evidences))

        evidences2idx = {}
        for i, j in enumerate(evidences):
            evidences2idx[j] = i

        labels = []
        for i in temp_labels:
            labels.append([evidences2idx[j] for j in i])

        evidences = [self.evidence_input_data[i] for i in evidences]
        evidences_pos = [list(range(self.evidence_len)) for _ in range(len(evidences))]

        batch_encoding = {}
        batch_encoding["queries"] = torch.LongTensor(queries)
        batch_encoding["evidences"] = torch.LongTensor(evidences)
        batch_encoding["queries_pos"] = torch.LongTensor(queries_pos)
        batch_encoding["evidences_pos"] = torch.LongTensor(evidences_pos)
        batch_encoding["labels"] = labels

        return batch_encoding

In [25]:
train_set = TrainDataset(train_input, evidences_input, train_sort_evidences, train_claim_evidences, negative_num=800)
from torch.utils.data import DataLoader

dataloader = DataLoader(train_set, batch_size=5, shuffle=True, num_workers=0, collate_fn=train_set.collate_fn)

# 2. Model Implementation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [26]:
# from workshop
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, vocab_emb, embed_dim, hidden_size, num_layers, max_position=180, dropout=0.2):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_emb, embed_dim)
        self.pos_embedding = nn.Embedding(max_position, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, text_data, position_text):
        text_x = self.embedding(text_data) + self.pos_embedding(position_text)
        text_x = self.dropout(text_x)
        x_encoded, _ = self.encoder(text_x)
        x_encoded = self.dropout(x_encoded)
        return x_encoded

In [27]:
lstm_encoder = Encoder(vocab_emb=vocab_size, embed_dim=512, hidden_size=512, num_layers=6, max_position=180)
lstm_encoder.cuda()

Encoder(
  (embedding): Embedding(90097, 512)
  (pos_embedding): Embedding(180, 512)
  (encoder): LSTM(512, 512, num_layers=6, batch_first=True, dropout=0.2, bidirectional=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

### Training

In [28]:
torch.manual_seed(41)
torch.cuda.manual_seed_all(41)
random.seed(41)

weight_decay = 1e-4
encoder_optimizer = optim.Adam(lstm_encoder.parameters(), weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(encoder_optimizer, mode='min', factor=0.1, patience=10, verbose=True, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-8, eps=1e-08)

max_lr = 1e-5
for param_group in encoder_optimizer.param_groups:
    param_group['lr'] = max_lr

accumulate_step = 3
grad_norm = 0.5
warmup_steps = 500
report_freq = 10
eval_interval = 50
save_dir = "model_ckpts"

In [29]:
retrieval_num = 5
dev_candis_num = 10

def validate(dev_text_idx, evidence_text_idx, dev_sort_evidences, dev_claim_evidences, encoder_model):
    # get evidence embeddings
    start_idx = 0
    batch_size = 800
    evidence_len = len(evidence_text_idx[0])
    text_len = len(dev_text_idx[0])

    evidence_embeddings = []
    encoder_model.eval()

    while start_idx < len(evidence_text_idx):
        end_idx = min(start_idx + batch_size, len(evidence_text_idx))

        cur_evidence = torch.LongTensor(evidence_text_idx[start_idx:end_idx]).view(-1, evidence_len).cuda()
        cur_evidence_pos = torch.LongTensor([list(range(evidence_len)) for _ in range(end_idx - start_idx)]).cuda()
        print()
        cur_embedding = encoder_model(cur_evidence, cur_evidence_pos)
        cur_embedding = cur_embedding[:, -1, :].detach()
        cur_embedding_cpu = F.normalize(cur_embedding, p=2, dim=1).cpu()
        del cur_embedding, cur_evidence, cur_evidence_pos
        start_idx = end_idx
        evidence_embeddings.append(cur_embedding_cpu)

    evidence_embeddings = torch.cat(evidence_embeddings, dim=0).t()
    print("get all evidence embeddings!")
    f = []

    start_idx = 0
    batch_size = 800

    while start_idx < len(dev_text_idx):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))

        cur_query = torch.LongTensor(dev_text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        cur_query_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        query_embedding = encoder_model(cur_query, cur_query_pos)
        query_embedding = query_embedding[:, -1, :].detach()
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu()

        scores = torch.mm(query_embedding, evidence_embeddings)

        for i in range(scores.size(0)):
            new_score = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(new_score).tolist()
            select_ids = topk_ids[:retrieval_num]

            evidence_correct = 0
            pred_evidences = [dev_sort_evidences[start_idx+i][j] for j in select_ids]
            label = dev_claim_evidences[start_idx+i]
            for evidence_id in label:
                if evidence_id in pred_evidences:
                    evidence_correct += 1
            if evidence_correct > 0:
                evidence_recall = float(evidence_correct) / len(label)
                evidence_precision = float(evidence_correct) / len(pred_evidences)
                evidence_fscore = (2 * evidence_precision * evidence_recall) / (evidence_precision + evidence_recall)
                print(evidence_fscore)
            else:
                evidence_fscore = 0
            f.append(evidence_fscore)

        start_idx = end_idx
        # print("----")
    fscore = np.mean(f)
    print("\n")
    print("Evidence Retrieval F-score: %.3f" % fscore)
    print("\n")
    encoder_model.train()
    return fscore

In [30]:
%env WANDB_NOTEBOOK_NAME Mon5PMGroup7_COMP90042.ipynb

env: WANDB_NOTEBOOK_NAME=Mon5PMGroup7_COMP90042.ipynb


In [31]:
import subprocess

def run_command(command):
    result = subprocess.run(command, shell=True, text=True, capture_output=True)
    if result.returncode != 0:
        print(f"Command failed: {command}\n{result.stderr}")
    else:
        print(f"Command succeeded: {command}\n{result.stdout}")
    return result

required_packages = ["wandb"]
run_command(f"pip install {' '.join(required_packages)}")

Command succeeded: pip install wandb
Collecting wandb
  Downloading wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.7/6.7 MB 51.0 MB/s eta 0:00:00
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.3/207.3 kB 20.5 MB/s eta 0:00:00
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.3.1-py2.py3-none-any.whl (289 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 289.0/289.0 kB 29.3 MB/s eta 0:00:00
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb)
  Down



In [32]:
# start training
import wandb
import os
wandb.init(project="nlp", name="dpr")

from tqdm import tqdm
import numpy as np

encoder_optimizer.zero_grad()
step_cnt = 0
all_step_cnt = 0
avg_loss = 0
maximum_f_score = 0

for epoch in range(5):
    epoch_step = 0

    for (i, batch) in enumerate(tqdm(dataloader)):

        step_cnt += 1


        query_embeddings = lstm_encoder(batch["queries"].cuda(), batch["queries_pos"].cuda())
        evidence_embeddings = lstm_encoder(batch["evidences"].cuda(), batch["evidences_pos"].cuda())

        query_embeddings = query_embeddings[:, -1, :]
        evidence_embeddings = evidence_embeddings[:, -1, :]

        assert query_embeddings.size(1) == evidence_embeddings.size(1), "Embedding dimensions do not match!"

        query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
        evidence_embeddings = torch.nn.functional.normalize(evidence_embeddings, p=2, dim=1)

        cos_sims = torch.mm(query_embeddings, evidence_embeddings.t())
        scores = cos_sims / 0.1
        loss = []
        start_idx = 0
        criterion = torch.nn.CrossEntropyLoss()
        for idx, labels in enumerate(batch["labels"]):
            labels = torch.LongTensor(labels).cuda()
            cur_loss = criterion(scores[idx].unsqueeze(0).repeat(len(labels), 1), labels)
            loss.append(cur_loss)

        # cos_sims = torch.mm(query_embeddings, evidence_embeddings.t())
        # scores = - torch.nn.functional.log_softmax(cos_sims / 0.1 + 1e-10, dim=1)
        # loss = []
        # start_idx = 0
        # for idx, label in enumerate(batch["labels"]):
        #     label = torch.LongTensor(label).cuda()
        #     cur_loss = torch.mean(torch.index_select(scores[idx], 0, label))
        #     loss.append(cur_loss)

        # cos_sims = torch.mm(query_embeddings, evidence_embeddings.t())
        # scores = - torch.nn.functional.log_softmax(cos_sims / 0.1, dim=1)

        # loss = []
        # start_idx = 0
        # for idx, label in enumerate(batch["labels"]):
        #     label = torch.LongTensor(label).cuda()
        #     cur_loss = torch.mean(torch.index_select(scores[idx], 0, label))
        #     loss.append(cur_loss)

        loss = torch.stack(loss).mean()
        loss = loss / accumulate_step
        loss.backward()

        avg_loss += loss.item()
        if step_cnt == accumulate_step:
            # updating
            if grad_norm > 0:
                nn.utils.clip_grad_norm_(lstm_encoder.parameters(), grad_norm)

            step_cnt = 0
            epoch_step += 1
            all_step_cnt += 1

            # adjust learning rate
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-5

            encoder_optimizer.step()
            encoder_optimizer.zero_grad()

        if all_step_cnt % report_freq == 0 and step_cnt == 0:
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-5

            wandb.log({"learning_rate": lr}, step=all_step_cnt)
            wandb.log({"loss": avg_loss / report_freq}, step=all_step_cnt)

            # report stats
            print("\n")
            print("epoch: %d, epoch_step: %d, avg loss: %.6f" % (epoch + 1, epoch_step, avg_loss / report_freq))
            print(f"learning rate: {lr:.6f}")
            print("\n")
            avg_loss = 0
        del loss, cos_sims, query_embeddings, evidence_embeddings

        if all_step_cnt % eval_interval == 0 and all_step_cnt != 0 and step_cnt == 0:
            # evaluate the model as a scorer
            print("\nEvaluate:\n")

            f_score = validate(dev_input, evidences_input, dev_sort_evidences, dev_claim_evidences, lstm_encoder)
            wandb.log({"f_score": f_score}, step=all_step_cnt)

            if f_score > maximum_f_score:
                maximum_f_score = f_score
                os.makedirs(save_dir, exist_ok=True)
                torch.save(lstm_encoder.state_dict(), os.path.join(os.path.abspath(save_dir), "best_ckpt.bin"))
                print("\n")
                print("best val loss - epoch: %d, epoch_step: %d" % (epoch, epoch_step))
                print("maximum_f_score", f_score)
                print("\n")



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


  9%|▊         | 21/246 [00:04<00:26,  8.52it/s]



epoch: 1, epoch_step: 10, avg loss: 4.152601
learning rate: 0.000002




 17%|█▋        | 41/246 [00:06<00:25,  8.03it/s]



epoch: 1, epoch_step: 20, avg loss: 4.213582
learning rate: 0.000004




 25%|██▍       | 61/246 [00:09<00:22,  8.33it/s]



epoch: 1, epoch_step: 30, avg loss: 4.214569
learning rate: 0.000006




 33%|███▎      | 81/246 [00:11<00:20,  8.13it/s]



epoch: 1, epoch_step: 40, avg loss: 4.204774
learning rate: 0.000008




 40%|████      | 99/246 [00:13<00:17,  8.30it/s]



epoch: 1, epoch_step: 50, avg loss: 4.234063
learning rate: 0.000010



Evaluate:






















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 41%|████      | 100/246 [06:03<4:15:46, 105.12s/it]



best val loss - epoch: 0, epoch_step: 50
maximum_f_score 0.07337146980004124




 49%|████▉     | 121/246 [06:06<00:25,  4.83it/s]



epoch: 1, epoch_step: 60, avg loss: 4.200689
learning rate: 0.000012




 57%|█████▋    | 141/246 [06:08<00:13,  7.95it/s]



epoch: 1, epoch_step: 70, avg loss: 4.209951
learning rate: 0.000014




 65%|██████▌   | 161/246 [06:11<00:10,  8.20it/s]



epoch: 1, epoch_step: 80, avg loss: 4.216153
learning rate: 0.000016




 74%|███████▎  | 181/246 [06:13<00:07,  8.54it/s]



epoch: 1, epoch_step: 90, avg loss: 4.216439
learning rate: 0.000018




 81%|████████  | 199/246 [06:15<00:05,  8.09it/s]



epoch: 1, epoch_step: 100, avg loss: 4.222216
learning rate: 0.000020



Evaluate:





















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 82%|████████▏ | 201/246 [12:06<55:13, 73.64s/it]   



best val loss - epoch: 0, epoch_step: 100
maximum_f_score 0.07402082044939189




 90%|████████▉ | 221/246 [12:08<00:04,  5.45it/s]



epoch: 1, epoch_step: 110, avg loss: 4.202668
learning rate: 0.000022




 98%|█████████▊| 241/246 [12:10<00:00,  8.63it/s]



epoch: 1, epoch_step: 120, avg loss: 4.188036
learning rate: 0.000024




100%|██████████| 246/246 [12:11<00:00,  2.97s/it]
  6%|▌         | 15/246 [00:01<00:27,  8.26it/s]



epoch: 2, epoch_step: 7, avg loss: 4.141691
learning rate: 0.000026




 14%|█▍        | 35/246 [00:04<00:25,  8.43it/s]



epoch: 2, epoch_step: 17, avg loss: 4.198451
learning rate: 0.000028




 22%|██▏       | 53/246 [00:06<00:24,  7.94it/s]



epoch: 2, epoch_step: 27, avg loss: 4.206359
learning rate: 0.000030



Evaluate:






















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 22%|██▏       | 54/246 [05:56<5:36:41, 105.22s/it]



best val loss - epoch: 1, epoch_step: 27
maximum_f_score 0.07407235621521338




 30%|███       | 75/246 [05:59<00:34,  4.93it/s]



epoch: 2, epoch_step: 37, avg loss: 4.207467
learning rate: 0.000032




 39%|███▊      | 95/246 [06:01<00:18,  8.30it/s]



epoch: 2, epoch_step: 47, avg loss: 4.204693
learning rate: 0.000034




 47%|████▋     | 115/246 [06:04<00:16,  7.95it/s]



epoch: 2, epoch_step: 57, avg loss: 4.215429
learning rate: 0.000036




 55%|█████▍    | 135/246 [06:06<00:13,  8.17it/s]



epoch: 2, epoch_step: 67, avg loss: 4.208705
learning rate: 0.000038




 62%|██████▏   | 153/246 [06:08<00:11,  8.24it/s]



epoch: 2, epoch_step: 77, avg loss: 4.213001
learning rate: 0.000040



Evaluate:






















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 71%|███████   | 175/246 [12:02<00:12,  5.70it/s]



epoch: 2, epoch_step: 87, avg loss: 4.184637
learning rate: 0.000042




 79%|███████▉  | 195/246 [12:04<00:06,  8.46it/s]



epoch: 2, epoch_step: 97, avg loss: 4.207277
learning rate: 0.000044




 87%|████████▋ | 215/246 [12:07<00:03,  8.36it/s]



epoch: 2, epoch_step: 107, avg loss: 4.190912
learning rate: 0.000046




 96%|█████████▌| 235/246 [12:09<00:01,  7.99it/s]



epoch: 2, epoch_step: 117, avg loss: 4.222260
learning rate: 0.000048




100%|██████████| 246/246 [12:11<00:00,  2.97s/it]
  3%|▎         | 7/246 [00:00<00:29,  8.02it/s]



epoch: 3, epoch_step: 4, avg loss: 4.186563
learning rate: 0.000050



Evaluate:























































































































































































































































































































































































































































































































































































































































































































































































































































































































































  4%|▎         | 9/246 [05:50<5:02:54, 76.68s/it] 



best val loss - epoch: 2, epoch_step: 4
maximum_f_score 0.07730364873222018




 12%|█▏        | 29/246 [05:53<00:40,  5.39it/s]



epoch: 3, epoch_step: 14, avg loss: 4.215192
learning rate: 0.000052




 20%|█▉        | 49/246 [05:55<00:23,  8.23it/s]



epoch: 3, epoch_step: 24, avg loss: 4.201501
learning rate: 0.000054




 28%|██▊       | 69/246 [05:58<00:21,  8.10it/s]



epoch: 3, epoch_step: 34, avg loss: 4.217076
learning rate: 0.000056




 36%|███▌      | 89/246 [06:00<00:19,  8.25it/s]



epoch: 3, epoch_step: 44, avg loss: 4.195935
learning rate: 0.000058




 43%|████▎     | 107/246 [06:02<00:16,  8.40it/s]



epoch: 3, epoch_step: 54, avg loss: 4.193678
learning rate: 0.000060



Evaluate:






















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 52%|█████▏    | 129/246 [11:54<00:21,  5.51it/s]



epoch: 3, epoch_step: 64, avg loss: 4.216608
learning rate: 0.000062




 61%|██████    | 149/246 [11:57<00:11,  8.33it/s]



epoch: 3, epoch_step: 74, avg loss: 4.226557
learning rate: 0.000064




 69%|██████▊   | 169/246 [11:59<00:09,  8.10it/s]



epoch: 3, epoch_step: 84, avg loss: 4.191457
learning rate: 0.000066




 77%|███████▋  | 189/246 [12:02<00:06,  8.21it/s]



epoch: 3, epoch_step: 94, avg loss: 4.209315
learning rate: 0.000068




 84%|████████▍ | 207/246 [12:04<00:04,  8.22it/s]



epoch: 3, epoch_step: 104, avg loss: 4.197675
learning rate: 0.000070



Evaluate:





















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 93%|█████████▎| 229/246 [17:56<00:03,  5.53it/s]



epoch: 3, epoch_step: 114, avg loss: 4.208269
learning rate: 0.000072




100%|██████████| 246/246 [17:58<00:00,  4.38s/it]
  1%|          | 3/246 [00:00<00:28,  8.38it/s]



epoch: 4, epoch_step: 1, avg loss: 4.184046
learning rate: 0.000074




  9%|▉         | 23/246 [00:02<00:27,  7.97it/s]



epoch: 4, epoch_step: 11, avg loss: 4.200738
learning rate: 0.000076




 17%|█▋        | 43/246 [00:05<00:26,  7.79it/s]



epoch: 4, epoch_step: 21, avg loss: 4.194684
learning rate: 0.000078




 25%|██▍       | 61/246 [00:07<00:23,  7.95it/s]



epoch: 4, epoch_step: 31, avg loss: 4.222091
learning rate: 0.000080



Evaluate:






















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 34%|███▎      | 83/246 [05:59<00:29,  5.61it/s]



epoch: 4, epoch_step: 41, avg loss: 4.189929
learning rate: 0.000082




 42%|████▏     | 103/246 [06:01<00:17,  8.20it/s]



epoch: 4, epoch_step: 51, avg loss: 4.208453
learning rate: 0.000084




 50%|█████     | 123/246 [06:04<00:14,  8.23it/s]



epoch: 4, epoch_step: 61, avg loss: 4.201815
learning rate: 0.000086




 58%|█████▊    | 143/246 [06:06<00:12,  8.21it/s]



epoch: 4, epoch_step: 71, avg loss: 4.205131
learning rate: 0.000088




 65%|██████▌   | 161/246 [06:08<00:10,  8.06it/s]



epoch: 4, epoch_step: 81, avg loss: 4.204673
learning rate: 0.000090



Evaluate:






















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 66%|██████▋   | 163/246 [11:59<1:41:54, 73.67s/it] 



best val loss - epoch: 3, epoch_step: 81
maximum_f_score 0.08424551638837353




 74%|███████▍  | 183/246 [12:01<00:11,  5.45it/s]



epoch: 4, epoch_step: 91, avg loss: 4.217607
learning rate: 0.000092




 83%|████████▎ | 203/246 [12:04<00:05,  8.00it/s]



epoch: 4, epoch_step: 101, avg loss: 4.211913
learning rate: 0.000094




 91%|█████████ | 223/246 [12:06<00:02,  8.11it/s]



epoch: 4, epoch_step: 111, avg loss: 4.212825
learning rate: 0.000096




 99%|█████████▉| 243/246 [12:08<00:00,  8.46it/s]



epoch: 4, epoch_step: 121, avg loss: 4.199550
learning rate: 0.000098




100%|██████████| 246/246 [12:09<00:00,  2.96s/it]
  6%|▌         | 15/246 [00:01<00:27,  8.38it/s]



epoch: 5, epoch_step: 8, avg loss: 4.181140
learning rate: 0.000100



Evaluate:























































































































































































































































































































































































































































































































































































































































































































































































































































































































































  7%|▋         | 17/246 [05:52<4:41:39, 73.80s/it] 



best val loss - epoch: 4, epoch_step: 8
maximum_f_score 0.08920841063698207




 15%|█▌        | 37/246 [05:54<00:38,  5.49it/s]



epoch: 5, epoch_step: 18, avg loss: 4.175956
learning rate: 0.000000




 23%|██▎       | 57/246 [05:56<00:23,  8.21it/s]



epoch: 5, epoch_step: 28, avg loss: 4.179324
learning rate: -0.000100




 31%|███▏      | 77/246 [05:59<00:20,  8.12it/s]



epoch: 5, epoch_step: 38, avg loss: 4.202888
learning rate: -0.000200




 39%|███▉      | 97/246 [06:01<00:18,  8.11it/s]



epoch: 5, epoch_step: 48, avg loss: 4.211966
learning rate: -0.000300




 47%|████▋     | 115/246 [06:03<00:16,  8.04it/s]



epoch: 5, epoch_step: 58, avg loss: 4.208861
learning rate: -0.000400



Evaluate:





















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 48%|████▊     | 118/246 [11:54<2:00:58, 56.71s/it] 



best val loss - epoch: 4, epoch_step: 58
maximum_f_score 0.09160482374768089




 56%|█████▌    | 137/246 [11:56<00:22,  4.84it/s]



epoch: 5, epoch_step: 68, avg loss: 4.198362
learning rate: -0.000500




 64%|██████▍   | 157/246 [11:59<00:10,  8.14it/s]



epoch: 5, epoch_step: 78, avg loss: 4.222757
learning rate: -0.000600




 72%|███████▏  | 177/246 [12:01<00:08,  8.14it/s]



epoch: 5, epoch_step: 88, avg loss: 4.202088
learning rate: -0.000700




 80%|████████  | 197/246 [12:04<00:05,  8.64it/s]



epoch: 5, epoch_step: 98, avg loss: 4.190010
learning rate: -0.000800




 87%|████████▋ | 215/246 [12:06<00:03,  8.08it/s]



epoch: 5, epoch_step: 108, avg loss: 4.237365
learning rate: -0.000900



Evaluate:




















































































































































































































































































































































































































































































































































































































































































































































































































































































































































 96%|█████████▋| 237/246 [18:00<00:01,  5.66it/s]



epoch: 5, epoch_step: 118, avg loss: 4.210722
learning rate: -0.001000




100%|██████████| 246/246 [18:01<00:00,  4.40s/it]


In [33]:
torch.cuda.empty_cache()

# 3.Testing and Evaluation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [34]:
import os
lstm_encoder.load_state_dict(torch.load(os.path.join(save_dir, "best_ckpt.bin")))
lstm_encoder.cuda()
lstm_encoder.eval()

Encoder(
  (embedding): Embedding(90097, 512)
  (pos_embedding): Embedding(180, 512)
  (encoder): LSTM(512, 512, num_layers=6, batch_first=True, dropout=0.2, bidirectional=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [35]:
evidence_embeddings = []
start_idx = 0
batch_size = 800
evidence_len = len(evidences_input[0])

while start_idx < len(evidences_input):
    end_idx = min(start_idx + batch_size, len(evidences_input))

    cur_evidence = torch.LongTensor(evidences_input[start_idx:end_idx]).view(-1, evidence_len).cuda()
    cur_evidence_pos = torch.LongTensor([list(range(evidence_len)) for _ in range(end_idx - start_idx)]).cuda()

    cur_embedding = lstm_encoder(cur_evidence, cur_evidence_pos)
    cur_embedding = cur_embedding[:, -1, :].detach()
    cur_embedding_cpu = F.normalize(cur_embedding, p=2, dim=1).cpu()  # for cosine similarity

    del cur_embedding, cur_evidence, cur_evidence_pos
    start_idx = end_idx
    evidence_embeddings.append(cur_embedding_cpu)

evidence_embeddings = torch.cat(evidence_embeddings, dim=0).t()


In [36]:
torch.cuda.empty_cache()

In [37]:
import numpy as np

In [38]:
retrieval_num = 5
dev_candis_num = 10

def validate_(dev_text_idx, evidence_embeddings, dev_sort_evidences, dev_claim_evidences, encoder_model):
    # get evidence embeddings
    encoder_model.eval()

    text_len = len(dev_text_idx[0])
    f = []

    start_idx = 0
    batch_size = 200

    while start_idx < len(dev_text_idx):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))

        cur_query = torch.LongTensor(dev_text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        cur_query_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        query_embedding = encoder_model(cur_query, cur_query_pos)
        query_embedding = query_embedding[:, -1, :].detach()
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu()

        scores = torch.mm(query_embedding, evidence_embeddings)

        for i in range(scores.size(0)):
            new_score = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(new_score).tolist()
            select_ids = topk_ids[:retrieval_num]

            evidence_correct = 0
            pred_evidences = [dev_sort_evidences[start_idx+i][j] for j in select_ids]

            label = dev_claim_evidences[start_idx+i]

            for evidence_id in label:
                if evidence_id in pred_evidences:
                    evidence_correct += 1
            if evidence_correct > 0:
                evidence_recall = float(evidence_correct) / len(label)
                evidence_precision = float(evidence_correct) / len(pred_evidences)
                evidence_fscore = (2 * evidence_precision * evidence_recall) / (evidence_precision + evidence_recall)
                # print(evidence_fscore)
            else:
                evidence_fscore = 0
            f.append(evidence_fscore)

        start_idx = end_idx
        # print("----")
    fscore = np.mean(f)
    print("\n")
    print("Evidence Retrieval F-score: %.3f" % fscore)
    print("\n")
    return fscore

In [39]:
retrieval_num = 5
dev_candis_num = 10
fscore = validate_(dev_input, evidence_embeddings, dev_sort_evidences, dev_claim_evidences, lstm_encoder)
print(fscore)



Evidence Retrieval F-score: 0.092


0.09160482374768089


In [40]:
retrieval_num = 5
dev_candis_num = 10

def evidence_predicts(dev_text_idx, evidences_embeddings, dev_sort_evidences, cleaned_evidence_id, encoder_model):
    # get evidence embeddings
    text_len = len(dev_text_idx[0])
    encoder_model.eval()

    f = []
    start_idx = 0
    batch_size = 200
    preds = []
    while start_idx < len(dev_text_idx):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))

        cur_query = torch.LongTensor(dev_text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        cur_query_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        query_embedding = encoder_model(cur_query, cur_query_pos)
        query_embedding = query_embedding[:, -1, :].detach()
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu()

        scores = torch.mm(query_embedding, evidences_embeddings)

        for i in range(scores.size(0)):
            new_score = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(new_score).tolist()
            select_ids = topk_ids[:retrieval_num]

            pred_evidences = [cleaned_evidence_id[dev_sort_evidences[start_idx+i][j]] for j in select_ids]
            preds.append(pred_evidences)

        start_idx = end_idx
    return preds

In [41]:
dev_evidences_ids = evidence_predicts(dev_input, evidence_embeddings, dev_sort_evidences, cleaned_evidence_id, lstm_encoder)
test_evidences_ids = evidence_predicts(test_input, evidence_embeddings, test_sort_evidences, cleaned_evidence_id, lstm_encoder)

In [42]:
pred_dev_claims = {}
pred_test_claims = {}
dev_claims = json.load(open("data/dev-claims.json", "r"))
test_claims = json.load(open("data/test-claims-unlabelled.json", "r"))

for idx, evidence_ids in enumerate(dev_evidences_ids):
    cur_data = dev_claims[dev_claim_id[idx]]
    cur_data['evidences'] = evidence_ids
    pred_dev_claims[dev_claim_id[idx]] = cur_data


for idx, evidence_ids in enumerate(test_evidences_ids):
    cur_data = test_claims[test_claim_id[idx]]
    cur_data['evidences'] = evidence_ids
    pred_test_claims[test_claim_id[idx]] = cur_data


In [43]:
json.dump(pred_dev_claims, open("data/dev_predict.json", "w"))
json.dump(pred_test_claims, open("data/test-claims-unlabelled.json", "w"))

In [44]:
retrieval_num = 5
dev_candis_num = 10

train_evidences_ids = evidence_predicts(train_input, evidence_embeddings, train_sort_evidences, cleaned_evidence_id, lstm_encoder)

pred_train_negative_evidences = []
for idx, evidence_ids in enumerate(train_evidences_ids):
    temp_ = []
    for i in evidence_ids:
        if evidences_id_dict[i] not in train_claim_evidences[idx]:
            temp_.append(evidences_id_dict[i])
    pred_train_negative_evidences.append(temp_)

In [45]:
## save prediction data

json.dump(pred_train_negative_evidences, open("pred_train_negative_evidences.json", "w"))

In [46]:
## save cls data

dev_cls_data = []
test_cls_data = []
text_max_len = 60
evidence_max_len = 100
all_max_len = 580

for idx, dev_text in enumerate(dev_text_idx):
    cur_data = {"label": dev_claim_label[idx]}
    temp_text = [wordidx["<cls>"]] + dev_text_idx[idx][:text_max_len]
    for i in dev_evidences_ids[idx]:
        temp_text.extend([wordidx["<sep>"]] + evidences_text_idx[evidences_id_dict[i]][:evidence_max_len])
    temp_text.append(wordidx["<sep>"])
    if len(temp_text) < all_max_len:
        temp_text.extend([wordidx["<pad>"]] * (all_max_len - len(temp_text)))
    cur_data['text'] = temp_text
    dev_cls_data.append(cur_data)

for idx, dev_text in enumerate(test_text_idx):
    cur_data = {}
    temp_text = [wordidx["<cls>"]] + test_text_idx[idx][:text_max_len]
    for i in test_evidences_ids[idx]:
        temp_text.extend([wordidx["<sep>"]] + evidences_text_idx[evidences_id_dict[i]][:evidence_max_len])
    temp_text.append(wordidx["<sep>"])
    if len(temp_text) < all_max_len:
        temp_text.extend([wordidx["<pad>"]] * (all_max_len - len(temp_text)))
    cur_data['text'] = temp_text
    test_cls_data.append(cur_data)

json.dump(dev_cls_data, open("dev_cls_data.json", "w"))
json.dump(test_cls_data, open("test_cls_data.json", "w"))

Task2

Preprocessing

In [47]:
import json

dev_cls_data = json.load(open("dev_cls_data.json", "r"))
test_cls_data = json.load(open("test_cls_data.json", "r"))

text_max_len = 60
evidence_max_len = 100
all_max_len = 580
retrieval_num = 5

id2labels = ["SUPPORTS", "NOT_ENOUGH_INFO", "REFUTES", "DISPUTED"]
labels2id = {"SUPPORTS": 0, "NOT_ENOUGH_INFO": 1, "REFUTES": 2, "DISPUTED": 3}

train_negative_evidences = json.load(open("pred_train_negative_evidences.json", "r"))


In [48]:
from torch.utils.data import Dataset
import random

class TrainDataset(Dataset):
    def __init__(self, text_data, evidence_data, positive_evidences, negative_evidences, cls_label, cls_idx, sep_idx, pad_idx, evidence_num=5):
        self.text_data = text_data
        self.evidence_data = evidence_data

        self.negative_evidences = negative_evidences

        self.cls_label = [labels2id[i] for i in cls_label]
        self.evidence_num = evidence_num
        self.positive_evidences = positive_evidences

        self.cls_idx = cls_idx
        self.sep_idx = sep_idx
        self.pad_idx = pad_idx

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, idx):
        return [self.text_data[idx][:text_max_len], self.positive_evidences[idx], self.negative_evidences[idx], self.cls_label[idx]]

    def collate_fn(self, batch):
        queries = []
        queries_pos = []
        labels = []

        for i, j, h, k in batch:
            temp_text = [self.cls_idx]
            temp_text.extend(i)
            for p in j:
                temp_text.append(self.sep_idx)
                temp_text.extend(self.evidence_data[p][:evidence_max_len])
            if self.evidence_num > len(j):
                n = random.sample(h, self.evidence_num - len(j))
                for p in n:
                    temp_text.append(self.sep_idx)
                    temp_text.extend(self.evidence_data[p][:evidence_max_len])
            temp_text.append(self.sep_idx)
            if len(temp_text) < all_max_len:
                temp_text.extend([self.pad_idx] * (all_max_len - len(temp_text)))

            queries.append(temp_text)
            queries_pos.append(list(range(all_max_len)))
            labels.append(k)

        batch_encoding = {}
        batch_encoding["queries"] = torch.LongTensor(queries)
        batch_encoding["queries_pos"] = torch.LongTensor(queries_pos)
        batch_encoding["labels"] = torch.LongTensor(labels)

        return batch_encoding

In [49]:
dev_inputs = [i['text'] for i in dev_cls_data]
test_inputs = [i['text'] for i in test_cls_data]
dev_outputs = [labels2id[i["label"]] for i in dev_cls_data]

In [50]:
train_set = TrainDataset(train_text_idx, evidences_text_idx, train_claim_evidences, train_negative_evidences, train_claim_label, wordidx["<cls>"], wordidx["<sep>"], wordidx["<pad>"], evidence_num=retrieval_num)
from torch.utils.data import DataLoader
dataloader = DataLoader(train_set, batch_size=10, shuffle=True, num_workers=0, collate_fn=train_set.collate_fn)

In [51]:
from collections import Counter
print(Counter(train_claim_label))

Counter({'SUPPORTS': 519, 'NOT_ENOUGH_INFO': 386, 'REFUTES': 199, 'DISPUTED': 124})


In [52]:
# from workshop
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class CLS(nn.Module):
    def __init__(self, vocab_emb, embed_dim, hidden_size, output_size, num_layers, max_position=all_max_len):
        super(CLS, self).__init__()

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_emb, embed_dim)
        self.pos_embedding = nn.Embedding(max_position, embed_dim)

        self.encoder = nn.LSTM(embed_dim, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.hidden_layer = nn.Linear(hidden_size * 2, hidden_size)
        self.cls = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, text_data, position_text):
        text_x = self.embedding(text_data) + self.pos_embedding(position_text) * 0.01
        x_encoded,_ = self.encoder(text_x)
        x_cls = x_encoded[:, 0, :]
        x_hidden = F.tanh(self.hidden_layer(x_cls))
        self.dropout(x_hidden)
        cls_res = self.cls(x_hidden)
        return cls_res


In [53]:
cls_model = CLS(vocab_emb=len(idxword), embed_dim=256, hidden_size=256, output_size=4, num_layers=7, max_position=700)
cls_model.cuda()

CLS(
  (embedding): Embedding(90097, 256)
  (pos_embedding): Embedding(700, 256)
  (encoder): LSTM(256, 256, num_layers=7, batch_first=True, bidirectional=True)
  (hidden_layer): Linear(in_features=512, out_features=256, bias=True)
  (cls): Linear(in_features=256, out_features=4, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [54]:
torch.manual_seed(41)
torch.cuda.manual_seed_all(41)
random.seed(41)

weight_decay = 0.02
encoder_optimizer = optim.AdamW(cls_model.parameters(), weight_decay=weight_decay)
max_lr = 1e-2
for param_group in encoder_optimizer.param_groups:
    param_group['lr'] = max_lr
accumulate_step = 2
grad_norm = 3
warmup_steps = 300
report_freq = 10
eval_interval = 50
save_dir = "model_ckpts"

In [55]:
def validate(dev_input, dev_output, cls_model_):
    # get evidence embeddings
    start_idx = 0
    batch_size = 40
    pos_len = len(dev_input[0])
    cls_model.eval()

    acc = []
    correct_count = 0
    while start_idx < len(dev_output):
        end_idx = min(start_idx + batch_size, len(dev_output))

        cur_input = torch.LongTensor(dev_input[start_idx:end_idx]).view(-1, pos_len).cuda()
        cur_pos = torch.LongTensor([list(range(pos_len)) for _ in range(end_idx - start_idx)]).cuda()

        cur_res = cls_model_(cur_input, cur_pos)
        cur_res = torch.argmax(cur_res, 1).tolist()

        del cur_input, cur_pos

        for i, j in zip(cur_res, dev_output[start_idx: end_idx]):
            if i == j:
                correct_count += 1

        start_idx = end_idx
    acc = correct_count / len(dev_output)
    print("\n")
    print("Classification Accuracy: %.3f" % acc)
    print("\n")

    cls_model.train()
    return acc

In [56]:
%env WANDB_NOTEBOOK_NAME Mon5PMGroup7_COMP90042.ipynb

env: WANDB_NOTEBOOK_NAME=Mon5PMGroup7_COMP90042.ipynb


In [57]:
# start training
import wandb
import os
wandb.init(project="nlp", name="cls")

from tqdm import tqdm
import numpy as np

encoder_optimizer.zero_grad()
step_cnt = 0
all_step_cnt = 0
avg_loss = 0
maximum_f_score = 0
ce_fn = nn.CrossEntropyLoss(torch.FloatTensor([0.2, 0.3, 0.5, 1.]).cuda())

for epoch in range(5):
    epoch_step = 0

    for (i, batch) in enumerate(tqdm(dataloader)):

        step_cnt += 1

        # forward pass

        cur_res = cls_model(batch["queries"].cuda(), batch["queries_pos"].cuda())

        loss = ce_fn(cur_res, batch["labels"].cuda())
        loss = loss / accumulate_step
        loss.backward()

        avg_loss += loss.item()
        if step_cnt == accumulate_step:
            # updating
            if grad_norm > 0:
                nn.utils.clip_grad_norm_(cls_model.parameters(), grad_norm)

            step_cnt = 0
            epoch_step += 1
            all_step_cnt += 1

            # adjust learning rate
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-6

            encoder_optimizer.step()
            encoder_optimizer.zero_grad()

        if all_step_cnt % report_freq == 0 and step_cnt == 0:
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-6

            wandb.log({"learning_rate": lr}, step=all_step_cnt)
            wandb.log({"loss": avg_loss / report_freq}, step=all_step_cnt)

            # report stats
            print("\n")
            print("epoch: %d, epoch_step: %d, avg loss: %.6f" % (epoch + 1, epoch_step, avg_loss / report_freq))
            print(f"learning rate: {lr:.6f}")
            print("\n")

            avg_loss = 0
        del loss, cur_res

        if all_step_cnt % eval_interval == 0 and all_step_cnt != 0 and step_cnt == 0:
            # evaluate the model as a scorer
            print("\nEvaluate:\n")

            f_score = validate(dev_inputs, dev_outputs, cls_model)
            wandb.log({"acc": f_score}, step=all_step_cnt)

            if f_score > maximum_f_score:
                maximum_f_score = f_score
                torch.save(cls_model.state_dict(), os.path.join(save_dir, "best_cls_ckpt.bin"))
                print("\n")
                print("best val loss - epoch: %d, epoch_step: %d" % (epoch, epoch_step))
                print("maximum_f_score", f_score)
                print("\n")

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
f_score,▂▂▂▁▃▂▂▁▅▇█▆
learning_rate,▇▇▇▇▇▇▇██████████████████████████▇▆▅▄▄▂▁
loss,▂▆▆█▆▇▇▆▁▅▆▆▆▄▆▇▄▆▇▅▇▅▆▆▄▅▅▆▆▆▆▆▄▄▆▆▅▇▅▆

0,1
f_score,0.08661
learning_rate,-0.001
loss,4.21072


 16%|█▋        | 20/123 [00:04<00:22,  4.55it/s]



epoch: 1, epoch_step: 10, avg loss: 2.123071
learning rate: 0.000333




 33%|███▎      | 40/123 [00:08<00:18,  4.59it/s]



epoch: 1, epoch_step: 20, avg loss: 1.776262
learning rate: 0.000667




 49%|████▉     | 60/123 [00:13<00:13,  4.55it/s]



epoch: 1, epoch_step: 30, avg loss: 1.697694
learning rate: 0.001000




 65%|██████▌   | 80/123 [00:17<00:09,  4.55it/s]



epoch: 1, epoch_step: 40, avg loss: 1.437948
learning rate: 0.001333




 80%|████████  | 99/123 [00:21<00:05,  4.55it/s]



epoch: 1, epoch_step: 50, avg loss: 1.518521
learning rate: 0.001667



Evaluate:



 81%|████████▏ | 100/123 [00:22<00:08,  2.57it/s]



Classification Accuracy: 0.117




best val loss - epoch: 0, epoch_step: 50
maximum_f_score 0.11688311688311688




 98%|█████████▊| 120/123 [00:26<00:00,  4.50it/s]



epoch: 1, epoch_step: 60, avg loss: 1.722920
learning rate: 0.002000




100%|██████████| 123/123 [00:27<00:00,  4.46it/s]
 14%|█▍        | 17/123 [00:03<00:23,  4.54it/s]



epoch: 2, epoch_step: 9, avg loss: 1.582268
learning rate: 0.002333




 30%|███       | 37/123 [00:08<00:19,  4.50it/s]



epoch: 2, epoch_step: 19, avg loss: 1.594963
learning rate: 0.002667




 46%|████▋     | 57/123 [00:12<00:14,  4.59it/s]



epoch: 2, epoch_step: 29, avg loss: 1.543151
learning rate: 0.003000




 62%|██████▏   | 76/123 [00:16<00:10,  4.58it/s]



epoch: 2, epoch_step: 39, avg loss: 1.496019
learning rate: 0.003333



Evaluate:



Classification Accuracy: 0.266




 63%|██████▎   | 77/123 [00:17<00:18,  2.46it/s]



best val loss - epoch: 1, epoch_step: 39
maximum_f_score 0.2662337662337662




 79%|███████▉  | 97/123 [00:21<00:05,  4.52it/s]



epoch: 2, epoch_step: 49, avg loss: 1.666746
learning rate: 0.003667




 95%|█████████▌| 117/123 [00:26<00:01,  4.56it/s]



epoch: 2, epoch_step: 59, avg loss: 1.511119
learning rate: 0.004000




100%|██████████| 123/123 [00:27<00:00,  4.45it/s]
 11%|█▏        | 14/123 [00:03<00:23,  4.56it/s]



epoch: 3, epoch_step: 7, avg loss: 1.554546
learning rate: 0.004333




 28%|██▊       | 34/123 [00:07<00:19,  4.58it/s]



epoch: 3, epoch_step: 17, avg loss: 1.424294
learning rate: 0.004667




 43%|████▎     | 53/123 [00:11<00:15,  4.56it/s]



epoch: 3, epoch_step: 27, avg loss: 1.758125
learning rate: 0.005000



Evaluate:



 44%|████▍     | 54/123 [00:12<00:22,  3.10it/s]



Classification Accuracy: 0.117




 60%|██████    | 74/123 [00:16<00:10,  4.48it/s]



epoch: 3, epoch_step: 37, avg loss: 1.713394
learning rate: 0.005333




 76%|███████▋  | 94/123 [00:20<00:06,  4.59it/s]



epoch: 3, epoch_step: 47, avg loss: 1.753071
learning rate: 0.005667




 93%|█████████▎| 114/123 [00:25<00:01,  4.56it/s]



epoch: 3, epoch_step: 57, avg loss: 1.725835
learning rate: 0.006000




100%|██████████| 123/123 [00:27<00:00,  4.50it/s]
  9%|▉         | 11/123 [00:02<00:24,  4.52it/s]



epoch: 4, epoch_step: 6, avg loss: 1.722324
learning rate: 0.006333




 24%|██▍       | 30/123 [00:06<00:20,  4.57it/s]



epoch: 4, epoch_step: 16, avg loss: 1.620107
learning rate: 0.006667



Evaluate:



 25%|██▌       | 31/123 [00:07<00:29,  3.09it/s]



Classification Accuracy: 0.175




 41%|████▏     | 51/123 [00:11<00:15,  4.58it/s]



epoch: 4, epoch_step: 26, avg loss: 1.769225
learning rate: 0.007000




 58%|█████▊    | 71/123 [00:15<00:11,  4.59it/s]



epoch: 4, epoch_step: 36, avg loss: 1.666883
learning rate: 0.007333




 74%|███████▍  | 91/123 [00:20<00:06,  4.57it/s]



epoch: 4, epoch_step: 46, avg loss: 1.585195
learning rate: 0.007667




 90%|█████████ | 111/123 [00:24<00:02,  4.56it/s]



epoch: 4, epoch_step: 56, avg loss: 1.696485
learning rate: 0.008000




100%|██████████| 123/123 [00:27<00:00,  4.51it/s]
  6%|▌         | 7/123 [00:01<00:25,  4.59it/s]



epoch: 5, epoch_step: 4, avg loss: 1.606911
learning rate: 0.008333



Evaluate:



  7%|▋         | 8/123 [00:02<00:37,  3.04it/s]



Classification Accuracy: 0.117




 23%|██▎       | 28/123 [00:06<00:20,  4.56it/s]



epoch: 5, epoch_step: 14, avg loss: 1.641708
learning rate: 0.008667




 39%|███▉      | 48/123 [00:10<00:16,  4.55it/s]



epoch: 5, epoch_step: 24, avg loss: 1.623939
learning rate: 0.009000




 55%|█████▌    | 68/123 [00:15<00:12,  4.52it/s]



epoch: 5, epoch_step: 34, avg loss: 1.645367
learning rate: 0.009333




 72%|███████▏  | 88/123 [00:19<00:07,  4.55it/s]



epoch: 5, epoch_step: 44, avg loss: 1.649816
learning rate: 0.009667




 87%|████████▋ | 107/123 [00:23<00:03,  4.54it/s]



epoch: 5, epoch_step: 54, avg loss: 1.559312
learning rate: 0.010000



Evaluate:



Classification Accuracy: 0.442




 88%|████████▊ | 108/123 [00:24<00:05,  2.52it/s]



best val loss - epoch: 4, epoch_step: 54
maximum_f_score 0.44155844155844154




100%|██████████| 123/123 [00:27<00:00,  4.40it/s]


In [58]:
def predict(dev_input, cls_model_):
    # get evidence embeddings
    start_idx = 0
    batch_size = 50
    pos_len = len(dev_input[0])
    cls_model.eval()

    cls_res = []
    correct_count = 0
    while start_idx < len(dev_input):
        end_idx = min(start_idx + batch_size, len(dev_input))

        cur_input = torch.LongTensor(dev_input[start_idx:end_idx]).view(-1, pos_len).cuda()
        cur_pos = torch.LongTensor([list(range(pos_len)) for _ in range(end_idx - start_idx)]).cuda()

        cur_res = cls_model_(cur_input, cur_pos)
        cur_res = torch.argmax(cur_res, 1).tolist()

        del cur_input, cur_pos

        cls_res.extend(cur_res)

        start_idx = end_idx

    return cls_res

In [59]:
torch.cuda.empty_cache()

In [60]:
import os
cls_model.load_state_dict(torch.load(os.path.join(save_dir, "best_cls_ckpt.bin")))

dev_classes = predict(dev_inputs, cls_model)
test_classes = predict(test_inputs, cls_model)

In [61]:
pred_dev_claims = json.load(open("data/dev_predict.json", "r"))
pred_test_claims = json.load(open("data/test-claims-unlabelled.json", "r"))

for i, j in zip(dev_claim_id, dev_classes):
    claim_label = id2labels[j]
    evidences = pred_dev_claims[i]['evidences']
    pred_dev_claims[i] = {'claim_text': pred_dev_claims[i]['claim_text'], 'claim_label': claim_label, 'evidences': evidences}

for i, j in zip(test_claim_id, test_classes):
    claim_label = id2labels[j]
    evidences = pred_test_claims[i]['evidences']
    pred_test_claims[i] = {'claim_text': pred_test_claims[i]['claim_text'], 'claim_label': claim_label, 'evidences': evidences}

json.dump(pred_dev_claims, open("data/dev_predict.json", "w"))
json.dump(pred_test_claims, open("data/test-claims-unlabelled.json", "w"))


In [62]:
from collections import Counter
print(Counter(dev_classes))

Counter({0: 154})


In [63]:
print(Counter(test_classes))

Counter({0: 153})


In [64]:
import subprocess

output = subprocess.check_output("python eval.py --predictions data/dev_predict.json --groundtruth data/dev-claims.json", shell=True)
output_str = output.decode('utf-8')

# Split the output into lines
output_lines = output_str.strip().split('\n')

# Format the output
formatted_lines = []
for line in output_lines:
    metric, value = line.split('=')
    metric = metric.strip()
    value = value.strip()
    formatted_line = f"{metric}: {value}"
    formatted_lines.append(formatted_line)

# Join the formatted lines into a single string
formatted_output = '\n'.join(formatted_lines)
print(formatted_output)

Evidence Retrieval F-score (F): 0.0916048237476809
Claim Classification Accuracy (A): 0.44155844155844154
Harmonic Mean of F and A: 0.15173169588132618


## Object Oriented Programming codes here

*You can use multiple code snippets. Just add more if needed*