In [33]:
from IPython.display import display, HTML
# widen the notebook's main container to 90% of the browser window
display(HTML("<style>.container { width:90% !important; }</style>"))

In [3]:
# data imports
import glob
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from collections import Counter
# model imports
# similar to https://github.com/huggingface/transformers/blob/14e9d2954c3a7256a49a3e581ae25364c76f521e/src/transformers/models/bert/modeling_bert.py
import logging

from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, HingeEmbeddingLoss

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel, BertPreTrainedModel, BertConfig
from transformers.utils import logging

from transformers.file_utils import ModelOutput
from typing import Optional

# logger = logging.get_logger(__name__)

# Trainer imports
from transformers import Trainer, TrainingArguments

# Dataset

In [71]:
# each training instance consists of a paragraph, edu splits and edu labels

class ArgumentDataset(Dataset):
    """Paragraph-level dataset of EDU-segmented argumentative text.

    Each item is one paragraph tokenized as a single sequence in which the EDUs
    are joined with the ``[EDU_SEP]`` token, together with one label id per EDU
    slot and one label id per whitespace token of each EDU.

    Paragraph, EDU and label files are paired *by position* across the three
    globs, so the three directories must yield their files in the same order.
    """

    # Paragraphs dropped by default, as indices into the (unsorted) glob order on mac.
    # linux file order equivalent: [7, 24, 89, 231, 298, 348, 370, 373, 421, 473, 481, 485, 496, 508, 599, 680]
    DEFAULT_FILTEROUT = (27, 99, 163, 183, 191, 194, 226, 239, 259, 271, 289, 377, 410, 582, 626, 656)

    def __init__(self, tokenizer, paragraph_files, edus_files, labels_files, max_len=256, max_edu_seq=50,
                 filterout=None, sort_files=False, pad_label_id=0):
        """
        Args:
            tokenizer: callable tokenizer; must return ``input_ids``,
                ``attention_mask`` and ``token_type_ids`` for a list of strings.
            paragraph_files: glob pattern for the raw paragraph text files.
            edus_files: glob pattern for files holding one EDU per line.
            labels_files: glob pattern for files with one line per EDU:
                ``<edu label>\\t<space-separated token labels>``.
            max_len: tokenizer max length; also the number of token-label slots per EDU.
            max_edu_seq: number of EDU slots per paragraph (extra EDUs are ignored).
            filterout: indices (into the file order) of paragraphs to drop;
                ``None`` uses ``DEFAULT_FILTEROUT``, ``[]`` keeps everything.
            sort_files: sort each glob result so the three file lists pair up
                deterministically. The default drop indices refer to the unsorted
                order, so pass an explicit ``filterout`` when enabling this.
            pad_label_id: value stored in EDU/token label slots that have no gold
                label. The default 0 coincides with the ``'O'`` label; -100 makes
                ``torch.nn.CrossEntropyLoss`` (default ``ignore_index``) skip them.
        """
        self.max_len, self.max_edu_seq = max_len, max_edu_seq
        self.tokenizer = tokenizer

        def list_files(pattern):
            # glob order is filesystem-dependent; optionally make it deterministic
            files = glob.glob(pattern)
            return sorted(files) if sort_files else files

        self.paragraphs = [self._read_text(path) for path in list_files(paragraph_files)]
        self.edus = [self._read_lines(path) for path in list_files(edus_files)]
        self.labels = [self._read_lines(path) for path in list_files(labels_files)]
        self.label2id = {'B-claim': 1, 'I-claim': 2, 'B-premise': 3, 'I-premise': 4, 'O' : 0}

        if filterout is None:
            filterout = self.DEFAULT_FILTEROUT
        # pop from the highest index down so earlier pops do not shift later indices
        for i in sorted(set(filterout), reverse=True):
            self.paragraphs.pop(i); self.edus.pop(i); self.labels.pop(i)

        self.labels = [
            [self._parse_label_line(line) for line in para_labels]
            for para_labels in self.labels
        ]
        self.label_edus = [[pad_label_id] * self.max_edu_seq for _ in self.labels]
        self.label_tokens = [
            [[pad_label_id] * self.max_len for _ in range(self.max_edu_seq)]
            for _ in self.labels
        ]

        self.para_edu_splits = [' [EDU_SEP] '.join(line.rstrip() for line in para_edus) for para_edus in self.edus]
        self.para_edu_splits_tok = self.tokenizer(self.para_edu_splits, truncation=True, padding='max_length', max_length=self.max_len)

        for i, para_labels in enumerate(self.labels):
            for j, edu_label in enumerate(para_labels[:self.max_edu_seq]):
                self.label_edus[i][j] = self.label2id[edu_label['edu']]
                # split once per EDU instead of once per token
                for k, tag in enumerate(edu_label['tokens'].split()[:self.max_len]):
                    self.label_tokens[i][j][k] = self.label2id[tag]

    @staticmethod
    def _read_text(path):
        """Return the whole content of ``path``; the file is closed afterwards."""
        with open(path) as fh:
            return fh.read()

    @staticmethod
    def _read_lines(path):
        """Return the lines of ``path`` (newlines kept); the file is closed afterwards."""
        with open(path) as fh:
            return fh.readlines()

    @staticmethod
    def _parse_label_line(line):
        """Split a ``<edu label>\\t<token labels>`` line into its two fields."""
        fields = line.rstrip().split('\t')
        return {'edu': fields[0], 'tokens': fields[1]}

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        return {'input_ids': self.para_edu_splits_tok['input_ids'][i],
                'attention_mask': self.para_edu_splits_tok['attention_mask'][i],
                'token_type_ids': self.para_edu_splits_tok['token_type_ids'][i],
                'edu_labels' : self.label_edus[i],
                'token_labels' : self.label_tokens[i]
               }
        

In [70]:
# BERT tokenizer extended with the EDU separator used to join EDUs into one sequence
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'additional_special_tokens': ['[EDU_SEP]']})

paragraph_files = '../data/ets/para_text/*'
edus_files = '../data/ets/para_edu/*'
labels_files = '../data/ets/para_edu_label/*'
argdata_old = ArgumentDataset(tokenizer, paragraph_files, edus_files, labels_files)

# Model

In [61]:
''' this model should use BertModel to extract the embeddings of the paragraph
    then do the following:
        1. use the [EDU_SEP] positions to collect the embeddings of the tokens in each EDU
        2. represent the EDU as the average embedding of its member tokens
        3. pass the EDU embedding to the classifier layer to make predictions
        4. calculate the loss based on the predicted and gold EDU labels
'''

class BertForPhraseClassification(BertPreTrainedModel):
    """BERT encoder plus a linear head that classifies every EDU of a paragraph.

    The paragraph is encoded as one sequence whose EDUs are separated by an
    ``[EDU_SEP]`` token. Each EDU is represented by the mean of its token
    states, and each EDU representation is classified independently.
    """

    # _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config, edu_sequence_length=50, edu_sep_id=30522):
        """
        Args:
            config: BERT config; ``config.num_labels`` sets the number of EDU classes.
            edu_sequence_length: number of EDU slots produced per paragraph.
            edu_sep_id: vocabulary id of ``[EDU_SEP]``. 30522 is presumably the id
                it receives when added to the bert-base-uncased tokenizer — confirm
                with ``tokenizer.convert_tokens_to_ids('[EDU_SEP]')``.
        """
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.edu_sep_id = edu_sep_id
        self.edu_sequence_length = edu_sequence_length

        # no pooler: only per-token hidden states are used
        self.bert = BertModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        edu_labels=None,
        token_labels=None,
    ):
        """Encode the paragraphs, average token states per EDU and classify each EDU.

        Args:
            input_ids, attention_mask, token_type_ids: tokenizer outputs for a batch.
            edu_labels: optional gold label per EDU slot; it is flattened to match the
                logits, so it must hold batch * edu_sequence_length entries.
            token_labels: accepted so dataset batches can be passed through unchanged; unused.

        Returns:
            ``(loss, logits, *extra)`` when ``edu_labels`` is given, else ``(logits, *extra)``,
            where ``logits`` is (batch, edu_sequence_length, num_labels) and ``extra`` is
            the encoder's outputs beyond its first two.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        edu_embeddings = self.get_edu_emb(input_ids, sequence_output)

        edu_embeddings = self.dropout(edu_embeddings)
        logits = self.classifier(edu_embeddings)

        loss = None
        if edu_labels is not None:
            # CrossEntropyLoss skips targets equal to its ignore_index (default -100)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), edu_labels.view(-1))

        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    def get_edu_emb(self, input_ids, outputs, edu_seperator_id=None, cls_token_id=101, sep_token_id=102):
        """Average the token states of each EDU for every paragraph in the batch.

        EDU ``j`` covers the tokens strictly between boundary ``j`` and ``j + 1``,
        where the boundaries are position 0 ([CLS]), every ``[EDU_SEP]`` position
        and the first [SEP] position (or the sequence end if there is none).

        Args:
            input_ids: (batch, seq_len) token ids.
            outputs: (batch, seq_len, hidden) token states from the encoder.
            edu_seperator_id: id of ``[EDU_SEP]``; ``None`` uses ``self.edu_sep_id``.
            cls_token_id: id expected at position 0.
            sep_token_id: id that closes the last EDU.

        Returns:
            (batch, self.edu_sequence_length, hidden) tensor with the same device and
            dtype as ``outputs``. EDUs beyond ``edu_sequence_length`` are dropped;
            unused slots and EDUs without tokens stay zero.
        """
        if edu_seperator_id is None:
            edu_seperator_id = self.edu_sep_id
        batch_size, seq_len, hidden_size = outputs.shape
        # allocate next to the encoder output so the classifier also works on GPU / fp16
        batch_para_edu_avg_emb = outputs.new_zeros(batch_size, self.edu_sequence_length, hidden_size)

        for i in range(batch_size):
            para_ids = input_ids[i]
            assert para_ids[0].item() in (cls_token_id, edu_seperator_id)

            sep_positions = (para_ids == sep_token_id).nonzero(as_tuple=True)[0]
            para_end = sep_positions[0].item() if len(sep_positions) > 0 else seq_len
            edu_seps = [p for p in (para_ids == edu_seperator_id).nonzero(as_tuple=True)[0].tolist() if p < para_end]

            # n separators delimit n + 1 EDUs, including the one closed by [SEP]
            boundaries = [0] + edu_seps + [para_end]
            n_edus = min(len(boundaries) - 1, self.edu_sequence_length)
            for j in range(n_edus):
                start, end = boundaries[j] + 1, boundaries[j + 1]
                if start < end:  # adjacent separators would otherwise give a NaN mean
                    batch_para_edu_avg_emb[i, j] = outputs[i, start:end].mean(dim=0)

        return batch_para_edu_avg_emb
    

        

In [111]:
# argdata is only built in the Trainer cell further down; argdata_old (built above
# with the same tokenizer and paths) keeps this cell valid under Restart & Run All
len(argdata_old)

783

# Trainer

In [9]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'additional_special_tokens':['[EDU_SEP]']})

paragraph_files, edus_files, labels_files = '../data/ets/para_text/*', '../data/ets/para_edu/*', '../data/ets/para_edu_label/*'
argdata = ArgumentDataset(tokenizer, paragraph_files, edus_files, labels_files)

# size the head and the EDU slots from the dataset so the model's logits always
# line up with the dataset's edu_labels (previously both were hard-coded separately)
config = BertConfig.from_pretrained("bert-base-uncased", num_labels=len(argdata.label2id))
edu_tag_model = BertForPhraseClassification.from_pretrained(
    "bert-base-uncased", config=config, edu_sequence_length=argdata.max_edu_seq
)
# make room for the added [EDU_SEP] token in the embedding matrix
edu_tag_model.resize_token_embeddings(len(tokenizer))

training_args = TrainingArguments(
    output_dir='./',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=0,
    do_train=True,
    dataloader_drop_last=True
)

trainer = Trainer(
    model=edu_tag_model,
    args=training_args,
    train_dataset=argdata,
)

# trainer.train()

# trainer.train()

# Sketching

In [4]:
paragraph_files = '../data/ets/para_text/*'
edus_files = '../data/ets/para_edu/*'
labels_files = '../data/ets/para_edu_label/*'
# eyeball that the three globs pair up the same paragraph on every row
for para_path, edu_path, label_path in zip(glob.glob(paragraph_files), glob.glob(edus_files), glob.glob(labels_files)):
    print(f'{para_path}\n{edu_path}\n{label_path}\n\n')

In [203]:
input_ids_batch = torch.tensor(argdata.para_edu_splits_tok['input_ids'][:16])
batch_size, seq_len = input_ids_batch.shape
# fake encoder output shaped like the ids (768 presumably = bert-base hidden size)
outputs, edu_seperator_id = torch.rand(batch_size, seq_len, 768), 30522
seperators = (input_ids_batch == edu_seperator_id).nonzero(as_tuple=True)

# [EDU_SEP] count per paragraph, with 0 for paragraphs that have none
# (Counter(...).values() silently dropped those paragraphs)
edu_per_para = torch.bincount(seperators[0], minlength=batch_size).tolist()

seperators[0].shape, seperators[1].shape, input_ids_batch.shape

(torch.Size([356]), torch.Size([356]), torch.Size([16, 512]))

In [66]:
# NOTE(review): expects `labels` to be a list of per-paragraph line lists; the
# `labels` assigned further down holds the lines of a single file, so this cell
# relies on a variable from an earlier/deleted cell — confirm its source.
def _label_field_count(line):
    """Number of tab-separated fields in a label line (well-formed lines have 2)."""
    return len(line.rstrip().split('\t'))

L = [
    [
        (para_idx, line_idx + 1, _label_field_count(line))
        for line_idx, line in enumerate(para_lines)
        if _label_field_count(line) != 2
    ]
    for para_idx, para_lines in enumerate(labels)
]
[bad_lines[0][0] for bad_lines in L if bad_lines]

[293, 342, 363, 365, 412, 463, 470, 473, 483, 494, 583, 662]

In [19]:
# inspect one paragraph: its EDUs, raw text and per-EDU labels
para_file = '1-AbxwVc5Fvl-bX7-RBPA8fLHOghrgNifbzu0hYLtSRY_1.txt'
with open(f'../data/ets/para_edu/{para_file}') as fh:
    edus = fh.readlines()
with open(f'../data/ets/para_text/{para_file}') as fh:
    text = fh.readlines()
with open(f'../data/ets/para_edu_label_all/{para_file}') as fh:
    labels = fh.readlines()
para_labels = [{'edu': line.rstrip().split('\t')[0], 'tokens': line.rstrip().split('\t')[1]} for line in labels]

In [103]:
from transformers import BertModel, BertConfig
config = BertConfig()
# randomly initialised encoder (no pretrained weights), no pooler
bert = BertModel(config, add_pooling_layer=False)
# pretrained alternative: BertModel.from_pretrained('bert-base-uncased', add_pooling_layer=False)

string_tok = tokenizer('I love movies')
# give every encoded field a leading batch dimension of 1
ids, attn_mask, seg_ids = (
    torch.tensor(string_tok[field]).unsqueeze(0)
    for field in ('input_ids', 'attention_mask', 'token_type_ids')
)

res = bert(ids, attention_mask=attn_mask, token_type_ids=seg_ids)