# 2024 COMP90042 Project
*Make sure you change the file name with your group id.*

# Readme
*If there is something to be noted for the marker, please mention here.*

*If you are planning to implement a program with Object Oriented Programming style, please put those the bottom of this ipynb file*

**We use pytorch, nltk, scikit-learn in this project.**

# 1.DataSet Processing
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

## PreProcess for evidence and claims

### preprocessing function

In [1]:
import torch

print("CUDA available:", torch.cuda.is_available())

CUDA available: True


In [2]:
print("Current CUDA device:", torch.cuda.current_device())
print("Device count:", torch.cuda.device_count())
print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

Current CUDA device: 0
Device count: 1
Device name: NVIDIA GeForce RTX 3070 Ti


### read files

In [3]:
import json
import nltk
import string
import re
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from collections import Counter
from statistics import mean

nltk.download('stopwords')
nltk.download('wordnet')

with open('data/train-claims.json', 'r') as input_file:
    train_claims = json.load(input_file)

# Read in development data (claim)
with open('data/dev-claims.json', 'r') as input_file:
    dev_claims = json.load(input_file)

# Read in test data (claim)
with open('data/test-claims-unlabelled.json', 'r') as input_file:
    test_claims = json.load(input_file)

# Read in evidence data
with open('data/evidence.json', 'r') as input_file:
    evidences = json.load(input_file)

#EDA
claim_count = 0
evi_count = 0
claim_length = []
evidence_count = []
evidence_length = []
labels = []

for key,value in train_claims.items():
    claim_count+=1
    claim_length.append(len(value["claim_text"]))
    evidence_count.append(len(value["evidences"]))
    evidence_length += [len(evidences[x]) for x in value["evidences"]]
    labels.append(value["claim_label"])

for key,value in evidences.items():
    evi_count+=1

print("claim count: ",claim_count)
print("evidence count: ",evi_count)
print("max claim length: ",max(claim_length))
print("min claim length: ",min(claim_length))
print("mean claim length: ",mean(claim_length))
print("max evidence count: ",max(evidence_count))
print("min evidence count: ",min(evidence_count))
print("mean evidence count: ",mean(evidence_count))
print("max evidence length: ",max(evidence_length))
print("min evidence length: ",min(evidence_length))
print("mean evidence length: ",mean(evidence_length))
print(Counter(labels))

inside = 0
outside = 0

train_evi_id = []
for claim_id,claim_value in train_claims.items():
    train_evi_id=train_evi_id+claim_value['evidences']

for claim_id,claim_value in dev_claims.items():
    test_evi_id=claim_value['evidences']
    for e in test_evi_id:
        if e in train_evi_id:
            inside += 1
        else:
            outside += 1
print("Dev evi inside train evi", inside)
print("Dev evi outside train evi", outside)

full_evidence_id = list(evidences.keys())
full_evidence_text  = list(evidences.values())
train_claim_id = list(train_claims.keys())
train_claim_text  = [ v["claim_text"] for v in train_claims.values()]
print("Train claim count: ",len(train_claim_id))

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\ABC\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\ABC\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


claim count:  1228
evidence count:  1208827
max claim length:  332
min claim length:  26
mean claim length:  122.95521172638436
max evidence count:  5
min evidence count:  1
mean evidence count:  3.3566775244299674
max evidence length:  1979
min evidence length:  13
mean evidence length:  173.5
Counter({'SUPPORTS': 519, 'NOT_ENOUGH_INFO': 386, 'REFUTES': 199, 'DISPUTED': 124})
Dev evi inside train evi 163
Dev evi outside train evi 328
Train claim count:  1228


In [4]:
lemmatizer = nltk.stem.wordnet.WordNetLemmatizer()
stopwords = set(stopwords.words('english'))

def lemmatize(word):
    lemma = lemmatizer.lemmatize(word, 'v')
    return lemma if lemma != word else lemmatizer.lemmatize(word, 'n')

def is_pure_english(text):
    english_letters = set(string.ascii_letters)
    cleaned_text = ''.join(char for char in text if char.isalpha() or char.isspace())
    return all(char in english_letters or char.isspace() for char in cleaned_text)

def remove_non_eng(dictionary):
    eng_data = {}
    for key, value in dictionary.items():
        if is_pure_english(value):
            eng_data[key] = value
    return eng_data

def contains_climate_keywords(text, keywords):
    text = text.lower()
    for keyword in keywords:
        if re.search(r"\b" + re.escape(keyword) + r"\b", text):
            return True
    return False

def filter_climate_related(dictionary, keywords):
    cs_data = {}
    for key, value in dictionary.items():
        if contains_climate_keywords(value, keywords):
            cs_data[key] = value
    return cs_data

def text_preprocessing(text, remove_stopwords=False):
    if isinstance(text, str):
        words = [lemmatize(w) for w in text.lower().split()]
    elif isinstance(text, list):
        words = [lemmatize(w.lower()) for w in text]
    else:
        raise ValueError("Unsupported type for 'text'. Expected str or list.")
    
    if remove_stopwords:
        words = [w for w in words if w not in stopwords]
    
    return " ".join(words)

In [5]:
climate_keywords = [
    "climate", "environment", "global warming", "greenhouse effect", "carbon", "co2", "carbon dioxide",
    "methane", "renewable energy", "sustainability", "ecology", "biodiversity", "fossil fuels",
    "emissions", "air quality", "ozone", "solar energy", "wind energy", "climate change", "climate crisis",
    "climate adaptation", "climate mitigation", "ocean", "sea levels", "ice melting", "deforestation",
    "reforestation", "pollution"," electricity","energy","solar","wind","renewable","fossil","fuel","emission","air","quality","ozone","solar","wind","climate","change","crisis","adaptation","mitigation","ocean","sea","level","ice","melt","deforestation",
]


def preprocess_claim_data(claim_data, existed_evidences_id=None):
    claim_data = remove_non_eng(claim_data)
    claim_data_text = []
    claim_data_id = []
    claim_data_label = []
    claim_evidences = []
    
    for key in claim_data.keys():
        claim_data[key]["claim_text"] = word_tokenize(claim_data[key]["claim_text"])
        claim_data[key]["claim_text"] = text_preprocessing(claim_data[key]["claim_text"])
        
        claim_data_text.append(claim_data[key]["claim_text"])
        claim_data_id.append(key)
        
        if "claim_label" in claim_data[key]:
            claim_data_label.append(claim_data[key]["claim_label"])
        else:
            claim_data_label.append(None)
        
        if existed_evidences_id and "evidences" in claim_data[key]:
            valid_evidences = [existed_evidences_id[i] for i in claim_data[key]["evidences"] if i in existed_evidences_id]
            claim_evidences.append(valid_evidences)
        else:
            claim_evidences.append([])
    
    return claim_data_text, claim_data_id, claim_data_label, claim_evidences


def preprocess_evi_data(evi_data, climate_keywords):
    evi_data = remove_non_eng(evi_data)
    cs_evi_data = filter_climate_related(evi_data, climate_keywords)
    
    for key in cs_evi_data.keys():
        cs_evi_data[key] = word_tokenize(cs_evi_data[key])
        cs_evi_data[key] = text_preprocessing(cs_evi_data[key], remove_stopwords=True)
        
    cleaned_evidence_text = list(cs_evi_data.values())
    cleaned_evidence_id = list(cs_evi_data.keys())
    
    return cleaned_evidence_text, cleaned_evidence_id

In [6]:
cleaned_evidence_text, cleaned_evidence_id = preprocess_evi_data(evidences, climate_keywords)

evidences_id_dict = {evidence_id: idx for idx, evidence_id in enumerate(cleaned_evidence_id)}

train_claim_text, train_claim_id, train_claim_label, train_claim_evidences = preprocess_claim_data(train_claims, evidences_id_dict)

dev_claim_text, dev_claim_id, dev_claim_label, dev_claim_evidences = preprocess_claim_data(dev_claims, evidences_id_dict)

test_claim_text, test_claim_id, _, _ = preprocess_claim_data(test_claims)

### tfidf retrieval

In [7]:

from sklearn.feature_extraction.text import TfidfVectorizer

# vectorizer = TfidfVectorizer(max_features=500000)
vectorizer = TfidfVectorizer()
vectorizer.fit(cleaned_evidence_text)
# TODO can svd 
train_tfidf = vectorizer.transform(train_claim_text)
dev_tfidf = vectorizer.transform(dev_claim_text  )
test_tfidf = vectorizer.transform(test_claim_text)
evidence_tfidf = vectorizer.transform(cleaned_evidence_text)


In [8]:
train_cos_sims = cosine_similarity(train_tfidf, evidence_tfidf)
dev_cos_sims = cosine_similarity(dev_tfidf, evidence_tfidf)
test_cos_sims = cosine_similarity(test_tfidf, evidence_tfidf)
print(train_cos_sims.shape)

(1228, 54272)


In [9]:
def test_retrieval_topk(k, cur_scores, cur_labels):
    ACC = []
    top_ids = torch.topk(torch.FloatTensor(cur_scores), k, -1).indices.tolist()
    for i in range(len(cur_labels)):
        all_count = 0
        recall_count = 0
        for cur_ in cur_labels[i]:
            if cur_ in top_ids[i]:
                recall_count += 1
            all_count += 1
        if all_count == 0:
            all_count = 1e-9  # to avoid division by zero
        ACC.append(recall_count / all_count)
    print(sum(ACC) / len(ACC))

topK = 10
test_retrieval_topk(topK, train_cos_sims, train_claim_evidences)
test_retrieval_topk(topK, dev_cos_sims, dev_claim_evidences)

0.14695982627578721
0.15670995670995672


In [10]:
# need to change this code
def sort_evidence_candidates(cos_sims):
    # dev_candis = np.argpartition(-cos_sims, candi_num, 1)[:candi_num]
    candis = []
    for i in range(cos_sims.shape[0]):
        cur_top_ids = np.argsort(-cos_sims[i]).tolist()[:10000]
        candis.append(cur_top_ids)
    return candis

In [11]:
dev_sort_evidences = sort_evidence_candidates(dev_cos_sims)
test_sort_evidences = sort_evidence_candidates(test_cos_sims)
train_sort_evidences = sort_evidence_candidates(train_cos_sims)

### construct vocab and indexing

In [12]:
# construct word2idx and idx2word for taining data
# if we use the unprocessed text else use ***_p_texts

min_count = 10
wordcount = {}
idx2word = ["<pad>", "<cls>", "<sep>", "<unk>"]
word2idx = {"<pad>": 0, "<cls>": 1, "<sep>": 2, "<unk>": 3}
for texts in train_claim_text + cleaned_evidence_text:
    for word in texts.split():
        wordcount[word] = wordcount.get(word, 0) + 1
idx = 4
for i, j in wordcount.items():
    if j > min_count:
        idx2word.append(i)
        word2idx[i] = idx
        idx += 1

In [13]:
def convert2idx(text_data, word2idx_, idx2word_):
    idx_data = []
    for texts in text_data:
        temp_idx = []
        for word in texts.split():
            temp_idx.append(word2idx_.get(word, word2idx_["<unk>"]))
        idx_data.append(temp_idx)
    return idx_data

In [14]:
train_text_idx = convert2idx(train_claim_text, word2idx, idx2word)
dev_text_idx = convert2idx(dev_claim_text, word2idx, idx2word)
test_text_idx = convert2idx(test_claim_text, word2idx, idx2word)
evidences_text_idx = convert2idx(cleaned_evidence_text, word2idx, idx2word)

In [15]:
print(max([len(i) for i in train_text_idx]), max([len(i) for i in dev_text_idx]), max([len(i) for i in test_text_idx]), max([len(i) for i in evidences_text_idx]))

76 73 60 231


In [16]:
text_pad_len = 60
evidences_pad_len = 100

In [17]:
def construct_input_text(text_idx, padding_len, word2idx_):
    idx_data = []
    for texts in text_idx:
        if len(texts) < padding_len:
            idx_data.append([word2idx_["<cls>"]] + texts + [word2idx_["<sep>"]] + [word2idx_["<pad>"]] * (padding_len - len(texts)))
        else:
            idx_data.append([word2idx_["<cls>"]] + texts[:padding_len] + [word2idx_["<sep>"]])
    return idx_data
    

In [18]:
train_input = construct_input_text(train_text_idx, text_pad_len, word2idx)
dev_input = construct_input_text(dev_text_idx, text_pad_len, word2idx)
test_input = construct_input_text(test_text_idx, text_pad_len, word2idx)
evidences_input = construct_input_text(evidences_text_idx, evidences_pad_len, word2idx)

In [19]:
print(len(train_input[0]), len(evidences_input[0]))

62 102


In [20]:
vocab_size = len(idx2word)
print(vocab_size)

7075


In [21]:
from torch.utils.data import Dataset
import random

class TrainDataset(Dataset):
    def __init__(self, text_input_data, evidence_input_data, tfidf_sort_evidences, evidence_label, negative_num=10):
        self.text_input_data = text_input_data
        self.evidence_input_data = evidence_input_data
        self.tfidf_sort_evidences = tfidf_sort_evidences
        self.evidence_label = evidence_label
        self.negative_num = 10
        self.evidence_len = len(evidence_input_data[0])
        self.text_len = len(text_input_data[0])
        
    def __len__(self):
        return len(self.text_input_data)

    def __getitem__(self, idx):
        # please note the negative evidences
        return [self.text_input_data[idx], random.sample(self.tfidf_sort_evidences[idx][10: self.negative_num*10], self.negative_num), self.evidence_label[idx]]

    def collate_fn(self, batch):
        queries = []
        queries_pos = []
        evidences = []
        temp_labels = []
        
        for i, j, k in batch:
            queries.append(i)
            queries_pos.append(list(range(self.text_len)))
            temp_labels.append(k)    
            evidences.extend(k + j)

        evidences = list(set(evidences))
        
        evidences2idx = {}
        for i, j in enumerate(evidences):
            evidences2idx[j] = i
        
        labels = []
        for i in temp_labels:
            labels.append([evidences2idx[j] for j in i])
            
        evidences = [self.evidence_input_data[i] for i in evidences]
        evidences_pos = [list(range(self.evidence_len)) for _ in range(len(evidences))]

        batch_encoding = {}
        batch_encoding["queries"] = torch.LongTensor(queries)
        batch_encoding["evidences"] = torch.LongTensor(evidences)
        
        batch_encoding["queries_pos"] = torch.LongTensor(queries_pos)
        batch_encoding["evidences_pos"] = torch.LongTensor(evidences_pos)
        batch_encoding["labels"] = labels
        
        return batch_encoding

In [22]:
train_set = TrainDataset(train_input, evidences_input, train_sort_evidences, train_claim_evidences, negative_num=1000)
from torch.utils.data import DataLoader

dataloader = DataLoader(train_set, batch_size=5, shuffle=True, num_workers=0, collate_fn=train_set.collate_fn)

# 2. Model Implementation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [23]:
# from workshop
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, vocab_emb, embed_dim, hidden_size, num_layers, max_position=180, dropout=0.2):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_emb, embed_dim)
        self.pos_embedding = nn.Embedding(max_position, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, text_data, position_text):
        text_x = self.embedding(text_data) + self.pos_embedding(position_text)
        text_x = self.dropout(text_x)
        x_encoded, _ = self.encoder(text_x)
        x_encoded = self.dropout(x_encoded)
        return x_encoded

In [41]:
trans_encoder = Encoder(vocab_emb=vocab_size, embed_dim=256, hidden_size=256, num_layers=6, max_position=180)
trans_encoder.cuda()

Encoder(
  (embedding): Embedding(7075, 256)
  (pos_embedding): Embedding(180, 256)
  (encoder): LSTM(256, 256, num_layers=6, batch_first=True, dropout=0.2, bidirectional=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

### Training

In [42]:
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
random.seed(42)

weight_decay = 0.01
encoder_optimizer = optim.Adam(trans_encoder.parameters(), weight_decay=weight_decay)
max_lr = 1e-3
for param_group in encoder_optimizer.param_groups:
    param_group['lr'] = max_lr

accumulate_step = 2
grad_norm = 0.1
warmup_steps = 200
report_freq = 10
eval_interval = 50
save_dir = "model_ckpts"

In [44]:
retrieval_num = 5
dev_candis_num = 10

def validate(dev_text_idx, evidence_text_idx, dev_sort_evidences, dev_claim_evidences, encoder_model):
    # get evidence embeddings
    start_idx = 0
    batch_size = 800
    evidence_len = len(evidence_text_idx[0])
    text_len = len(dev_text_idx[0])
    
    evidence_embeddings = []
    encoder_model.eval()
    
    while start_idx < len(evidence_text_idx):
        end_idx = min(start_idx + batch_size, len(evidence_text_idx))
        
        cur_evidence = torch.LongTensor(evidence_text_idx[start_idx:end_idx]).view(-1, evidence_len).cuda()
        cur_evidence_pos = torch.LongTensor([list(range(evidence_len)) for _ in range(end_idx - start_idx)]).cuda()
        print()
        cur_embedding = encoder_model(cur_evidence, cur_evidence_pos)
        cur_embedding = cur_embedding[:, -1, :].detach()
        cur_embedding_cpu = F.normalize(cur_embedding, p=2, dim=1).cpu()
        del cur_embedding, cur_evidence, cur_evidence_pos
        start_idx = end_idx
        evidence_embeddings.append(cur_embedding_cpu)
        
    evidence_embeddings = torch.cat(evidence_embeddings, dim=0).t()
    print("get all evidence embeddings!")
    f = []
    
    start_idx = 0
    batch_size = 800
    
    while start_idx < len(dev_text_idx):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))
        
        cur_query = torch.LongTensor(dev_text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        cur_query_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        query_embedding = encoder_model(cur_query, cur_query_pos)
        query_embedding = query_embedding[:, -1, :].detach()
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu()
        
        scores = torch.mm(query_embedding, evidence_embeddings)

        for i in range(scores.size(0)):
            new_score = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(new_score).tolist()
            select_ids = topk_ids[:retrieval_num]
            
            evidence_correct = 0
            pred_evidences = [dev_sort_evidences[start_idx+i][j] for j in select_ids]
            label = dev_claim_evidences[start_idx+i]
            for evidence_id in label:
                if evidence_id in pred_evidences:
                    evidence_correct += 1
            if evidence_correct > 0:
                evidence_recall = float(evidence_correct) / len(label)
                evidence_precision = float(evidence_correct) / len(pred_evidences)
                evidence_fscore = (2 * evidence_precision * evidence_recall) / (evidence_precision + evidence_recall)
                print(evidence_fscore)
            else:
                evidence_fscore = 0
            f.append(evidence_fscore)
            
        start_idx = end_idx
        # print("----")
    fscore = np.mean(f)
    print("\n")
    print("Evidence Retrieval F-score: %.3f" % fscore)
    print("\n")
    encoder_model.train()
    return fscore

In [27]:
%env WANDB_NOTEBOOK_NAME MelMoxue_NLP_retrieval.ipynb

env: WANDB_NOTEBOOK_NAME=MelMoxue_NLP_retrieval.ipynb


In [45]:
# start training
import wandb
import os
wandb.init(project="nlp", name="dpr")

from tqdm import tqdm
import numpy as np

encoder_optimizer.zero_grad()
step_cnt = 0
all_step_cnt = 0
avg_loss = 0
maximum_f_score = 0

for epoch in range(5): 
    epoch_step = 0

    for (i, batch) in enumerate(tqdm(dataloader)):
        
        step_cnt += 1
        # forward pass
        # query_embeddings = trans_encoder(batch["queries"].cuda(), batch["queries_pos"].cuda())
        # evidence_embeddings = trans_encoder(batch["evidences"].cuda(), batch["evidences_pos"].cuda())
        
        # query_embeddings = query_embeddings[:, 0, :]
        # evidence_embeddings = evidence_embeddings[:, 0, :]

        query_embeddings = trans_encoder(batch["queries"].cuda(), batch["queries_pos"].cuda())
        evidence_embeddings = trans_encoder(batch["evidences"].cuda(), batch["evidences_pos"].cuda())
        
        query_embeddings = query_embeddings[:, -1, :]
        evidence_embeddings = evidence_embeddings[:, -1, :]
        
        assert query_embeddings.size(1) == evidence_embeddings.size(1), "Embedding dimensions do not match!"
        
        query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
        evidence_embeddings = torch.nn.functional.normalize(evidence_embeddings, p=2, dim=1)

        cos_sims = torch.mm(query_embeddings, evidence_embeddings.t())
        scores = - torch.nn.functional.log_softmax(cos_sims / 0.1, dim=1)

        loss = []
        start_idx = 0
        for idx, label in enumerate(batch["labels"]):
            label = torch.LongTensor(label).cuda()
            cur_loss = torch.mean(torch.index_select(scores[idx], 0, label))
            loss.append(cur_loss)

        loss = torch.stack(loss).mean()
        # cos_sims = torch.mm(query_embeddings, evidence_embeddings.t())
        # scores = cos_sims / 0.1
        # loss = []
        # start_idx = 0
        # criterion = torch.nn.CrossEntropyLoss()
        # for idx, labels in enumerate(batch["labels"]):
        #     labels = torch.LongTensor(labels).cuda()
        #     cur_loss = criterion(scores[idx].unsqueeze(0).repeat(len(labels), 1), labels)
        #     loss.append(cur_loss)
            
        # loss = torch.stack(loss).mean()
        loss = loss / accumulate_step
        loss.backward()

        avg_loss += loss.item()
        if step_cnt == accumulate_step:
            # updating
            if grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_encoder.parameters(), grad_norm)

            step_cnt = 0
            epoch_step += 1
            all_step_cnt += 1
            
            # adjust learning rate
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-5
                
            encoder_optimizer.step()
            encoder_optimizer.zero_grad()
        
        if all_step_cnt % report_freq == 0 and step_cnt == 0:
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-5

            wandb.log({"learning_rate": lr}, step=all_step_cnt)
            wandb.log({"loss": avg_loss / report_freq}, step=all_step_cnt)
            
            # report stats
            print("\n")
            print("epoch: %d, epoch_step: %d, avg loss: %.6f" % (epoch + 1, epoch_step, avg_loss / report_freq))
            print(f"learning rate: {lr:.6f}")
            print("\n")
            avg_loss = 0
        del loss, cos_sims, query_embeddings, evidence_embeddings

        if all_step_cnt % eval_interval == 0 and all_step_cnt != 0 and step_cnt == 0:
            # evaluate the model as a scorer
            print("\nEvaluate:\n")
            
            f_score = validate(dev_input, evidences_input, dev_sort_evidences, dev_claim_evidences, trans_encoder)
            wandb.log({"f_score": f_score}, step=all_step_cnt)

            if f_score > maximum_f_score:
                maximum_f_score = f_score
                os.makedirs(save_dir, exist_ok=True)
                torch.save(trans_encoder.state_dict(), os.path.join(os.path.abspath(save_dir), "best_ckpt.bin"))
                print("\n")
                print("best val loss - epoch: %d, epoch_step: %d" % (epoch, epoch_step))
                print("maximum_f_score", f_score)
                print("\n")

0,1
f_score,▁
learning_rate,▁▂▃▃▄▅▆▆▇█

0,1
f_score,0.03721
learning_rate,0.0
loss,


  9%|▊         | 21/246 [00:12<00:58,  3.82it/s]



epoch: 1, epoch_step: 10, avg loss: nan
learning rate: 0.000025




 16%|█▋        | 40/246 [00:17<00:37,  5.49it/s]



epoch: 1, epoch_step: 20, avg loss: nan
learning rate: 0.000050




 25%|██▍       | 61/246 [00:27<00:58,  3.17it/s]



epoch: 1, epoch_step: 30, avg loss: nan
learning rate: 0.000075




 33%|███▎      | 81/246 [00:35<00:45,  3.64it/s]



epoch: 1, epoch_step: 40, avg loss: nan
learning rate: 0.000100




 40%|████      | 99/246 [00:46<01:17,  1.90it/s]



epoch: 1, epoch_step: 50, avg loss: nan
learning rate: 0.000125



Evaluate:





































































get all evidence embeddings!
0.22222222222222224
0.20000000000000004
0.28571428571428575
0.28571428571428575
0.28571428571428575
0.5714285714285715
0.20000000000000004
0.28571428571428575
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.33333333333333337
0.25
0.25
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575


Evidence Retrieval F-score: 0.048




best val loss - epoch: 0, epoch_step: 50
maximum_f_score 0.04813440527726242




 49%|████▉     | 121/246 [01:07<00:41,  2.98it/s]



epoch: 1, epoch_step: 60, avg loss: nan
learning rate: 0.000150




 57%|█████▋    | 140/246 [01:20<01:45,  1.00it/s]



epoch: 1, epoch_step: 70, avg loss: nan
learning rate: 0.000175




 65%|██████▌   | 161/246 [01:46<03:37,  2.56s/it]



epoch: 1, epoch_step: 80, avg loss: nan
learning rate: 0.000200




 74%|███████▎  | 181/246 [01:59<00:25,  2.57it/s]



epoch: 1, epoch_step: 90, avg loss: nan
learning rate: 0.000225




 81%|████████  | 199/246 [02:05<00:14,  3.13it/s]



epoch: 1, epoch_step: 100, avg loss: nan
learning rate: 0.000250



Evaluate:







































































 82%|████████▏ | 201/246 [02:24<03:07,  4.18s/it]

get all evidence embeddings!
0.22222222222222224
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.28571428571428575
0.33333333333333337
0.28571428571428575
0.20000000000000004
0.33333333333333337
0.25
0.25
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.20000000000000004
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.043




 90%|████████▉ | 221/246 [02:30<00:07,  3.30it/s]



epoch: 1, epoch_step: 110, avg loss: nan
learning rate: 0.000275




 98%|█████████▊| 241/246 [02:49<00:01,  3.14it/s]



epoch: 1, epoch_step: 120, avg loss: nan
learning rate: 0.000300




100%|██████████| 246/246 [02:51<00:00,  1.43it/s]
  6%|▌         | 15/246 [00:05<01:24,  2.73it/s]



epoch: 2, epoch_step: 7, avg loss: nan
learning rate: 0.000325




 14%|█▍        | 35/246 [00:14<01:07,  3.12it/s]



epoch: 2, epoch_step: 17, avg loss: nan
learning rate: 0.000350




 22%|██▏       | 53/246 [00:19<00:41,  4.67it/s]



epoch: 2, epoch_step: 27, avg loss: nan
learning rate: 0.000375



Evaluate:






































































 22%|██▏       | 54/246 [00:47<27:29,  8.59s/it]


get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.25
0.25
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.20000000000000004
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.045




 30%|███       | 74/246 [00:59<00:42,  4.08it/s]



epoch: 2, epoch_step: 37, avg loss: nan
learning rate: 0.000400




 39%|███▊      | 95/246 [01:05<00:32,  4.58it/s]



epoch: 2, epoch_step: 47, avg loss: nan
learning rate: 0.000425




 47%|████▋     | 115/246 [01:11<00:53,  2.46it/s]



epoch: 2, epoch_step: 57, avg loss: nan
learning rate: 0.000450




 55%|█████▍    | 135/246 [01:18<00:35,  3.12it/s]



epoch: 2, epoch_step: 67, avg loss: nan
learning rate: 0.000475




 62%|██████▏   | 153/246 [01:25<00:18,  5.06it/s]



epoch: 2, epoch_step: 77, avg loss: nan
learning rate: 0.000500



Evaluate:






































































 63%|██████▎   | 154/246 [02:04<18:32, 12.09s/it]


get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.5714285714285715
0.28571428571428575
0.20000000000000004
0.28571428571428575
0.28571428571428575
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575


Evidence Retrieval F-score: 0.043




 71%|███████   | 175/246 [02:10<00:22,  3.22it/s]



epoch: 2, epoch_step: 87, avg loss: nan
learning rate: 0.000525




 79%|███████▉  | 195/246 [02:15<00:09,  5.53it/s]



epoch: 2, epoch_step: 97, avg loss: nan
learning rate: 0.000550




 87%|████████▋ | 215/246 [02:21<00:09,  3.23it/s]



epoch: 2, epoch_step: 107, avg loss: nan
learning rate: 0.000575




 96%|█████████▌| 235/246 [02:26<00:02,  5.27it/s]



epoch: 2, epoch_step: 117, avg loss: nan
learning rate: 0.000600




100%|██████████| 246/246 [02:29<00:00,  1.64it/s]
  3%|▎         | 7/246 [00:00<00:28,  8.46it/s]



epoch: 3, epoch_step: 4, avg loss: nan
learning rate: 0.000625



Evaluate:





































































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.049




  4%|▎         | 9/246 [00:17<14:44,  3.73s/it]



best val loss - epoch: 2, epoch_step: 4
maximum_f_score 0.04934549577406721




 12%|█▏        | 29/246 [00:24<00:33,  6.44it/s]



epoch: 3, epoch_step: 14, avg loss: nan
learning rate: 0.000650




 20%|█▉        | 49/246 [00:30<01:01,  3.21it/s]



epoch: 3, epoch_step: 24, avg loss: nan
learning rate: 0.000675




 28%|██▊       | 68/246 [00:37<00:27,  6.38it/s]



epoch: 3, epoch_step: 34, avg loss: nan
learning rate: 0.000700




 36%|███▌      | 89/246 [00:43<00:36,  4.35it/s]



epoch: 3, epoch_step: 44, avg loss: nan
learning rate: 0.000725




 43%|████▎     | 107/246 [00:47<00:21,  6.50it/s]



epoch: 3, epoch_step: 54, avg loss: nan
learning rate: 0.000750



Evaluate:





































































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.25
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.051




best val loss - epoch: 2, epoch_step: 54
maximum_f_score 0.05065965780251495




 52%|█████▏    | 129/246 [01:10<00:25,  4.65it/s]



epoch: 3, epoch_step: 64, avg loss: nan
learning rate: 0.000775




 60%|██████    | 148/246 [01:15<00:17,  5.63it/s]



epoch: 3, epoch_step: 74, avg loss: nan
learning rate: 0.000800




 69%|██████▊   | 169/246 [01:33<01:02,  1.24it/s]



epoch: 3, epoch_step: 84, avg loss: nan
learning rate: 0.000825




 77%|███████▋  | 189/246 [01:44<00:12,  4.53it/s]



epoch: 3, epoch_step: 94, avg loss: nan
learning rate: 0.000850




 84%|████████▍ | 207/246 [01:50<00:10,  3.80it/s]



epoch: 3, epoch_step: 104, avg loss: nan
learning rate: 0.000875



Evaluate:





































































get all evidence embeddings!


 85%|████████▍ | 209/246 [02:09<02:29,  4.04s/it]

0.22222222222222224
0.28571428571428575
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.25
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.049




 93%|█████████▎| 229/246 [02:14<00:03,  5.20it/s]



epoch: 3, epoch_step: 114, avg loss: nan
learning rate: 0.000900




100%|██████████| 246/246 [02:18<00:00,  1.77it/s]
  1%|          | 2/246 [00:01<03:16,  1.24it/s]



epoch: 4, epoch_step: 1, avg loss: nan
learning rate: 0.000925




  9%|▉         | 23/246 [05:49<16:48,  4.52s/it]  



epoch: 4, epoch_step: 11, avg loss: nan
learning rate: 0.000950




 17%|█▋        | 43/246 [08:00<05:34,  1.65s/it]  



epoch: 4, epoch_step: 21, avg loss: nan
learning rate: 0.000975




 25%|██▍       | 61/246 [09:03<01:41,  1.82it/s]



epoch: 4, epoch_step: 31, avg loss: nan
learning rate: 0.001000



Evaluate:





































































get all evidence embeddings!
0.22222222222222224
0.20000000000000004
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.4000000000000001
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 25%|██▌       | 62/246 [17:02<7:21:17, 143.90s/it]



best val loss - epoch: 3, epoch_step: 31
maximum_f_score 0.05225211296639869




 34%|███▎      | 83/246 [17:23<02:45,  1.01s/it]   



epoch: 4, epoch_step: 41, avg loss: nan
learning rate: 0.000900




 41%|████▏     | 102/246 [17:55<01:59,  1.20it/s]



epoch: 4, epoch_step: 51, avg loss: nan
learning rate: 0.000800




 50%|█████     | 123/246 [22:04<25:51, 12.61s/it]  



epoch: 4, epoch_step: 61, avg loss: nan
learning rate: 0.000700




 58%|█████▊    | 142/246 [22:59<03:01,  1.75s/it]



epoch: 4, epoch_step: 71, avg loss: nan
learning rate: 0.000600




 65%|██████▌   | 161/246 [24:26<10:18,  7.27s/it]



epoch: 4, epoch_step: 81, avg loss: nan
learning rate: 0.000500



Evaluate:





































































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.053




 66%|██████▋   | 163/246 [31:57<2:16:06, 98.39s/it] 



best val loss - epoch: 3, epoch_step: 81
maximum_f_score 0.05336528550814265




 74%|███████▍  | 182/246 [33:01<02:30,  2.35s/it]  



epoch: 4, epoch_step: 91, avg loss: nan
learning rate: 0.000400




 82%|████████▏ | 202/246 [34:46<01:36,  2.18s/it]



epoch: 4, epoch_step: 101, avg loss: nan
learning rate: 0.000300




 91%|█████████ | 223/246 [36:04<00:46,  2.00s/it]



epoch: 4, epoch_step: 111, avg loss: nan
learning rate: 0.000200




 98%|█████████▊| 242/246 [36:16<00:01,  3.15it/s]



epoch: 4, epoch_step: 121, avg loss: nan
learning rate: 0.000100




100%|██████████| 246/246 [36:18<00:00,  8.86s/it]
  6%|▌         | 15/246 [01:28<52:14, 13.57s/it]



epoch: 5, epoch_step: 8, avg loss: nan
learning rate: 0.000000



Evaluate:





































































get all evidence embeddings!


  7%|▋         | 16/246 [06:31<6:04:40, 95.13s/it]

0.22222222222222224
0.20000000000000004
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.28571428571428575
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575


Evidence Retrieval F-score: 0.052




 15%|█▌        | 37/246 [11:44<1:08:48, 19.76s/it]



epoch: 5, epoch_step: 18, avg loss: nan
learning rate: -0.000100




 23%|██▎       | 57/246 [14:10<01:48,  1.75it/s]  



epoch: 5, epoch_step: 28, avg loss: nan
learning rate: -0.000200




 29%|██▉       | 72/246 [15:35<37:41, 13.00s/it]


KeyboardInterrupt: 

In [48]:
torch.cuda.empty_cache()

# 3.Testing and Evaluation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [49]:
import os
trans_encoder.load_state_dict(torch.load(os.path.join(save_dir, "best_ckpt.bin")))
trans_encoder.cuda()
trans_encoder.eval()

Encoder(
  (embedding): Embedding(7075, 256)
  (pos_embedding): Embedding(180, 256)
  (encoder): LSTM(256, 256, num_layers=6, batch_first=True, dropout=0.2, bidirectional=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [50]:
evidence_embeddings = []
start_idx = 0
batch_size = 1000
evidence_len = len(evidences_input[0])

while start_idx < len(evidences_input):
    end_idx = min(start_idx + batch_size, len(evidences_input))
    
    cur_evidence = torch.LongTensor(evidences_input[start_idx:end_idx]).view(-1, evidence_len).cuda()
    cur_evidence_pos = torch.LongTensor([list(range(evidence_len)) for _ in range(end_idx - start_idx)]).cuda()

    cur_embedding = trans_encoder(cur_evidence, cur_evidence_pos)
    cur_embedding = cur_embedding[:, -1, :].detach()
    cur_embedding_cpu = F.normalize(cur_embedding, p=2, dim=1).cpu()  # for cosine similarity
    
    del cur_embedding, cur_evidence, cur_evidence_pos
    start_idx = end_idx
    evidence_embeddings.append(cur_embedding_cpu)
    
evidence_embeddings = torch.cat(evidence_embeddings, dim=0).t()


In [51]:
torch.cuda.empty_cache()

In [52]:
import numpy as np

In [53]:
retrieval_num = 5
dev_candis_num = 10

def validate_(dev_text_idx, evidence_embeddings, dev_sort_evidences, dev_claim_evidences, encoder_model):
    # get evidence embeddings
    encoder_model.eval()

    text_len = len(dev_text_idx[0])
    f = []
    
    start_idx = 0
    batch_size = 200
    
    while start_idx < len(dev_text_idx):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))
        
        cur_query = torch.LongTensor(dev_text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        print(cur_query.size())
        cur_query_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        query_embedding = encoder_model(cur_query, cur_query_pos)
        query_embedding = query_embedding[:, -1, :].detach()
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu()
        
        scores = torch.mm(query_embedding, evidence_embeddings)
        
        for i in range(scores.size(0)):
            new_score = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(new_score).tolist()
            select_ids = topk_ids[:retrieval_num]
            
            evidence_correct = 0
            pred_evidences = [dev_sort_evidences[start_idx+i][j] for j in select_ids]
                        
            label = dev_claim_evidences[start_idx+i]
                        
            for evidence_id in label:
                if evidence_id in pred_evidences:
                    evidence_correct += 1
            if evidence_correct > 0:
                evidence_recall = float(evidence_correct) / len(label)
                evidence_precision = float(evidence_correct) / len(pred_evidences)
                evidence_fscore = (2 * evidence_precision * evidence_recall) / (evidence_precision + evidence_recall)
                # print(evidence_fscore)
            else:
                evidence_fscore = 0
            f.append(evidence_fscore)
            
        start_idx = end_idx
        # print("----")
    fscore = np.mean(f)
    print("\n")
    print("Evidence Retrieval F-score: %.3f" % fscore)
    print("\n")
    return fscore

In [54]:
retrieval_num = 5
dev_candis_num = 10
fscore = validate_(dev_input, evidence_embeddings, dev_sort_evidences, dev_claim_evidences, trans_encoder)
print(fscore)

torch.Size([154, 62])


Evidence Retrieval F-score: 0.053


0.05336528550814265


In [55]:
retrieval_num = 5
dev_candis_num = 10

def evidence_predicts(dev_text_idx, evidences_embeddings, dev_sort_evidences, cleaned_evidence_id, encoder_model):
    # get evidence embeddings
    text_len = len(dev_text_idx[0])
    encoder_model.eval()

    f = []
    start_idx = 0
    batch_size = 200
    preds = []
    while start_idx < len(dev_text_idx):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))
        
        cur_query = torch.LongTensor(dev_text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        cur_query_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        query_embedding = encoder_model(cur_query, cur_query_pos)
        query_embedding = query_embedding[:, -1, :].detach()
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu()
        
        scores = torch.mm(query_embedding, evidences_embeddings)

        for i in range(scores.size(0)):
            new_score = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(new_score).tolist()
            select_ids = topk_ids[:retrieval_num]
            
            pred_evidences = [cleaned_evidence_id[dev_sort_evidences[start_idx+i][j]] for j in select_ids]
            preds.append(pred_evidences)
            
        start_idx = end_idx
    return preds

In [56]:
dev_evidences_ids = evidence_predicts(dev_input, evidence_embeddings, dev_sort_evidences, cleaned_evidence_id, trans_encoder)
test_evidences_ids = evidence_predicts(test_input, evidence_embeddings, test_sort_evidences, cleaned_evidence_id, trans_encoder)

In [57]:
pred_dev_claims = {}
pred_test_claims = {}
dev_claims = json.load(open("data/dev-claims.json", "r"))
test_claims = json.load(open("data/test-claims-unlabelled.json", "r"))

for idx, evidence_ids in enumerate(dev_evidences_ids):
    cur_data = dev_claims[dev_claim_id[idx]]
    cur_data['evidences'] = evidence_ids
    pred_dev_claims[dev_claim_id[idx]] = cur_data
    

for idx, evidence_ids in enumerate(test_evidences_ids):
    cur_data = test_claims[test_claim_id[idx]]
    cur_data['evidences'] = evidence_ids
    pred_test_claims[test_claim_id[idx]] = cur_data


In [58]:
json.dump(pred_dev_claims, open("data/dev_predict.json", "w"))
json.dump(pred_test_claims, open("data/test-claims-unlabelled.json", "w"))

In [59]:
retrieval_num = 5
dev_candis_num = 10

train_evidences_ids = evidence_predicts(train_input, evidence_embeddings, train_sort_evidences, cleaned_evidence_id, trans_encoder)

pred_train_negative_evidences = []
for idx, evidence_ids in enumerate(train_evidences_ids):
    temp_ = []
    for i in evidence_ids:
        if evidences_id_dict[i] not in train_claim_evidences[idx]:
            temp_.append(evidences_id_dict[i])
    pred_train_negative_evidences.append(temp_)

In [60]:
## save prediction data

json.dump(pred_train_negative_evidences, open("pred_train_negative_evidences.json", "w"))

In [61]:
## save cls data

dev_cls_data = []
test_cls_data = []
text_max_len = 60
evidence_max_len = 100
all_max_len = 580

for idx, dev_text in enumerate(dev_text_idx):
    cur_data = {"label": dev_claim_label[idx]}
    temp_text = [word2idx["<cls>"]] + dev_text_idx[idx][:text_max_len]
    for i in dev_evidences_ids[idx]:
        temp_text.extend([word2idx["<sep>"]] + evidences_text_idx[evidences_id_dict[i]][:evidence_max_len])
    temp_text.append(word2idx["<sep>"])
    if len(temp_text) < all_max_len:
        temp_text.extend([word2idx["<pad>"]] * (all_max_len - len(temp_text)))
    cur_data['text'] = temp_text
    dev_cls_data.append(cur_data)

for idx, dev_text in enumerate(test_text_idx):
    cur_data = {}
    temp_text = [word2idx["<cls>"]] + test_text_idx[idx][:text_max_len]
    for i in test_evidences_ids[idx]:
        temp_text.extend([word2idx["<sep>"]] + evidences_text_idx[evidences_id_dict[i]][:evidence_max_len])
    temp_text.append(word2idx["<sep>"])
    if len(temp_text) < all_max_len:
        temp_text.extend([word2idx["<pad>"]] * (all_max_len - len(temp_text)))
    cur_data['text'] = temp_text
    test_cls_data.append(cur_data)

json.dump(dev_cls_data, open("dev_cls_data.json", "w"))
json.dump(test_cls_data, open("test_cls_data.json", "w"))

Task2 

Preprocessing

In [62]:
import json

dev_cls_data = json.load(open("dev_cls_data.json", "r"))
test_cls_data = json.load(open("test_cls_data.json", "r"))


# train_text_idx = json.load(open("temp_data/train_text_idx.json", "r"))
# evidences_text_idx = json.load(open("temp_data/evidences_text_idx.json", "r"))

text_max_len = 60
evidence_max_len = 100
all_max_len = 580
retrieval_num = 5

id2labels = ["SUPPORTS", "NOT_ENOUGH_INFO", "REFUTES", "DISPUTED"]
labels2id = {"SUPPORTS": 0, "NOT_ENOUGH_INFO": 1, "REFUTES": 2, "DISPUTED": 3}

train_negative_evidences = json.load(open("pred_train_negative_evidences.json", "r"))


In [65]:
from torch.utils.data import Dataset
import random

class TrainDataset(Dataset):
    def __init__(self, text_data, evidence_data, positive_evidences, negative_evidences, cls_label, cls_idx, sep_idx, pad_idx, evidence_num=5):
        self.text_data = text_data
        self.evidence_data = evidence_data
        
        self.negative_evidences = negative_evidences

        self.cls_label = [labels2id[i] for i in cls_label]
        self.evidence_num = evidence_num
        self.positive_evidences = positive_evidences
        
        self.cls_idx = cls_idx
        self.sep_idx = sep_idx
        self.pad_idx = pad_idx
        
    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, idx):
        return [self.text_data[idx][:text_max_len], self.positive_evidences[idx], self.negative_evidences[idx], self.cls_label[idx]]

    def collate_fn(self, batch):
        queries = []
        queries_pos = []
        labels = []
        
        for i, j, h, k in batch:
            temp_text = [self.cls_idx]
            temp_text.extend(i)
            for p in j:
                temp_text.append(self.sep_idx)
                temp_text.extend(self.evidence_data[p][:evidence_max_len])
            if self.evidence_num > len(j):
                n = random.sample(h, self.evidence_num - len(j))
                for p in n:
                    temp_text.append(self.sep_idx)
                    temp_text.extend(self.evidence_data[p][:evidence_max_len])
            temp_text.append(self.sep_idx)
            if len(temp_text) < all_max_len:
                temp_text.extend([self.pad_idx] * (all_max_len - len(temp_text)))
                
            queries.append(temp_text)
            queries_pos.append(list(range(all_max_len)))
            labels.append(k)    

        batch_encoding = {}
        batch_encoding["queries"] = torch.LongTensor(queries)        
        batch_encoding["queries_pos"] = torch.LongTensor(queries_pos)
        batch_encoding["labels"] = torch.LongTensor(labels)
        
        return batch_encoding

In [67]:
dev_inputs = [i['text'] for i in dev_cls_data]
test_inputs = [i['text'] for i in test_cls_data]
dev_outputs = [labels2id[i["label"]] for i in dev_cls_data]

In [68]:
train_set = TrainDataset(train_text_idx, evidences_text_idx, train_claim_evidences, train_negative_evidences, train_claim_label, word2idx["<cls>"], word2idx["<sep>"], word2idx["<pad>"], evidence_num=retrieval_num)
from torch.utils.data import DataLoader
dataloader = DataLoader(train_set, batch_size=10, shuffle=True, num_workers=0, collate_fn=train_set.collate_fn)

In [69]:
from collections import Counter
print(Counter(train_claim_label))

Counter({'SUPPORTS': 519, 'NOT_ENOUGH_INFO': 386, 'REFUTES': 199, 'DISPUTED': 124})


In [105]:
# from workshop
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class CLS(nn.Module):
    def __init__(self, vocab_emb, embed_dim, hidden_size, output_size, num_layers, max_position=all_max_len):
        super(CLS, self).__init__()

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_emb, embed_dim)
        self.pos_embedding = nn.Embedding(max_position, embed_dim)
        
        self.encoder = nn.LSTM(embed_dim, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.hidden_layer = nn.Linear(hidden_size * 2, hidden_size)  
        self.cls = nn.Linear(hidden_size, output_size)  
        self.dropout = nn.Dropout(0.1)

    def forward(self, text_data, position_text):
        text_x = self.embedding(text_data) + self.pos_embedding(position_text) * 0.01
        x_encoded,_ = self.encoder(text_x)
        x_cls = x_encoded[:, 0, :]
        x_hidden = F.tanh(self.hidden_layer(x_cls))
        self.dropout(x_hidden)
        cls_res = self.cls(x_hidden)
        return cls_res


In [106]:
cls_model = CLS(vocab_emb=len(idx2word), embed_dim=512, hidden_size=512, output_size=4, num_layers=6, max_position=700)
cls_model.cuda()

CLS(
  (embedding): Embedding(7075, 512)
  (pos_embedding): Embedding(700, 512)
  (encoder): LSTM(512, 512, num_layers=6, batch_first=True, bidirectional=True)
  (hidden_layer): Linear(in_features=1024, out_features=512, bias=True)
  (cls): Linear(in_features=512, out_features=4, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [93]:
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
random.seed(42)

encoder_optimizer = optim.Adam(cls_model.parameters())
max_lr = 1e-3
for param_group in encoder_optimizer.param_groups:
    param_group['lr'] = max_lr
accumulate_step = 2
grad_norm = 4
warmup_steps = 300
report_freq = 10
eval_interval = 50
save_dir = "model_ckpts"

In [94]:
def validate(dev_input, dev_output, cls_model_):
    # get evidence embeddings
    start_idx = 0
    batch_size = 50
    pos_len = len(dev_input[0])
    cls_model.eval()

    acc = []
    correct_count = 0
    while start_idx < len(dev_output):
        end_idx = min(start_idx + batch_size, len(dev_output))
        
        cur_input = torch.LongTensor(dev_input[start_idx:end_idx]).view(-1, pos_len).cuda()
        cur_pos = torch.LongTensor([list(range(pos_len)) for _ in range(end_idx - start_idx)]).cuda()

        cur_res = cls_model_(cur_input, cur_pos)
        cur_res = torch.argmax(cur_res, 1).tolist()
        
        del cur_input, cur_pos
        
        for i, j in zip(cur_res, dev_output[start_idx: end_idx]):
            if i == j:
                correct_count += 1
        
        start_idx = end_idx
    acc = correct_count / len(dev_output)
    print("\n")
    print("Classification Accuracy: %.3f" % acc)
    print("\n")
    
    cls_model.train()
    return acc

In [None]:
%env WANDB_NOTEBOOK_NAME MelMoxue_NLP_CLS.ipynb

env: WANDB_NOTEBOOK_NAME=MelMoxue_NLP_CLS.ipynb


In [95]:
# start training
import wandb
import os
wandb.init(project="nlp", name="cls")

from tqdm import tqdm
import numpy as np

encoder_optimizer.zero_grad()
step_cnt = 0
all_step_cnt = 0
avg_loss = 0
maximum_f_score = 0
ce_fn = nn.CrossEntropyLoss(torch.FloatTensor([0.2, 0.3, 0.5, 1.]).cuda())

for epoch in range(5): 
    epoch_step = 0

    for (i, batch) in enumerate(tqdm(dataloader)):
        
        step_cnt += 1
        
        # forward pass
            
        cur_res = cls_model(batch["queries"].cuda(), batch["queries_pos"].cuda())

        loss = ce_fn(cur_res, batch["labels"].cuda())
        loss = loss / accumulate_step
        loss.backward()

        avg_loss += loss.item()
        if step_cnt == accumulate_step:
            # updating
            if grad_norm > 0:
                nn.utils.clip_grad_norm_(cls_model.parameters(), grad_norm)

            step_cnt = 0
            epoch_step += 1
            all_step_cnt += 1
            
            # adjust learning rate
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-6
                
            encoder_optimizer.step()
            encoder_optimizer.zero_grad()
        
        if all_step_cnt % report_freq == 0 and step_cnt == 0:
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-6

            wandb.log({"learning_rate": lr}, step=all_step_cnt)
            wandb.log({"loss": avg_loss / report_freq}, step=all_step_cnt)
            
            # report stats
            print("\n")
            print("epoch: %d, epoch_step: %d, avg loss: %.6f" % (epoch + 1, epoch_step, avg_loss / report_freq))
            print(f"learning rate: {lr:.6f}")
            print("\n")

            avg_loss = 0
        del loss, cur_res

        if all_step_cnt % eval_interval == 0 and all_step_cnt != 0 and step_cnt == 0:
            # evaluate the model as a scorer
            print("\nEvaluate:\n")
            
            f_score = validate(dev_inputs, dev_outputs, cls_model)
            wandb.log({"acc": f_score}, step=all_step_cnt)

            if f_score > maximum_f_score:
                maximum_f_score = f_score
                torch.save(cls_model.state_dict(), os.path.join(save_dir, "best_cls_ckpt.bin"))
                # torch.save(last_evidence_embeddings, os.path.join(save_dir, "evidence_embeddings"))
                print("\n")
                print("best val loss - epoch: %d, epoch_step: %d" % (epoch, epoch_step))
                print("maximum_f_score", f_score)
                print("\n")

 16%|█▋        | 20/123 [00:09<00:48,  2.13it/s]



epoch: 1, epoch_step: 10, avg loss: 1.415848
learning rate: 0.000033




 33%|███▎      | 40/123 [00:19<00:42,  1.97it/s]



epoch: 1, epoch_step: 20, avg loss: 1.386649
learning rate: 0.000067




 49%|████▉     | 60/123 [00:29<00:32,  1.95it/s]



epoch: 1, epoch_step: 30, avg loss: 1.388714
learning rate: 0.000100




 65%|██████▌   | 80/123 [00:38<00:19,  2.22it/s]



epoch: 1, epoch_step: 40, avg loss: 1.383278
learning rate: 0.000133




 80%|████████  | 99/123 [00:47<00:11,  2.17it/s]



epoch: 1, epoch_step: 50, avg loss: 1.401325
learning rate: 0.000167



Evaluate:



Classification Accuracy: 0.442




 81%|████████▏ | 100/123 [00:50<00:30,  1.32s/it]



best val loss - epoch: 0, epoch_step: 50
maximum_f_score 0.44155844155844154




 98%|█████████▊| 120/123 [00:59<00:01,  2.26it/s]



epoch: 1, epoch_step: 60, avg loss: 1.400904
learning rate: 0.000200




100%|██████████| 123/123 [01:01<00:00,  2.01it/s]
 14%|█▍        | 17/123 [00:07<00:53,  2.00it/s]



epoch: 2, epoch_step: 9, avg loss: 1.396503
learning rate: 0.000233




 30%|███       | 37/123 [00:18<00:43,  2.00it/s]



epoch: 2, epoch_step: 19, avg loss: 1.383510
learning rate: 0.000267




 46%|████▋     | 57/123 [00:27<00:29,  2.21it/s]



epoch: 2, epoch_step: 29, avg loss: 1.394570
learning rate: 0.000300




 62%|██████▏   | 76/123 [00:36<00:21,  2.21it/s]



epoch: 2, epoch_step: 39, avg loss: 1.403684
learning rate: 0.000333



Evaluate:



 63%|██████▎   | 77/123 [00:37<00:34,  1.33it/s]



Classification Accuracy: 0.266




 79%|███████▉  | 97/123 [00:46<00:11,  2.20it/s]



epoch: 2, epoch_step: 49, avg loss: 1.388360
learning rate: 0.000367




 95%|█████████▌| 117/123 [00:55<00:02,  2.33it/s]



epoch: 2, epoch_step: 59, avg loss: 1.397062
learning rate: 0.000400




100%|██████████| 123/123 [00:58<00:00,  2.11it/s]
 11%|█▏        | 14/123 [00:06<00:50,  2.18it/s]



epoch: 3, epoch_step: 7, avg loss: 1.380241
learning rate: 0.000433




 28%|██▊       | 34/123 [00:15<00:40,  2.21it/s]



epoch: 3, epoch_step: 17, avg loss: 1.379174
learning rate: 0.000467




 43%|████▎     | 53/123 [00:23<00:30,  2.33it/s]



epoch: 3, epoch_step: 27, avg loss: 1.400237
learning rate: 0.000500



Evaluate:



 44%|████▍     | 54/123 [00:25<00:49,  1.41it/s]



Classification Accuracy: 0.292




 60%|██████    | 74/123 [00:33<00:20,  2.42it/s]



epoch: 3, epoch_step: 37, avg loss: 1.388691
learning rate: 0.000533




 76%|███████▋  | 94/123 [00:42<00:12,  2.37it/s]



epoch: 3, epoch_step: 47, avg loss: 1.365255
learning rate: 0.000567




 93%|█████████▎| 114/123 [00:50<00:03,  2.39it/s]



epoch: 3, epoch_step: 57, avg loss: 1.381990
learning rate: 0.000600




100%|██████████| 123/123 [00:54<00:00,  2.26it/s]
  9%|▉         | 11/123 [00:04<00:47,  2.35it/s]



epoch: 4, epoch_step: 6, avg loss: 1.375136
learning rate: 0.000633




 24%|██▍       | 30/123 [00:12<00:39,  2.35it/s]



epoch: 4, epoch_step: 16, avg loss: 1.382608
learning rate: 0.000667



Evaluate:



 25%|██▌       | 31/123 [00:14<01:04,  1.43it/s]



Classification Accuracy: 0.279




 41%|████▏     | 51/123 [00:22<00:30,  2.37it/s]



epoch: 4, epoch_step: 26, avg loss: 1.358906
learning rate: 0.000700




 58%|█████▊    | 71/123 [00:30<00:21,  2.40it/s]



epoch: 4, epoch_step: 36, avg loss: 1.333897
learning rate: 0.000733




 74%|███████▍  | 91/123 [00:39<00:14,  2.28it/s]



epoch: 4, epoch_step: 46, avg loss: 1.346778
learning rate: 0.000767




 90%|█████████ | 111/123 [00:47<00:05,  2.39it/s]



epoch: 4, epoch_step: 56, avg loss: 1.329010
learning rate: 0.000800




100%|██████████| 123/123 [00:52<00:00,  2.32it/s]
  6%|▌         | 7/123 [00:02<00:49,  2.35it/s]



epoch: 5, epoch_step: 4, avg loss: 1.315999
learning rate: 0.000833



Evaluate:



  7%|▋         | 8/123 [00:04<01:22,  1.39it/s]



Classification Accuracy: 0.195




 23%|██▎       | 28/123 [00:12<00:40,  2.37it/s]



epoch: 5, epoch_step: 14, avg loss: 1.318257
learning rate: 0.000867




 39%|███▉      | 48/123 [00:21<00:31,  2.39it/s]



epoch: 5, epoch_step: 24, avg loss: 1.292128
learning rate: 0.000900




 55%|█████▌    | 68/123 [00:29<00:23,  2.38it/s]



epoch: 5, epoch_step: 34, avg loss: 1.251879
learning rate: 0.000933




 72%|███████▏  | 88/123 [00:38<00:15,  2.24it/s]



epoch: 5, epoch_step: 44, avg loss: 1.269737
learning rate: 0.000967




 87%|████████▋ | 107/123 [00:47<00:07,  2.19it/s]



epoch: 5, epoch_step: 54, avg loss: 1.254603
learning rate: 0.001000



Evaluate:



 88%|████████▊ | 108/123 [00:48<00:11,  1.33it/s]



Classification Accuracy: 0.286




100%|██████████| 123/123 [00:55<00:00,  2.22it/s]


In [96]:
def predict(dev_input, cls_model_):
    # get evidence embeddings
    start_idx = 0
    batch_size = 50
    pos_len = len(dev_input[0])
    cls_model.eval()

    cls_res = []
    correct_count = 0
    while start_idx < len(dev_input):
        end_idx = min(start_idx + batch_size, len(dev_input))
        
        cur_input = torch.LongTensor(dev_input[start_idx:end_idx]).view(-1, pos_len).cuda()
        cur_pos = torch.LongTensor([list(range(pos_len)) for _ in range(end_idx - start_idx)]).cuda()

        cur_res = cls_model_(cur_input, cur_pos)
        cur_res = torch.argmax(cur_res, 1).tolist()
        
        del cur_input, cur_pos
        
        cls_res.extend(cur_res)
        
        start_idx = end_idx

    return cls_res

In [97]:
torch.cuda.empty_cache()

In [98]:
import os
cls_model.load_state_dict(torch.load(os.path.join(save_dir, "best_cls_ckpt.bin")))

dev_classes = predict(dev_inputs, cls_model)
test_classes = predict(test_inputs, cls_model)

In [99]:
pred_dev_claims = json.load(open("data/dev_predict.json", "r"))
pred_test_claims = json.load(open("data/test-claims-unlabelled.json", "r"))

for i, j in zip(dev_claim_id, dev_classes):
    claim_label = id2labels[j]
    evidences = pred_dev_claims[i]['evidences']
    pred_dev_claims[i] = {'claim_text': pred_dev_claims[i]['claim_text'], 'claim_label': claim_label, 'evidences': evidences}

for i, j in zip(test_claim_id, test_classes):
    claim_label = id2labels[j]
    evidences = pred_test_claims[i]['evidences']
    pred_test_claims[i] = {'claim_text': pred_test_claims[i]['claim_text'], 'claim_label': claim_label, 'evidences': evidences}

json.dump(pred_dev_claims, open("data/dev_predict.json", "w"))
json.dump(pred_test_claims, open("data/test-claims-unlabelled.json", "w"))
    

In [100]:
from collections import Counter
print(Counter(dev_classes))

Counter({0: 154})


In [101]:
print(Counter(test_classes))

Counter({0: 153})


In [103]:
import subprocess

output = subprocess.check_output("python eval.py --predictions data/dev_predict.json --groundtruth data/dev-claims.json", shell=True)
output_str = output.decode('utf-8')

# Split the output into lines
output_lines = output_str.strip().split('\n')

# Format the output
formatted_lines = []
for line in output_lines:
    metric, value = line.split('=')
    metric = metric.strip()
    value = value.strip()
    formatted_line = f"{metric}: {value}"
    formatted_lines.append(formatted_line)

# Join the formatted lines into a single string
formatted_output = '\n'.join(formatted_lines)
print(formatted_output)

Evidence Retrieval F-score (F): 0.04954648526077098
Claim Classification Accuracy (A): 0.44155844155844154
Harmonic Mean of F and A: 0.08909570082361655


## Object Oriented Programming codes here

*You can use multiple code snippets. Just add more if needed*