In [None]:
!wget https://raw.githubusercontent.com/mhulden/eztransformer/refs/heads/main/eztr.py
!mkdir downloads
!mv eztr.py downloads
## to clean up: rm downloads/eztr.py

%pip install -r "../requirements.txt"

In [1]:
from typing import Literal
from datasets import load_dataset

def load_data(
    lang: Literal['egyptian', 'gulf'],
    split: Literal['all', 'train', 'dev', 'no'] = 'all',
):
    lang = {'egyptian': 'arz', 'gulf': 'afb'}[lang]
    prefix = f"../data/{lang}/{lang}"
    files = {}
    if split in {"train", "all"}:
        files["train"] = f"{prefix}.trn"

    if split in {"train", "dev", "all", "no"}:
        files["valid"] = f"{prefix}.dev"

    if split in {"test", "all"}:
        files["test"] = f"{prefix}.tst"
    
    ds = load_dataset(
        'csv', delimiter="\t", data_files=files,
        column_names=["lemma", "features", "form"])

    if split == "no":
        return ds["valid"]
    
    return ds


def load_partial(train_size: int, valid_size: int|float, shift=0, lang='egyptian'):
    """Load a small, contiguous train/valid subset carved from the training file.

    Both returned splits come from the *training* file, so the real dev set is
    never looked at while debugging.

    Args:
        train_size: number of training rows, starting at row ``shift``.
        valid_size: number of validation rows taken immediately after the
            training slice. A float is treated as a fraction of ``train_size``
            (truncated with ``int``).
        shift: index of the first row of the training slice.
        lang: dialect name forwarded to ``load_data``.

    Returns:
        The object from ``load_data(split='train')`` with its 'train' and 'valid'
        entries replaced by the two slices.
    """
    dataset = load_data(lang=lang, split='train')
    full_train = dataset['train']

    n_valid = int(valid_size * train_size) if isinstance(valid_size, float) else valid_size
    train_end = shift + train_size
    valid_end = train_end + n_valid

    dataset['train'] = full_train.select(range(shift, train_end))
    dataset['valid'] = full_train.select(range(train_end, valid_end))
    return dataset

In [4]:
# Rows per batch for the Dataset.map call below. The model's own training batch
# size is set separately (`bts` in the EZTransformer constructor).
batch_size = 512
num_proc = 4
train_size = 2048 * 2
valid_size = 256

# NOTE(review): the training slice starts at row 3491 of the training file;
# the reason for this offset is not recorded here — TODO document it.
data = load_partial(train_size=train_size, valid_size=valid_size, shift=3491)


def to_eztr_batch(batch):
    """Turn a batch of raw rows into EZTransformer's space-separated token strings.

    The source is the lemma's characters followed by its ';'-separated feature
    tags; the target is the form's characters. For example, lemma 'ab' with
    features 'N;SG' and form 'abc' becomes src 'a b N SG' and form 'a b c'.
    (The space-joined string format exists only for eztransformer's build_vocab.)
    """
    src = [
        ' '.join(list(lemma) + features.split(';'))
        for lemma, features in zip(batch['lemma'], batch['features'])
    ]
    form = [' '.join(chars) for chars in batch['form']]
    return {'src': src, 'form': form}


# Single pass instead of three chained maps: same resulting 'src'/'form'
# columns, but only one round of worker processes and one cache file per split.
# remove_columns is applied after the function runs, so lemma/features are
# still available inside to_eztr_batch.
data = data.map(
    to_eztr_batch,
    batched=True,
    batch_size=batch_size,
    num_proc=num_proc,
    remove_columns=['features', 'lemma'],
)

Map (num_proc=4):   0%|          | 0/4096 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/256 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/4096 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/256 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/4096 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/256 [00:00<?, ? examples/s]

In [5]:
from downloads.eztr import EZTransformer

import torch

# Pick the accelerator automatically instead of a hard-coded flag, so the cell
# runs unchanged on Apple silicon (MPS), NVIDIA (CUDA) or CPU-only machines.
# `apple_device` is kept as a name in case later cells read it.
apple_device = torch.backends.mps.is_available()
if apple_device:
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Initialize model
trf = EZTransformer(
    device=device,
    # learning rate
    lrt=1e-3,
    # training batch size (distinct from the Dataset.map `batch_size`)
    bts=256,
    # embedding
    eed=256,
    ded=256,
    # hidden size:
    ehs=512,
    dhs=512,
    # layers:
    enl=2,
    dnl=2,
    # heads:
    eah=4,
    dah=4,

    # NOTE(review): exact semantics live in eztr.py — TODO confirm.
    save_best=10,
    # dropout
    drp=0.1,

    # NOTE(review): `lst` and `cnm` presumably name further EZTransformer
    # options left at their defaults — confirm in eztr.py or remove.
    # lst
    # cnm
)

def ez_train(train_data, valid_data, print_validation_examples):
    if not isinstance(train_data, list):
        train_data = list(zip(train_data["src"], train_data["form"]))
        valid_data = list(zip(valid_data["src"], valid_data["form"]))

    # Train model
    trf.fit(
        train_data = train_data, 
        valid_data = valid_data, 
        print_validation_examples = 0, 
        max_epochs = 40,
        warmup=(train_size//batch_size)*20,
        )

    trf.print_validation_examples(valid_data, print_validation_examples)

# Train on the prepared train/valid slices. The last argument is the number of
# validation examples to print once training finishes.
ez_train(data['train'], data['valid'], 10)

Epoch 1/40: 100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


LR: 6.719840027857805e-05
Epoch 1: Training Loss: 4.009897
Epoch 1: Validation Loss: 3.560622


Epoch 2/40: 100%|██████████| 16/16 [00:08<00:00,  1.81it/s]


LR: 0.00013044395348194566
Epoch 2: Training Loss: 3.326871
Epoch 2: Validation Loss: 3.028870


Epoch 3/40: 100%|██████████| 16/16 [00:09<00:00,  1.74it/s]


LR: 0.00014285714285714284
Epoch 3: Training Loss: 2.831384
Epoch 3: Validation Loss: 2.539448


Epoch 4/40: 100%|██████████| 16/16 [00:09<00:00,  1.75it/s]


LR: 0.00012403473458920844
Epoch 4: Training Loss: 2.409231
Epoch 4: Validation Loss: 2.209717


Epoch 5/40: 100%|██████████| 16/16 [00:09<00:00,  1.76it/s]


LR: 0.0001111111111111111
Epoch 5: Training Loss: 2.120853
Epoch 5: Validation Loss: 2.008362


Epoch 6/40: 100%|██████████| 16/16 [00:08<00:00,  1.82it/s]


LR: 0.00010153461651336191
Epoch 6: Training Loss: 1.906537
Epoch 6: Validation Loss: 1.841191


Epoch 7/40: 100%|██████████| 16/16 [00:08<00:00,  1.80it/s]


LR: 9.407208683835973e-05
Epoch 7: Training Loss: 1.726263
Epoch 7: Validation Loss: 1.717804


Epoch 8/40: 100%|██████████| 16/16 [00:08<00:00,  1.78it/s]


LR: 8.804509063256238e-05
Epoch 8: Training Loss: 1.576620
Epoch 8: Validation Loss: 1.616507


Epoch 9/40: 100%|██████████| 16/16 [00:09<00:00,  1.74it/s]


LR: 8.304547985373997e-05
Epoch 9: Training Loss: 1.457408
Epoch 9: Validation Loss: 1.526002


Epoch 10/40: 100%|██████████| 16/16 [00:08<00:00,  1.87it/s]


LR: 7.881104062391007e-05
Epoch 10: Training Loss: 1.370110
Epoch 10: Validation Loss: 1.454587
Model saved to best_model.pt


Epoch 11/40: 100%|██████████| 16/16 [00:08<00:00,  1.91it/s]


LR: 7.516460280028288e-05
Epoch 11: Training Loss: 1.290725
Epoch 11: Validation Loss: 1.387624


Epoch 12/40: 100%|██████████| 16/16 [00:09<00:00,  1.68it/s]


LR: 7.198157507486945e-05
Epoch 12: Training Loss: 1.229834
Epoch 12: Validation Loss: 1.341344


Epoch 13/40: 100%|██████████| 16/16 [00:08<00:00,  1.93it/s]


LR: 6.917144638660747e-05
Epoch 13: Training Loss: 1.183038
Epoch 13: Validation Loss: 1.305748


Epoch 14/40: 100%|██████████| 16/16 [00:08<00:00,  1.86it/s]


LR: 6.666666666666667e-05
Epoch 14: Training Loss: 1.141157
Epoch 14: Validation Loss: 1.276766


Epoch 15/40: 100%|██████████| 16/16 [00:08<00:00,  1.82it/s]


LR: 6.441566264008309e-05
Epoch 15: Training Loss: 1.105794
Epoch 15: Validation Loss: 1.249902


Epoch 16/40: 100%|██████████| 16/16 [00:08<00:00,  1.83it/s]


LR: 6.237828615518053e-05
Epoch 16: Training Loss: 1.075669
Epoch 16: Validation Loss: 1.220177


Epoch 17/40: 100%|██████████| 16/16 [00:09<00:00,  1.76it/s]


LR: 6.052275326688024e-05
Epoch 17: Training Loss: 1.049833
Epoch 17: Validation Loss: 1.196577


Epoch 18/40: 100%|██████████| 16/16 [00:11<00:00,  1.43it/s]


LR: 5.882352941176471e-05
Epoch 18: Training Loss: 1.026758
Epoch 18: Validation Loss: 1.182135


Epoch 19/40: 100%|██████████| 16/16 [00:09<00:00,  1.62it/s]


LR: 5.7259833431386826e-05
Epoch 19: Training Loss: 1.012857
Epoch 19: Validation Loss: 1.166431


Epoch 20/40: 100%|██████████| 16/16 [00:08<00:00,  1.79it/s]


LR: 5.5814557218594757e-05
Epoch 20: Training Loss: 0.999641
Epoch 20: Validation Loss: 1.158888
Model saved to best_model.pt


Epoch 21/40: 100%|██████████| 16/16 [00:09<00:00,  1.75it/s]


LR: 5.447347107028433e-05
Epoch 21: Training Loss: 0.981753
Epoch 21: Validation Loss: 1.148877


Epoch 22/40: 100%|██████████| 16/16 [00:08<00:00,  1.86it/s]


LR: 5.322462954123495e-05
Epoch 22: Training Loss: 0.970694
Epoch 22: Validation Loss: 1.143110


Epoch 23/40: 100%|██████████| 16/16 [00:08<00:00,  1.86it/s]


LR: 5.2057920629535354e-05
Epoch 23: Training Loss: 0.960761
Epoch 23: Validation Loss: 1.141107


Epoch 24/40: 100%|██████████| 16/16 [00:08<00:00,  1.83it/s]


LR: 5.0964719143762554e-05
Epoch 24: Training Loss: 0.953834
Epoch 24: Validation Loss: 1.129697


Epoch 25/40: 100%|██████████| 16/16 [00:11<00:00,  1.34it/s]


LR: 4.9937616943892234e-05
Epoch 25: Training Loss: 0.941548
Epoch 25: Validation Loss: 1.124174


Epoch 26/40: 100%|██████████| 16/16 [00:11<00:00,  1.41it/s]


LR: 4.8970210687439175e-05
Epoch 26: Training Loss: 0.936821
Epoch 26: Validation Loss: 1.124932


Epoch 27/40: 100%|██████████| 16/16 [00:08<00:00,  1.82it/s]


LR: 4.8056933133221275e-05
Epoch 27: Training Loss: 0.929742
Epoch 27: Validation Loss: 1.115888


Epoch 28/40: 100%|██████████| 16/16 [00:08<00:00,  1.90it/s]


LR: 4.719291781830087e-05
Epoch 28: Training Loss: 0.923648
Epoch 28: Validation Loss: 1.113866


Epoch 29/40: 100%|██████████| 16/16 [00:08<00:00,  1.83it/s]


LR: 4.6373889576016824e-05
Epoch 29: Training Loss: 0.917785
Epoch 29: Validation Loss: 1.113462


Epoch 30/40: 100%|██████████| 16/16 [00:09<00:00,  1.75it/s]


LR: 4.5596075258755325e-05
Epoch 30: Training Loss: 0.911947
Epoch 30: Validation Loss: 1.107085
Model saved to best_model.pt


Epoch 31/40: 100%|██████████| 16/16 [00:08<00:00,  1.86it/s]


LR: 4.485613040162567e-05
Epoch 31: Training Loss: 0.907266
Epoch 31: Validation Loss: 1.107163


Epoch 32/40: 100%|██████████| 16/16 [00:08<00:00,  1.89it/s]


LR: 4.41510785688348e-05
Epoch 32: Training Loss: 0.904646
Epoch 32: Validation Loss: 1.104763


Epoch 33/40: 100%|██████████| 16/16 [00:08<00:00,  1.92it/s]


LR: 4.347826086956522e-05
Epoch 33: Training Loss: 0.897924
Epoch 33: Validation Loss: 1.100756


Epoch 34/40: 100%|██████████| 16/16 [00:08<00:00,  1.93it/s]


LR: 4.2835293687811936e-05
Epoch 34: Training Loss: 0.894951
Epoch 34: Validation Loss: 1.098140


Epoch 35/40: 100%|██████████| 16/16 [00:08<00:00,  1.78it/s]


LR: 4.222003309207491e-05
Epoch 35: Training Loss: 0.891247
Epoch 35: Validation Loss: 1.093215


Epoch 36/40: 100%|██████████| 16/16 [00:08<00:00,  1.91it/s]


LR: 4.163054471218133e-05
Epoch 36: Training Loss: 0.888613
Epoch 36: Validation Loss: 1.090995


Epoch 37/40: 100%|██████████| 16/16 [00:08<00:00,  1.79it/s]


LR: 4.106507811765909e-05
Epoch 37: Training Loss: 0.882325
Epoch 37: Validation Loss: 1.092755


Epoch 38/40: 100%|██████████| 16/16 [00:08<00:00,  1.89it/s]


LR: 4.052204492365539e-05
Epoch 38: Training Loss: 0.880506
Epoch 38: Validation Loss: 1.089500


Epoch 39/40: 100%|██████████| 16/16 [00:08<00:00,  1.94it/s]


LR: 4e-05
Epoch 39: Training Loss: 0.877390
Epoch 39: Validation Loss: 1.089065


Epoch 40/40: 100%|██████████| 16/16 [00:08<00:00,  1.92it/s]


LR: 3.9497625276668216e-05
Epoch 40: Training Loss: 0.874431
Epoch 40: Validation Loss: 1.089770
Model saved to best_model.pt

Validation Examples:
Input:     م َ غ ْ ن ا ط ِ ي س ِ ي ADJ MASC SG PSSD
Target:    م َ غ ْ ن ا ط ِ ي س ِ ي
[91mPredicted:[0m م َ غ ْ ط ا ن ِ ي ن ِ ي ِ ي

Input:     م َ ن ْ ه ُ و ب ADJ MASC SG PSSD
Target:    م َ ن ْ ه ُ و ب
[92mPredicted:[0m م َ ن ْ ه ُ و ب

Input:     م َ و ّ ِ ت V IPFV NOM(MASC,PL,3)
Target:    ي ِ م َ و ّ ِ ت ُ و ا
[92mPredicted:[0m ي ِ م َ و ّ ِ ت ُ و ا

Input:     م َ ن ْ ظ ُ و ر ADJ INDF MASC SG
Target:    م َ ن ْ ظ ُ و ر
[91mPredicted:[0m م َ ن ْ و ر ُ و ر

Input:     م َ ك ْ س َ ب N INDF MASC SG
Target:    م َ ك ْ س َ ب
[92mPredicted:[0m م َ ك ْ س َ ب

Input:     م ُ ب ا ر ا ة N MASC SG PSSD
Target:    م ُ ب ا ر ا ة
[91mPredicted:[0m م ُ ب ا ر

Input:     م َ ن ْ ظ ُ و ر ADJ MASC SG PSSD
Target:    م َ ن ْ ظ ُ و ر
[91mPredicted:[0m م َ ن ْ ن ُ و ر

Input:     م َ ل ّ V PFV NOM(FEM,SG,2)
Target:    م َ ل ّ َ ي ت ِ ي
[92m