# 2024 COMP90042 Project
*Make sure you change the file name with your group id.*

# Readme
*If there is something to be noted for the marker, please mention here.*

*If you are planning to implement a program with Object Oriented Programming style, please put those the bottom of this ipynb file*

**We use pytorch, nltk, scikit-learn in this project.**

# 1.DataSet Processing
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

## PreProcess for evidence and claims

### preprocessing function

In [2]:
import torch

print("CUDA available:", torch.cuda.is_available())

CUDA available: True


In [3]:
print("Current CUDA device:", torch.cuda.current_device())
print("Device count:", torch.cuda.device_count())
print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

Current CUDA device: 0
Device count: 1
Device name: NVIDIA GeForce RTX 3070 Ti


### read files

In [59]:
import json
import nltk
import string
import re
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from collections import Counter
from statistics import mean

nltk.download('stopwords')
nltk.download('wordnet')

with open('data/train-claims.json', 'r') as input_file:
    train_claims = json.load(input_file)

# Read in development data (claim)
with open('data/dev-claims.json', 'r') as input_file:
    dev_claims = json.load(input_file)

# Read in test data (claim)
with open('data/test-claims-unlabelled.json', 'r') as input_file:
    test_claims = json.load(input_file)

# Read in evidence data
with open('data/evidence.json', 'r') as input_file:
    evidences = json.load(input_file)

#EDA


claim_count = 0
evi_count = 0
claim_length = []
evidence_count = []
evidence_length = []
labels = []

for key,value in train_claims.items():
    claim_count+=1
    claim_length.append(len(value["claim_text"]))
    evidence_count.append(len(value["evidences"]))
    evidence_length += [len(evidences[x]) for x in value["evidences"]]
    labels.append(value["claim_label"])

for key,value in evidences.items():
    evi_count+=1

print("claim count: ",claim_count)
print("evidence count: ",evi_count)
print("max claim length: ",max(claim_length))
print("min claim length: ",min(claim_length))
print("mean claim length: ",mean(claim_length))
print("max evidence count: ",max(evidence_count))
print("min evidence count: ",min(evidence_count))
print("mean evidence count: ",mean(evidence_count))
print("max evidence length: ",max(evidence_length))
print("min evidence length: ",min(evidence_length))
print("mean evidence length: ",mean(evidence_length))
print(Counter(labels))



inside = 0
outside = 0

train_evi_id = []
for claim_id,claim_value in train_claims.items():
    train_evi_id=train_evi_id+claim_value['evidences']

for claim_id,claim_value in dev_claims.items():
    test_evi_id=claim_value['evidences']
    for e in test_evi_id:
        if e in train_evi_id:
            inside += 1
        else:
            outside += 1
print("Dev evi inside train evi", inside)
print("Dev evi outside train evi", outside)

full_evidence_id = list(evidences.keys())
full_evidence_text  = list(evidences.values())
train_claim_id = list(train_claims.keys())
train_claim_text  = [ v["claim_text"] for v in train_claims.values()]
print("Train claim count: ",len(train_claim_id))

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\ABC\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\ABC\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


claim count:  1228
evidence count:  1208827
max claim length:  332
min claim length:  26
mean claim length:  122.95521172638436
max evidence count:  5
min evidence count:  1
mean evidence count:  3.3566775244299674
max evidence length:  1979
min evidence length:  13
mean evidence length:  173.5
Counter({'SUPPORTS': 519, 'NOT_ENOUGH_INFO': 386, 'REFUTES': 199, 'DISPUTED': 124})
Dev evi inside train evi 163
Dev evi outside train evi 328
Train claim count:  1228


In [60]:
lemmatizer = nltk.stem.wordnet.WordNetLemmatizer()
stopwords = set(stopwords.words('english'))

def lemmatize(word):
    lemma = lemmatizer.lemmatize(word, 'v')
    return lemma if lemma != word else lemmatizer.lemmatize(word, 'n')

def is_pure_english(text):
    english_letters = set(string.ascii_letters)
    cleaned_text = ''.join(char for char in text if char.isalpha() or char.isspace())
    return all(char in english_letters or char.isspace() for char in cleaned_text)

def remove_non_eng(dictionary):
    eng_data = {}
    for key, value in dictionary.items():
        if is_pure_english(value):
            eng_data[key] = value
    return eng_data

def contains_climate_keywords(text, keywords):
    text = text.lower()
    for keyword in keywords:
        if re.search(r"\b" + re.escape(keyword) + r"\b", text):
            return True
    return False

def filter_climate_related(dictionary, keywords):
    cs_data = {}
    for key, value in dictionary.items():
        if contains_climate_keywords(value, keywords):
            cs_data[key] = value
    return cs_data

def text_preprocessing(text, remove_stopwords=False):
    if isinstance(text, str):
        words = [lemmatize(w) for w in text.lower().split()]
    elif isinstance(text, list):
        words = [lemmatize(w.lower()) for w in text]
    else:
        raise ValueError("Unsupported type for 'text'. Expected str or list.")
    
    if remove_stopwords:
        words = [w for w in words if w not in stopwords]
    
    return " ".join(words)

In [61]:
climate_keywords = [
    "climate", "environment", "global warming", "greenhouse effect", "carbon", "co2", "carbon dioxide",
    "methane", "renewable energy", "sustainability", "ecology", "biodiversity", "fossil fuels",
    "emissions", "air quality", "ozone", "solar energy", "wind energy", "climate change", "climate crisis",
    "climate adaptation", "climate mitigation", "ocean", "sea levels", "ice melting", "deforestation",
    "reforestation", "pollution"," electricity","energy","solar","wind","renewable","fossil","fuel","emission","air","quality","ozone","solar","wind","climate","change","crisis","adaptation","mitigation","ocean","sea","level","ice","melt","deforestation",
]


def preprocess_claim_data(claim_data, existed_evidences_id=None):
    claim_data = remove_non_eng(claim_data)
    claim_data_text = []
    claim_data_id = []
    claim_data_label = []
    claim_evidences = []
    
    for key in claim_data.keys():
        claim_data[key]["claim_text"] = word_tokenize(claim_data[key]["claim_text"])
        claim_data[key]["claim_text"] = text_preprocessing(claim_data[key]["claim_text"])
        
        claim_data_text.append(claim_data[key]["claim_text"])
        claim_data_id.append(key)
        
        if "claim_label" in claim_data[key]:
            claim_data_label.append(claim_data[key]["claim_label"])
        else:
            claim_data_label.append(None)
        
        if existed_evidences_id and "evidences" in claim_data[key]:
            valid_evidences = [existed_evidences_id[i] for i in claim_data[key]["evidences"] if i in existed_evidences_id]
            claim_evidences.append(valid_evidences)
        else:
            claim_evidences.append([])
    
    return claim_data_text, claim_data_id, claim_data_label, claim_evidences


def preprocess_evi_data(evi_data, climate_keywords):
    evi_data = remove_non_eng(evi_data)
    cs_evi_data = filter_climate_related(evi_data, climate_keywords)
    
    for key in cs_evi_data.keys():
        cs_evi_data[key] = word_tokenize(cs_evi_data[key])
        cs_evi_data[key] = text_preprocessing(cs_evi_data[key], remove_stopwords=True)
        
    cleaned_evidence_text = list(cs_evi_data.values())
    cleaned_evidence_id = list(cs_evi_data.keys())
    
    return cleaned_evidence_text, cleaned_evidence_id

In [63]:
cleaned_evidence_text, cleaned_evidence_id = preprocess_evi_data(evidences, climate_keywords)

evidences_texts = cleaned_evidence_text
evidences_ids = cleaned_evidence_id
evidences_id_dict = {evidence_id: idx for idx, evidence_id in enumerate(cleaned_evidence_id)}
evidences_p_texts = cleaned_evidence_text


train_claim_text, train_claim_id, train_claim_label, train_claim_evidences = preprocess_claim_data(train_claims, evidences_id_dict)

train_texts = train_claim_text
train_ids = train_claim_id
train_labels = train_claim_label
train_evidences = train_claim_evidences
train_p_texts = train_claim_text


dev_claim_text, dev_claim_id, dev_claim_label, dev_claim_evidences = preprocess_claim_data(dev_claims, evidences_id_dict)

dev_texts = dev_claim_text
dev_ids = dev_claim_id
dev_labels = dev_claim_label
dev_evidences = dev_claim_evidences
dev_p_texts = dev_claim_text  


test_claim_text, test_claim_id, _, _ = preprocess_claim_data(test_claims)

test_texts = test_claim_text
test_ids = test_claim_id
test_p_texts = test_claim_text

In [68]:
# save for debug
# need to delete when submitting
json.dump(train_ids, open("temp_data/train_ids.json", "w"))
json.dump(train_texts, open("temp_data/train_texts.json", "w"))
json.dump(train_p_texts, open("temp_data/train_p_texts.json", "w"))
json.dump(train_evidences, open("temp_data/train_evidences.json", "w"))
json.dump(train_labels, open("temp_data/train_labels.json", "w"))

json.dump(dev_ids, open("temp_data/dev_ids.json", "w"))
json.dump(dev_texts, open("temp_data/dev_texts.json", "w"))
json.dump(dev_p_texts, open("temp_data/dev_p_texts.json", "w"))
json.dump(dev_evidences, open("temp_data/dev_evidences.json", "w"))
json.dump(dev_labels, open("temp_data/dev_labels.json", "w"))

json.dump(test_ids, open("temp_data/test_ids.json", "w"))
json.dump(test_texts, open("temp_data/test_texts.json", "w"))
json.dump(test_p_texts, open("temp_data/test_p_texts.json", "w"))

json.dump(evidences_texts, open("temp_data/evidences_texts.json", "w"))
json.dump(evidences_p_texts, open("temp_data/evidences_p_texts.json", "w"))
json.dump(evidences_ids, open("temp_data/evidences_ids.json", "w"))
json.dump(evidences_id_dict, open("temp_data/evidences_id_dict.json", "w"))


In [8]:

train_ids = json.load(open("temp_data/train_ids.json", "r"))
train_texts = json.load(open("temp_data/train_texts.json", "r"))
train_p_texts = json.load(open("temp_data/train_p_texts.json", "r"))
train_evidences = json.load(open("temp_data/train_evidences.json", "r"))
train_labels = json.load(open("temp_data/train_labels.json", "r"))

dev_ids = json.load(open("temp_data/dev_ids.json", "r"))
dev_texts = json.load(open("temp_data/dev_texts.json", "r"))
dev_p_texts = json.load(open("temp_data/dev_p_texts.json", "r"))
dev_evidences = json.load(open("temp_data/dev_evidences.json", "r"))
dev_labels = json.load(open("temp_data/dev_labels.json", "r"))

test_ids = json.load(open("temp_data/test_ids.json", "r"))
test_texts = json.load(open("temp_data/test_texts.json", "r"))
test_p_texts = json.load(open("temp_data/test_p_texts.json", "r"))

evidences_texts = json.load(open("temp_data/evidences_texts.json", "r"))
evidences_p_texts = json.load(open("temp_data/evidences_p_texts.json", "r"))
evidences_ids = json.load(open("temp_data/evidences_ids.json", "r"))
evidences_id_dict = json.load(open("temp_data/evidences_id_dict.json", "r"))

### tfidf retrieval

In [64]:
# from workshop and a1
from sklearn.feature_extraction.text import TfidfVectorizer

# vectorizer = TfidfVectorizer(max_features=500000)
vectorizer = TfidfVectorizer()
vectorizer.fit(evidences_texts)
# TODO can svd 
train_tfidf = vectorizer.transform(train_p_texts)
dev_tfidf = vectorizer.transform(dev_p_texts)
test_tfidf = vectorizer.transform(test_p_texts)
evidence_tfidf = vectorizer.transform(evidences_p_texts)


In [65]:
# need to change this code

# scikit-learn can calculate cosine similarity
# train_cos_sims = np.dot(train_tfidf, evidence_tfidf.transpose()).toarray()
# dev_cos_sims = np.dot(dev_tfidf, evidence_tfidf.transpose()).toarray()
# test_cos_sims = np.dot(test_tfidf, evidence_tfidf.transpose()).toarray()
train_cos_sims = cosine_similarity(train_tfidf, evidence_tfidf)
dev_cos_sims = cosine_similarity(dev_tfidf, evidence_tfidf)
test_cos_sims = cosine_similarity(test_tfidf, evidence_tfidf)
print(train_cos_sims.shape)

(1228, 54272)


In [69]:
def test_retrieval_topk(k, cur_scores, cur_labels):
    ACC = []
    top_ids = torch.topk(torch.FloatTensor(cur_scores), k, -1).indices.tolist()
    for i in range(len(cur_labels)):
        all_count = 0
        recall_count = 0
        for cur_ in cur_labels[i]:
            if cur_ in top_ids[i]:
                recall_count += 1
            all_count += 1
        if all_count == 0:
            all_count = 1e-9  # 避免除以0
        ACC.append(recall_count / all_count)
    print(sum(ACC) / len(ACC))

topK = 10
test_retrieval_topk(topK, train_cos_sims, train_evidences)
test_retrieval_topk(topK, dev_cos_sims, dev_evidences)

0.14695982627578721
0.15670995670995672


In [70]:
# need to change this code
def sort_evidence_candidates(cos_sims):
    # dev_candis = np.argpartition(-cos_sims, candi_num, 1)[:candi_num]
    candis = []
    for i in range(cos_sims.shape[0]):
        cur_top_ids = np.argsort(-cos_sims[i]).tolist()[:10000]
        candis.append(cur_top_ids)
    return candis

In [71]:
dev_sort_evidences = sort_evidence_candidates(dev_cos_sims)
test_sort_evidences = sort_evidence_candidates(test_cos_sims)
train_sort_evidences = sort_evidence_candidates(train_cos_sims)

In [72]:
json.dump(dev_sort_evidences, open("temp_data/dev_sort_evidences.json", "w"))
json.dump(test_sort_evidences, open("temp_data/test_sort_evidences.json", "w"))
json.dump(train_sort_evidences, open("temp_data/train_sort_evidences.json", "w"))

### construct vocab and indexing

In [73]:
# construct word2idx and idx2word for taining data
# if we use the unprocessed text else use ***_p_texts

min_count = 10
wordcount = {}
idx2word = ["<pad>", "<cls>", "<sep>", "<unk>"]
word2idx = {"<pad>": 0, "<cls>": 1, "<sep>": 2, "<unk>": 3}
for texts in train_texts + evidences_texts:
    for word in texts.split():
        wordcount[word] = wordcount.get(word, 0) + 1
idx = 4
for i, j in wordcount.items():
    if j > min_count:
        idx2word.append(i)
        word2idx[i] = idx
        idx += 1

json.dump(idx2word, open("temp_data/idx2word.json", "w"))
json.dump(word2idx, open("temp_data/word2idx.json", "w"))

In [74]:
def convert2idx(text_data, word2idx_, idx2word_):
    idx_data = []
    for texts in text_data:
        temp_idx = []
        for word in texts.split():
            temp_idx.append(word2idx_.get(word, word2idx_["<unk>"]))
        idx_data.append(temp_idx)
    return idx_data

In [75]:
train_text_idx = convert2idx(train_texts, word2idx, idx2word)
dev_text_idx = convert2idx(dev_texts, word2idx, idx2word)
test_text_idx = convert2idx(test_texts, word2idx, idx2word)
evidences_text_idx = convert2idx(evidences_texts, word2idx, idx2word)

In [76]:
json.dump(train_text_idx, open("temp_data/train_text_idx.json", "w"))
json.dump(dev_text_idx, open("temp_data/dev_text_idx.json", "w"))
json.dump(test_text_idx, open("temp_data/test_text_idx.json", "w"))
json.dump(evidences_text_idx, open("temp_data/evidences_text_idx.json", "w"))

In [77]:
print(max([len(i) for i in train_text_idx]), max([len(i) for i in dev_text_idx]), max([len(i) for i in test_text_idx]), max([len(i) for i in evidences_text_idx]))

76 73 60 231


In [78]:
text_pad_len = 60
evidences_pad_len = 100

In [79]:
def construct_input_text(text_idx, padding_len, word2idx_):
    idx_data = []
    for texts in text_idx:
        if len(texts) < padding_len:
            idx_data.append([word2idx_["<cls>"]] + texts + [word2idx_["<sep>"]] + [word2idx_["<pad>"]] * (padding_len - len(texts)))
        else:
            idx_data.append([word2idx_["<cls>"]] + texts[:padding_len] + [word2idx_["<sep>"]])
    return idx_data
    

In [80]:
train_input = construct_input_text(train_text_idx, text_pad_len, word2idx)
dev_input = construct_input_text(dev_text_idx, text_pad_len, word2idx)
test_input = construct_input_text(test_text_idx, text_pad_len, word2idx)
evidences_input = construct_input_text(evidences_text_idx, evidences_pad_len, word2idx)

In [81]:
print(len(train_input[0]), len(evidences_input[0]))

62 102


In [82]:
json.dump(train_input, open("temp_data/train_input.json", "w"))
json.dump(dev_input, open("temp_data/dev_input.json", "w"))
json.dump(test_input, open("temp_data/test_input.json", "w"))
json.dump(evidences_input, open("temp_data/evidences_input.json", "w"))

In [83]:
# load save data
# need to delete when submitting

train_ids = json.load(open("temp_data/train_ids.json", "r"))
train_texts = json.load(open("temp_data/train_texts.json", "r"))
train_p_texts = json.load(open("temp_data/train_p_texts.json", "r"))
train_evidences = json.load(open("temp_data/train_evidences.json", "r"))
train_labels = json.load(open("temp_data/train_labels.json", "r"))

dev_ids = json.load(open("temp_data/dev_ids.json", "r"))
dev_texts = json.load(open("temp_data/dev_texts.json", "r"))
dev_p_texts = json.load(open("temp_data/dev_p_texts.json", "r"))
dev_evidences = json.load(open("temp_data/dev_evidences.json", "r"))
dev_labels = json.load(open("temp_data/dev_labels.json", "r"))

test_ids = json.load(open("temp_data/test_ids.json", "r"))
test_texts = json.load(open("temp_data/test_texts.json", "r"))
test_p_texts = json.load(open("temp_data/test_p_texts.json", "r"))

evidences_texts = json.load(open("temp_data/evidences_texts.json", "r"))
evidences_p_texts = json.load(open("temp_data/evidences_p_texts.json", "r"))
evidences_ids = json.load(open("temp_data/evidences_ids.json", "r"))
evidences_id_dict = json.load(open("temp_data/evidences_id_dict.json", "r"))

idx2word = json.load(open("temp_data/idx2word.json", "r"))
word2idx = json.load(open("temp_data/word2idx.json", "r"))

train_text_idx = json.load(open("temp_data/train_text_idx.json", "r"))
dev_text_idx = json.load(open("temp_data/dev_text_idx.json", "r"))
test_text_idx = json.load(open("temp_data/test_text_idx.json", "r"))
evidences_text_idx = json.load(open("temp_data/evidences_text_idx.json", "r"))

train_input = json.load(open("temp_data/train_input.json", "r"))
dev_input = json.load(open("temp_data/dev_input.json", "r"))
test_input = json.load(open("temp_data/test_input.json", "r"))
evidences_input = json.load(open("temp_data/evidences_input.json", "r"))

dev_sort_evidences = json.load(open("temp_data/dev_sort_evidences.json", "r"))
test_sort_evidences = json.load(open("temp_data/test_sort_evidences.json", "r"))
train_sort_evidences = json.load(open("temp_data/train_sort_evidences.json", "r"))

In [84]:
vocab_size = len(idx2word)
print(vocab_size)

7075


In [85]:
from torch.utils.data import Dataset
import random

class TrainDataset(Dataset):
    def __init__(self, text_input_data, evidence_input_data, tfidf_sort_evidences, evidence_label, negative_num=10):
        self.text_input_data = text_input_data
        self.evidence_input_data = evidence_input_data
        self.tfidf_sort_evidences = tfidf_sort_evidences
        self.evidence_label = evidence_label
        self.negative_num = 10
        self.evidence_len = len(evidence_input_data[0])
        self.text_len = len(text_input_data[0])
        
    def __len__(self):
        return len(self.text_input_data)

    def __getitem__(self, idx):
        # please note the negative evidences
        return [self.text_input_data[idx], random.sample(self.tfidf_sort_evidences[idx][10: self.negative_num*10], self.negative_num), self.evidence_label[idx]]

    def collate_fn(self, batch):
        queries = []
        queries_pos = []
        evidences = []
        temp_labels = []
        
        for i, j, k in batch:
            queries.append(i)
            queries_pos.append(list(range(self.text_len)))
            temp_labels.append(k)    
            evidences.extend(k + j)

        evidences = list(set(evidences))
        
        evidences2idx = {}
        for i, j in enumerate(evidences):
            evidences2idx[j] = i
        
        labels = []
        for i in temp_labels:
            labels.append([evidences2idx[j] for j in i])
            
        evidences = [self.evidence_input_data[i] for i in evidences]
        evidences_pos = [list(range(self.evidence_len)) for _ in range(len(evidences))]

        batch_encoding = {}
        batch_encoding["queries"] = torch.LongTensor(queries)
        batch_encoding["evidences"] = torch.LongTensor(evidences)
        
        batch_encoding["queries_pos"] = torch.LongTensor(queries_pos)
        batch_encoding["evidences_pos"] = torch.LongTensor(evidences_pos)
        batch_encoding["labels"] = labels
        
        return batch_encoding

In [86]:
train_set = TrainDataset(train_input, evidences_input, train_sort_evidences, train_evidences, negative_num=1000)
from torch.utils.data import DataLoader

dataloader = DataLoader(train_set, batch_size=5, shuffle=True, num_workers=0, collate_fn=train_set.collate_fn)

# 2. Model Implementation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [87]:
# from workshop
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, vocab_emb, embed_dim, hidden_size, nhead, num_layers, max_position=180):
        super(Encoder, self).__init__()

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_emb, embed_dim)
        self.pos_embedding = nn.Embedding(max_position, embed_dim)  # https://pytorch.org/tutorials/beginner/translation_transformer.html other method
        
        #encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=nhead, batch_first=True)
        #self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=nn.LayerNorm(hidden_size))
        self.encoder = nn.LSTM(embed_dim, hidden_size, num_layers=2,batch_first=True)

    def forward(self, text_data, position_text):
        mask_ = text_data == 0
        text_x = self.embedding(text_data) + self.pos_embedding(position_text)
        #x_encoded = self.encoder(text_x, src_key_padding_mask=mask_)  # https://zhuanlan.zhihu.com/p/353365423
        x_encoded,_ = self.encoder(text_x)
        return x_encoded

In [88]:
trans_encoder = Encoder(vocab_emb=vocab_size, embed_dim=512, hidden_size=512, nhead=8, num_layers=6, max_position=180)
trans_encoder.cuda()

Encoder(
  (embedding): Embedding(7075, 512)
  (pos_embedding): Embedding(180, 512)
  (encoder): LSTM(512, 512, num_layers=2, batch_first=True)
)

In [89]:
def to_cuda(batch):
    # for n in batch.keys():
    #     if n != "labels":
    #         batch[n] = batch[n].cuda()
    ## TODO: you can use to define .cuda()
    pass

### Training

In [90]:
# from workshop but need to change because I add some speciall setting

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
random.seed(42)

encoder_optimizer = optim.Adam(trans_encoder.parameters())
max_lr = 1e-2
for param_group in encoder_optimizer.param_groups:
    param_group['lr'] = max_lr

accumulate_step = 2
grad_norm = 0.1
warmup_steps = 200
report_freq = 10
eval_interval = 50
save_dir = "model_ckpts"

In [91]:
retrieval_num = 5
dev_candis_num = 10

def validate(dev_text_idx, evidence_text_idx, dev_sort_evidences, dev_evidences, encoder_model):
    # get evidence embeddings
    start_idx = 0
    batch_size = 1000
    evidence_len = len(evidence_text_idx[0])
    text_len = len(dev_text_idx[0])
    
    evidence_embeddings = []
    encoder_model.eval()
    
    while start_idx < len(evidence_text_idx):
        end_idx = min(start_idx + batch_size, len(evidence_text_idx))
        
        cur_evidence = torch.LongTensor(evidence_text_idx[start_idx:end_idx]).view(-1, evidence_len).cuda()
        cur_evidence_pos = torch.LongTensor([list(range(evidence_len)) for _ in range(end_idx - start_idx)]).cuda()
        print()
        cur_embedding = encoder_model(cur_evidence, cur_evidence_pos)
        cur_embedding = cur_embedding[:, 0, :].detach()
        cur_embedding_cpu = F.normalize(cur_embedding, p=2, dim=1).cpu()
        del cur_embedding, cur_evidence, cur_evidence_pos
        start_idx = end_idx
        evidence_embeddings.append(cur_embedding_cpu)
        
    evidence_embeddings = torch.cat(evidence_embeddings, dim=0).t()
    print("get all evidence embeddings!")
    f = []
    
    start_idx = 0
    batch_size = 1000
    
    while start_idx < len(dev_text_idx):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))
        
        cur_query = torch.LongTensor(dev_text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        cur_query_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        query_embedding = encoder_model(cur_query, cur_query_pos)
        query_embedding = query_embedding[:, 0, :].detach()
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu()
        
        scores = torch.mm(query_embedding, evidence_embeddings)

        for i in range(scores.size(0)):
            new_score = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(new_score).tolist()
            select_ids = topk_ids[:retrieval_num]
            
            evidence_correct = 0
            pred_evidences = [dev_sort_evidences[start_idx+i][j] for j in select_ids]
            label = dev_evidences[start_idx+i]
            for evidence_id in label:
                if evidence_id in pred_evidences:
                    evidence_correct += 1
            if evidence_correct > 0:
                evidence_recall = float(evidence_correct) / len(label)
                evidence_precision = float(evidence_correct) / len(pred_evidences)
                evidence_fscore = (2 * evidence_precision * evidence_recall) / (evidence_precision + evidence_recall)
                print(evidence_fscore)
            else:
                evidence_fscore = 0
            f.append(evidence_fscore)
            
        start_idx = end_idx
        # print("----")
    fscore = np.mean(f)
    print("\n")
    print("Evidence Retrieval F-score: %.3f" % fscore)
    print("\n")
    encoder_model.train()
    return fscore

In [92]:
%env WANDB_NOTEBOOK_NAME MelMoxue_NLP_retrieval.ipynb

env: WANDB_NOTEBOOK_NAME=MelMoxue_NLP_retrieval.ipynb


In [35]:
# import wandb
# import sys
# print(sys.path)
# print(wandb.__path__)

In [93]:
# start training
import wandb
import os
wandb.init(project="nlp", name="dpr")

from tqdm import tqdm
import numpy as np

encoder_optimizer.zero_grad()
step_cnt = 0
all_step_cnt = 0
avg_loss = 0
maximum_f_score = 0

for epoch in range(5): 
    epoch_step = 0

    for (i, batch) in enumerate(tqdm(dataloader)):
        
        step_cnt += 1
        # forward pass
        query_embeddings = trans_encoder(batch["queries"].cuda(), batch["queries_pos"].cuda())
        evidence_embeddings = trans_encoder(batch["evidences"].cuda(), batch["evidences_pos"].cuda())
        
        query_embeddings = query_embeddings[:, 0, :]
        evidence_embeddings = evidence_embeddings[:, 0, :]

        query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
        evidence_embeddings = torch.nn.functional.normalize(evidence_embeddings, p=2, dim=1)

        cos_sims = torch.mm(query_embeddings, evidence_embeddings.t())
        scores = - torch.nn.functional.log_softmax(cos_sims / 0.1, dim=1)

        loss = []
        start_idx = 0
        for idx, label in enumerate(batch["labels"]):
            label = torch.LongTensor(label).cuda()
            # print(scores[idx].size())
            # assert 0 == 1
            cur_loss = torch.mean(torch.index_select(scores[idx], 0, label))
            loss.append(cur_loss)

        loss = torch.stack(loss).mean()
        loss = loss / accumulate_step
        loss.backward()

        avg_loss += loss.item()
        if step_cnt == accumulate_step:
            # updating
            if grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_encoder.parameters(), grad_norm)

            step_cnt = 0
            epoch_step += 1
            all_step_cnt += 1
            
            # adjust learning rate
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-5
                
            encoder_optimizer.step()
            encoder_optimizer.zero_grad()
        
        if all_step_cnt % report_freq == 0 and step_cnt == 0:
            if all_step_cnt <= warmup_steps:
                lr = all_step_cnt * (max_lr - 2e-8) / warmup_steps + 2e-8
            else:
                lr = max_lr - (all_step_cnt - warmup_steps) * 1e-5

            wandb.log({"learning_rate": lr}, step=all_step_cnt)
            wandb.log({"loss": avg_loss / report_freq}, step=all_step_cnt)
            
            # report stats
            print("\n")
            print("epoch: %d, epoch_step: %d, avg loss: %.6f" % (epoch + 1, epoch_step, avg_loss / report_freq))
            print(f"learning rate: {lr:.6f}")
            print("\n")
            avg_loss = 0
        del loss, cos_sims, query_embeddings, evidence_embeddings

        if all_step_cnt % eval_interval == 0 and all_step_cnt != 0 and step_cnt == 0:
            # evaluate the model as a scorer
            print("\nEvaluate:\n")
            
            f_score = validate(dev_input, evidences_input, dev_sort_evidences, dev_evidences, trans_encoder)
            wandb.log({"f_score": f_score}, step=all_step_cnt)

            if f_score > maximum_f_score:
                maximum_f_score = f_score
                torch.save(trans_encoder.state_dict(), os.path.join(save_dir, "best_ckpt.bin"))
                # torch.save(last_evidence_embeddings, os.path.join(save_dir, "evidence_embeddings"))
                print("\n")
                print("best val loss - epoch: %d, epoch_step: %d" % (epoch, epoch_step))
                print("maximum_f_score", f_score)
                print("\n")

0,1
f_score,▁▁▁▁▁▁▁▁▁▁▁▁
learning_rate,▁▁▂▂▃▄▄▅▅▆▇▇██████▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅▅▅▅
loss,▆▆▅▅▇▇▅▆▂▆▄▄▆▆▇▅▂█▆▇▆▇▆▅▁▆▅▆▅▆▅▅▂█▆▇▄▆▅▄

0,1
f_score,0.07792
learning_rate,0.0059
loss,4.18025


  9%|▉         | 23/246 [00:01<00:10, 20.33it/s]



epoch: 1, epoch_step: 10, avg loss: nan
learning rate: 0.000500




 17%|█▋        | 41/246 [00:02<00:09, 21.51it/s]



epoch: 1, epoch_step: 20, avg loss: nan
learning rate: 0.001000




 25%|██▌       | 62/246 [00:03<00:08, 20.78it/s]



epoch: 1, epoch_step: 30, avg loss: nan
learning rate: 0.001500




 34%|███▎      | 83/246 [00:04<00:08, 20.15it/s]



epoch: 1, epoch_step: 40, avg loss: nan
learning rate: 0.002000




 40%|███▉      | 98/246 [00:04<00:07, 20.83it/s]



epoch: 1, epoch_step: 50, avg loss: nan
learning rate: 0.002500



Evaluate:
























































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 41%|████      | 101/246 [00:11<01:35,  1.52it/s]



best val loss - epoch: 0, epoch_step: 50
maximum_f_score 0.05150999793856938




 50%|████▉     | 122/246 [00:12<00:12, 10.09it/s]



epoch: 1, epoch_step: 60, avg loss: nan
learning rate: 0.003000




 58%|█████▊    | 143/246 [00:13<00:05, 20.00it/s]



epoch: 1, epoch_step: 70, avg loss: nan
learning rate: 0.003500




 65%|██████▌   | 161/246 [00:13<00:04, 20.76it/s]



epoch: 1, epoch_step: 80, avg loss: nan
learning rate: 0.004000




 74%|███████▍  | 182/246 [00:14<00:02, 21.77it/s]



epoch: 1, epoch_step: 90, avg loss: nan
learning rate: 0.004500




 80%|████████  | 197/246 [00:15<00:02, 20.84it/s]



epoch: 1, epoch_step: 100, avg loss: nan
learning rate: 0.005000



Evaluate:
























































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 90%|████████▉ | 221/246 [00:22<00:02, 10.36it/s]



epoch: 1, epoch_step: 110, avg loss: nan
learning rate: 0.005500




 98%|█████████▊| 242/246 [00:23<00:00, 18.28it/s]



epoch: 1, epoch_step: 120, avg loss: nan
learning rate: 0.006000




100%|██████████| 246/246 [00:24<00:00, 10.24it/s]
  6%|▌         | 15/246 [00:00<00:10, 21.49it/s]



epoch: 2, epoch_step: 7, avg loss: nan
learning rate: 0.006500




 15%|█▍        | 36/246 [00:01<00:10, 20.44it/s]



epoch: 2, epoch_step: 17, avg loss: nan
learning rate: 0.007000




 21%|██        | 51/246 [00:02<00:09, 21.57it/s]



epoch: 2, epoch_step: 27, avg loss: nan
learning rate: 0.007500



Evaluate:
























































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 31%|███       | 76/246 [00:09<00:14, 11.72it/s]



epoch: 2, epoch_step: 37, avg loss: nan
learning rate: 0.008000




 39%|███▉      | 97/246 [00:10<00:07, 19.14it/s]



epoch: 2, epoch_step: 47, avg loss: nan
learning rate: 0.008500




 48%|████▊     | 118/246 [00:11<00:05, 23.76it/s]



epoch: 2, epoch_step: 57, avg loss: nan
learning rate: 0.009000




 55%|█████▌    | 136/246 [00:12<00:05, 21.06it/s]



epoch: 2, epoch_step: 67, avg loss: nan
learning rate: 0.009500




 61%|██████▏   | 151/246 [00:13<00:04, 21.03it/s]



epoch: 2, epoch_step: 77, avg loss: nan
learning rate: 0.010000



Evaluate:


























































 63%|██████▎   | 154/246 [00:19<01:00,  1.52it/s]

get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 72%|███████▏  | 176/246 [00:20<00:06, 11.65it/s]



epoch: 2, epoch_step: 87, avg loss: nan
learning rate: 0.009900




 80%|████████  | 197/246 [00:21<00:02, 20.31it/s]



epoch: 2, epoch_step: 97, avg loss: nan
learning rate: 0.009800




 87%|████████▋ | 215/246 [00:22<00:01, 20.13it/s]



epoch: 2, epoch_step: 107, avg loss: nan
learning rate: 0.009700




 96%|█████████▌| 236/246 [00:23<00:00, 22.27it/s]



epoch: 2, epoch_step: 117, avg loss: nan
learning rate: 0.009600




100%|██████████| 246/246 [00:23<00:00, 10.34it/s]
  2%|▏         | 6/246 [00:00<00:11, 20.70it/s]



epoch: 3, epoch_step: 4, avg loss: nan
learning rate: 0.009500



Evaluate:

























































  4%|▎         | 9/246 [00:06<03:51,  1.02it/s]


get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 12%|█▏        | 30/246 [00:07<00:21, 10.09it/s]



epoch: 3, epoch_step: 14, avg loss: nan
learning rate: 0.009400




 21%|██        | 51/246 [00:08<00:10, 19.42it/s]



epoch: 3, epoch_step: 24, avg loss: nan
learning rate: 0.009300




 29%|██▉       | 72/246 [00:09<00:08, 21.02it/s]



epoch: 3, epoch_step: 34, avg loss: nan
learning rate: 0.009200




 37%|███▋      | 90/246 [00:10<00:07, 20.96it/s]



epoch: 3, epoch_step: 44, avg loss: nan
learning rate: 0.009100




 43%|████▎     | 105/246 [00:11<00:06, 20.83it/s]



epoch: 3, epoch_step: 54, avg loss: nan
learning rate: 0.009000



Evaluate:
























































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337


 45%|████▌     | 111/246 [00:17<01:04,  2.11it/s]

0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 53%|█████▎    | 131/246 [00:18<00:09, 12.33it/s]



epoch: 3, epoch_step: 64, avg loss: nan
learning rate: 0.008900




 61%|██████    | 149/246 [00:19<00:05, 18.61it/s]



epoch: 3, epoch_step: 74, avg loss: nan
learning rate: 0.008800




 69%|██████▉   | 170/246 [00:20<00:03, 21.23it/s]



epoch: 3, epoch_step: 84, avg loss: nan
learning rate: 0.008700




 78%|███████▊  | 191/246 [00:21<00:02, 20.73it/s]



epoch: 3, epoch_step: 94, avg loss: nan
learning rate: 0.008600




 84%|████████▎ | 206/246 [00:22<00:01, 21.39it/s]



epoch: 3, epoch_step: 104, avg loss: nan
learning rate: 0.008500



Evaluate:
























































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 93%|█████████▎| 230/246 [00:28<00:01, 10.95it/s]



epoch: 3, epoch_step: 114, avg loss: nan
learning rate: 0.008400




100%|██████████| 246/246 [00:29<00:00,  8.37it/s]
  2%|▏         | 6/246 [00:00<00:09, 25.16it/s]



epoch: 4, epoch_step: 1, avg loss: nan
learning rate: 0.008300




 10%|▉         | 24/246 [00:01<00:09, 23.33it/s]



epoch: 4, epoch_step: 11, avg loss: nan
learning rate: 0.008200




 18%|█▊        | 45/246 [00:01<00:08, 23.19it/s]



epoch: 4, epoch_step: 21, avg loss: nan
learning rate: 0.008100




 24%|██▍       | 60/246 [00:02<00:08, 22.95it/s]



epoch: 4, epoch_step: 31, avg loss: nan
learning rate: 0.008000



Evaluate:
























































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 34%|███▍      | 84/246 [00:09<00:15, 10.47it/s]



epoch: 4, epoch_step: 41, avg loss: nan
learning rate: 0.007900




 43%|████▎     | 106/246 [00:10<00:06, 21.70it/s]



epoch: 4, epoch_step: 51, avg loss: nan
learning rate: 0.007800




 50%|█████     | 124/246 [00:11<00:05, 22.25it/s]



epoch: 4, epoch_step: 61, avg loss: nan
learning rate: 0.007700




 59%|█████▉    | 145/246 [00:12<00:04, 23.28it/s]



epoch: 4, epoch_step: 71, avg loss: nan
learning rate: 0.007600




 65%|██████▌   | 160/246 [00:12<00:03, 21.53it/s]



epoch: 4, epoch_step: 81, avg loss: nan
learning rate: 0.007500



Evaluate:


























































 66%|██████▋   | 163/246 [00:18<00:49,  1.68it/s]

get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 75%|███████▍  | 184/246 [00:19<00:05, 11.67it/s]



epoch: 4, epoch_step: 91, avg loss: nan
learning rate: 0.007400




 83%|████████▎ | 205/246 [00:20<00:01, 21.29it/s]



epoch: 4, epoch_step: 101, avg loss: nan
learning rate: 0.007300




 92%|█████████▏| 226/246 [00:21<00:00, 23.62it/s]



epoch: 4, epoch_step: 111, avg loss: nan
learning rate: 0.007200




 99%|█████████▉| 244/246 [00:21<00:00, 22.88it/s]



epoch: 4, epoch_step: 121, avg loss: nan
learning rate: 0.007100




100%|██████████| 246/246 [00:21<00:00, 11.24it/s]
  6%|▌         | 15/246 [00:00<00:10, 22.61it/s]



epoch: 5, epoch_step: 8, avg loss: nan
learning rate: 0.007000



Evaluate:
























































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 15%|█▌        | 37/246 [00:07<00:21,  9.76it/s]



epoch: 5, epoch_step: 18, avg loss: nan
learning rate: 0.006900




 24%|██▍       | 60/246 [00:08<00:09, 20.49it/s]



epoch: 5, epoch_step: 28, avg loss: nan
learning rate: 0.006800




 32%|███▏      | 78/246 [00:09<00:08, 20.30it/s]



epoch: 5, epoch_step: 38, avg loss: nan
learning rate: 0.006700




 40%|████      | 99/246 [00:10<00:06, 22.08it/s]



epoch: 5, epoch_step: 48, avg loss: nan
learning rate: 0.006600




 46%|████▋     | 114/246 [00:11<00:06, 21.04it/s]



epoch: 5, epoch_step: 58, avg loss: nan
learning rate: 0.006500



Evaluate:
























































get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 56%|█████▌    | 138/246 [00:18<00:09, 11.05it/s]



epoch: 5, epoch_step: 68, avg loss: nan
learning rate: 0.006400




 65%|██████▍   | 159/246 [00:19<00:04, 20.78it/s]



epoch: 5, epoch_step: 78, avg loss: nan
learning rate: 0.006300




 73%|███████▎  | 180/246 [00:20<00:02, 24.26it/s]



epoch: 5, epoch_step: 88, avg loss: nan
learning rate: 0.006200




 80%|████████  | 198/246 [00:21<00:02, 22.49it/s]



epoch: 5, epoch_step: 98, avg loss: nan
learning rate: 0.006100




 87%|████████▋ | 213/246 [00:21<00:01, 22.22it/s]



epoch: 5, epoch_step: 108, avg loss: nan
learning rate: 0.006000



Evaluate:


























































 89%|████████▉ | 219/246 [00:27<00:11,  2.33it/s]

get all evidence embeddings!
0.22222222222222224
0.28571428571428575
0.33333333333333337
0.22222222222222224
0.33333333333333337
0.28571428571428575
0.28571428571428575
0.20000000000000004
0.5714285714285715
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.33333333333333337
0.20000000000000004
0.33333333333333337
0.33333333333333337
0.28571428571428575
0.33333333333333337
0.33333333333333337
0.25
0.20000000000000004
0.33333333333333337
0.28571428571428575
0.4000000000000001
0.28571428571428575
0.28571428571428575


Evidence Retrieval F-score: 0.052




 98%|█████████▊| 240/246 [00:28<00:00, 13.34it/s]



epoch: 5, epoch_step: 118, avg loss: nan
learning rate: 0.005900




100%|██████████| 246/246 [00:28<00:00,  8.57it/s]


In [94]:
torch.cuda.empty_cache()

# 3.Testing and Evaluation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [95]:
import os
trans_encoder.load_state_dict(torch.load(os.path.join(save_dir, "best_ckpt.bin")))
trans_encoder.cuda()
trans_encoder.eval()

Encoder(
  (embedding): Embedding(7075, 512)
  (pos_embedding): Embedding(180, 512)
  (encoder): LSTM(512, 512, num_layers=2, batch_first=True)
)

In [96]:
evidence_embeddings = []
start_idx = 0
batch_size = 1000
evidence_len = len(evidences_input[0])

while start_idx < len(evidences_input):
    end_idx = min(start_idx + batch_size, len(evidences_input))
    
    cur_evidence = torch.LongTensor(evidences_input[start_idx:end_idx]).view(-1, evidence_len).cuda()
    cur_evidence_pos = torch.LongTensor([list(range(evidence_len)) for _ in range(end_idx - start_idx)]).cuda()

    cur_embedding = trans_encoder(cur_evidence, cur_evidence_pos)
    cur_embedding = cur_embedding[:, 0, :].detach()
    cur_embedding_cpu = F.normalize(cur_embedding, p=2, dim=1).cpu()  # for cosine similarity
    
    del cur_embedding, cur_evidence, cur_evidence_pos
    start_idx = end_idx
    evidence_embeddings.append(cur_embedding_cpu)
    
evidence_embeddings = torch.cat(evidence_embeddings, dim=0).t()


In [97]:
torch.cuda.empty_cache()

In [98]:
import numpy as np

In [99]:
retrieval_num = 5
dev_candis_num = 10

def validate_(dev_text_idx, evidence_embeddings, dev_sort_evidences, dev_evidences, encoder_model):
    # get evidence embeddings
    encoder_model.eval()

    text_len = len(dev_text_idx[0])
    f = []
    
    start_idx = 0
    batch_size = 200
    
    while start_idx < len(dev_text_idx):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))
        
        cur_query = torch.LongTensor(dev_text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        print(cur_query.size())
        cur_query_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        query_embedding = encoder_model(cur_query, cur_query_pos)
        query_embedding = query_embedding[:, 0, :].detach()
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu()
        
        scores = torch.mm(query_embedding, evidence_embeddings)
        
        for i in range(scores.size(0)):
            new_score = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(new_score).tolist()
            select_ids = topk_ids[:retrieval_num]
            
            evidence_correct = 0
            pred_evidences = [dev_sort_evidences[start_idx+i][j] for j in select_ids]
                        
            label = dev_evidences[start_idx+i]
                        
            for evidence_id in label:
                if evidence_id in pred_evidences:
                    evidence_correct += 1
            if evidence_correct > 0:
                evidence_recall = float(evidence_correct) / len(label)
                evidence_precision = float(evidence_correct) / len(pred_evidences)
                evidence_fscore = (2 * evidence_precision * evidence_recall) / (evidence_precision + evidence_recall)
                # print(evidence_fscore)
            else:
                evidence_fscore = 0
            f.append(evidence_fscore)
            
        start_idx = end_idx
        # print("----")
    fscore = np.mean(f)
    print("\n")
    print("Evidence Retrieval F-score: %.3f" % fscore)
    print("\n")
    return fscore

In [100]:
retrieval_num = 5
dev_candis_num = 10
fscore = validate_(dev_input, evidence_embeddings, dev_sort_evidences, dev_evidences, trans_encoder)
print(fscore)

torch.Size([154, 62])


Evidence Retrieval F-score: 0.052


0.05150999793856938


In [101]:
retrieval_num = 5
dev_candis_num = 10

def evidence_predicts(dev_text_idx, evidences_embeddings, dev_sort_evidences, evidences_ids, encoder_model):
    # get evidence embeddings
    text_len = len(dev_text_idx[0])
    encoder_model.eval()

    f = []
    start_idx = 0
    batch_size = 200
    preds = []
    while start_idx < len(dev_text_idx):
        end_idx = min(start_idx + batch_size, len(dev_text_idx))
        
        cur_query = torch.LongTensor(dev_text_idx[start_idx:end_idx]).view(-1, text_len).cuda()
        cur_query_pos = torch.LongTensor([list(range(text_len)) for _ in range(end_idx - start_idx)]).cuda()

        query_embedding = encoder_model(cur_query, cur_query_pos)
        query_embedding = query_embedding[:, 0, :].detach()
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu()
        
        scores = torch.mm(query_embedding, evidences_embeddings)

        for i in range(scores.size(0)):
            new_score = torch.index_select(scores[i], 0, torch.LongTensor(dev_sort_evidences[start_idx+i][:dev_candis_num]))
            topk_ids = torch.argsort(new_score).tolist()
            select_ids = topk_ids[:retrieval_num]
            
            pred_evidences = [evidences_ids[dev_sort_evidences[start_idx+i][j]] for j in select_ids]
            preds.append(pred_evidences)
            
        start_idx = end_idx
    return preds

In [102]:
dev_evidences_ids = evidence_predicts(dev_input, evidence_embeddings, dev_sort_evidences, evidences_ids, trans_encoder)
test_evidences_ids = evidence_predicts(test_input, evidence_embeddings, test_sort_evidences, evidences_ids, trans_encoder)

In [103]:
pred_dev_claims = {}
pred_test_claims = {}
dev_claims = json.load(open("data/dev-claims.json", "r"))
test_claims = json.load(open("data/test-claims-unlabelled.json", "r"))

for idx, evidence_ids in enumerate(dev_evidences_ids):
    cur_data = dev_claims[dev_ids[idx]]
    cur_data['evidences'] = evidence_ids
    pred_dev_claims[dev_ids[idx]] = cur_data
    

for idx, evidence_ids in enumerate(test_evidences_ids):
    cur_data = test_claims[test_ids[idx]]
    cur_data['evidences'] = evidence_ids
    pred_test_claims[test_ids[idx]] = cur_data


In [104]:
json.dump(pred_dev_claims, open("data/dev_predict.json", "w"))
json.dump(pred_test_claims, open("data/test-claims-unlabelled.json", "w"))

In [105]:
retrieval_num = 5
dev_candis_num = 10

train_evidences_ids = evidence_predicts(train_input, evidence_embeddings, train_sort_evidences, evidences_ids, trans_encoder)

pred_train_negative_evidences = []
for idx, evidence_ids in enumerate(train_evidences_ids):
    temp_ = []
    for i in evidence_ids:
        if evidences_id_dict[i] not in train_evidences[idx]:
            temp_.append(evidences_id_dict[i])
    pred_train_negative_evidences.append(temp_)

In [106]:
## save prediction data

json.dump(pred_train_negative_evidences, open("pred_train_negative_evidences.json", "w"))

In [107]:
## save cls data

dev_cls_data = []
test_cls_data = []
text_max_len = 60
evidence_max_len = 100
all_max_len = 580

for idx, dev_text in enumerate(dev_text_idx):
    cur_data = {"label": dev_labels[idx]}
    temp_text = [word2idx["<cls>"]] + dev_text_idx[idx][:text_max_len]
    for i in dev_evidences_ids[idx]:
        temp_text.extend([word2idx["<sep>"]] + evidences_text_idx[evidences_id_dict[i]][:evidence_max_len])
    temp_text.append(word2idx["<sep>"])
    if len(temp_text) < all_max_len:
        temp_text.extend([word2idx["<pad>"]] * (all_max_len - len(temp_text)))
    cur_data['text'] = temp_text
    dev_cls_data.append(cur_data)

for idx, dev_text in enumerate(test_text_idx):
    cur_data = {}
    temp_text = [word2idx["<cls>"]] + test_text_idx[idx][:text_max_len]
    for i in test_evidences_ids[idx]:
        temp_text.extend([word2idx["<sep>"]] + evidences_text_idx[evidences_id_dict[i]][:evidence_max_len])
    temp_text.append(word2idx["<sep>"])
    if len(temp_text) < all_max_len:
        temp_text.extend([word2idx["<pad>"]] * (all_max_len - len(temp_text)))
    cur_data['text'] = temp_text
    test_cls_data.append(cur_data)

json.dump(dev_cls_data, open("dev_cls_data.json", "w"))
json.dump(test_cls_data, open("test_cls_data.json", "w"))

In [108]:
import subprocess

# proc = subprocess.Popen(["python", "eval.py", "--predictions", "data\dev_predict.json", "--groundtruth", "data\dev-claims.json"
# ], stdout=subprocess.PIPE, shell=True)
# (out, err) = proc.communicate()
# print(str(out))

# 高自动化模型/预处理选择，可以自动读取准确度
output = subprocess.check_output("python eval.py --predictions data/dev_predict.json --groundtruth data/dev-claims.json", shell=True)
output_str = output.decode('utf-8')

# Split the output into lines
output_lines = output_str.strip().split('\n')

# Format the output
formatted_lines = []
for line in output_lines:
    metric, value = line.split('=')
    metric = metric.strip()
    value = value.strip()
    formatted_line = f"{metric}: {value}"
    formatted_lines.append(formatted_line)

# Join the formatted lines into a single string
formatted_output = '\n'.join(formatted_lines)
print(formatted_output)

Evidence Retrieval F-score (F): 0.0476911976911977
Claim Classification Accuracy (A): 1.0
Harmonic Mean of F and A: 0.09104056194476966


## Object Oriented Programming codes here

*You can use multiple code snippets. Just add more if needed*