In [None]:
"""hf_seq2seq_lemmatization.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1M0LODo6g2mREdKOXunrg4mvjJUiSJ5tr

# The task

* Lemmatization
* Input: wordform + morpho information
* Output: word baseform
* Easy for English, but not so much for Finnish or many other languages

Here are a few examples:

* dogs+NOUN+Plural -> dog
* sheep+NOUN+Plural -> sheep
* voi+VERB+... -> voida
* voi+NOUN+Singular -> voi

# Data preparation

* We can use universaldependencies.org
* Collection of treebanks
* Pick your favorite language; I will use Finnish
"""

In [None]:
!pip3 install --quiet datasets transformers

For English, you can use the UD_English-EWT treebank instead.

In [None]:
!wget -O train.conllu https://github.com/UniversalDependencies/UD_Finnish-TDT/raw/master/fi_tdt-ud-train.conllu
!wget -O validation.conllu https://github.com/UniversalDependencies/UD_Finnish-TDT/raw/master/fi_tdt-ud-dev.conllu
!wget -O test.conllu https://github.com/UniversalDependencies/UD_Finnish-TDT/raw/master/fi_tdt-ud-test.conllu

In [None]:
"""# Data preparation

* The CoNLL format should be familiar to you by now
* Here are a few lines (the delimiter is TAB)



```
# newdoc id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200
# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0001
# newpar id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-p0001
# text = What if Google Morphed Into GoogleOS?
1	What	what	PRON	WP	PronType=Int	0	root	0:root	_
2	if	if	SCONJ	IN	_	4	mark	4:mark	_
3	Google	Google	PROPN	NNP	Number=Sing	4	nsubj	4:nsubj	_
4	Morphed	morph	VERB	VBD	Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin	1	advcl	1:advcl:if	_
5	Into	into	ADP	IN	_	6	case	6:case	_
6	GoogleOS	GoogleOS	PROPN	NNP	Number=Sing	4	obl	4:obl:into	SpaceAfter=No
7	?	?	PUNCT	.	_	4	punct	4:punct	_


```

* Let us form training examples like so:
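    * Input is the `wordform`, then the separator `+++`, then its `POS` and `FEATS` tags (this is the format the tokenization code below expects)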
    * Output is the lemma
* We can reuse part of our dataset preparation code from the [MLP notebook](https://github.com/TurkuNLP/Deep_Learning_in_LangTech_course/blob/master/hf_trainer_mlp.ipynb)
"""

In [None]:
import json
import datasets

In [None]:
# Load the test split from JSON lines; every record carries the model input
# ("form_tags") and the target lemma ("lemma"), both read as plain strings.
# NOTE(review): test.jsonl is not created by any cell shown here; presumably it is
# converted from test.conllu as in the training notebook -- confirm.
dataset = datasets.load_dataset(
    "json",
    data_files={"test": "test.jsonl"},
    split={"test": "test"},
    features=datasets.Features(
        dict(
            form_tags=datasets.Value("string"),
            lemma=datasets.Value("string"),
        )
    ),
)

In [None]:
# Shuffle with a fixed seed: the prediction cell at the end takes the first N
# examples, so an unseeded shuffle would show a different sample on every run.
dataset = dataset.shuffle(seed=42)

In [None]:
# Show the splits, their columns and row counts (last expression -> rich display).
dataset

In [None]:
"""# Tokenize and prep"""

In [None]:
import transformers

In [None]:
# Why can a word-level BERT tokenizer work here? The input and output strings are
# built with a space between every character (see `tokenize` below), so the
# tokenizer ends up emitting roughly one token per character.
model_name = "TurkuNLP/bert-base-finnish-cased-v1"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# Register [unused1] / [unused2] as special tokens so they are never split;
# they serve as the start/end markers around every input and output string.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["[unused1]", "[unused2]"]}
)
"""The examples are formed much like the ones you have seen before:

* `input_ids` is the input side
* `attention_mask` is the input attention mask
* `labels` is the output ids
* the encoder-decoder model should take care of the rest
"""

In [None]:
def tokenize(example, tok=None):
    """Convert one {"form_tags", "lemma"} example into encoder-decoder features.

    The word form and the lemma are fed character by character: joining the
    characters with spaces makes the tokenizer emit roughly one token per
    character. The tags after "+++" are split on "|" so each tag becomes its
    own word. Both sides are wrapped in the [unused1] ... [unused2] markers.

    Args:
        example: mapping with string fields "form_tags" (expected shape
            "<word form>+++<tags>", tags separated by "|") and "lemma".
        tok: tokenizer to use; defaults to the notebook-level `tokenizer`.
            Mainly useful for exercising the function with a stand-in.

    Returns:
        dict with "input_ids" and "attention_mask" for the input side and
        "labels" holding the token ids of the lemma.

    Raises:
        ValueError: if "form_tags" does not contain the "+++" separator.
    """
    if tok is None:
        tok = tokenizer
    form_tags = example["form_tags"]
    # Split at the LAST "+++": a word form may itself contain "+" (the symbol
    # "+" yields "++++SYM..."), while the tag string presumably never does.
    form, sep, tags = form_tags.rpartition("+++")
    if not sep:
        raise ValueError(f"expected '<form>+++<tags>' in form_tags, got {form_tags!r}")

    source = "[unused1] " + " ".join(form) + " " + tags.replace("|", " ") + " [unused2]"
    target = "[unused1] " + " ".join(example["lemma"]) + " [unused2]"

    inp_tok = tok(source, truncation=True)
    outp_tok = tok(target, truncation=True)
    return {
        "input_ids": inp_tok["input_ids"],
        "attention_mask": inp_tok["attention_mask"],
        "labels": outp_tok["input_ids"],
    }

In [None]:
# Add input_ids / attention_mask / labels to every split; the original
# "form_tags" and "lemma" columns are kept (they are printed at the end).
dataset = dataset.map(function=tokenize)

In [None]:
# Load the fine-tuned seq2seq lemmatizer from the local "s2s_lemmatizer" directory.
# NOTE(review): presumably saved by the training notebook -- confirm it exists.
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
    "s2s_lemmatizer"
)

In [None]:
# Pad each batch dynamically to its longest example and return PyTorch tensors;
# labels are padded with the collator's label pad id (-100 by default).
collator = transformers.DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    return_tensors="pt",
)

In [None]:
# Sanity check: collate just the first test example to inspect the batch layout.
first_example = dataset["test"][0]
lst = [{key: first_example[key] for key in ("input_ids", "labels", "attention_mask")}]
batch = collator(lst)
print(batch)

In [None]:
# This notebook only runs `trainer.predict` (the trainer gets no train or eval
# dataset), so only the prediction settings are configured. Training-schedule
# options (eval/save/logging steps, load_best_model_at_end, max_steps, ...) are
# left out: newer transformers versions dropped the `evaluation_strategy` keyword
# and reject an eval strategy when no eval dataset is given.
trainer_args = transformers.Seq2SeqTrainingArguments(
    "checkpoints",                   # output_dir (required positional argument)
    per_device_eval_batch_size=128,  # batch size used by trainer.predict
    predict_with_generate=True,      # predictions come from model.generate(), not logits
)

In [None]:
# Early stopping after 5 evaluations without improvement (a training-time setting).
# NOTE: this callback is not passed to the Seq2SeqTrainer in this notebook,
# which only runs prediction, so it currently has no effect.
early_stopping = transformers.EarlyStoppingCallback(early_stopping_patience=5)

In [None]:
# A trainer with no datasets attached: it is only used for trainer.predict below.
trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=trainer_args,
    tokenizer=tokenizer,
    data_collator=collator,
)

In [None]:
# Predict lemmas for a sample of the test set and compare them with the gold lemmas.
N_PREDICT = 133  # number of test examples to run through the model
test_data = dataset["test"].select(range(min(N_PREDICT, len(dataset["test"]))))
predictions = trainer.predict(test_data)

n_correct = 0
for pred_ids, example in zip(predictions.predictions, test_data):
    # -100 may appear as padding when prediction batches of different lengths are
    # concatenated; it is not a valid token id, so drop it before decoding.
    pred_ids = [int(t) for t in pred_ids if t != -100]
    # skip_special_tokens removes [PAD], the [unused1]/[unused2] markers and the
    # other specials (note: this also drops [UNK], i.e. unknown characters).
    decoded = tokenizer.decode(pred_ids, skip_special_tokens=True)
    # Targets were written one character per "word" (see `tokenize`), so join the
    # characters back together to recover the lemma string.
    predicted_lemma = "".join(decoded.split())
    n_correct += predicted_lemma == example["lemma"]
    print(example["form_tags"], example["lemma"], predicted_lemma, sep="\t")

print(f"exact-match accuracy: {n_correct}/{len(test_data)}")