## Imports

In [1]:
from src.tokenizer import Tokenizer, normalize_text
import numpy as np

import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

import jupyter_black

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

jupyter_black.load()
%matplotlib inline

In [2]:
# Read the tab-separated sentence pairs, one pair per line.
# Decode explicitly as UTF-8: without `encoding=`, open() falls back to the
# platform's locale encoding, so non-ASCII text could be mis-decoded (or fail
# to decode) on systems whose default is not UTF-8.
# NOTE(review): assumes data/fra.txt is UTF-8 encoded — confirm.
with open("data/fra.txt", "r", encoding="utf-8") as fp:
    data = fp.readlines()

# Normalize the first two columns while they are still plain Python strings,
# and only then build the array. Writing normalized text back into an existing
# NumPy string array is fragile: the element width is fixed when the array is
# created, so a replacement longer than that width is silently truncated.
# Lines with fewer than two tab-separated fields (e.g. blank lines) are
# skipped instead of producing a ragged array.
sents = np.array(
    [
        [normalize_text(fields[0]), normalize_text(fields[1])]
        for fields in (line.strip("\n").split("\t") for line in data)
        if len(fields) >= 2
    ]
)
# Fail loudly here rather than further down if nothing usable was parsed.
assert sents.ndim == 2 and sents.shape[1] == 2, "no sentence pairs were parsed"

#######################################################

# Upper bound on the number of space-separated tokens a sentence may have.
# Fixing it up front gives every sequence a known maximum size for padding.
MAX_LENGTH = 10

# Only pairs whose first-column sentence starts with one of these openers are kept.
eng_prefixes = (
    "i am ",
    "i m ",
    "he is",
    "he s ",
    "she is",
    "she s ",
    "you are",
    "you re ",
    "we are",
    "we re ",
    "they are",
    "they re ",
)


def _short_enough(sentence):
    """Return True when `sentence` has at most MAX_LENGTH space-separated tokens."""
    # Splitting on " " always yields (number of spaces + 1) pieces, so counting
    # the spaces is equivalent to measuring the length of the split.
    return sentence.count(" ") < MAX_LENGTH


# One boolean per row: True when the first-column sentence has an accepted
# opener and both sentences of the pair fit within the length limit.
condition = [
    eng.startswith(eng_prefixes) and _short_enough(eng) and _short_enough(fra)
    for eng, fra in sents
]
sents = sents[condition]

# Column 1 of `sents` is used as the source side, column 0 as the target side.
tgt_ix, src_ix = 0, 1


def _fit_and_save_tokenizer(texts, path):
    """Create a Tokenizer, fit it on `texts`, persist it to `path`, and return it."""
    # Two slots beyond MAX_LENGTH — presumably for start/end markers; TODO confirm
    # against src.tokenizer.Tokenizer.
    tok = Tokenizer(max_length=MAX_LENGTH + 2)
    tok.fit(texts, max_vocab_size=508)
    torch.save(tok, path)
    return tok


# Source tokenizer first, then target — each is fitted and written to disk in turn.
src_tok = _fit_and_save_tokenizer(sents[:, src_ix], "data/src_tok.pt")
tgt_tok = _fit_and_save_tokenizer(sents[:, tgt_ix], "data/tgt_tok.pt")

# Encode each side with its own tokenizer.
X = src_tok.tok_seqs(sents[:, src_ix])
y = tgt_tok.tok_seqs(sents[:, tgt_ix])

# Stage 1: keep 80% for training and set 20% aside as a hold-out pool.
Xtr, X_hold, ytr, y_hold = train_test_split(X, y, test_size=0.2, random_state=42)

# Stage 2: halve the hold-out pool into validation and test (10% / 10% overall).
Xval, Xte, yval, yte = train_test_split(
    X_hold, y_hold, test_size=0.5, random_state=42
)

# Turn every split into a tensor and move it onto the selected device.
Xtr, Xval, Xte, ytr, yval, yte = (
    torch.tensor(split).to(device) for split in (Xtr, Xval, Xte, ytr, yval, yte)
)

#########################################

batch_size = 25
# Both values below are read from the source tokenizer only.
vocab_size = src_tok.vocab_size
pad_token_ix = src_tok.wtoi[src_tok.pad_token]
print("Vocab size:", vocab_size)

# One dataset per split, pairing source tensors with target tensors.
train_ds, val_ds, test_ds = (
    TensorDataset(inputs, targets)
    for inputs, targets in ((Xtr, ytr), (Xval, yval), (Xte, yte))
)
# A 10-example slice of the training data, served in batches of 5.
tiny_train_ds = TensorDataset(Xtr[:10], ytr[:10])

# NOTE(review): no `shuffle=` is passed, so every loader yields batches in
# dataset order on each pass — confirm that is intended for train_dl.
train_dl, val_dl, test_dl = (
    DataLoader(dataset, batch_size=batch_size)
    for dataset in (train_ds, val_ds, test_ds)
)
tiny_train_dl = DataLoader(tiny_train_ds, batch_size=5)

# Persist each loader object to disk, in the order train, val, test, tiny.
for loader, path in (
    (train_dl, "data/train_dl.pt"),
    (val_dl, "data/val_dl.pt"),
    (test_dl, "data/test_dl.pt"),
    (tiny_train_dl, "data/tiny_train_dl.pt"),
):
    torch.save(loader, path)

Vocab size: 512
