In [22]:
# OpenNMT-py is used only through its command-line tools (onmt_preprocess,
# onmt_train, onmt_translate via `!` below). It is pinned to 1.2.0 because
# onmt_preprocess and flags such as -data / -rnn_size belong to the 1.x CLI.
# NOTE(review): this replaces any newer OpenNMT-py already installed
# (3.5.1 in the captured output) — confirm nothing else in the env needs it.
!pip3 install OpenNMT-py==1.2.0

Defaulting to user installation because normal site-packages is not writeable
Collecting OpenNMT-py==1.2.0
  Downloading OpenNMT_py-1.2.0-py3-none-any.whl (195 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.2/195.2 KB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting torchtext==0.4.0
  Downloading torchtext-0.4.0-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 KB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Collecting future
  Downloading future-1.0.0-py3-none-any.whl (491 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.3/491.3 KB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: future, torchtext, OpenNMT-py
  Attempting uninstall: OpenNMT-py
    Found existing installation: OpenNMT-py 3.5.1
    Uninstalling OpenNMT-py-3.5.1:
      Successfully uninstalled OpenNMT-py-3.5.1
Successfully installed OpenNMT-py-1.2.0 future-1.0.0 to

In [45]:
import requests

def get_forms(sysID, lang, lemma_tags_set, timeout=120,
              url='https://test2.kurdinus.com/oracle/GetForms'):
    """Ask the oracle's GetForms endpoint to inflect each lemma/tags pair.

    sysID: system identifier sent to the oracle.
    lang: language code sent to the oracle (e.g. "lat").
    lemma_tags_set: `[{"lemma": "go", "tags": "V;PST"}, ...]`
    timeout: seconds `requests` waits to connect / between received bytes
        before raising. Without a timeout, a stalled server blocks the
        kernel indefinitely.
    url: endpoint to POST to.

    Returns the UTF-8-decoded response body on HTTP 200, otherwise the
    string `"#FAILED: <status code>"`. Connection errors and timeouts are
    not caught and propagate as `requests` exceptions.
    """
    request_data = {
        "sysID": sysID,
        "lang": lang,
        "data": lemma_tags_set,
    }
    response = requests.post(url, json=request_data, timeout=timeout)
    if response.status_code == 200:
        # Decode explicitly as UTF-8: forms contain macrons (ā, ē, ...).
        return response.content.decode('utf-8')
    return "#FAILED: " + str(response.status_code)

In [46]:
def check_forms(sysID, lang, lemma_form_tags_set, timeout=120,
                url='https://test2.kurdinus.com/oracle/CheckForms'):
    """Ask the oracle's CheckForms endpoint to check predicted forms.

    sysID: system identifier sent to the oracle.
    lang: language code sent to the oracle (e.g. "lat").
    lemma_form_tags_set: `[{"lemma": "go", "form": "goed", "tags": "V;PST"}, ...]`
    timeout: seconds `requests` waits to connect / between received bytes
        before raising. Without a timeout, a stalled server blocks the
        kernel indefinitely.
    url: endpoint to POST to.

    Returns the UTF-8-decoded response body on HTTP 200, otherwise the
    string `"#FAILED: <status code>"`. Connection errors and timeouts are
    not caught and propagate as `requests` exceptions.
    """
    request_data = {
        "sysID": sysID,
        "lang": lang,
        "data": lemma_form_tags_set,
    }
    # (The former debug `print(response)` is gone: it printed one
    # `<Response [...]>` line per call, flooding the chunked prediction loop.)
    response = requests.post(url, json=request_data, timeout=timeout)
    if response.status_code == 200:
        # Decode explicitly as UTF-8: forms contain macrons (ā, ē, ...).
        return response.content.decode('utf-8')
    return "#FAILED: " + str(response.status_code)

In [47]:
import pandas as pd
import random

# Whole lat.tsv (tab-separated, no header) as a list of row lists. get_data
# reads row[0] as lemma, row[1] as form and row[-1] as tags.
# get_data removes the rows it sends to the oracle from this list IN PLACE,
# so whatever remains afterwards is used as the prediction ("test") pool.
# NOTE(review): read_csv's default NA handling turns empty cells (and strings
# such as "NA"/"nan") into float NaN; pass keep_default_na=False if any
# lemma/form could look like that.
file = pd.read_csv('lat.tsv', sep='\t', header=None).values.tolist()

def _remove_rows(pool, rows):
    """Remove one occurrence of each list in `rows` from the list `pool`, in place."""
    from collections import Counter

    # Same result as running `if row in pool: pool.remove(row)` for every row,
    # but with a single pass over `pool` instead of one full scan per row.
    # Only lists are matched against lists, as `==` does: a tuple (e.g. a
    # prediction from get_sorted) never equals a list row, so it is never removed.
    pending = Counter(tuple(row) for row in rows if isinstance(row, list))
    if not pending:
        return
    kept = []
    for row in pool:
        if isinstance(row, list) and pending[tuple(row)] > 0:
            pending[tuple(row)] -= 1
        else:
            kept.append(row)
    pool[:] = kept


def get_data(rows, function, sysID="fumo_test8", lang="lat", pool=None):
    """Send `rows` to an oracle function and return its non-empty response lines.

    rows: `[[lemma, form(if checking data), tags], ...]`. row[0] is sent as
        the lemma, row[1] as the form and row[-1] as the tags.
    function: `get_forms` or `check_forms`, or any callable taking
        `(sysID, lang, lemma_form_tags_set)` and returning the response text.
    sysID, lang: forwarded to `function`.
    pool: list from which the sent rows are removed (one occurrence each),
        in place, after a successful call. Defaults to the global `file`, so
        rows sampled for train/dev drop out of the remaining prediction pool.
        Only list rows are removed; tuple rows are left alone.

    output: `["lemma\\t form\\t tags", ...]`

    Raises:
        RuntimeError: if `function` returns a "#FAILED: <status>" string.
            Previously this was silently turned into an empty list, and the
            rows were already gone from `file`.
    """
    rows = list(rows)
    if not rows:
        return []

    lemma_forms_tags_set = [
        {"lemma": row[0], "form": row[1], "tags": row[-1]} for row in rows
    ]

    response = function(sysID, lang, lemma_forms_tags_set)
    if response.startswith("#FAILED"):
        raise RuntimeError(f"oracle request failed: {response}")

    # Only consume rows from the pool once the oracle call has succeeded.
    _remove_rows(file if pool is None else pool, rows)

    # Drop blank lines rather than slicing off the last element: `[:-1]` lost
    # a real record whenever the response had no trailing newline.
    return [line for line in response.split('\n') if line.strip()]

In [48]:

# Sample N_SAMPLES rows, have the oracle inflect them, and split the result
# TRAIN_FRACTION / (1 - TRAIN_FRACTION) into train and dev.
# A fixed seed makes the sample, and therefore the train/dev split and the
# scores reported below, reproducible on Restart & Run All.
SEED = 42
N_SAMPLES = 1000
TRAIN_FRACTION = 0.9

rows = random.Random(SEED).sample(file, N_SAMPLES)

oracle_data = get_data(rows, get_forms)
n_train = int(len(oracle_data) * TRAIN_FRACTION)
train_data = oracle_data[:n_train]
dev_data = oracle_data[n_train:]
oracle_data[:5]

['apprēnsō\tapprēnsābis\tV;IND;ACT;FUT;2;SG', 'congaudeō\tcongaudet\tV;IND;ACT;PRS;3;SG', 'adprēnsō\tadprēnsāberis\tV;IND;PASS;FUT;2;SG', 'abhorrēscō\tabhorrēscēbātis\tV;IND;ACT;PST;IPFV;2;PL', 'attestor\tattestābāris\tV;IND;ACT;PST;IPFV;2;SG']


In [49]:
def create_data_files(name, data, data_dir='data'):
    """Write `data` as OpenNMT parallel files `<data_dir>/<name>.src` and `.tgt`.

    data: `["lemma\\t form\\t tags", ...]` strings, or already-split rows
        `[lemma, form, tags]` / `[lemma, tags]` (lists or tuples). Rows with
        only two fields have no form and produce an empty target line.
    data_dir: output directory; it is created if it does not exist.

    output: `name.src`, `name.tgt` files with the data in this format:
        `g o # V PST` and `w e n t` (one character per token; tags split on ';').
    """
    import os

    os.makedirs(data_dir, exist_ok=True)
    src_path = os.path.join(data_dir, f'{name}.src')
    tgt_path = os.path.join(data_dir, f'{name}.tgt')
    # Explicit UTF-8: forms contain macrons (ā, ē, ...) that the platform's
    # default encoding may not represent. `with` closes both files even if a
    # malformed row raises part-way through.
    with open(src_path, 'w', encoding='utf-8') as src_file, \
            open(tgt_path, 'w', encoding='utf-8') as tgt_file:
        for result in data:
            if isinstance(result, str):
                result = result.split('\t')
            if len(result) == 2:
                lemma, tags = result
                word = ''
            else:
                lemma, word, tags = result
            src_file.write(' '.join(lemma) + ' # ' + ' '.join(tags.split(';')) + '\n')
            tgt_file.write(' '.join(word) + '\n')

In [50]:
# Materialize both splits as OpenNMT source/target files under data/.
for split_name, split_rows in (('train', train_data), ('dev', dev_data)):
    create_data_files(split_name, split_rows)

In [51]:
# Build vocabularies and serialized train/valid shards (run/data.*, e.g.
# run/data.train.0.pt and run/data.valid.0.pt in the log below) from the
# character-level files written above; -overwrite allows replacing the
# run/data.* files from a previous run.
!onmt_preprocess -train_src data/train.src -train_tgt data/train.tgt -valid_src data/dev.src -valid_tgt data/dev.tgt -save_data run/data -overwrite

[2024-05-07 21:53:43,436 INFO] Extracting features...
[2024-05-07 21:53:43,446 INFO]  * number of source features: 0.
[2024-05-07 21:53:43,446 INFO]  * number of target features: 0.
[2024-05-07 21:53:43,446 INFO] Building `Fields` object...
[2024-05-07 21:53:43,446 INFO] Building & saving training data...
[2024-05-07 21:53:43,523 INFO] Building shard 0.
[2024-05-07 21:53:43,542 INFO]  * saving 0th train data shard to run/data.train.0.pt.
[2024-05-07 21:53:43,777 INFO]  * tgt vocab size: 36.
[2024-05-07 21:53:43,778 INFO]  * src vocab size: 48.
[2024-05-07 21:53:43,822 INFO] Building & saving validation data...
[2024-05-07 21:53:45,026 INFO] Building shard 0.
[2024-05-07 21:53:45,028 INFO]  * saving 0th valid data shard to run/data.valid.0.pt.


In [52]:
%%capture cap1 --no-stderr
# Train an LSTM model on run/data: 1 layer, 128-dim hidden state
# (-rnn_size) and 128-dim embeddings (-word_vec_size). Validate and save a
# checkpoint every 200 steps; -early_stopping 2 stops after 2 validations
# without improvement.
# NOTE(review): no -train_steps is given, and later cells hard-code
# run/model_step_1000.pt — confirm that checkpoint exists and is the one wanted.
!onmt_train -data run/data -save_model run/model -encoder_type rnn -rnn_type LSTM -rnn_size 128 -layers 1 -word_vec_size 128 -save_checkpoint_steps 200 -valid_steps 200 -early_stopping 2
# Persist the captured stdout (stderr is not captured because of --no-stderr).
with open('train.log', 'w') as f:
    f.write(cap1.stdout)

In [61]:
%%capture cap2 --no-stderr
# Translate the dev set with the step-1000 checkpoint. -verbose makes the
# translator print the SENT / PRED / PRED SCORE lines that get_sorted()
# parses; they are saved to pred_scores.log, get_sorted()'s default input.
!onmt_translate -model run/model_step_1000.pt -src data/dev.src -output data/dev.hyp -replace_unk -verbose
with open('pred_scores.log', 'w') as f:
    f.write(cap2.stdout)

In [62]:
# Compare dev predictions (data/dev.hyp) with the gold forms (data/dev.tgt)
# using the local evaluate.jl script (not part of this notebook). Judging by
# its output, it reports exact-match accuracy, a character edit distance, and
# the mismatched pairs (presumably gold -> predicted).
!julia evaluate.jl data/dev.tgt data/dev.hyp

Accuracy: 79 / 100   0.79
Character edit distance 0.46
mētiuntur -> mētīuntur
supant -> supāus
foetēbuntur -> fotēbuntur
olēbās -> olōēbās
urvābat -> urveābat
cunient -> cunībunt
oreris -> orieris
glabrō -> glabrās
immorāberis -> immormorāberis
caurītis -> cauriētis
expetessor -> expetessēor
aporiāberis -> aporiēris
cōnfit -> cōnfīt
cooperor -> coperārur
foetentur -> fotīuntur
faetēbantur -> faeteēbantur
dēpangis -> dēpangīs
remūgient -> remūgībunt
dēmōlientur -> dēmōlībuntur
praecognōscēbat -> praecābat
pigritābāris -> pitritābāris


In [63]:
import re

def get_sorted(text=None):
    """Parse `onmt_translate -verbose` output into predictions sorted by score.

    text: captured translator output. When None, `pred_scores.log` is read
        (as UTF-8) instead.

    Returns a list of `(lemma, prediction, tags, score)` tuples sorted by
    ascending score, so the lowest-scoring predictions come first. `lemma`
    and `prediction` have their per-character spaces removed, and `tags` are
    re-joined with ';' (e.g. `V;PST`).

    Each sentence's -verbose block looks like:
        SENT 1: ['g', 'o', '#', 'V', 'PST']
        PRED 1: w e n t
        PRED SCORE: -0.0123
    PRED comes *before* PRED SCORE, so a record is emitted only once both have
    been seen for the current SENT. The previous version attached each
    prediction to the score of the sentence before it. Summary lines such as
    `PRED AVG SCORE: ...` match neither pattern and are ignored.

    Raises:
        ValueError: if a line starting with "SENT" is not followed by a
            token list.
    """
    import ast

    if text is None:
        with open('pred_scores.log', 'r', encoding='utf-8') as log_file:
            text = log_file.read()

    data = []
    lemma = tags = ''
    prediction = score = None
    for line in text.split('\n'):
        if line.startswith('SENT'):
            sent_match = re.match(r"SENT \d+: *(\[.*\])\s*$", line)
            if sent_match is None:
                raise ValueError(f"unparseable SENT line: {line!r}")
            # The source is printed as a Python list of tokens: lemma
            # characters, then '#', then the individual tags.
            tokens = ast.literal_eval(sent_match.group(1))
            if '#' in tokens:
                cut = tokens.index('#')
                lemma, tags = ''.join(tokens[:cut]), ';'.join(tokens[cut + 1:])
            else:
                lemma, tags = ''.join(tokens), ''
            prediction = score = None
            continue

        score_match = re.match(r"PRED SCORE: *(\S+)", line)
        pred_match = None if score_match else re.match(r"PRED \d+:(.*)", line)
        if score_match:
            score = float(score_match.group(1))
        elif pred_match:
            prediction = pred_match.group(1).strip().replace(' ', '')
        else:
            continue

        if prediction is not None and score is not None:
            data.append((lemma, prediction, tags, score))
            prediction = score = None

    data.sort(key=lambda x: x[3])
    return data

In [64]:
%%capture cap3 --no-stderr
# Write every row still in `file` (the rows get_data did not remove, i.e.
# not sampled for train/dev) as the test set and translate it in one pass.
# NOTE(review): this overwrites pred_scores.log, which held the dev-set
# output from the earlier translate cell.
create_data_files("test", file)
!onmt_translate -model run/model_step_1000.pt -src data/test.src -output data/test.hyp -replace_unk -verbose
with open('pred_scores.log', 'w') as f:
    f.write(cap3.stdout)

In [69]:
%%capture cap2 --no-stderr
# now predict all data in `file`
# with open('pred_scores.log', 'w') as f:
#     f.write('')

results = []

for i in range(len(file)//100):
    create_data_files("test", file[100*i : len(file) if 100*(i+1)>len(file) else 100*(i+1)])
    %%capture cap2 --no-stderr
    !onmt_translate -model run/model_step_1000.pt -src data/test.src -output data/test.hyp -replace_unk -verbose
    data = cap2.stdout
    data = [x[:-1] for x in get_sorted(data)]
    results += get_data(data, check_forms)


UsageError: Line magic function `%%capture` not found.
