In [1]:
from typing import List, Dict
import codecs
import torch
import sys
#import myutils
from transformers import AutoModel, AutoTokenizer


# set seed for consistency
torch.manual_seed(8446)
# Set some constants
MLM = 'distilbert-base-cased'
BATCH_SIZE = 8
LEARNING_RATE = 0.00001
EPOCHS = 3
# We have an UNK label for robustness purposes, it makes it easier to run on
# data with other labels, or without labels.
UNK = "[UNK]"
MAX_TRAIN_SENTS=64
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


In [2]:
def read_data(file_path: str):
    sents = []
    labels = []

    for line in codecs.open(file_path, encoding='utf-8'):
        tok = line.strip().split('\t')
        if len(tok)>=2:
            sents.append(tok[0])
            labels.append(tok[1])

    return sents, labels

def labels2lookup(labels: List[str], PAD):

    id2label = [PAD]
    label2id = {PAD: 0}
    for label in labels:
        if label not in label2id:
            label2id[label] = len(label2id)
            id2label.append(label)
    return id2label, label2id

def tok(data: List[str], tokzr: AutoTokenizer):
    tok_data = []

    for sent in data:
        tok_data.append(tokzr.encode(sent))

    return tok_data

def to_batch(text: List[List[int]], labels: List[int], batch_size: int, padding_id: int, DEVICE: str):
    text_batches = []
    label_batches = []
    num_batches = int(len(text)/batch_size)

    for batch_idx in range(num_batches):
        beg_idx = batch_idx * batch_size
        end_idx = (batch_idx+1) * batch_size
        max_len = max([len(sent) for sent in text[beg_idx:end_idx]])

        new_batch_text = torch.full((batch_size, max_len), padding_id, dtype = torch.long, device=DEVICE)
        new_batch_labels = torch.zeros(batch_size, dtype = torch.long, device=DEVICE)

        for sent_idx in range(batch_size):
            new_batch_labels[sent_idx] = labels[beg_idx + sent_idx]
            for word_idx, word_id in enumerate(text[beg_idx + sent_idx]):
                new_batch_text[sent_idx][word_idx] = word_id
        text_batches.append(new_batch_text)
        label_batches.append(new_batch_labels)

    return text_batches, label_batches

In [13]:
class ClassModel(torch.nn.Module):
    def __init__(self, nlabels: int, mlm: str):

        super().__init__()

        # The transformer model to use
        self.mlm = AutoModel.from_pretrained(mlm)

        # Find the size of the output of the masked language model
        if hasattr(self.mlm.config, 'hidden_size'):
            self.mlm_out_size = self.mlm.config.hidden_size
        elif hasattr(self.mlm.config, 'dim'):
            self.mlm_out_size = self.mlm.config.dim
        else: # if not found, guess
            self.mlm_out_size = 768

        # Create prediction layer
        self.hidden_to_label = torch.nn.Linear(self.mlm_out_size, nlabels)

    def forward(self, input: torch.tensor):
        """
        Forward pass
    
        Parameters
        ----------
        input : torch.tensor
            Tensor with wordpiece indices. shape=(batch_size, max_sent_len).

        Returns
        -------
        output_scores : torch.tensor
            ?. shape=(?,?)
        """
        # Run transformer model on input
        mlm_out = self.mlm(input)

        # Keep only the last layer: shape=(batch_size, max_len, DIM_EMBEDDING)
        mlm_out = mlm_out.last_hidden_state
        # Keep only the output for the first ([CLS]) token: shape=(batch_size, DIM_EMBEDDING)
        mlm_out = mlm_out[:,:1,:].squeeze()

        # Matrix multiply to get scores for each label: shape=(?,?)
        output_scores = self.hidden_to_label(mlm_out)

        return output_scores

    def run_eval(self, text_batched: List[torch.tensor], labels_batched: List[torch.tensor]):
        """
        Run evaluation: predict and score
    
        Parameters
        ----------
        text_batched : List[torch.tensor]
            list with batches of text, containing wordpiece indices.
        labels_batched : List[torch.tensor]
            list with batches of labels (converted to ints).
        model : torch.nn.module
            The model to use for prediction.
    
        Returns
        -------
        score : float
            accuracy of model on labels_batches given feats_batches
        """
        self.eval()
        match = 0
        total = 0
        for sents, labels in zip(text_batched, labels_batched):
            output_scores = self.forward(sents)
            pred_labels = torch.argmax(output_scores, 1)
            print(pred_labels)
            for gold_label, pred_label in zip(labels, pred_labels):
                total += 1
                if gold_label.item() == pred_label.item():
                    match+= 1
        return(match/total)       
         
    def predict(self, text_batched: List[torch.tensor], labels_batched: List[torch.tensor]):
        self.eval()
        match = 0
        total = 0
        all_pred_labels=[]
        for sents, labels in zip(text_batched, labels_batched):
            output_scores = self.forward(sents)
            pred_labels = torch.argmax(output_scores, 1)
            all_pred_labels.append(pred_labels)
            print(pred_labels)
            for gold_label, pred_label in zip(labels, pred_labels):
                total += 1
                if gold_label.item() == pred_label.item():
                    match+= 1
        return(all_pred_labels)       

if __name__ == '__main__':
    print('reading data...')
    train_text, train_labels = read_data('en_ewt_nn_train.conll')
    train_text = train_text[:MAX_TRAIN_SENTS]
    train_labels = train_labels[:MAX_TRAIN_SENTS]
    
    
    id2label, label2id = labels2lookup(train_labels, UNK)
    NLABELS = len(id2label)
    #print(train_labels)
    #print(label2id)
    #print(id2label)
    
    train_labels = [label2id[label] for label in train_labels]
    
    dev_text, dev_labels = read_data('en_ewt_nn_answers_test.conll')
    #print(dev_labels)
    
    convertedlabels=[]
    for label in dev_labels:
        if label in label2id:
            convertedlabels.append(label2id[label])
        else:
            convertedlabels.append(label2id[UNK])
    #dev_labels = [label2id[label] for label in dev_labels]
    dev_labels=convertedlabels
    
    print('tokenizing...')
    tokzr = AutoTokenizer.from_pretrained(MLM)
    train_tokked = tok(train_text, tokzr)
    dev_tokked = tok(dev_text, tokzr)
    PAD = tokzr.pad_token_id
    
    print('converting to batches...')
    train_text_batched, train_labels_batched = to_batch(train_tokked, train_labels, BATCH_SIZE, PAD, DEVICE)
    # Note, some data is trown away if len(text_tokked)%BATCH_SIZE!= 0
    dev_text_batched, dev_labels_batched = to_batch(dev_tokked, dev_labels, BATCH_SIZE, PAD, DEVICE)
    
    print('initializing model...')
    model = ClassModel(NLABELS, MLM)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_function = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='sum')
    
    print('training...')
    for epoch in range(EPOCHS):
        print('=====================')
        print('starting epoch ' + str(epoch))
        model.train() 
    
        # Loop over batches
        loss = 0
        for batch_idx in range(0, len(train_text_batched)):
            optimizer.zero_grad()

            output_scores = model.forward(train_text_batched[batch_idx])
            batch_loss = loss_function(output_scores, train_labels_batched[batch_idx])
            loss += batch_loss.item()
    
            batch_loss.backward()

            optimizer.step()
    
        dev_score = model.run_eval(dev_text_batched, dev_labels_batched)
        print('Loss: {:.2f}'.format(loss))
        print('Acc(dev): {:.2f}'.format(100*dev_score))
        print()
    


reading data...
tokenizing...
converting to batches...
initializing model...


Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


training...
starting epoch 0
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 4])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 

In [16]:
preds = model.predict(dev_text_batched, dev_labels_batched)

tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([3,

In [17]:
preds

[tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([3, 3, 3, 3,

In [12]:
print(dev_labels_batched[0])

print(dev_text_batched[0])

tensor([3, 3, 3, 3, 3, 7, 3, 4])
tensor([[  101,  1327,   102,     0,     0,     0],
        [  101,  1846,   102,     0,     0,     0],
        [  101,  1110,   102,     0,     0,     0],
        [  101,  5029,   102,     0,     0,     0],
        [  101,  1107,   102,     0,     0,     0],
        [  101,   146, 13855, 10337,   102,     0],
        [  101,   136,   102,     0,     0,     0],
        [  101,  4104,  9610, 10589,  1162,   102]])


In [11]:
label2id

{'[UNK]': 0,
 'B-ORG': 1,
 'I-ORG': 2,
 'O': 3,
 'B-LOCderiv': 4,
 'B-PER': 5,
 'I-PER': 6,
 'B-LOC': 7}

In [9]:
print(dev_labels)

[3, 3, 3, 3, 3, 7, 3, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 7, 3, 3, 3, 3, 3, 3, 7, 3, 3, 3, 3, 3, 3, 3, 3, 7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 7, 3, 7, 3, 3, 4, 3, 3, 3, 3, 7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 5, 6, 3, 3, 3, 0, 3, 3, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 3, 3, 3, 3, 3, 3, 3, 7, 3, 3, 3, 3, 3, 0, 3, 3, 0, 3, 3, 3, 3, 3, 1, 2, 2, 3, 3, 7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 2, 2, 3, 3, 3, 3, 7, 3, 3, 1, 2, 3, 1, 3, 3, 3, 3, 3, 3, 3, 7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 2, 3, 7, 0, 3, 3, 3, 5, 6, 6, 3, 7, 3, 3, 3, 3, 3, 3, 3, 4, 3, 7, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 