In [1]:
import platform
import numpy as np
import pandas as pd
import random

import torch 
from torch import optim 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn

from tqdm.notebook import tqdm
from transformers import AutoTokenizer

# enable tqdm in pandas
tqdm.pandas()

# select device
if torch.cuda.is_available():
    device = torch.device('cuda')
elif 'arm64' in platform.platform():
    device = torch.device('mps') # 'mps'
else:
    device = torch.device('cpu')
print(f'device: {device.type}') 

# random seed
seed = 1234

# pytorch ignores this label in the loss
ignore_index = -100

# set random seed
if seed is not None:
    print(f'random seed: {seed}')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

# which transformer to use
transformer_name = "bert-base-cased" # 'xlm-roberta-base' # 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(transformer_name)

device: mps
random seed: 1234


In [2]:

# map labels to the first token in each word
def align_labels(word_ids, labels, label_to_index):
    label_ids = []
    previous_word_id = None
    for word_id in word_ids:
        if word_id is None or word_id == previous_word_id:
            # ignore if not a word or word id has already been seen
            label_ids.append(ignore_index)
        else:
            # get label id for corresponding word
            label_id = label_to_index[labels[word_id]]
            label_ids.append(label_id)
        # remember this word id
        previous_word_id = word_id
    
    return label_ids
            
# build a set of labels in the dataset            
def read_label_set(fn):
    labels = set()
    with open(fn) as f:
        for index, line in enumerate(f):
            line = line.strip()
            tokens = line.split()
            if tokens != []:
                label = tokens[-1]
                labels.add(label)
    return labels

# converts a two-column file in the basic MTL format ("word \t label") into a dataframe
def read_dataframe(fn, label_to_index, task_id):
    # now build the actual dataframe for this dataset
    data = {'words': [], 'str_labels': [], 'input_ids': [], 'word_ids': [], 'labels': [], 'task_ids': []}
    with open(fn) as f:
        sent_words = []
        sent_labels = [] 
        for index, line in tqdm(enumerate(f)):
            line = line.strip()
            tokens = line.split()
            if tokens == []:
                data['words'].append(sent_words)
                data['str_labels'].append(sent_labels)
                
                # tokenize each sentence
                token_input = tokenizer(sent_words, is_split_into_words = True)  
                token_ids = token_input['input_ids']
                word_ids = token_input.word_ids(batch_index = 0)
                
                # map labels to the first token in each word
                token_labels = align_labels(word_ids, sent_labels, label_to_index)
                
                data['input_ids'].append(token_ids)
                data['word_ids'].append(word_ids)
                data['labels'].append(token_labels)
                data['task_ids'].append(task_id)
                sent_words = []
                sent_labels = [] 
            else:
                sent_words.append(tokens[0])
                sent_labels.append(tokens[1])
    return pd.DataFrame(data)


In [3]:
class Task():
    def __init__(self, task_id, train_file_name, dev_file_name, test_file_name):
        self.task_id = task_id
        # we need an index of labels first
        self.labels = read_label_set(train_file_name)
        self.index_to_label = {i:t for i,t in enumerate(self.labels)}
        self.label_to_index = {t:i for i,t in enumerate(self.labels)}
        self.num_labels = len(self.index_to_label)
        # create data frames for the datasets
        self.train_df = read_dataframe(train_file_name, self.label_to_index, self.task_id)
        self.dev_df = read_dataframe(dev_file_name, self.label_to_index, self.task_id)
        self.test_df = read_dataframe(test_file_name, self.label_to_index, self.task_id)
                

In [4]:
ner_task = Task(0, "data/conll-ner/train_small.txt", "data/conll-ner/dev.txt", "data/conll-ner/test.txt")
pos_task = Task(1, "data/pos/train_small.txt", "data/pos/dev.txt", "data/pos/test.txt")

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [5]:
ner_task.train_df

Unnamed: 0,words,str_labels,input_ids,word_ids,labels,task_ids
0,"[EU, rejects, German, call, to, boycott, Briti...","[B-ORG, O, B-MISC, O, O, O, B-MISC, O, O]","[101, 7270, 22961, 1528, 1840, 1106, 21423, 14...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, None]","[-100, 8, 4, 2, 4, 4, 4, 2, 4, -100, 4, -100]",0
1,"[Peter, Blackburn]","[B-PER, I-PER]","[101, 1943, 14428, 102]","[None, 0, 1, None]","[-100, 7, 6, -100]",0
2,"[BRUSSELS, 1996-08-22]","[B-LOC, O]","[101, 26660, 13329, 12649, 15928, 1820, 118, 4...","[None, 0, 0, 0, 0, 1, 1, 1, 1, 1, None]","[-100, 5, -100, -100, -100, 4, -100, -100, -10...",0
3,"[The, European, Commission, said, on, Thursday...","[O, B-ORG, I-ORG, O, O, O, O, O, O, B-MISC, O,...","[101, 1109, 1735, 2827, 1163, 1113, 9170, 1122...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1...","[-100, 4, 8, 0, 4, 4, 4, 4, 4, 4, 2, 4, 4, 4, ...",0
4,"[Germany, 's, representative, to, the, Europea...","[B-LOC, O, O, O, O, B-ORG, I-ORG, O, O, O, B-P...","[101, 1860, 112, 188, 4702, 1106, 1103, 1735, ...","[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10,...","[-100, 5, 4, -100, 4, 4, 4, 8, 0, 4, -100, 4, ...",0
5,"["", We, do, n't, support, any, such, recommend...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[101, 107, 1284, 1202, 183, 112, 189, 1619, 12...","[None, 0, 1, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10,...","[-100, 4, 4, 4, 4, -100, -100, 4, 4, 4, 4, 4, ...",0
6,"[He, said, further, scientific, study, was, re...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[101, 1124, 1163, 1748, 3812, 2025, 1108, 2320...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1...","[-100, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, ...",0
7,"[He, said, a, proposal, last, month, by, EU, F...","[O, O, O, O, O, O, O, B-ORG, O, O, B-PER, I-PE...","[101, 1124, 1163, 170, 5835, 1314, 2370, 1118,...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1...","[-100, 4, 4, 4, 4, 4, 4, 4, 8, 4, 4, 7, 6, -10...",0
8,"[Fischler, proposed, EU-wide, measures, after,...","[B-PER, O, B-MISC, O, O, O, O, B-LOC, O, B-LOC...","[101, 17355, 9022, 2879, 3000, 7270, 118, 2043...","[None, 0, 0, 0, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8, ...","[-100, 7, -100, -100, 4, 2, -100, -100, 4, 4, ...",0
9,"[But, Fischler, agreed, to, review, his, propo...","[O, B-PER, O, O, O, O, O, O, O, B-ORG, O, O, O...","[101, 1252, 17355, 9022, 2879, 2675, 1106, 318...","[None, 0, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,...","[-100, 4, 7, -100, -100, 4, 4, 4, 4, 4, 4, 4, ...",0


In [6]:
pos_task.train_df

Unnamed: 0,words,str_labels,input_ids,word_ids,labels,task_ids
0,"[Pierre, Vinken, ,, 61, years, old, ,, will, j...","[NNP, NNP, ,, CD, NNS, JJ, ,, MD, VB, DT, NN, ...","[101, 4855, 25354, 6378, 117, 5391, 1201, 1385...","[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11...","[-100, 44, 44, -100, 17, 12, 33, 30, 17, 22, 4...",1
1,"[Mr., Vinken, is, chairman, of, Elsevier, N.V....","[NNP, NNP, VBZ, NN, IN, NNP, NNP, ,, DT, NNP, ...","[101, 1828, 119, 25354, 6378, 1110, 3931, 1104...","[None, 0, 0, 1, 1, 2, 3, 4, 5, 5, 5, 6, 6, 6, ...","[-100, 44, -100, 44, -100, 10, 21, 18, 44, -10...",1
2,"[Rudolph, Agnew, ,, 55, years, old, and, forme...","[NNP, NNP, ,, CD, NNS, JJ, CC, JJ, NN, IN, NNP...","[101, 19922, 138, 8376, 2246, 117, 3731, 1201,...","[None, 0, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,...","[-100, 44, 44, -100, -100, 17, 12, 33, 30, 42,...",1
3,"[A, form, of, asbestos, once, used, to, make, ...","[DT, NN, IN, NN, RB, VBN, TO, VB, NNP, NN, NNS...","[101, 138, 1532, 1104, 1112, 12866, 11990, 151...","[None, 0, 1, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10,...","[-100, 16, 21, 18, 21, -100, -100, 20, 19, 6, ...",1
4,"[The, asbestos, fiber, ,, crocidolite, ,, is, ...","[DT, NN, NN, ,, NN, ,, VBZ, RB, JJ, IN, PRP, V...","[101, 1109, 1112, 12866, 11990, 12753, 117, 17...","[None, 0, 1, 1, 1, 2, 3, 4, 4, 4, 4, 4, 5, 6, ...","[-100, 16, 21, -100, -100, 21, 17, 21, -100, -...",1
...,...,...,...,...,...,...
196,"[On, the, other, hand, ,, had, it, existed, th...","[IN, DT, JJ, NN, ,, VBD, PRP, VBN, RB, ,, NNP,...","[101, 1212, 1103, 1168, 1289, 117, 1125, 1122,...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 1...","[-100, 18, 16, 30, 21, 17, 41, 9, 19, 20, 17, ...",1
197,"[Mr., Cray, ,, who, could, n't, be, reached, f...","[NNP, NNP, ,, WP, MD, RB, VB, VBN, IN, NN, ,, ...","[101, 1828, 119, 140, 6447, 117, 1150, 1180, 1...","[None, 0, 0, 1, 1, 2, 3, 4, 5, 5, 5, 6, 7, 8, ...","[-100, 44, -100, 44, -100, 17, 39, 22, 20, -10...",1
198,"[Regarded, as, the, father, of, the, supercomp...","[VBN, IN, DT, NN, IN, DT, NN, ,, NNP, NNP, VBD...","[101, 23287, 26541, 1112, 1103, 1401, 1104, 11...","[None, 0, 0, 1, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, ...","[-100, 19, -100, 18, 16, 21, 18, 16, 21, -100,...",1
199,"[At, Cray, Computer, ,, he, will, be, paid, $,...","[IN, NNP, NNP, ,, PRP, MD, VB, VBN, $, CD, .]","[101, 1335, 140, 6447, 6701, 117, 1119, 1209, ...","[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, ...","[-100, 18, 44, -100, 44, 17, 9, 22, 43, 19, 28...",1


In [7]:
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
from transformers import PreTrainedModel
from transformers import AutoConfig, AutoModel

# This class is adapted from: https://towardsdatascience.com/how-to-create-and-train-a-multi-task-transformer-model-18c54a146240
class TokenClassificationModel(BertPreTrainedModel):    
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.output_heads = nn.ModuleDict() # these are initialized in add_heads
        self.init_weights()
        
    def add_heads(self, tasks):
        for task in tasks:
            head = TokenClassificationHead(self.bert.config.hidden_size, task.num_labels, config.hidden_dropout_prob)
            # ModuleDict requires keys to be strings
            self.output_heads[str(task.task_id)] = head
        return self
    
    def summarize_heads(self):
        print(f'Found {len(self.output_heads)} heads')
        for task_id in self.output_heads:
            self.output_heads[task_id].summarize(task_id)
        
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, task_ids=None, **kwargs):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        sequence_output = outputs[0]
        
        #print(f'batch size = {len(input_ids)}')
        #print(f'task_ids in this batch: {task_ids}')
        
        # generate specific predictions and losses for each task head
        unique_task_ids_list = torch.unique(task_ids).tolist()
        logits = None
        loss_list = []
        for unique_task_id in unique_task_ids_list:
            task_id_filter = task_ids == unique_task_id
            filtered_sequence_output = sequence_output[task_id_filter]
            filtered_labels = None if labels is None else labels[task_id_filter]
            filtered_attention_mask = None if attention_mask is None else attention_mask[task_id_filter]
            #print(f'size of batch for task {unique_task_id} is: {len(filtered_sequence_output)}')
            logits, task_loss = self.output_heads[str(unique_task_id)].forward(
                filtered_sequence_output, None,
                filtered_labels,
                filtered_attention_mask,
            )
            if filtered_labels is not None:
                loss_list.append(task_loss)
                
        loss = None if len(loss_list) == 0 else torch.stack(loss_list)
                    
        # logits are only used for eval, in which case we handle a single task at a time
        # TODO: allow all tasks in the forward pass at inference                     
        return TokenClassifierOutput(
            loss=loss.mean(),
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class TokenClassificationHead(nn.Module):
    def __init__(self, hidden_size, num_labels, dropout_p=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

        self._init_weights()

    def _init_weights(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()
            
    def summarize(self, task_id):
        print(f"Task {task_id} with {self.num_labels} labels.")
        print(f'Dropout is {self.dropout}')
        print(f'Classifier layer is {self.classifier}')

    def forward(self, sequence_output, pooled_output, labels=None, attention_mask=None, **kwargs):
        sequence_output_dropout = self.dropout(sequence_output)
        logits = self.classifier(sequence_output_dropout)
        
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()            
            inputs = logits.view(-1, self.num_labels)
            targets = labels.view(-1)
            loss = loss_fn(inputs, targets)

        return logits, loss

In [8]:
tasks = [ner_task, pos_task]
config = AutoConfig.from_pretrained(transformer_name)
model= TokenClassificationModel.from_pretrained(transformer_name, config=config).add_heads(tasks)
model.summarize_heads()

Some weights of the model checkpoint at bert-base-cased were not used when initializing TokenClassificationModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing TokenClassificationModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TokenClassificationModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Found 2 heads
Task 0 with 9 labels.
Dropout is Dropout(p=0.1, inplace=False)
Classifier layer is Linear(in_features=768, out_features=9, bias=True)
Task 1 with 45 labels.
Dropout is Dropout(p=0.1, inplace=False)
Classifier layer is Linear(in_features=768, out_features=45, bias=True)


In [9]:
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    # gold labels
    label_ids = eval_pred.label_ids
    # predictions
    pred_ids = np.argmax(eval_pred.predictions, axis=-1)
    # collect gold and predicted labels, ignoring ignore_index label
    y_true, y_pred = [], []
    batch_size, seq_len = pred_ids.shape
    for i in range(batch_size):
        for j in range(seq_len):
            if label_ids[i, j] != ignore_index:
                y_true.append(label_ids[i][j]) #index_to_label[label_ids[i][j]])
                y_pred.append(pred_ids[i][j]) #index_to_label[pred_ids[i][j]])
    # return computed metrics
    return {'accuracy': accuracy_score(y_true, y_pred)}

In [10]:
from datasets import Dataset, DatasetDict

ds = DatasetDict()
ds['train'] = Dataset.from_pandas(pd.concat([ner_task.train_df, pos_task.train_df]))
ds['validation'] = Dataset.from_pandas(pd.concat([ner_task.dev_df, pos_task.dev_df]))
ds['test'] = Dataset.from_pandas(pd.concat([ner_task.test_df, pos_task.test_df]))

# these are no longer needed; discard them to save memory
ner_task.train_df = None
pos_task.train_df = None

ds

DatasetDict({
    train: Dataset({
        features: ['words', 'str_labels', 'input_ids', 'word_ids', 'labels', 'task_ids', '__index_level_0__'],
        num_rows: 244
    })
    validation: Dataset({
        features: ['words', 'str_labels', 'input_ids', 'word_ids', 'labels', 'task_ids', '__index_level_0__'],
        num_rows: 8504
    })
    test: Dataset({
        features: ['words', 'str_labels', 'input_ids', 'word_ids', 'labels', 'task_ids', '__index_level_0__'],
        num_rows: 6101
    })
})

In [11]:
from sklearn.metrics import classification_report

# compute accuracy
def evaluation_classification_report(trainer, task, name, useTest=False):
    print(f"Test classification report for task {name}:")
    num_labels = task.num_labels
    df = task.test_df if useTest == False else task.dev_df
    ds = Dataset.from_pandas(df)
    output = trainer.predict(ds)
    label_ids = output.label_ids.reshape(-1)
    predictions = output.predictions.reshape(-1, num_labels)
    predictions = np.argmax(predictions, axis=-1)
    mask = label_ids != ignore_index
    
    y_true = label_ids[mask]
    y_pred = predictions[mask]
    target_names = [task.index_to_label.get(ele, "") for ele in range(num_labels)]
    print(target_names)
    
    total = 0
    correct = 0
    for(t, p) in zip(y_true, y_pred):
        total = total + 1
        if t == p:
            correct = correct + 1
    accuracy = correct / total
    
    report = classification_report(
        y_true, y_pred,
        target_names=target_names
    )
    print(report)
    print(f'locally computed accuracy: {accuracy}')
    return accuracy

# compute loss and accuracy
def evaluate(trainer, task, name):
    print(f"Evaluating on the validation dataset for task {name}:")
    ds = Dataset.from_pandas(task.dev_df)
    scores = trainer.evaluate(ds)
    acc = evaluation_classification_report(trainer, task, name, useTest = False)
    return scores, acc

In [12]:
import os

def save_task(task_head, task, task_name, task_checkpoint):
    numpy_weights = task_head.classifier.weight.cpu().detach().numpy()
    numpy_bias = task_head.classifier.bias.cpu().detach().numpy()
    labels = task.labels
    #print(f"Shape of weights: {numpy_weights.shape}")
    #print(f"Weights are:\n{numpy_weights}")
    #print(f"Shape of bias: {numpy_bias.shape}")
    #print(f"Bias is: {numpy_bias}")
    #print(f"Labels are: {labels}")
    
    os.makedirs(task_checkpoint, exist_ok = True)
    lf = open(task_checkpoint + "/labels", "w")
    for label in labels:
        lf.write(f'{label}\n')
    lf.close()
    
    wf = open(task_checkpoint + "/weights", "w")
    wf.write(f'{numpy_weights.shape[0]} {numpy_weights.shape[1]}\n')
    for i, x in enumerate(numpy_weights):
        for j, y in enumerate(x):
            wf.write(f'{y} ')
        wf.write('\n')
    wf.close()
    
    bf = open(task_checkpoint + "/biases", "w")
    bf.write(f'{numpy_bias.shape[0]}\n')
    for i, x in enumerate(numpy_bias):
        bf.write(f'{x} ')
    bf.write('\n')
    bf.close()

def onnx_save(model, checkpoint):
    orig_words = ["Using", "transformers", "with", "ONNX", "runtime"]
    token_input = tokenizer(orig_words, is_split_into_words = True, return_tensors = "pt")
    print(token_input)
    token_ids = token_input['input_ids']
                
    inputs = (token_ids) 
    input_names = ["token_ids"] 
    output_names = ["sequence_output"]

    torch.onnx.export(model.bert,
        inputs,
        checkpoint,
        export_params=True,
        do_constant_folding=True,
        input_names = input_names,
        output_names = output_names,
        opset_version=10, 
        dynamic_axes = {"token_ids": {1: 'sent length'}}
    )


In [13]:
from transformers import TrainingArguments
from transformers import Trainer
from transformers import DataCollatorForTokenClassification
import time
from datetime import timedelta

epochs = 1
batch_size = 128
weight_decay = 0.01
use_mps_device = True if str(device) == 'mps' else False
model_name = f'{transformer_name}-mtl'

data_collator = DataCollatorForTokenClassification(tokenizer)
last_checkpoint = None

for epoch in range(1, epochs + 1):
    print(f'STARTING EPOCH {epoch}')
    if last_checkpoint != None:
        print(f'Resuming from checkpoint {last_checkpoint}')
            
    training_args = TrainingArguments(
        output_dir=model_name,
        log_level='error',
        num_train_epochs=1, # one epoch at a time
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        # evaluation_strategy='epoch',
        do_eval=False, # we will evaluate each task explicitly
        weight_decay=weight_decay,
        resume_from_checkpoint = last_checkpoint,
        use_mps_device = use_mps_device
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        # compute_metrics=compute_metrics,
        train_dataset=ds['train'],
        # eval_dataset=ds['validation'],
        tokenizer=tokenizer
    )
    
    model.summarize_heads()

    start_time = time.monotonic()
    trainer.train()
    end_time = time.monotonic()
    print(f"Elapsed time for epoch {epoch}: {timedelta(seconds=end_time - start_time)}")

    ner_scores, ner_acc = evaluate(trainer, ner_task, "NER")
    pos_scores, pos_acc = evaluate(trainer, pos_task, "POS")
    macro_loss = (ner_scores['eval_loss'] + pos_scores['eval_loss'])/2
    print(f'DEV MACRO LOSS FOR EPOCH {epoch}: {macro_loss}\n\n')
    macro_acc = (ner_acc + pos_acc)/2
    print(f'DEV MACRO ACC FOR EPOCH {epoch}: {macro_acc}')

    # save the transformer encoder + the head for each task
    last_checkpoint = training_args.output_dir + '/mtl_model_epoch' + str(epoch)
    trainer.save_model(last_checkpoint)
    save_task(model.output_heads["0"], ner_task, "NER", last_checkpoint + "/ner_head")
    save_task(model.output_heads["1"], pos_task, "POS", last_checkpoint + "/pos_head")
    
    # save the ONNX model
    onnx_checkpoint = training_args.output_dir + '/onnx_epoch' + str(epoch)
    onnx_save(model, onnx_checkpoint)


STARTING EPOCH 1
Found 2 heads
Task 0 with 9 labels.
Dropout is Dropout(p=0.1, inplace=False)
Classifier layer is Linear(in_features=768, out_features=9, bias=True)
Task 1 with 45 labels.
Dropout is Dropout(p=0.1, inplace=False)
Classifier layer is Linear(in_features=768, out_features=45, bias=True)


  output, inverse_indices, counts = torch._unique2(


Step,Training Loss


Elapsed time for epoch 1: 0:00:03.097341
Evaluating on the validation dataset for task NER:


Test classification report for task NER:
['I-ORG', 'I-LOC', 'B-MISC', 'I-MISC', 'O', 'B-LOC', 'I-PER', 'B-PER', 'B-ORG']
              precision    recall  f1-score   support

       I-ORG       0.05      0.01      0.02       835
       I-LOC       0.00      0.00      0.00       257
      B-MISC       0.00      0.00      0.00       702
      I-MISC       0.00      0.00      0.00       216
           O       0.82      0.95      0.88     38554
       B-LOC       0.02      0.01      0.01      1668
       I-PER       0.40      0.00      0.00      1156
       B-PER       0.07      0.00      0.00      1617
       B-ORG       0.03      0.00      0.00      1661

    accuracy                           0.79     46666
   macro avg       0.16      0.11      0.10     46666
weighted avg       0.70      0.79      0.73     46666

locally computed accuracy: 0.7857326533236189
Evaluating on the validation dataset for task POS:


Test classification report for task POS:
['VBG', 'WDT', 'WP$', "''", '``', '#', 'TO', 'NNPS', 'PDT', 'PRP', 'VBZ', 'RBS', 'CD', '.', '-LRB-', 'EX', 'DT', ',', 'IN', 'VBN', 'RB', 'NN', 'MD', 'VBP', 'JJS', 'WRB', 'PRP$', 'FW', '$', '-RRB-', 'JJ', 'SYM', 'RBR', 'NNS', 'POS', 'JJR', ':', 'RP', 'UH', 'WP', 'LS', 'VBD', 'CC', 'VB', 'NNP']
              precision    recall  f1-score   support

         VBG       0.00      0.00      0.00       817
         WDT       0.10      0.00      0.01       280
         WP$       0.00      0.00      0.00        21
          ''       0.02      0.01      0.01       512
          ``       0.01      0.01      0.01       531
           #       0.00      0.00      0.00         5
          TO       0.00      0.00      0.00      1240
        NNPS       0.00      0.00      0.00        42
         PDT       0.00      0.00      0.00        18
         PRP       0.01      0.00      0.00      1055
         VBZ       0.01      0.01      0.01      1233
         RBS    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'input_ids': tensor([[  101,  7993, 11303,  1468,  1114, 21748,  2249,  3190,  1576,  4974,
           102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


RuntimeError: Placeholder storage has not been allocated on MPS device!

In [50]:
#model = TokenClassificationModel.from_pretrained('bert-base-cased-mtl/mtl_model_epoch2', local_files_only=True)
#model.summarize_heads()

In [51]:
ner_acc = evaluation_classification_report(trainer, ner_task, "NER", useTest = True)
pos_acc = evaluation_classification_report(trainer, pos_task, "POS", useTest = True)
macro_acc = (ner_acc + pos_acc)/2
print(f"MTL macro accuracy: {macro_acc}")

Test classification report for task NER:
['I-MISC', 'B-MISC', 'O', 'B-PER', 'I-PER', 'B-LOC', 'I-ORG', 'B-ORG', 'I-LOC']
              precision    recall  f1-score   support

      I-MISC       0.00      0.00      0.00       346
      B-MISC       0.31      0.53      0.39       922
           O       0.97      0.97      0.97     42975
       B-PER       0.58      0.59      0.59      1842
       I-PER       0.70      0.50      0.58      1307
       B-LOC       0.53      0.76      0.63      1837
       I-ORG       0.32      0.33      0.32       751
       B-ORG       0.27      0.16      0.20      1341
       I-LOC       0.00      0.00      0.00       257

    accuracy                           0.88     51578
   macro avg       0.41      0.43      0.41     51578
weighted avg       0.89      0.88      0.88     51578

locally computed accuracy: 0.8836131684051339
Test classification report for task POS:
['VBN', 'POS', 'IN', 'VBZ', 'PRP', 'WP', '.', 'VB', 'WRB', '-LRB-', 'DT', "''", '``', '

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
