# In [None]:
import tensorflow as tf
from subword_nmt.apply_bpe import BPE
import os

# File paths
CHECKPOINT_DIR = "./checkpoints"  # Replace with your checkpoint directory path
TRAIN_EN_FILE = "data/wmt14_de_en/train.tok.clean.bpe.32000.en"  # Update with correct path
TRAIN_DE_FILE = "data/wmt14_de_en/train.tok.clean.bpe.32000.de"  # Update with correct path
BPE_CODES_FILE = "data/wmt14_de_en/bpe.32000"  # Update with correct path
VOCAB_FILE_EN = "data/wmt14_de_en/vocab.bpe.32000"  # Update with correct path
VOCAB_FILE_DE = "data/wmt14_de_en/vocab.bpe.32000"  # Update with correct path

# Vocabulary: each stripped line is a token and its ID is the zero-based
# line number it appears on.
with open(VOCAB_FILE_EN, "r", encoding="utf-8") as f:
    vocab_en = {line.strip(): idx for idx, line in enumerate(f)}
with open(VOCAB_FILE_DE, "r", encoding="utf-8") as f:
    vocab_de = {line.strip(): idx for idx, line in enumerate(f)}

# Reverse vocabulary for German (to decode token IDs back to words)
inv_vocab_de = {idx: word for word, idx in vocab_de.items()}

# Initialize BPE encoding. The codes file is opened in a context manager so
# the handle is closed once the BPE object is built (BPE reads the merge
# table in its constructor -- see the subword-nmt apply_bpe source).
with open(BPE_CODES_FILE, "r", encoding="utf-8") as codes_file:
    bpe = BPE(codes_file)

# Special tokens (looked up in the English vocabulary, with fixed fallbacks)
SOS_TOKEN_ID = vocab_en.get("<sos>", 1)  # Use 1 if <sos> isn't in vocab
EOS_TOKEN_ID = vocab_en.get("<eos>", 2)  # Use 2 if <eos> isn't in vocab
UNK_TOKEN_ID = vocab_en.get("<unk>", 3)  # Use 3 if <unk> isn't in vocab

# Define encoder and decoder models (assuming you've implemented them previously)
encoder = Encoder(vocab_size=len(vocab_en), embedding_dim=256, enc_units=512, batch_sz=1)
decoder = Decoder(vocab_size=len(vocab_de), embedding_dim=256, dec_units=512, batch_sz=1)

# Checkpoint to restore the model weights
checkpoint = tf.train.Checkpoint(encoder=encoder, decoder=decoder)
latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
if latest_checkpoint is None:
    # latest_checkpoint() returns None when the directory holds no checkpoint;
    # restoring from None would leave the models with untrained weights while
    # still reporting success, so fail loudly instead.
    raise FileNotFoundError(
        f"No checkpoint found in {CHECKPOINT_DIR!r}; cannot restore model weights."
    )
# expect_partial() silences warnings about checkpoint values that are not
# matched to objects here.
checkpoint.restore(latest_checkpoint).expect_partial()
print("Model weights restored from checkpoint.")

### Utility Functions

def preprocess_sentence(sentence, vocab, bpe):
    """
    Turn a raw sentence into a batch-of-one tensor of vocabulary IDs.

    The sentence is segmented with ``bpe.process_line``, split on whitespace,
    and each resulting piece is looked up in ``vocab``. Pieces missing from
    ``vocab`` are mapped to the module-level ``UNK_TOKEN_ID``.

    Args:
        sentence: Text to encode.
        vocab: Mapping from token string to integer ID.
        bpe: Object exposing ``process_line(str) -> str``.

    Returns:
        An int32 tensor built from a single-row nested list ``[ids]``.
    """
    pieces = bpe.process_line(sentence).split()
    ids = []
    for piece in pieces:
        # Fall back to the unknown-token ID for out-of-vocabulary pieces.
        ids.append(vocab.get(piece, UNK_TOKEN_ID))
    # Wrap in an outer list so the tensor carries a leading batch axis of 1.
    batch = [ids]
    return tf.convert_to_tensor(batch, dtype=tf.int32)

def decode_tokens(tokens, inv_vocab, separator="@@"):
    """
    Decode a sequence of token IDs back to words.

    IDs are mapped through ``inv_vocab`` (unknown IDs become ``"<unk>"``),
    decoding stops at the first ``"<eos>"`` entry, and BPE sub-word pieces are
    merged back into whole words by removing the continuation marker.

    Args:
        tokens: Iterable of token IDs.
        inv_vocab: Mapping from token ID to token string.
        separator: BPE continuation marker that ends a non-final sub-word
            piece (``"@@"`` is the subword-nmt default). Pass ``None`` or
            ``""`` to keep the raw space-joined pieces.

    Returns:
        The decoded text as a single string; ``""`` for an empty sequence.
    """
    words = []
    for token in tokens:
        word = inv_vocab.get(token, "<unk>")
        if word == "<eos>":
            break
        words.append(word)

    text = " ".join(words)
    if separator:
        # "Hal@@ lo" -> "Hallo": a piece ending in the marker continues into
        # the next piece, so drop the marker together with the joining space.
        text = text.replace(separator + " ", "")
        # A dangling marker on the final piece has nothing left to join to.
        if text.endswith(separator):
            text = text[: -len(separator)]
    return text

def translate_sentence(sentence, encoder, decoder, max_length=50):
    """
    Translate an English sentence to German using greedy decoding.

    The sentence is encoded with ``preprocess_sentence`` and the module-level
    ``vocab_en`` / ``bpe``, passed through ``encoder``, and then ``decoder`` is
    called one step at a time, feeding back the highest-scoring token ID until
    ``EOS_TOKEN_ID`` is predicted or ``max_length`` tokens have been produced.

    Args:
        sentence: English source text.
        encoder: Callable returning ``(enc_output, enc_hidden)`` for the
            encoded input tensor.
        decoder: Callable taking ``(dec_input, enc_output, dec_hidden)`` and
            returning a 3-tuple whose first two items are the predictions and
            the next hidden state.
        max_length: Maximum number of output tokens to generate (default 50).

    Returns:
        The decoded German sentence, or ``""`` when ``sentence`` is empty or
        contains only whitespace.
    """
    # Nothing to translate: skip the models rather than feeding the encoder
    # a zero-length sequence.
    if not sentence or not sentence.strip():
        return ""

    # Preprocess and encode the English sentence
    input_seq = preprocess_sentence(sentence, vocab_en, bpe)
    enc_output, enc_hidden = encoder(input_seq)

    # Initialize the decoder with the <sos> token
    dec_input = tf.expand_dims([SOS_TOKEN_ID], 0)
    dec_hidden = enc_hidden

    # Store the translation
    result = []

    # Generate translation token-by-token
    for _step in range(max_length):
        predictions, dec_hidden, _unused = decoder(dec_input, enc_output, dec_hidden)
        # Convert to a plain Python int so the collected IDs and the next
        # decoder input are built the same way as the initial <sos> input.
        predicted_id = int(tf.argmax(predictions[0]).numpy())

        # Stop if the <eos> token is predicted
        if predicted_id == EOS_TOKEN_ID:
            break

        # Append the predicted token to the result
        result.append(predicted_id)

        # Use the predicted token as the next input to the decoder
        dec_input = tf.expand_dims([predicted_id], 0)

    # Decode the sequence of token IDs into a German sentence
    return decode_tokens(result, inv_vocab_de)

### Translate a Sample Sentence

# Run one fixed English sentence through the restored models and show the
# source text next to the generated German output.
sample_sentence = "Hello, how are you?"
translated_text = translate_sentence(
    sample_sentence,
    encoder=encoder,
    decoder=decoder,
)

print("English:", sample_sentence)
print("German:", translated_text)
