In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [29]:
# data imports
import glob
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from collections import Counter
# model imports
# similar to https://github.com/huggingface/transformers/blob/14e9d2954c3a7256a49a3e581ae25364c76f521e/src/transformers/models/bert/modeling_bert.py
import logging

from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, HingeEmbeddingLoss

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel, BertPreTrainedModel, BertConfig
from transformers.utils import logging

from transformers.file_utils import ModelOutput
from typing import Optional

# logger = logging.get_logger(__name__)

# Trainer imports
from transformers import Trainer, TrainingArguments

# Dataset

In [3]:
# each training instance consists of a paragraph, edu splits and edu labels

class ArgumentDataset(Dataset):
    
    def __init__(self, tokenizer, paragraph_files, edus_files, labels_files, max_len=256, max_edu_seq=50):
        
        self.max_len, self.max_edu_seq = max_len, max_edu_seq
        self.tokenizer = tokenizer
        tokenizer.add_special_tokens({'additional_special_tokens':['[EDU_SEP]']})
        
        self.paragraphs = [''.join(open(file).readlines()) for file in glob.glob(paragraph_files)]
        self.edus = [open(file).readlines() for file in glob.glob(edus_files)]
        self.labels = [open(file).readlines() for file in glob.glob(labels_files)]
        self.label2id = {'B-claim': 1, 'I-claim': 2, 'B-premise': 3, 'I-premise': 4, 'O' : 0}
        
        ######
        # filterout = [7, 24, 89, 231, 298, 348, 370, 373, 421, 473, 481, 485, 496, 508, 599, 680] # linux file order
        filterout = [27, 99, 163, 183, 191, 194, 226, 239, 259, 271, 289, 377, 410, 582, 626, 656] # mac file order
        for i in filterout[::-1]:
            self.paragraphs.pop(i); self.edus.pop(i); self.labels.pop(i)
        ######
        
        self.labels = [
            [{'edu': line.rstrip().split('\t')[0], 'tokens': line.rstrip().split('\t')[1]} for line in para_labels]
                      for para_labels in self.labels
        ]
        self.label_edus = [[0 for _ in range(self.max_edu_seq)] for _ in self.labels]
        self.label_tokens = [[[0 for _ in range(self.max_len)] for _ in range(self.max_edu_seq)] for _ in self.labels]
        
        self.para_edu_splits = [' [EDU_SEP] '.join([line.rstrip() for line in para_edus]) for para_edus in self.edus]
        self.para_edu_splits_tok = self.tokenizer(self.para_edu_splits, truncation=True, padding='max_length', max_length=self.max_len)
                
        for i, para_edus in enumerate(self.labels):
            for j in range(min(self.max_edu_seq, len(self.labels[i]))):
                self.label_edus[i][j] = self.label2id[self.labels[i][j]['edu']]
                for k in range(min(self.max_len, len(self.labels[i][j]['tokens'].split()))):
                    self.label_tokens[i][j][k] = self.label2id[self.labels[i][j]['tokens'].split()[k]]
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, i):
        return {'input_ids': self.para_edu_splits_tok['input_ids'][i],
                'attention_mask': self.para_edu_splits_tok['attention_mask'][i],
                'token_type_ids': self.para_edu_splits_tok['token_type_ids'][i],
                'edu_labels' : self.label_edus[i],
                'token_labels' : self.label_tokens[i]
               }
        

In [53]:
# argdata.edus[0], argdata.para_edu_splits[0], argdata.label_edus[0], argdata.labels[0]
# for i, j in zip(argdata.edus[0], argdata.labels[0]):
#     print(i, j)
argdata.para_edu_splits[0], argdata.para_edu_splits_tok['input_ids'][0]

("The first reason [EDU_SEP] war is needed [EDU_SEP] is people being treated equally . [EDU_SEP] Mexican - American War [EDU_SEP] ( 1846 - 1848 ) - This war was fought [EDU_SEP] following the annexation of Texas , [EDU_SEP] with Mexico [EDU_SEP] still claiming the land as their own . [EDU_SEP] Pretty much the Mexicans fought a war , [EDU_SEP] for Texas The U.S. outfought the Mexicans , [EDU_SEP] retaining Texas [EDU_SEP] and incorporating it as a state . [EDU_SEP] American Revolution [EDU_SEP] ( 1775 - 1783 ) - [EDU_SEP] The American Revolution gave the 13 North American colonies independence from British rule [EDU_SEP] and established the United States of America Internet Encyclopedia of Philosophy . [EDU_SEP] The Americans fought a war [EDU_SEP] to claim independence from England . [EDU_SEP] These were two of many examples [EDU_SEP] that war sorted out all of the problems . [EDU_SEP] In both of those situations they tried to talk [EDU_SEP] and it did n't work , [EDU_SEP] so they went

# Model

In [61]:
''' this model should use BertModel to extract the embeddings of the paragraph
    then do the following:
        1. use EDU split to get the embeddings of each tokens of an EDU
        2. represent the EDU as the average embedding of its member tokens
        3. pass the EDU embedding to the classifier layer to make predictions
        4. calculate the loss based on the predicted and gold EDU labels
'''

class BertForPhraseClassification(BertPreTrainedModel):

    # _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config, edu_sequence_length=50):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.edu_sep_id = 30522
        self.edu_sequence_length = edu_sequence_length

        self.bert = BertModel(config, add_pooling_layer=False)
        # self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # print(self.config)
        self.init_weights()


    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        edu_labels=None,
        token_labels=None,
    ):

        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        # print(input_ids.shape, outputs.shape)
        edu_embeddings = self.get_edu_emb(input_ids, sequence_output)
        print(edu_embeddings.shape)
        
        edu_embeddings = self.dropout(edu_embeddings)
        logits = self.classifier(edu_embeddings)
        print(logits.shape)

        loss = None
        if edu_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), edu_labels.view(-1))
            
        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    
    def get_edu_emb(self, input_ids, outputs, edu_seperator_id=30522):
        'Returns a sequence of 50 EDUs (padded or truncated) per paragraph represented as the average embeddings of their tokens'
        batch_size = outputs.shape[0]
        batch_para_edu_avg_emb = torch.zeros(batch_size,  self.edu_sequence_length, self.config.hidden_size)
        
        # finding the "[EDU_SEP]" token in each paragraph given a batch of input_ids
        # seperators[0] has paragraph id in a batch
        # seperators[1] has index of "[EDU_SEP]" in all paragraphs
        seperators = (input_ids == edu_seperator_id).nonzero(as_tuple=True)
        
        # getting the number of edus in each paragraph
        edu_per_para, all_keys, i = [], range(batch_size), 0
        for k, v in Counter([t.item() for t in seperators[0]]).items():
            while k != all_keys[i] and i < len(all_keys):
                edu_per_para.append(0); i+=1
            edu_per_para.append(v); i+=1
        
        # calculating the average embeddings for each EDU
        seperators_idx = 0
        for i, edu_count_per_para in enumerate(edu_per_para):
            prev_edu_sep = 0
            for j in range(edu_count_per_para):
                if j < self.edu_sequence_length:
                    cur_edu_sep = seperators[1][seperators_idx].item()
                    # print(i, j, prev_edu_sep, cur_edu_sep)
                    assert input_ids[i][prev_edu_sep] in [101, edu_seperator_id]
                    assert input_ids[i][cur_edu_sep] in [edu_seperator_id, 102]

                    batch_para_edu_avg_emb[i][j] = torch.mean(outputs[i][prev_edu_sep+1:cur_edu_sep], dim=0)
                    prev_edu_sep = cur_edu_sep

                seperators_idx += 1;
            
            if j < self.edu_sequence_length -1:
                # calculating embeddings of the last EDU that is between [EDU_SEP] and [SEP]
                cur_edu_sep = (input_ids[i] == 102).nonzero(as_tuple=True)[0].item()
                # print(i, j+1, prev_edu_sep, cur_edu_sep)
                assert input_ids[i][prev_edu_sep] in [101, edu_seperator_id]
                assert input_ids[i][cur_edu_sep] in [edu_seperator_id, 102] 
                batch_para_edu_avg_emb[i][j] = torch.mean(outputs[i][prev_edu_sep+1:cur_edu_sep], dim=0)
            
        return batch_para_edu_avg_emb
    

        

# Trainer

In [62]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

paragraph_files, edus_files, labels_files = '../data/ets/para_text/*', '../data/ets/para_edu/*', '../data/ets/para_edu_label/*'
argdata = ArgumentDataset(tokenizer, paragraph_files, edus_files, labels_files)

# edu_sep_id = tokenizer.convert_tokens_to_ids('[EDU_SEP]')
config = BertConfig.from_pretrained("bert-base-uncased", num_labels=5)
edu_tag_model = BertForPhraseClassification.from_pretrained("bert-base-uncased", config=config)
edu_tag_model.resize_token_embeddings(len(tokenizer))

training_args = TrainingArguments(
    output_dir='./',      
    num_train_epochs=3,
    per_device_train_batch_size=4,  
    save_steps=0, 
    do_train=True,
    dataloader_drop_last=True
)

trainer = Trainer(
    model=edu_tag_model,        
    args=training_args,                
    train_dataset=argdata,
)

loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /Users/tariq/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab11005bcd270f3c34464dc1704b715b5d7d52b1a461abe3b9e4e
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.11.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

loading file https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt from cache at /Users

In [63]:
trainer.train()
# trainer.evaluate(test_data)

***** Running training *****
  Num examples = 783
  Num Epochs = 3
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 585


torch.Size([4, 50, 768])
torch.Size([4, 50, 5])


Step,Training Loss


torch.Size([4, 50, 768])
torch.Size([4, 50, 5])
torch.Size([4, 50, 768])
torch.Size([4, 50, 5])


KeyboardInterrupt: 

# Sketching

In [76]:
paragraph_files, edus_files, labels_files = '../data/ets/para_text/*', '../data/ets/para_edu/*', '../data/ets/para_edu_label/*'
for pf, ef, lf in zip( glob.glob('../data/ets/para_text/*'), glob.glob('../data/ets/para_edu/*'), glob.glob('../data/ets/para_edu_label/*') ):
    print('{}\n{}\n{}\n\n'.format(pf, ef, lf))

../data/ets/para_text/1B0AnGw2RHR8NKeHQZIJ1bNHHgbohhltY6J8W7CMwuBk_last_2.txt
../data/ets/para_edu/1B0AnGw2RHR8NKeHQZIJ1bNHHgbohhltY6J8W7CMwuBk_last_2.txt
../data/ets/para_edu_label/1B0AnGw2RHR8NKeHQZIJ1bNHHgbohhltY6J8W7CMwuBk_last_2.txt


../data/ets/para_text/137PIyRZC1ULCCQA3CTFBHAx0p0Gu72UTJVuBt4bOQsY_3.txt
../data/ets/para_edu/137PIyRZC1ULCCQA3CTFBHAx0p0Gu72UTJVuBt4bOQsY_3.txt
../data/ets/para_edu_label/137PIyRZC1ULCCQA3CTFBHAx0p0Gu72UTJVuBt4bOQsY_3.txt


../data/ets/para_text/1gAqk5scB6PMsemM2Y2r2-mvy9B8Z3rp1Pqzad6ulnM4_5.txt
../data/ets/para_edu/1gAqk5scB6PMsemM2Y2r2-mvy9B8Z3rp1Pqzad6ulnM4_5.txt
../data/ets/para_edu_label/1gAqk5scB6PMsemM2Y2r2-mvy9B8Z3rp1Pqzad6ulnM4_5.txt


../data/ets/para_text/UNASSIGNED_DOCUMENT_573_last_3.txt
../data/ets/para_edu/UNASSIGNED_DOCUMENT_573_last_3.txt
../data/ets/para_edu_label/UNASSIGNED_DOCUMENT_573_last_3.txt


../data/ets/para_text/1DcTNSXuT3VFnaBY9rCopGtzilHeBGTbVXK42qqocSP4_5.txt
../data/ets/para_edu/1DcTNSXuT3VFnaBY9rCopGtzilHeBGTbVXK42qq

In [203]:
input_ids_batch = torch.tensor(argdata.para_edu_splits_tok['input_ids'][:16])
outputs, edu_seperator_id = torch.rand(16, 512, 768), 30522
seperators = (input_ids_batch == edu_seperator_id).nonzero(as_tuple=True)

edu_per_para = list(Counter([t.item() for t in seperators[0]]).values())

seperators[0].shape, seperators[1].shape, input_ids_batch.shape

(torch.Size([356]), torch.Size([356]), torch.Size([16, 512]))

In [238]:
edu_per_para, all_keys, i = [], range(16), 0
for k, v in Counter([t.item() for t in seperators[0]]).items():
    while k != all_keys[i] and i < len(all_keys):
        edu_per_para.append(0); i+=1
    edu_per_para.append(v); i+=1

len(edu_per_para), edu_per_para

(16, [23, 32, 34, 40, 14, 45, 16, 4, 0, 19, 9, 54, 39, 0, 11, 16])

In [275]:
seperators_idx = 0
batch_para_edu_avg_emb = torch.zeros(16,  50, 768)

for i, edu_count_per_para in enumerate(edu_per_para):
    para_edu_avg_emb, prev_edu_sep = [], 0
    
    for j in range(edu_count_per_para):
        if j < 50:
            cur_edu_sep = seperators[1][seperators_idx].item()
            print(i, j, prev_edu_sep, cur_edu_sep)
            assert input_ids_batch[i][prev_edu_sep] in [101, 30522]
            assert input_ids_batch[i][cur_edu_sep] in [30522, 102]

            batch_para_edu_avg_emb[i][j] = torch.mean(outputs[i][prev_edu_sep+1:cur_edu_sep], dim=0)
            prev_edu_sep = cur_edu_sep
        
        seperators_idx += 1;
    
    if j < 50-1:
        # calculating embeddings of the last EDU that is between [EDU_SEP] and [SEP]
        cur_edu_sep = (input_ids_batch[i] == 102).nonzero(as_tuple=True)[0].item()
        print(i, j+1, prev_edu_sep, cur_edu_sep)
        assert input_ids_batch[i][prev_edu_sep] in [101, 30522]
        assert input_ids_batch[i][cur_edu_sep] in [30522, 102] 
        batch_para_edu_avg_emb[i][j] = torch.mean(outputs[i][prev_edu_sep+1:cur_edu_sep], dim=0)
    

0 0 0 4
0 1 4 8
0 2 8 15
0 3 15 20
0 4 20 31
0 5 31 38
0 6 38 41
0 7 41 50
0 8 50 60
0 9 60 75
0 10 75 78
0 11 78 86
0 12 86 89
0 13 89 96
0 14 96 110
0 15 110 123
0 16 123 129
0 17 129 136
0 18 136 143
0 19 143 153
0 20 153 163
0 21 163 172
0 22 172 178
0 23 178 186
1 0 0 15
1 1 15 21
1 2 21 26
1 3 26 36
1 4 36 51
1 5 51 58
1 6 58 67
1 7 67 70
1 8 70 77
1 9 77 87
1 10 87 93
1 11 93 101
1 12 101 119
1 13 119 126
1 14 126 136
1 15 136 145
1 16 145 159
1 17 159 169
1 18 169 178
1 19 178 183
1 20 183 197
1 21 197 217
1 22 217 225
1 23 225 230
1 24 230 236
1 25 236 245
1 26 245 254
1 27 254 262
1 28 262 272
1 29 272 279
1 30 279 283
1 31 283 292
1 32 292 301
2 0 0 4
2 1 4 15
2 2 15 28
2 3 28 38
2 4 38 42
2 5 42 45
2 6 45 58
2 7 58 62
2 8 62 75
2 9 75 81
2 10 81 86
2 11 86 93
2 12 93 109
2 13 109 118
2 14 118 121
2 15 121 128
2 16 128 139
2 17 139 142
2 18 142 152
2 19 152 160
2 20 160 169
2 21 169 176
2 22 176 181
2 23 181 188
2 24 188 199
2 25 199 207
2 26 207 215
2 27 215 227
2 28 227 23

In [55]:
len(argdata)

783

In [315]:
argdata[0]['edu_seq_input_ids'].shape, argdata[0]['edu_seq_attention_mask'].shape, argdata[0]['edu_seq_token_type_ids'].shape, \
len(argdata[0]['edu_labels']), len(argdata[0]['token_labels']), \
argdata[0]['edu_labels'][0], len(argdata[0]['token_labels'][0])

(torch.Size([50, 128]),
 torch.Size([50, 128]),
 torch.Size([50, 128]),
 50,
 50,
 1,
 128)

In [66]:
L = [[(i, j+1, len(line.rstrip().split('\t'))) for j, line in enumerate(para_labels) if len(line.rstrip().split('\t')) != 2] for i, para_labels in enumerate(labels)]
[l[0][0] for l in L if len(l)>0]

[293, 342, 363, 365, 412, 463, 470, 473, 483, 494, 583, 662]

In [19]:
edus = open('../data/ets/para_edu/1-AbxwVc5Fvl-bX7-RBPA8fLHOghrgNifbzu0hYLtSRY_1.txt').readlines()
text = open('../data/ets/para_text/1-AbxwVc5Fvl-bX7-RBPA8fLHOghrgNifbzu0hYLtSRY_1.txt').readlines()
labels = open('../data/ets/para_edu_label_all/1-AbxwVc5Fvl-bX7-RBPA8fLHOghrgNifbzu0hYLtSRY_1.txt').readlines()
para_labels = [{'edu': line.rstrip().split('\t')[0], 'tokens': line.rstrip().split('\t')[1]} for line in labels]

In [30]:
len(''.join(text).split()), len(''.join(edus).split()), sum([len(line['tokens'].split()) for line in para_labels])

(81, 81, 81)

In [139]:
seq = torch.zeros(10, 5)
seq[0] = torch.Tensor([1, 2, 3, 4, 5])
seq

tensor([[1., 2., 3., 4., 5.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [103]:
from transformers import BertModel, BertConfig
config = BertConfig()
bert = BertModel(config, add_pooling_layer=False)
# bert = bert.from_pretrained('bert-based-uncased')

string_tok = tokenizer('I love movies')
ids = torch.tensor(string_tok['input_ids']).unsqueeze(0)
attn_mask = torch.tensor(string_tok['attention_mask']).unsqueeze(0)
seg_ids = torch.tensor(string_tok['token_type_ids']).unsqueeze(0)

res = bert(ids, attention_mask=attn_mask, token_type_ids=seg_ids)