# Installs

In [1]:
# !pip install -q torch==1.13.1 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116

In [2]:
# !pip install -q -r requirements.txt

# Specs

In [3]:
!nvidia-smi

Sat Sep 16 12:29:02 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   51C    P8    32W / 300W |  19567MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Imports

In [4]:
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)

import subprocess
from joblib import Parallel, delayed
import multiprocessing

import cv2
import PIL
from PIL import Image
import matplotlib.pyplot as plt

from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold, StratifiedGroupKFold, GroupKFold
from sklearn.metrics import f1_score, mean_squared_error, accuracy_score

In [44]:
from collections import Counter
from glob import glob

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch import LongTensor
from torch import nn, optim
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler

In [7]:
import argparse
import logging

from scipy.sparse import save_npz, load_npz
# from seqeval.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm, trange
import datasets
from datasets import load_dataset, DatasetDict, Dataset as HFDataset
from datasets import concatenate_datasets, interleave_datasets
from datasets import ClassLabel, load_metric

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    DataCollatorForTokenClassification,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

%env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [8]:
# %load_ext watermark
# %watermark --iversions

# Envs

In [9]:
def disable_warnings(strict=False):
	warnings.simplefilter('ignore')
	if strict:
		logging.disable(logging.WARNING)

def seed_everything(seed=42):
	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed)
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed(seed)
	torch.backends.cudnn.deterministic = True
	torch.backends.cudnn.benchmark = False

In [10]:
SEED = 42

disable_warnings()
seed_everything(SEED)

# Data

In [11]:
path  = '../masakhane-pos/data/'
add_path = '../masakhane-pos/transfer_corpus/'
pseudo_path = './pseudos/pos-ner-afro-xlmr-large-75L-best.csv'

langs = sorted(os.listdir(path))
add_langs = os.listdir(add_path)

In [12]:
lang2family = {
 'bam': 'Mande',
 'bbj': 'Grassfields',
 'ewe': 'Kwa',
 'fon': 'Volta-Niger',
 'hau': 'Chadic',
 'ibo': 'Volta-Niger',
 'kin': 'Bantu',
 'lug': 'Bantu',
 'mos': 'Gur',
 'nya': 'Bantu',
 'pcm': 'English-Creole',
 'sna': 'Bantu',
 'swa': 'Bantu',
 'twi': 'Kwa',
 'wol': 'Senegambia',
 'xho': 'Bantu',
 'yor': 'Volta-Niger',
 'zul': 'Bantu'
}

lang2region = {
    'bam': 'West',
    'bbj': 'Central',
    'ewe': 'West',
    'fon': 'West',
    'hau': 'West',
    'ibo': 'West',
    'kin': 'East',
    'lug': 'East',
    'mos': 'West',
    'nya': 'East',
    'pcm': 'West',
    'sna': 'South',
    'swa': 'East',
    'twi': 'South',
    'wol': 'West',
    'xho': 'South',
    'yor': 'West',
    'zul': 'South'
}

mapper = {'NOUN': 0,
 'ADJ': 1,
 'PUNCT': 2,
 'CCONJ': 3,
 'PRON': 4,
 'ADV': 5,
 'AUX': 6,
 'VERB': 7,
 'ADP': 8,
 'PART': 9,
 'SCONJ': 10,
 'PROPN': 11,
 'X': 12,
 'NUM': 13,
 'INTJ': 14,
 'SYM': 15,
 'DET': 16}

In [13]:
lang2fix = {
    'yo': 'yor',
    'bm': 'bam'
}

def load_lang_data(data_path, lang, split='train'):
    data = []
    language = lang.split('_')[0]
    language = lang2fix.get(language, language)
    
    try:
        with open(f'{data_path}{lang}/{split}.txt', 'r') as f:
            for line in f.readlines():
                line = line.rstrip()
                splits = line.split(' ')
                if len(splits) != 2:
                    continue
                data.append(splits)
        data = np.array(data).squeeze()
        data = pd.DataFrame({
            'Word': data[:, 0],
            'Pos': data[:, 1],
            'Language': language.split('_')[0]
        })
        data['family'] = lang2family.get(language, language)
        data['region'] = lang2region.get(language, language)
        return data
    except FileNotFoundError:
        return pd.DataFrame()

In [14]:
def load_lang_data_ner(data_path, lang, split='train'):
    data = []
    language = lang.split('_')[0]
    language = lang2fix.get(language, language)
    
    try:
        with open(f'{data_path}{lang}/{split}.txt', 'r') as f:
            sentences, labels = [], []
            
            for line in f.readlines():
                line = line.rstrip()
                splits = line.split(' ')
                
                if len(splits) == 2:
                    sentences.append(splits[0])
                    labels.append(splits[1])
                elif len(splits) == 1:
                    data.append(
                        [
                            ' '.join(sentences),
                            ' '.join(labels),
                        ]
                    )
                    
                    sentences, labels = [], []

        data = np.array(data).squeeze()
        data = pd.DataFrame({
            'Word': data[:, 0],
            'Pos': data[:, 1],
            'Language': language.split('_')[0]
        })
        data['family'] = lang2family.get(language, language)
        data['region'] = lang2region.get(language, language)
        return data
    except FileNotFoundError:
        return pd.DataFrame()

In [15]:
def clean_data(df):  
    # pos = df.Pos.unique()
    use_pos = ['PUNCT', 'SYM']
    
    df_others = df[~df.Pos.isin(use_pos)]
    df_toclean = df[df.Pos.isin(use_pos)]
    
    dfs = [df_others]
    
    for p in use_pos:
        df_p = df_toclean[df_toclean.Pos == p]
        words = df_p.Word.value_counts()[df_p.Word.value_counts() < 2].index.values
        df_p = df_p[~df_p.Word.isin(words)]
        
        dfs.append(df_p)
        
    return pd.concat(dfs).sample(frac=1).reset_index(drop=True)

In [16]:
def load_ner_data(with_pseudo=False):
    train = pd.concat([load_lang_data_ner(path, lang, split) for lang in langs for split in ['train', 'dev', 'test']], axis=0).dropna().reset_index(drop=True)
    add_data = (
        pd.concat([load_lang_data_ner(add_path, lang, split) for lang in add_langs for split in ['train', 'dev', 'test']], axis=0)
        .dropna()
        .drop_duplicates(subset=['Word', 'Pos'], keep='first')
        .reset_index(drop=True)
    )
    
    train = pd.concat([train, add_data])
    if with_pseudo:
        train = pd.concat([train, pd.read_csv(pseudo_path)])
    train = train.drop_duplicates().reset_index(drop=True)
    
    return train

def load_pos_data():
    train = pd.concat([load_lang_data(path, lang, split) for lang in langs for split in ['train', 'dev', 'test']], axis=0).dropna().reset_index(drop=True)
    add_data = (
        pd.concat([load_lang_data(add_path, lang, split) for lang in add_langs for split in ['train', 'dev', 'test']], axis=0)
        .dropna()
        .drop_duplicates(subset=['Word', 'Pos'], keep='first')
        .reset_index(drop=True)
    )

    train = pd.concat([train, add_data])
    train = train[train.Pos.isin(list(mapper.keys()))]

    train = train.drop_duplicates().drop_duplicates(subset=['Word', 'Language'], keep='first').reset_index(drop=True)
    # train = train.drop_duplicates().reset_index(drop=True)
    # train = clean_data(train)
    
    return train

In [17]:
with_pseudo = True

train_ner = load_ner_data(with_pseudo)
train = load_pos_data()

train.head()

Unnamed: 0,Word,Pos,Language,family,region
0,Muso,NOUN,bam,Mande,West
1,ŋana,ADJ,bam,Mande,West
2,",",PUNCT,bam,Mande,West
3,Afiriki,NOUN,bam,Mande,West
4,tilebinyanfan,NOUN,bam,Mande,West


In [18]:
train_ner.head()

Unnamed: 0,Word,Pos,Language,family,region
0,"Muso ŋana , Afiriki tilebinyanfan n' a cɛmancɛ...",NOUN ADJ PUNCT NOUN NOUN CCONJ PRON NOUN ADV C...,bam,Mande,West
1,"Ni mɔgɔ ka dɔgɔn kojugu , i bɛ mɔgɔw juguya mi...",CCONJ NOUN PART ADJ NOUN PUNCT PRON AUX NOUN N...,bam,Mande,West
2,A kɔrɔtalen a ka taɲɛ fɛ kow ɲɛmɔgɔyabaaraw la...,PRON VERB PRON PART NOUN PART NOUN NOUN PART C...,bam,Mande,West
3,Ale y' a ( basikɛti ) to a ka se ka bɔ a ka so...,PRON PART PRON PUNCT NOUN PUNCT VERB PRON PART...,bam,Mande,West
4,Sannayɛlɛn galabukɛnɛya a sera ka min fiyɛ Afi...,NOUN NOUN PRON VERB PART PRON VERB PROPN NOUN ...,bam,Mande,West


In [19]:
test = pd.read_csv('../Test.csv')
test.head()

Unnamed: 0,Id,Word,Language,Pos
0,Id00qog2f11n_0,Ne,luo,
1,Id00qog2f11n_1,otim,luo,
2,Id00qog2f11n_2,penj,luo,
3,Id00qog2f11n_3,e,luo,
4,Id00qog2f11n_4,kind,luo,


In [20]:
train.shape, train_ner.shape, test.shape

((255140, 5), (104039, 5), (32045, 4))

In [21]:
base_train = train_ner[train_ner.Language.isin(langs)].reset_index(drop=True)
add_train = train_ner[~train_ner.Language.isin(langs)].reset_index(drop=True)

In [22]:
train['fold'] = -1
train_ner['fold'] = -1


# val_langs = ['fon', 'pcm', 'twi', 'xho']
# train.loc[train.Language.isin(val_langs), 'fold'] = 0

base_train['fold'] = base_train['Language'].map(dict(zip(langs, range(len(langs)))))
base_train.loc[base_train['fold'].isna(), 'fold'] = -1
base_train['fold'] = base_train['fold'].astype(int)

In [23]:
n_fold = 5
# Fold = StratifiedKFold(n_splits=n_fold, shuffle=True, random_state=0)
# Fold = StratifiedGroupKFold(n_splits=n_fold, shuffle=True, random_state=0)
Fold = GroupKFold(n_splits=n_fold)

base_train['fold'] = -1
# for n, (train_index, val_index) in enumerate(Fold.split(train, train['region'], groups=train['Language'])):
for n, (train_index, val_index) in enumerate(Fold.split(base_train, base_train['Pos'], groups=base_train['Language'])):
    base_train.loc[val_index, 'fold'] = int(n)
    
base_train['fold'] = base_train['fold'].astype(int)

display(base_train.groupby('fold').size())

fold
0    10338
1     6908
2     6541
3     6133
4     7515
dtype: int64

In [24]:
train_ner.Language.unique()

array(['bam', 'bbj', 'ewe', 'fon', 'hau', 'ibo', 'kin', 'lug', 'mos',
       'nya', 'pcm', 'sna', 'swa', 'twi', 'wol', 'xho', 'yor', 'zul',
       'en', 'fr', 'eng-ron-wol-sna', 'ar', 'af', 'luo', 'tsn'],
      dtype=object)

In [25]:
add_train['fold'] = -1
base_train.loc[base_train.Language.isin(['sna', 'wol']), 'fold'] = -1
# base_train.loc[base_train.Language.isin(['sna', 'bam', 'pcm', 'yor', 'wol']), 'fold'] = -1

train_ner = pd.concat([base_train, add_train]).sample(frac=1).reset_index(drop=True)

In [26]:
train_ner.head()

Unnamed: 0,Word,Pos,Language,family,region,fold
0,Make anibodi wey dey call say make dem impeach...,VERB PRON PRON AUX VERB SCONJ VERB PRON VERB D...,pcm,English-Creole,West,0
1,وزير كورى : حل القضية النووية مع كوريا الشمالي...,NOUN ADJ PUNCT NOUN NOUN ADJ ADP X ADJ VERB SC...,ar,ar,ar,-1
2,Mano nyalo konyi ahinya e yudo tich .,PRON VERB VERB ADV ADP VERB NOUN PUNCT,luo,luo,luo,-1
3,"Son nom d' espèce , composé de loyalti et de l...",DET NOUN ADP NOUN PUNCT VERB ADP PROPN CCONJ A...,fr,fr,fr,-1
4,"Les lemmings sont des êtres expressifs , capab...",DET NOUN AUX DET NOUN ADJ PUNCT ADJ ADP NOUN C...,fr,fr,fr,-1


In [27]:
train_ner.groupby(['fold', 'Language']).count()

Unnamed: 0_level_0,Unnamed: 1_level_0,Word,Pos,family,region
fold,Language,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
-1,af,1899,1899,1899,1899
-1,ar,7534,7534,7534,7534
-1,en,15532,15532,15532,15532
-1,eng-ron-wol-sna,13120,13120,13120,13120
-1,fr,16339,16339,16339,16339
-1,luo,7244,7244,7244,7244
-1,sna,1492,1492,1492,1492
-1,tsn,4936,4936,4936,4936
-1,wol,1563,1563,1563,1563
0,pcm,10338,10338,10338,10338


# Utils

In [28]:
def paralellize(fct, data, verbose=0, with_tqdm=True):
    fn = map(delayed(fct), data)
    if with_tqdm:
        fn = tqdm(fn, total=len(data))
    return Parallel(n_jobs=-1, verbose=verbose, backend="multiprocessing")(fn)

In [29]:
class Timer:
    def __init__(self):
        self._time = 0

    def start(self):
        self._time = time()

    @property
    def time(self):
        return (time() - self._time) / 60

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# Training

In [30]:
train_ner.head()

Unnamed: 0,Word,Pos,Language,family,region,fold
0,Make anibodi wey dey call say make dem impeach...,VERB PRON PRON AUX VERB SCONJ VERB PRON VERB D...,pcm,English-Creole,West,0
1,وزير كورى : حل القضية النووية مع كوريا الشمالي...,NOUN ADJ PUNCT NOUN NOUN ADJ ADP X ADJ VERB SC...,ar,ar,ar,-1
2,Mano nyalo konyi ahinya e yudo tich .,PRON VERB VERB ADV ADP VERB NOUN PUNCT,luo,luo,luo,-1
3,"Son nom d' espèce , composé de loyalti et de l...",DET NOUN ADP NOUN PUNCT VERB ADP PROPN CCONJ A...,fr,fr,fr,-1
4,"Les lemmings sont des êtres expressifs , capab...",DET NOUN AUX DET NOUN ADJ PUNCT ADJ ADP NOUN C...,fr,fr,fr,-1


In [31]:
text_column_name = 'Word'
label_column_name = 'Pos'
label_list = sorted(train[label_column_name].unique()) + ['NAW']

label_to_id = {l: i for i, l in enumerate(label_list)}
num_labels = len(label_list)

label_to_id

{'ADJ': 0,
 'ADP': 1,
 'ADV': 2,
 'AUX': 3,
 'CCONJ': 4,
 'DET': 5,
 'INTJ': 6,
 'NOUN': 7,
 'NUM': 8,
 'PART': 9,
 'PRON': 10,
 'PROPN': 11,
 'PUNCT': 12,
 'SCONJ': 13,
 'SYM': 14,
 'VERB': 15,
 'X': 16,
 'NAW': 17}

In [32]:
id_to_label = {v:k for k,v in label_to_id.items()}
id_to_label

{0: 'ADJ',
 1: 'ADP',
 2: 'ADV',
 3: 'AUX',
 4: 'CCONJ',
 5: 'DET',
 6: 'INTJ',
 7: 'NOUN',
 8: 'NUM',
 9: 'PART',
 10: 'PRON',
 11: 'PROPN',
 12: 'PUNCT',
 13: 'SCONJ',
 14: 'SYM',
 15: 'VERB',
 16: 'X',
 17: 'NAW'}

In [41]:
def load_model():
    # model_path = glob(f'./pos-tagging-ner/{model_name.replace("/", "-")}/{fold}/checkpoint-*')[0]
    
    model = AutoModelForTokenClassification.from_pretrained(
        model_name,
        # model_path,
        # config=config,
        num_labels=num_labels, id2label=id_to_label, label2id=label_to_id,
        ignore_mismatched_sizes=True
    )
    
    return model

In [34]:
# model_name = 'Davlan/afro-xlmr-base'
# model_name = 'Davlan/afro-xlmr-large-61L'
# model_name = 'Davlan/afro-xlmr-large-75L'
# model_name = 'masakhane/afroxlmr-large-ner-masakhaner-1.0_2.0'
model_name = 'google/rembert'
# model_name = 'bonadossou/afrolm_active_learning'

max_seq_length = 256
padding = False

config = AutoConfig.from_pretrained(
    model_name,
    num_labels=num_labels
)

tokenizer_name_or_path = model_name
if config.model_type in {"gpt2", "roberta"}:
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        use_fast=True,
        add_prefix_space=True,
    )
else:
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        use_fast=True,
    )
    
# data_collator = DataCollatorForTokenClassification(tokenizer)
data_collator = DataCollatorForTokenClassification(tokenizer, max_length=max_seq_length)

In [35]:
def process_dataset(examples):
    is_test = examples.get(label_column_name) is None
    
    for idx in range(len(examples[text_column_name])):
        if not is_test:
            examples[label_column_name][idx] = examples[label_column_name][idx].split()
        examples[text_column_name][idx] = examples[text_column_name][idx].split()
        
    tokenized_inputs = tokenizer(
        examples[text_column_name],
        padding=padding,
        truncation=True,
        max_length=max_seq_length if not is_test else None,
        is_split_into_words=True,
    )
    
    if not is_test:
        labels = []
        for i, label in enumerate(examples[label_column_name]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                if word_idx is None:
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    label_ids.append(label_to_id[label[word_idx]])
                else:
                    label_ids.append(label_to_id['NAW'])
                previous_word_idx = word_idx
            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [36]:
metric = load_metric("seqeval")

Downloading builder script:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

In [37]:
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l not in [-100, 17]]
        for prediction, label in zip(predictions, labels)
    ]
    
    lengths = list(map(len, true_labels))
    
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l not in [-100, 17]]
        for prediction, label, length in zip(predictions, labels, lengths)
    ]
    
    results = metric.compute(predictions=true_predictions, references=true_labels)
    
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [38]:
sentences = []
words = []

last_lang = test.Language.values[0]
for i, row in enumerate(tqdm(test.itertuples(), total=len(test))):
    words.append(row.Word)
    
    if row.Word == '.' or i == len(test)-1 or row.Language != last_lang:
        sentences.append(' '.join(words))
        words = []
        
    last_lang = row.Language

100%|██████████| 32045/32045 [00:00<00:00, 309426.40it/s]


In [39]:
# test_dataset = HFDataset.from_pandas(test).remove_columns(column_names=['Id', 'Language', 'Pos'])
# test_dataset = HFDataset.from_pandas(test_ner).remove_columns(column_names=['Language', 'family', 'region'])
test_dataset = HFDataset.from_pandas(pd.DataFrame({'Word': sentences}))
test_dataset = test_dataset.map(
    process_dataset,
    batched=True,
    remove_columns=['Word'],
)
test_dataset

Map:   0%|          | 0/1974 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1974
})

In [40]:
class AWP:
    def __init__(self, model, optimizer, *, adv_param='weight',
                 adv_lr=0.001, adv_eps=0.001):
        self.model = model
        self.optimizer = optimizer
        self.adv_param = adv_param
        self.adv_lr = adv_lr
        self.adv_eps = adv_eps
        self.backup = {}

    def perturb(self):
        """
        Perturb model parameters for AWP gradient
        Call before loss and loss.backward()
        """
        self._save()  # save model parameters
        self._attack_step()  # perturb weights

    def _attack_step(self):
        e = 1e-6
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                grad = self.optimizer.state[param]['exp_avg']
                norm_grad = torch.norm(grad)
                norm_data = torch.norm(param.detach())

                if norm_grad != 0 and not torch.isnan(norm_grad):
                    # Set lower and upper limit in change
                    limit_eps = self.adv_eps * param.detach().abs()
                    param_min = param.data - limit_eps
                    param_max = param.data + limit_eps

                    # Perturb along gradient
                    # w += (adv_lr * |w| / |grad|) * grad
                    param.data.add_(grad, alpha=(self.adv_lr * (norm_data + e) / (norm_grad + e)))

                    # Apply the limit to the change
                    param.data.clamp_(param_min, param_max)

    def _save(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                if name not in self.backup:
                    self.backup[name] = param.clone().detach()
                else:
                    self.backup[name].copy_(param.data)

    def restore(self):
        """
        Restore model parameter to correct position; AWP do not perturbe weights, it perturb gradients
        Call after loss.backward(), before optimizer.step()
        """
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data.copy_(self.backup[name])

class MyTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        self.awp = AWP(self.model, self.optimizer, adv_lr=0.001, adv_eps=0.001)
    
    def training_step(self, model: nn.Module, inputs) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # if is_sagemaker_mp_enabled():
        #     loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
        #     return loss_mb.reduce_mean().detach().to(self.args.device)
        
        self.awp.perturb() # 
        
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)
            
        self.awp.restore()

        return loss.detach() / self.args.gradient_accumulation_steps

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            if is_peft_available() and isinstance(model, PeftModel):
                model_name = unwrap_model(model.base_model)._get_name()
            else:
                model_name = unwrap_model(model)._get_name()
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
        

In [42]:
def train_fold(fold):
    raw_dataset = DatasetDict({
        'train': HFDataset.from_pandas(train_ner[train_ner.fold != fold]),
        'validation': HFDataset.from_pandas(train_ner[train_ner.fold == fold])
    }).remove_columns(column_names=['fold', '__index_level_0__', 'Language', 'family', 'region'])

    train_dataset = raw_dataset["train"]
    column_names = train_dataset.column_names

    train_dataset = train_dataset.map(
        process_dataset,
        batched=True,
        remove_columns=column_names,
        desc="Running tokenizer on train dataset",
    )
    if max_seq_length is not None:
        train_dataset = train_dataset.filter(
            lambda example: len(example['input_ids']) <= max_seq_length
        )

    eval_dataset = raw_dataset['validation']
    eval_dataset = eval_dataset.map(
        process_dataset,
        batched=True,
        remove_columns=column_names,
    )
    if max_seq_length is not None:
        eval_dataset = eval_dataset.filter(
            lambda example: len(example['input_ids']) <= max_seq_length
        )

    training_args = TrainingArguments(
        output_dir=f'pos-tagging-ner/{model_name.replace("/", "-")}/{fold}',
        learning_rate=5e-5,
        do_train=True,
        do_eval=True,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        gradient_accumulation_steps=1,
        group_by_length=True,
        overwrite_output_dir=True,
        warmup_steps=0.15,
        num_train_epochs=5,
        lr_scheduler_type ='cosine',
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=500,
        save_steps=500,
        load_best_model_at_end=True,
        metric_for_best_model='accuracy',
        greater_is_better=True,
        save_total_limit=1,
        fp16=True,
        report_to='none'
    )

    model = load_model()

    trainer = Trainer(
    # trainer = MyTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )

    trainer.train()
    
    results = trainer.evaluate(eval_dataset)
    
    test_pos_ids = trainer.predict(test_dataset)
    
    return results['eval_accuracy'], test_pos_ids

In [45]:
scores = []
all_preds = []

for fold in [1, 2, 3, 4]:
    score, fold_pred = train_fold(fold)
    
    scores.append(score)
    all_preds.append(fold_pred)
    
    print(score)
    print()
    
avg_score = np.mean(scores)

Running tokenizer on train dataset:   0%|          | 0/97131 [00:00<?, ? examples/s]

Filter:   0%|          | 0/97131 [00:00<?, ? examples/s]

Map:   0%|          | 0/6908 [00:00<?, ? examples/s]

Filter:   0%|          | 0/6908 [00:00<?, ? examples/s]

You're using a RemBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


0.5690797151020079



Running tokenizer on train dataset:   0%|          | 0/97498 [00:00<?, ? examples/s]

Filter:   0%|          | 0/97498 [00:00<?, ? examples/s]

Map:   0%|          | 0/6541 [00:00<?, ? examples/s]

Filter:   0%|          | 0/6541 [00:00<?, ? examples/s]

0.6031341201993853



Running tokenizer on train dataset:   0%|          | 0/100961 [00:00<?, ? examples/s]

Filter:   0%|          | 0/100961 [00:00<?, ? examples/s]

Map:   0%|          | 0/3078 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3078 [00:00<?, ? examples/s]

0.6271357913669064



Running tokenizer on train dataset:   0%|          | 0/96524 [00:00<?, ? examples/s]

Filter:   0%|          | 0/96524 [00:00<?, ? examples/s]

Map:   0%|          | 0/7515 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7515 [00:00<?, ? examples/s]

0.6577001294139397



In [46]:
scores

[0.5690797151020079,
 0.6031341201993853,
 0.6271357913669064,
 0.6577001294139397]

In [47]:
avg_score

0.6142624390205598

In [48]:
test_pos_ids = np.mean([p.predictions for p in all_preds], axis=0)
print(test_pos_ids.shape)
test_pos_ids = test_pos_ids.argmax(axis=-1)

(1974, 168, 18)


In [49]:
final_pos = []
for pos_ids, sentence in zip(test_pos_ids, sentences):
    length = len(sentence.split(' '))
    final_pos.extend(
        list(map(id_to_label.get, [x for x in pos_ids if x not in [-100, 17]]))[:length]
        # list(map(id_to_label.get, [x for x in pos_ids if x not in [-100, 17]]))[1:length+1]
    )

In [50]:
test['Pos'] = final_pos

In [51]:
test.Pos.value_counts()

NOUN     6797
VERB     4979
ADP      4414
PROPN    3095
PUNCT    2953
AUX      2531
PRON     1784
SCONJ    1175
DET       908
ADV       808
CCONJ     719
PART      708
ADJ       555
NUM       550
INTJ       27
X          27
SYM        15
Name: Pos, dtype: int64

In [52]:
test.head()

Unnamed: 0,Id,Word,Language,Pos
0,Id00qog2f11n_0,Ne,luo,AUX
1,Id00qog2f11n_1,otim,luo,VERB
2,Id00qog2f11n_2,penj,luo,NOUN
3,Id00qog2f11n_3,e,luo,ADP
4,Id00qog2f11n_4,kind,luo,NOUN


In [53]:
test[['Id', 'Pos']].to_csv(f'submissions/ps-round1-pos-ner-{model_name.split("/")[-1]}-{avg_score:.3f}.csv', index=False)

In [None]:
# test[test.Pos == 'SYM']