In [16]:
import tkseem as tk
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

from dotless_arabic.processing import process
from dotless_arabic.experiments.nlms.src import constants
from dotless_arabic.experiments.constants import COLLECT_DATASET
from dotless_arabic.experiments.nlms.src.models import LitNeuralLanguageModel
from dotless_arabic.experiments.nlms.src.utils import generate_text,get_best_checkpoint,get_tokenizer,get_dataloader
from dotless_arabic.tokenizers import WordTokenizer,FarasaMorphologicalTokenizer,DisjointLetterTokenizer,CharacterTokenizer

In [17]:
# Notebook configuration -- edit these values to inspect a different model.
# tokenizer_class is handed to both the checkpoint lookup and the tokenizer builder below.
tokenizer_class = DisjointLetterTokenizer
# dataset_type is upper-cased into the checkpoint dataset id below.
dataset_type = "dotted"
# dataset_name is the key looked up in COLLECT_DATASET.
dataset_name = "quran"

In [18]:
# Collect the samples of the selected dataset and run every one through
# `process`, keeping the processed samples as a list.
dataset = [process(sample) for sample in COLLECT_DATASET[dataset_name]()]

In [19]:
# First split: hold out the test portion of the full dataset.
train_dataset, test_dataset = train_test_split(
    dataset,
    test_size=constants.TEST_SIZE,
    random_state=constants.RANDOM_SEED,
    shuffle=True,
)

# Second split: carve the validation portion out of what remains for training.
# Note that VAL_SIZE is applied to the remaining training data, not to the full dataset.
train_dataset, val_dataset = train_test_split(
    train_dataset,
    test_size=constants.VAL_SIZE,
    random_state=constants.RANDOM_SEED,
    shuffle=True,
)

In [20]:
# Restore the language model from the checkpoint that `get_best_checkpoint`
# returns for the configured dataset variant and tokenizer class.
model = LitNeuralLanguageModel.load_from_checkpoint(
    get_best_checkpoint(
        dataset_id=f"{dataset_type.upper()}-{dataset_name.upper()}_DATASET",
        tokenizer_class=tokenizer_class,
    )
)
# The model is only used for text generation here, so put it in evaluation
# mode: torch modules start in training mode, which would leave the dropout
# layer active during inference. `eval()` is idempotent, so this is harmless
# if `generate_text` already does it (not visible from this notebook).
# `eval()` returns the module itself, so the cell still displays its summary.
model.eval()

LitNeuralLanguageModel(
  (embedding_layer): Embedding(2224, 512)
  (gru_layer): GRU(512, 512, num_layers=4, batch_first=True)
  (first_dense_layer): Linear(in_features=512, out_features=512, bias=True)
  (dropout_layer): Dropout(p=0.333, inplace=False)
  (relu): ReLU()
  (second_dense_layer): Linear(in_features=512, out_features=2224, bias=True)
)

In [21]:
# Build the tokenizer from the training split only, requesting the same
# vocabulary size the loaded model reports.
# NOTE(review): presumably this must reproduce the tokenizer used at training
# time so token ids line up with the checkpoint -- confirm.
tokenizer = get_tokenizer(
    tokenizer_class=tokenizer_class,
    vocab_size=model.vocab_size,
    train_dataset=train_dataset,
)

Training DisjointLetterTokenizer ...


In [None]:
# Generate a text sample from the loaded model and print it.
# NOTE(review): num_tokens / sequence_length look like the number of tokens to
# generate and the context window length -- confirm against generate_text.
t = generate_text(
    tokenizer=tokenizer,
    lm_model=model,
    sequence_length=20,
    num_tokens=100,
)
print(t)