In [26]:
import json

import torch
from transformers import AutoModel, AutoTokenizer, DistilBertForSequenceClassification

In [27]:
def load_jsonl(data_path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        data_path: Path of a text file holding one JSON document per line.

    Returns:
        list: The decoded value of every non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    with open(data_path, encoding="utf-8") as f:
        # Blank / whitespace-only lines (e.g. extra newlines at the end of the
        # file) carry no record, and json.loads would raise on them.
        return [json.loads(line) for line in f if line.strip()]

In [54]:
# NOTE: despite its name, `bert_model` holds the "distilbert-base-uncased"
# *tokenizer* (it is created with AutoTokenizer), not a model.
bert_model = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# Decode the JSON Lines file into an in-memory list of records.
train_data = load_jsonl("reddit_data/reddit_cands_100.jsonl")
# Sanity check: show the container type that load_jsonl returned.
print(type(train_data))

<class 'list'>


In [55]:
# Pull the 'text' field out of every loaded record, keeping dataset order.
text_list = list(map(lambda record: record['text'], train_data))

In [56]:
class MatchSum(torch.nn.Module):
    """Score summaries and candidate summaries against their source document.

    Documents, reference summaries and candidate summaries are tokenised and
    run through one shared encoder. The hidden state at sequence position 0 is
    used as the embedding of each text, and every summary / candidate is scored
    by its cosine similarity to the embedding of its document.

    Args:
        candidate_num: Number of candidates per document. Only stored on the
            instance; ``forward`` derives the actual count from the data.
        encoder: One of
            * a ``torch.nn.Module`` that is called as
              ``encoder(input_ids, attention_mask=mask)`` and whose result at
              index 0 has shape ``[batch, seq_len, hidden_size]``;
            * a checkpoint name handed to ``AutoModel.from_pretrained``;
            * ``None`` to load ``DEFAULT_CHECKPOINT``.
        hidden_size: Width of the encoder's hidden states. Must match the
            encoder actually used; ``forward`` verifies it.
        tokenizer: Optional callable invoked as
            ``tokenizer(texts, padding=True, truncation=True, return_tensors="pt")``
            that returns a mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors. When omitted, ``AutoTokenizer`` is loaded for the same
            checkpoint as the encoder (``DEFAULT_CHECKPOINT`` if ``encoder`` is
            a module).

    Raises:
        TypeError: If ``encoder`` is neither a module, a string nor ``None``.
    """

    DEFAULT_CHECKPOINT = "distilbert-base-uncased"

    def __init__(self, candidate_num, encoder, hidden_size=768, tokenizer=None):
        super().__init__()

        self.hidden_size = hidden_size
        self.candidate_num = candidate_num

        checkpoint = self.DEFAULT_CHECKPOINT
        if isinstance(encoder, torch.nn.Module):
            self.encoder = encoder
        elif encoder is None or isinstance(encoder, str):
            if encoder:
                checkpoint = encoder
            # A bare model (not a tokenizer or a classification head) is needed
            # so that index 0 of the output is the per-token hidden states.
            self.encoder = AutoModel.from_pretrained(checkpoint)
        else:
            raise TypeError(
                "encoder must be a torch.nn.Module, a checkpoint name or None, "
                f"got {type(encoder).__name__}"
            )

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.tokenizer = tokenizer

    @staticmethod
    def _to_text(field, sentences=None):
        """Turn one record field into a single string.

        Accepts a string (returned unchanged), a list/tuple of strings (joined
        with single spaces) or, when ``sentences`` is a list/tuple, a
        list/tuple of integer positions into ``sentences``.
        """
        if isinstance(field, str):
            return field
        if isinstance(field, (list, tuple)):
            if all(isinstance(part, str) for part in field):
                return " ".join(field)
            if isinstance(sentences, (list, tuple)) and all(
                isinstance(part, int) for part in field
            ):
                return " ".join(sentences[position] for position in field)
        raise TypeError(
            "expected a string, a sequence of strings, or a sequence of sentence "
            f"indices into a list of sentences; got {field!r}"
        )

    def _candidate_texts(self, entry):
        """Return the candidate summaries of one record as a list of strings."""
        # NOTE(review): the schema of entry['idx'] is not visible in this
        # notebook. It is assumed to be a list of candidates, each either a
        # piece of text or a list of sentence indices into entry['text'] --
        # TODO confirm against the data file.
        raw_candidates = entry['idx']
        if isinstance(raw_candidates, str) or not isinstance(raw_candidates, (list, tuple)):
            raise TypeError(
                f"entry['idx'] must be a list of candidates, got {raw_candidates!r}"
            )
        return [self._to_text(cand, sentences=entry['text']) for cand in raw_candidates]

    def _embed(self, texts):
        """Encode ``texts`` and return the position-0 hidden state of each one.

        Returns a tensor of shape ``[len(texts), hidden_size]``.
        """
        encoded = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Run on whatever device the encoder's weights live on.
        reference = next(self.encoder.parameters(), None)
        if reference is not None:
            input_ids = input_ids.to(reference.device)
            attention_mask = attention_mask.to(reference.device)

        hidden = self.encoder(input_ids, attention_mask=attention_mask)[0]  # last layer
        embedding = hidden[:, 0, :]
        if embedding.size() != (len(texts), self.hidden_size):
            raise ValueError(
                f"encoder produced embeddings of shape {tuple(embedding.size())}, "
                f"expected {(len(texts), self.hidden_size)}; check hidden_size"
            )
        return embedding

    def forward(self, dataset):
        """Score the summary and the candidates of every record in ``dataset``.

        Args:
            dataset: Iterable of mappings with the keys ``'text'``,
                ``'summary'`` and ``'idx'``. ``'text'`` and ``'summary'`` may
                be a string or a list of sentence strings; see
                ``_candidate_texts`` for ``'idx'``.

        Returns:
            dict: ``'cand_score'`` with shape ``[batch_size, candidate_num]``
            and ``'summary_score'`` with shape ``[batch_size]``, both cosine
            similarities to the document embedding.

        Raises:
            ValueError: If ``dataset`` is empty, if the records do not all have
                the same non-zero number of candidates, or if the encoder
                output does not match ``hidden_size``.
            TypeError: If a field has an unsupported format.
        """
        entries = list(dataset)
        # The batch size follows the data instead of being fixed at 10.
        batch_size = len(entries)
        if batch_size == 0:
            raise ValueError("dataset must contain at least one entry")

        documents = [self._to_text(entry['text']) for entry in entries]
        summaries = [self._to_text(entry['summary']) for entry in entries]
        candidates = [self._candidate_texts(entry) for entry in entries]

        # Candidate scores form a rectangular [batch, candidate_num] tensor, so
        # every record has to provide the same number of candidates.
        candidate_num = len(candidates[0])
        if candidate_num == 0 or any(len(group) != candidate_num for group in candidates):
            raise ValueError(
                "every entry must provide the same non-zero number of candidates"
            )

        # get document embedding
        doc_emb = self._embed(documents)  # [batch_size, hidden_size]

        # get summary embedding and score
        summary_emb = self._embed(summaries)  # [batch_size, hidden_size]
        summary_score = torch.cosine_similarity(summary_emb, doc_emb, dim=-1)

        # get candidate embedding: encode all candidates as one flat batch,
        # then fold them back into [batch_size, candidate_num, hidden_size]
        flat_candidates = [text for group in candidates for text in group]
        candidate_emb = self._embed(flat_candidates).reshape(
            batch_size, candidate_num, self.hidden_size
        )

        # get candidate score
        doc_emb = doc_emb.unsqueeze(1).expand_as(candidate_emb)
        cand_score = torch.cosine_similarity(candidate_emb, doc_emb, dim=-1)  # [batch_size, candidate_num]

        return {'cand_score': cand_score, 'summary_score': summary_score}

    def train_model(self, train_data, num_iterations, batch_size=10):
        """Put the model in training mode and score ``num_iterations`` batches.

        Each iteration uses all of ``train_data`` when it has at most
        ``batch_size`` records, otherwise ``batch_size`` records drawn uniformly
        at random with replacement.

        TODO: no loss is computed and no optimiser step is taken yet, so the
        weights are not updated; this only runs the forward pass.

        Args:
            train_data: Indexable sequence of records accepted by ``forward``.
            num_iterations: Number of batches to score.
            batch_size: Maximum number of records per batch.

        Returns:
            The score dict of the last batch, or ``None`` when
            ``num_iterations`` is 0.
        """
        self.train()

        scores = None
        for _ in range(num_iterations):
            if batch_size >= len(train_data):
                batch = train_data
            else:
                # A plain list cannot be indexed with an index array, so pick
                # the sampled records one by one.
                picks = torch.randint(len(train_data), (batch_size,)).tolist()
                batch = [train_data[i] for i in picks]

            # Score every sampled batch, not only the last one.
            scores = self(batch)

        return scores
        