In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import spacy
from utils import build_phrase_vocab
from pre_processing import preprocess_with_phrases

In [None]:
# Sequence-length bounds handed to the preprocessing step and reused by later cells.
sequence_len = 128
min_len = 5

# Raw parallel corpora (placeholders: replace each `...` with the actual data).
en_texts = ...
hi_texts = ...

# Preprocessing returns the processed English and Hindi texts plus phrase tags.
en_proc, hi_proc, phrase_tags = preprocess_with_phrases(
    en_texts,
    hi_texts,
    min_len,
    sequence_len,
)
# Phrase-tag lookup; later cells call .get(tag, 0) on it and use len() for the vocab size.
phrase2idx = build_phrase_vocab()

In [None]:
class CharPhraseDataset(Dataset):
    """Parallel character-level dataset with a phrase-tag sequence per sample.

    Every item is a triple ``(x, y, p)`` of 1-D ``torch.long`` tensors, each of
    length ``sequence_len``:

    * ``x`` -- ids of the characters of ``self.x[idx]`` looked up in ``ch2i``
    * ``y`` -- ids of the characters of ``self.y[idx]`` looked up in ``ch2i``
    * ``p`` -- ids of the tags of ``self.phrases[idx]`` looked up in ``phrase2idx``

    Sequences longer than ``sequence_len`` are truncated; shorter ones are
    right-padded with 0. Characters or tags missing from the lookup tables are
    also mapped to 0, so unknown symbols are indistinguishable from padding.

    Args:
        x: Indexable collection of source sequences (iterated symbol by symbol).
        y: Indexable collection of target sequences, aligned with ``x``.
        phrases: Indexable collection of tag sequences, aligned with ``x``.
        sequence_len: Fixed length of every returned tensor.
        ch2i: Mapping from character to integer id.
        phrase2idx: Mapping from phrase tag to integer id.
    """

    # Fill value for padding; also the fallback id for unknown symbols.
    PAD_ID = 0

    def __init__(self, x, y, phrases, sequence_len, ch2i, phrase2idx):
        self.x, self.y, self.phrases = x, y, phrases
        self.sequence_len = sequence_len
        self.ch2i = ch2i
        self.phrase2idx = phrase2idx

    def __len__(self):
        """Return the number of samples (the length of ``x``)."""
        return len(self.x)

    def _fit(self, ids):
        """Truncate or right-pad ``ids`` so it has exactly ``sequence_len`` entries."""
        clipped = ids[:self.sequence_len]
        return clipped + [self.PAD_ID] * (self.sequence_len - len(clipped))

    def __getitem__(self, idx):
        """Return the ``(x, y, p)`` id tensors for sample ``idx``."""
        x = self._fit([self.ch2i.get(c, self.PAD_ID) for c in self.x[idx]])
        y = self._fit([self.ch2i.get(c, self.PAD_ID) for c in self.y[idx]])
        p = self._fit([self.phrase2idx.get(tag, self.PAD_ID) for tag in self.phrases[idx]])
        # Explicit dtype: these are index tensors, and dtype inference from an
        # empty list would otherwise yield a float tensor.
        return (
            torch.tensor(x, dtype=torch.long),
            torch.tensor(y, dtype=torch.long),
            torch.tensor(p, dtype=torch.long),
        )

In [None]:
# Character inventory drawn from every processed English and Hindi sentence.
chars = set(''.join(en_proc + hi_proc))
# Id 0 is reserved for '<pad>'; the characters follow in sorted order from 1.
ch2i = {'<pad>': 0}
ch2i.update((ch, pos) for pos, ch in enumerate(sorted(chars), start=1))

In [None]:
# Training dataset: English as source, Hindi as target, phrase tags alongside.
dataset = CharPhraseDataset(
    x=en_proc,
    y=hi_proc,
    phrases=phrase_tags,
    sequence_len=sequence_len,
    ch2i=ch2i,
    phrase2idx=phrase2idx,
)

In [None]:
# Build the phrase-aware transformer. Its phrase vocabulary size is taken from
# the phrase2idx table built earlier in the notebook.
from transformer import TransformerWithPhrase

# Placeholder: replace `...` with the real transformer config before running.
mconfig = ...  # transformer config as per your code
# NOTE(review): phrase_emb_dim=16 is a hard-coded value -- presumably the width
# of the phrase embedding; confirm against TransformerWithPhrase.
model = TransformerWithPhrase(mconfig, phrase_vocab_size=len(phrase2idx), phrase_emb_dim=16)

In [None]:
# Train the model on the character/phrase dataset.
from trainer import Trainer, TrainerConfig
# NOTE(review): `...` is passed as a literal Ellipsis argument -- presumably a
# placeholder for the real TrainerConfig settings; fill in before running.
trainer_config = TrainerConfig(...)
trainer = Trainer(model, dataset, trainer_config)
trainer.train()

In [None]:
# Inference on held-out sentences (placeholder: replace `[...]` with real strings).
test_sents = [...]

# NOTE(review): extract_7_phrases is neither defined nor imported in the cells
# visible here; on a fresh kernel this line raises NameError unless it is
# imported in the imports cell -- confirm its source module.
test_phrases = [extract_7_phrases(s) for s in test_sents]


def _pad_or_truncate(ids, length, pad_id=0):
    """Return ``ids`` cut or right-padded with ``pad_id`` to exactly ``length`` items."""
    clipped = ids[:length]
    return clipped + [pad_id] * (length - len(clipped))


# Encode exactly as CharPhraseDataset does for training: unknown symbols map to
# 0, and every row is truncated/padded to sequence_len. Without truncation, any
# input longer than sequence_len produces ragged rows and torch.tensor raises.
test_p_ids = [
    _pad_or_truncate([phrase2idx.get(tag, 0) for tag in tags], sequence_len)
    for tags in test_phrases
]
test_x_ids = [
    _pad_or_truncate([ch2i.get(c, 0) for c in s], sequence_len)
    for s in test_sents
]
# Explicit dtype keeps these as integer index tensors even for an empty batch.
test_x = torch.tensor(test_x_ids, dtype=torch.long)
test_p = torch.tensor(test_p_ids, dtype=torch.long)

# NOTE(review): the model's train/eval mode is not set here. If
# TransformerWithPhrase is a torch.nn.Module with dropout and generate() does
# not switch modes itself, call model.eval() first -- confirm.
with torch.no_grad():
    translations = model.generate(test_x, test_p)
print(translations)