In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification

In [None]:
import pandas as pd

def _get_predicates_from_sentence(lines):
    # here we expect that sentence is a number of lines
    pred_positions = []
    pred_columns = []

    for line in lines[2:]:
        split_line = line.split('\t')
        if any([c == 'V' for c in split_line[11:]]):
            pred_index = split_line.index('V')
            if not split_line[0].isdigit():
                print('found shit:', split_line[0:1])
            else:
                pred_positions.append(int(split_line[0])-1)
                pred_columns.append(pred_index)
    return pred_positions, pred_columns



def process_file(conll_file)->pd.DataFrame:
    """Parse a PropBank-annotated CoNLL-U file into one row per (sentence, predicate).

    Sentences are separated by blank lines. After an optional '# propbank' line and an
    optional '# newdoc' line, each sentence block is expected to have two header lines
    (presumably '# sent_id' / '# text') followed by tab-separated token rows.

    :param conll_file: path to the .conllu file (read as UTF-8).
    :returns: DataFrame with columns
        'sentence'     -- list of column-1 values (word forms) of every token row,
        'predicate'    -- column-1 value of the predicate's row,
        'pred columns' -- index of the column holding the predicate's 'V' tag,
        'labels'       -- that column's tag for every token row, joined with ', '.
    """
    rows = []
    with open(conll_file, encoding='utf-8') as f:
        text = f.read()
    for s in text.split('\n\n'):  # split by empty line - sent id+text+table with features
        # Trim stray newlines (e.g. a file ending in a single '\n'): an empty line left in
        # the token table would make `split('\t')[1]` below raise IndexError.
        lines = s.strip('\n').split('\n')
        if lines[0].startswith('# propbank'):
            lines = lines[1:]
        if lines and lines[0].startswith('# newdoc'):
            lines = lines[1:]

        if len(lines) > 1:
            sentence_words_list = [l.split('\t')[1] for l in lines[2:]]
            pred_idxs, pred_cols = _get_predicates_from_sentence(lines)

            labels = find_tokens_args(lines, pred_cols)

            for idx, col, label in zip(pred_idxs, pred_cols, labels):
                word = lines[idx+2].split('\t')[1]
                rows.append([sentence_words_list, word, col, ', '.join(label)])

    # Build the frame once; enlarging it row by row with .loc copies it every time.
    big_df = pd.DataFrame(rows, columns=['sentence', 'predicate', 'pred columns', 'labels'])
    print('process_file(): dataframe len:', len(big_df))
    return big_df


def _get_context_of_predicate(sentence_words_list, word, idx):
    context = ['_', word, '_']

    if idx >= 1:
        context[0] = sentence_words_list[idx-1]

    if idx < len(sentence_words_list)-1:
        context[2] = sentence_words_list[idx+1]

    return context


def advanced_process_file(conll_file)->pd.DataFrame:
    """Parse a PropBank-annotated CoNLL-U file into one row per (sentence, predicate), with context.

    Sentences are separated by blank lines. After an optional '# propbank' line and an
    optional '# newdoc' line, each sentence block is expected to have two header lines
    (presumably '# sent_id' / '# text') followed by tab-separated token rows.

    :param conll_file: path to the .conllu file (read as UTF-8).
    :returns: DataFrame with columns
        'sentence'     -- list of column-1 values (word forms) of every token row,
        'predicate'    -- column-1 value of the predicate's row,
        'pred columns' -- index of the column holding the predicate's 'V' tag,
        'context'      -- [previous word, predicate, next word], '_' past a sentence edge,
        'labels'       -- the predicate column's tag for every token row, joined with ', '.
    """
    rows = []
    with open(conll_file, encoding='utf-8') as f:
        text = f.read()
    for s in text.split('\n\n'):  # split by empty line - sent id+text+table with features
        # Trim stray newlines (e.g. a file ending in a single '\n'): an empty line left in
        # the token table would make `split('\t')[1]` below raise IndexError.
        lines = s.strip('\n').split('\n')
        if lines[0].startswith('# propbank'):
            lines = lines[1:]
        if lines and lines[0].startswith('# newdoc'):
            lines = lines[1:]
        if len(lines) > 1:
            sentence_words_list = [l.split('\t')[1] for l in lines[2:]]
            pred_idxs, pred_cols = _get_predicates_from_sentence(lines)
            labels = find_tokens_args(lines, pred_cols)
            for idx, col, label in zip(pred_idxs, pred_cols, labels):
                word = lines[idx+2].split('\t')[1]
                context = _get_context_of_predicate(sentence_words_list, word, idx)
                rows.append([sentence_words_list, word, col, context, ', '.join(label)])

    # Build the frame once; enlarging it row by row with .loc copies it every time.
    big_df = pd.DataFrame(rows, columns=['sentence', 'predicate', 'pred columns', 'context', 'labels'])
    print('advanced_process_file(): dataframe len:', len(big_df))
    return big_df


def find_tokens_args(lines, pred_cols):
    """Collect, for each predicate column, the tag of every token row.

    :param lines: sentence lines; lines[2:] are tab-separated token rows.
    :param pred_cols: column indices, one per predicate.
    :returns: one list per entry of ``pred_cols``, holding exactly one tag per row of
        ``lines[2:]``. Empty cells become '_'. Rows too short to have the column also get
        '_' (they used to be skipped, which left the tags shorter than -- and shifted
        against -- the word list built from the same rows).
    """
    # Split every row once instead of once per predicate column.
    token_rows = [line.split('\t') for line in lines[2:]]
    labels = []
    for predicate_col in pred_cols:
        column_tags = []
        for tags in token_rows:
            label = tags[predicate_col] if predicate_col < len(tags) else '_'
            column_tags.append(label if label != '' else '_')
        labels.append(column_tags)
    return labels


def extract_features(dataframe)->pd.DataFrame:
    """Disabled helper kept for interface compatibility; always raises.

    Its former body was unreachable (it followed this raise) and read a
    'pred columns values' column that neither ``process_file`` nor
    ``advanced_process_file`` produces, so it has been removed.

    :param dataframe: ignored.
    :raises ValueError: always.
    """
    raise ValueError("do not use this method!!")


In [None]:
import pandas as pd
from datasets import Dataset, DatasetDict, load_metric
import transformers
from transformers import AutoTokenizer
import numpy as np

task = "ner"
model_checkpoint = "distilbert-base-uncased"
labels_list = None
metric = load_metric("seqeval")

batch_size = 32 # subject to change, the bigger the better, but should fit into memory

def convert_to_dataset(train:pd.DataFrame,
                       val:pd.DataFrame,
                       test:pd.DataFrame)->DatasetDict:
    """Wrap the three split DataFrames in a ``DatasetDict``.

    :returns: DatasetDict with keys 'train', 'validation' and 'test'.

    Side effect: if the module-level ``labels_list`` is still falsy (None / empty), it
    is populated from the new dataset with ``get_labels_list_from_dataset``; an already
    populated ``labels_list`` is left as is.
    """
    global labels_list
    ds = DatasetDict({
        'train': Dataset.from_pandas(train),
        'validation': Dataset.from_pandas(val),
        'test': Dataset.from_pandas(test),
    })

    if not labels_list:
        labels_list = get_labels_list_from_dataset(ds)

    return ds


def get_labels_list_from_dataset(ds: "DatasetDict"):
    """Return the label vocabulary found in the 'train', 'test' and 'validation' splits.

    :param ds: mapping of split name to a split exposing a 'labels' column whose
        entries are label strings joined with ', '.
    :returns: sorted list of distinct labels. A label's id elsewhere is its index in
        this list, so the order must be reproducible. Iterating a set of strings is not
        reproducible, because str hashing is randomized per interpreter process, and a
        checkpoint trained in one session would stop matching ids in the next.
    """
    labels_set = set()
    for ds_name in ['train', 'test', 'validation']:
        for label in ds[ds_name]['labels']:
            labels_set.update(label.split(', '))
    return sorted(labels_set)


class Tokenizer:
    """Fast Hugging Face tokenizer plus alignment of word-level SRL labels to sub-word tokens.

    The ``tokenize*`` methods are written for ``Dataset.map(..., batched=True)``. ``examples``
    is a batch dict containing:
    - 'sentence': a list of word lists;
    - 'labels': per example, one label per word joined with ', ';
    - 'predicate' (a word) or 'context' (a word window), for the variants that append
      that information after the sentence.
    Each method returns the model inputs plus a 'labels' column with one label id per input
    token. A label id is the label's index in ``self.labels_list``; -100 marks tokens to ignore.
    """

    def __init__(self, model_checkpoint, labels_list) -> None:
        """
        :param model_checkpoint: name or path passed to ``AutoTokenizer.from_pretrained``.
        :param labels_list: ordered label vocabulary; a label's id is its index in this list.
        :raises AssertionError: if the loaded tokenizer is not a ``PreTrainedTokenizerFast``
            (label alignment relies on ``word_ids()``, which only fast tokenizers provide).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        assert isinstance(self.tokenizer, transformers.PreTrainedTokenizerFast), "tokenizer is not PreTrainedTokenizerFast!"

        self.labels_list = labels_list


    def _tokenize_input_string(self, input):
        """Tokenize a raw string, or a list of words (``is_split_into_words=True``), with truncation.

        :raises TypeError: if ``input`` is neither a str nor a list.
        """
        if isinstance(input, str):
            return self.tokenizer(input, truncation=True)
        elif isinstance(input, list):
            return self.tokenizer(input, truncation=True, is_split_into_words=True)
        else:
            raise TypeError(f'tokenizer input should be str or list, got {type(input)}')


    def _label_to_id(self):
        """Map each label to its index in ``self.labels_list`` (first occurrence wins, like ``list.index``)."""
        label2id = {}
        for idx, label in enumerate(self.labels_list):
            label2id.setdefault(label, idx)
        return label2id


    @staticmethod
    def _word_label_ids(word_ids, labels_as_list, label2id):
        """Label id for each token, given the token -> word mapping of one encoded sequence.

        Special tokens (``word_id is None``) get -100. A word past the end of
        ``labels_as_list`` (the data can have fewer labels than words) gets the '_' label.
        A label missing from ``label2id`` raises ``KeyError`` instead of being silently
        replaced.
        """
        out = []
        for word_id in word_ids:
            if word_id is None:
                out.append(-100)
            elif word_id < len(labels_as_list):
                out.append(label2id[labels_as_list[word_id]])
            else:
                out.append(label2id['_'])
        return out


    @staticmethod
    def _context_positions(sentence, context):
        """Sentence index of each word in ``context``; -1 for the '_' boundary placeholder.

        ``context`` is expected to be [previous word, predicate, next word], with '_' where
        the predicate sits at a sentence edge. The whole window is matched against the
        sentence, so a neighbour that also occurs earlier (e.g. "the" or ",") resolves to the
        predicate's real neighbour. If no window matches, or ``context`` is not three words
        long, each word falls back to its first occurrence in the sentence.
        """
        if len(context) == 3:
            window = tuple(context)
            last = len(sentence) - 1
            for idx, word in enumerate(sentence):
                left = sentence[idx - 1] if idx >= 1 else '_'
                right = sentence[idx + 1] if idx < last else '_'
                if (left, word, right) == window:
                    return [idx - 1 if idx >= 1 else -1, idx, idx + 1 if idx < last else -1]
        return [sentence.index(c) if c != '_' else -1 for c in context]


    def tokenize_align_labels_no_pred(self, examples):
        """Encode the sentences alone (truncated) and give each token its word's label id."""
        tokenized_sentences = self.tokenizer(examples["sentence"], truncation=True, is_split_into_words=True)

        label2id = self._label_to_id()
        # word_ids(batch_index=i) of the batch encoding line up with its input_ids exactly,
        # so no second tokenization pass is needed.
        tokenized_sentences['labels'] = [
            self._word_label_ids(tokenized_sentences.word_ids(batch_index=i), labels.split(', '), label2id)
            for i, labels in enumerate(examples['labels'])
        ]

        return tokenized_sentences


    def tokenize_and_align_labels_pred(self, examples):
        """Encode each sentence followed by its predicate, with labels for every token.

        With a BERT-style tokenizer the input is "[CLS] sentence [SEP] predicate [SEP]"
        (the predicate encoding's leading special token is dropped). Sentence tokens get
        their word's label, appended predicate tokens get '_', and special tokens get -100.
        Nothing is truncated.
        """
        tokenized_sentences = self.tokenizer(examples["sentence"], truncation=False, is_split_into_words=True)
        tokenized_predicates = self.tokenizer(examples["predicate"], truncation=False, is_split_into_words=False)

        tokenized_inputs = {
            key: [v1 + v2[1:] for v1, v2 in zip(tokenized_sentences[key], tokenized_predicates[key])]
            for key in tokenized_sentences.keys()
        }

        label2id = self._label_to_id()
        blank_id = label2id['_']
        labels_out = []
        for i, labels in enumerate(examples['labels']):
            # Take word ids from the very encodings that produced input_ids, so the label
            # list always has the same length as the input. The old code re-tokenized with
            # truncation=True, which broke this for long sentences.
            sentence_labels = self._word_label_ids(
                tokenized_sentences.word_ids(batch_index=i), labels.split(', '), label2id)
            predicate_labels = [
                -100 if word_id is None else blank_id
                for word_id in tokenized_predicates.word_ids(batch_index=i)[1:]
            ]
            labels_out.append(sentence_labels + predicate_labels)

        tokenized_inputs['labels'] = labels_out

        return tokenized_inputs

    def tokenize_and_align_labels_context(self, examples):
        """Encode each sentence followed by its [previous, predicate, next] window, with labels.

        With a BERT-style tokenizer the input is "[CLS] sentence [SEP] prev pred next [SEP]".
        Sentence tokens get their word's label. Each appended window token gets the label of
        the sentence word it repeats, or '_' for the boundary placeholder. Special tokens get
        -100. Nothing is truncated.
        """
        tokenized_sentences = self.tokenizer(examples["sentence"], truncation=False, is_split_into_words=True)
        tokenized_context = self.tokenizer(examples["context"], truncation=False, is_split_into_words=True)

        tokenized_inputs = {
            key: [v1 + v2[1:] for v1, v2 in zip(tokenized_sentences[key], tokenized_context[key])]
            for key in tokenized_sentences.keys()
        }

        label2id = self._label_to_id()
        labels_out = []
        for i, (sentence, context, labels) in enumerate(zip(examples['sentence'], examples['context'], examples['labels'])):
            labels_as_list = labels.split(', ')
            sentence_labels = self._word_label_ids(
                tokenized_sentences.word_ids(batch_index=i), labels_as_list, label2id)

            positions = self._context_positions(sentence, context)
            context_labels = []
            # One label per appended token, so lengths match the input exactly.
            for context_word in tokenized_context.word_ids(batch_index=i)[1:]:
                if context_word is None:
                    context_labels.append(-100)
                elif positions[context_word] == -1:
                    context_labels.append(label2id['_'])
                else:
                    context_labels.extend(
                        self._word_label_ids([positions[context_word]], labels_as_list, label2id))

            labels_out.append(sentence_labels + context_labels)

        tokenized_inputs['labels'] = labels_out
        return tokenized_inputs

def compute_metrics(p):
    """Seqeval scores for a Trainer evaluation.

    :param p: ``(logits, label_ids)`` pair. The logits are argmax-ed over axis 2, so they
        are presumably shaped (batch, seq_len, num_labels) -- TODO confirm.
    :returns: dict with 'precision', 'recall', 'f1' and 'accuracy', taken from seqeval's
        "overall_*" results.

    Positions whose gold id is -100 (special tokens) are dropped, as are positions whose
    predicted id is not a valid index into ``labels_list``. Reads the module-level
    ``labels_list`` and ``metric``.
    NOTE(review): the tags are PropBank roles ('ARG0', 'V', '_'), not IOB tags. seqeval
    treats the first character as a chunk prefix, so these scores may not be role-level
    F1 -- confirm.
    """
    logits, gold_ids = p
    pred_ids = np.argmax(logits, axis=2)
    n_labels = len(labels_list)

    true_predictions, true_labels = [], []
    for pred_row, gold_row in zip(pred_ids, gold_ids):
        kept = [(pred, gold) for pred, gold in zip(pred_row, gold_row) if gold != -100 and pred < n_labels]
        true_predictions.append([labels_list[pred] for pred, _ in kept])
        true_labels.append([labels_list[gold] for _, gold in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    summary = {}
    for out_key, seqeval_key in (("precision", "overall_precision"),
                                 ("recall", "overall_recall"),
                                 ("f1", "overall_f1"),
                                 ("accuracy", "overall_accuracy")):
        summary[out_key] = results[seqeval_key]
    return summary


In [None]:
task = "ner" # Should be one of "ner", "pos" or "chunk"
model_checkpoint = "distilbert-base-uncased"
batch_size = 32

In [None]:
# Directory holding the en_ewt-up .conllu splits ('/content' is the original Colab-style location).
DATA_DIR = '/content'

df_val = advanced_process_file(f'{DATA_DIR}/en_ewt-up-dev.conllu')
df_train = advanced_process_file(f'{DATA_DIR}/en_ewt-up-train.conllu')
df_test = advanced_process_file(f'{DATA_DIR}/en_ewt-up-test.conllu')

# Variant without the 'context' column (works with tokenize_and_align_labels_pred /
# tokenize_align_labels_no_pred, not with tokenize_and_align_labels_context):
# df_val = process_file(f'{DATA_DIR}/en_ewt-up-dev.conllu')
# df_train = process_file(f'{DATA_DIR}/en_ewt-up-train.conllu')
# df_test = process_file(f'{DATA_DIR}/en_ewt-up-test.conllu')

In [None]:
# Build the DatasetDict, then (re)derive the label vocabulary from all three splits;
# a label's id is its position in `labels_list`.
dataset = convert_to_dataset(train=df_train, val=df_val, test=df_test)
labels_list = get_labels_list_from_dataset(dataset)
print(sorted(labels_list))

In [None]:
# Fast tokenizer of the checkpoint, bundled with the label vocabulary used to map labels to ids.
tok = Tokenizer(model_checkpoint=model_checkpoint, labels_list=labels_list)

In [None]:
# Tokenize every split and align labels to tokens. The active `_pred` variant appends the
# predicate after the sentence; the commented-out `_context` variant appends the
# [previous, predicate, next] window instead and needs the 'context' column.
# tokenized_datasets = dataset.map(tok.tokenize_and_align_labels_context, batched=True)
tokenized_datasets = dataset.map(tok.tokenize_and_align_labels_pred, batched=True)

In [None]:
# Fresh model from the base checkpoint (use this when training from scratch):
# model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(labels_list))
# NOTE(review): this loads a previously fine-tuned model from '/content/model_checkpoints/pred'.
# Nothing in this notebook writes that directory on a fresh run, because trainer.train() and
# trainer.save_model() are commented out below. Its label ids also only match `labels_list` if
# the list has the same order as when the model was trained.
model = AutoModelForTokenClassification.from_pretrained('/content/model_checkpoints/pred', num_labels=len(labels_list))

In [None]:
model_name = model_checkpoint.rsplit("/", 1)[-1]  # NOTE(review): not used by any visible cell

# Fine-tuning hyper-parameters; checkpoints go under ./model_checkpoints/pred.
args = TrainingArguments(
    output_dir="model_checkpoints/pred",
    evaluation_strategy='epoch',  # NOTE(review): renamed `eval_strategy` in newer transformers releases
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    weight_decay=0.01,
    push_to_hub=False,
)

# Pads each batch dynamically, including the label sequences.
data_collator = DataCollatorForTokenClassification(tok.tokenizer)
metric = load_metric("seqeval")

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tok.tokenizer,
    compute_metrics=compute_metrics,
)


In [None]:
# Sanity check: every training example must carry exactly one label per input token.
# A mismatch is printed as "<index>: s:<#input ids>, l:<#labels>".
train_split = tokenized_datasets['train']
for example_idx, (input_ids, label_ids) in enumerate(zip(train_split['input_ids'], train_split['labels'])):
    if len(input_ids) != len(label_ids):
        print(f'{example_idx}: s:{len(input_ids)}, l:{len(label_ids)}')

In [None]:
# trainer.train()

In [None]:
# Evaluate the current model on the validation split (the trainer's eval_dataset) using compute_metrics.
trainer.evaluate()

In [None]:
# Raw logits and gold label ids for the validation split; the third element (metrics) is discarded.
# NOTE(review): assigning to `_` shadows IPython's "last output" variable for the rest of the session.
predictions_raw, labels, _ = trainer.predict(tokenized_datasets["validation"])
# Predicted label id per token: argmax over axis 2 (presumably the label axis).
predictions = np.argmax(predictions_raw, axis=2)

In [None]:
# Map ids back to label strings, skipping positions whose gold id is -100 (special tokens),
# then score the label sequences with seqeval.
true_predictions, true_labels = [], []
for pred_row, gold_row in zip(predictions, labels):
    kept_pairs = [(p, l) for p, l in zip(pred_row, gold_row) if l != -100]
    true_predictions.append([labels_list[p] for p, _ in kept_pairs])
    true_labels.append([labels_list[l] for _, l in kept_pairs])

results = metric.compute(predictions=true_predictions, references=true_labels)
results

In [None]:
# trainer.save_model()

In [None]:
# One row per validation example: the decoded model input (special tokens included) next to
# its predicted and gold label sequences. Rows are collected first and the frame is built
# once, because enlarging a DataFrame row by row with .loc copies it on every append.
comparison_rows = [
    [tok.tokenizer.decode(tokens), prediction, gold]
    for tokens, prediction, gold in zip(tokenized_datasets['validation']['input_ids'], true_predictions, true_labels)
]
df = pd.DataFrame(comparison_rows, columns=['sentence', 'prediction', 'gold'])

In [None]:
# Save the per-example comparison table to base.csv in the current working directory.
df.to_csv('base.csv')

In [None]:
# Preview the first rows of the comparison table.
df.head()