# In [None]:
import torch
import torch.nn as nn
import re
import unicodedata
import os
from enum import Enum

# ========================================================================================
# SECTION 1: Model & Preprocessing Code (Copied from training script)
# This section contains the necessary class definitions and functions
# to rebuild the model architecture and process input text.
# ========================================================================================

# --- LTCCell and Model Definitions ---
class MappingType(Enum):
    """Kinds of element-wise mapping applied to the inputs of an LTC cell."""

    Identity = 0
    Linear = 1
    Affine = 2
class ODESolver(Enum):
    """Integration schemes an LTC cell can use to advance its state."""

    SemiImplicit = 0
    Explicit = 1
    RungeKutta = 2

class LTCCell(nn.Module):
    """Recurrent LTC cell whose hidden state is advanced by an unfolded ODE solver.

    Mirrors the cell of the training script: the parameter names and shapes
    created here are what a saved state dict is matched against, so they
    must not be renamed.

    Args:
        input_size: Number of input features per step.
        num_units: Number of neurons, i.e. the size of the hidden state.
        solver: ``ODESolver`` member selecting the integration scheme.
        solver_clip: Symmetric clamp applied to the state after every
            explicit / Runge-Kutta sub-step; values <= 0 disable clamping.
    """

    def __init__(self, input_size, num_units, solver=ODESolver.SemiImplicit, solver_clip=5.0):
        super().__init__()
        self._input_size = input_size
        self._num_units = num_units
        self._solver = solver
        self._solver_clip = solver_clip

        # Number of solver sub-steps per call and the input mapping in use.
        self._ode_solver_unfolds = 6
        self._input_mapping = MappingType.Affine

        # Initialisation ranges.
        self._erev_init_factor = 1
        self._w_init_max = 1.0
        self._w_init_min = 0.01
        self._cm_init_min = 0.5
        self._cm_init_max = 0.5
        self._gleak_init_min = 1.0
        self._gleak_init_max = 1.0

        # When not None, the quantity is stored as a fixed buffer instead of
        # a trainable parameter.
        self._fix_cm = None
        self._fix_gleak = None
        self._fix_vleak = None

        # Bounds enforced by constrain_parameters().
        self._w_min_value = 1e-5
        self._w_max_value = 1000.0
        self._gleak_min_value = 1e-5
        self._gleak_max_value = 1000.0
        self._cm_t_min_value = 1e-6
        self._cm_t_max_value = 1000.0

        self._get_variables()
        self._map_inputs()

    @property
    def state_size(self):
        """Size of the hidden state (``num_units``)."""
        return self._num_units

    @property
    def output_size(self):
        """Size of the cell output (``num_units``)."""
        return self._num_units

    def _map_inputs(self):
        """Create the element-wise input scale (and bias for the affine mapping)."""
        mapping = self._input_mapping
        if mapping in [MappingType.Affine, MappingType.Linear]:
            self.input_w = nn.Parameter(torch.Tensor(self._input_size))
            nn.init.constant_(self.input_w, 1.0)
        if mapping == MappingType.Affine:
            self.input_b = nn.Parameter(torch.Tensor(self._input_size))
            nn.init.constant_(self.input_b, 0.0)

    def _get_variables(self):
        """Create and initialise the synaptic and membrane parameters.

        The statement order is kept deliberately: it fixes both the parameter
        registration order and the order in which random numbers are drawn.
        """
        n_in = self._input_size
        n_units = self._num_units

        # Sensory synapses (input feature -> neuron).
        self.sensory_mu = nn.Parameter(torch.Tensor(n_in, n_units))
        self.sensory_sigma = nn.Parameter(torch.Tensor(n_in, n_units))
        self.sensory_W = nn.Parameter(torch.Tensor(n_in, n_units))
        # Reversal potentials start as random signs (-1 or +1) times the init factor.
        sensory_erev_init = (2 * torch.randint(0, 2, size=[n_in, n_units]) - 1) * self._erev_init_factor
        self.sensory_erev = nn.Parameter(sensory_erev_init.float())
        nn.init.uniform_(self.sensory_mu, a=0.3, b=0.8)
        nn.init.uniform_(self.sensory_sigma, a=3.0, b=8.0)
        nn.init.uniform_(self.sensory_W, a=self._w_init_min, b=self._w_init_max)

        # Recurrent synapses (neuron -> neuron).
        self.mu = nn.Parameter(torch.Tensor(n_units, n_units))
        self.sigma = nn.Parameter(torch.Tensor(n_units, n_units))
        self.W = nn.Parameter(torch.Tensor(n_units, n_units))
        erev_init = (2 * torch.randint(0, 2, size=[n_units, n_units]) - 1) * self._erev_init_factor
        self.erev = nn.Parameter(erev_init.float())
        nn.init.uniform_(self.mu, a=0.3, b=0.8)
        nn.init.uniform_(self.sigma, a=3.0, b=8.0)
        nn.init.uniform_(self.W, a=self._w_init_min, b=self._w_init_max)

        # Leak potential.
        if self._fix_vleak is None:
            self.vleak = nn.Parameter(torch.Tensor(n_units))
            nn.init.uniform_(self.vleak, a=-0.2, b=0.2)
        else:
            self.register_buffer('vleak', torch.full([n_units], self._fix_vleak))

        # Leak conductance.
        if self._fix_gleak is None:
            self.gleak = nn.Parameter(torch.Tensor(n_units))
            if self._gleak_init_max > self._gleak_init_min:
                nn.init.uniform_(self.gleak, a=self._gleak_init_min, b=self._gleak_init_max)
            else:
                nn.init.constant_(self.gleak, self._gleak_init_min)
        else:
            self.register_buffer('gleak', torch.full([n_units], self._fix_gleak))

        # Membrane capacitance term.
        if self._fix_cm is None:
            self.cm_t = nn.Parameter(torch.Tensor(n_units))
            if self._cm_init_max > self._cm_init_min:
                nn.init.uniform_(self.cm_t, a=self._cm_init_min, b=self._cm_init_max)
            else:
                nn.init.constant_(self.cm_t, self._cm_init_min)
        else:
            self.register_buffer('cm_t', torch.full([n_units], self._fix_cm))

    def forward(self, inputs, state):
        """Advance the state by one input step.

        Args:
            inputs: Input features for this step; the encoder and decoder in
                this file pass a ``(batch, input_size)`` tensor.
            state: Current hidden state, ``(batch, num_units)`` in those callers.

        Returns:
            ``(next_state, next_state)``: the same tensor serves as both the
            output and the new hidden state.

        Raises:
            ValueError: If the configured solver is not a known ``ODESolver``.
        """
        mapping = self._input_mapping
        if mapping in [MappingType.Affine, MappingType.Linear]:
            inputs = inputs * self.input_w
        if mapping == MappingType.Affine:
            inputs = inputs + self.input_b

        solver = self._solver
        if solver == ODESolver.Explicit:
            next_state = self._ode_step_explicit(inputs, state)
        elif solver == ODESolver.SemiImplicit:
            next_state = self._ode_step_semi_implicit(inputs, state)
        elif solver == ODESolver.RungeKutta:
            next_state = self._ode_step_runge_kutta(inputs, state)
        else:
            raise ValueError(f"Unknown ODE solver '{str(self._solver)}'")
        return next_state, next_state

    def _ode_step_semi_implicit(self, inputs, state):
        """Run the unfolded semi-implicit update.

        Each sub-step computes
        ``v = (cm_t * v + gleak * vleak + sum(w * erev)) / (cm_t + gleak + sum(w))``
        where the sums run over dim 1 (the source axis) of the sensory and
        recurrent synapse activations.
        """
        v_pre = state

        # The sensory contribution does not depend on v_pre, so compute it once.
        sensory_w_activation = self.sensory_W * self._sigmoid(inputs, self.sensory_mu, self.sensory_sigma)
        sensory_rev_activation = sensory_w_activation * self.sensory_erev
        w_numerator_sensory = torch.sum(sensory_rev_activation, dim=1)
        w_denominator_sensory = torch.sum(sensory_w_activation, dim=1)

        for _ in range(self._ode_solver_unfolds):
            w_activation = self.W * self._sigmoid(v_pre, self.mu, self.sigma)
            rev_activation = w_activation * self.erev
            w_numerator = torch.sum(rev_activation, dim=1) + w_numerator_sensory
            w_denominator = torch.sum(w_activation, dim=1) + w_denominator_sensory
            numerator = self.cm_t * v_pre + self.gleak * self.vleak + w_numerator
            denominator = self.cm_t + self.gleak + w_denominator
            v_pre = numerator / denominator
        return v_pre

    def _f_prime(self, inputs, state):
        """Return the time derivative of the state used by the explicit solvers."""
        v_pre = state
        sensory_w_activation = self.sensory_W * self._sigmoid(inputs, self.sensory_mu, self.sensory_sigma)
        w_reduced_sensory = torch.sum(sensory_w_activation, dim=1)
        w_activation = self.W * self._sigmoid(v_pre, self.mu, self.sigma)
        w_reduced_synapse = torch.sum(w_activation, dim=1)
        sensory_in = self.sensory_erev * sensory_w_activation
        synapse_in = self.erev * w_activation
        sum_in = (
            torch.sum(sensory_in, dim=1)
            - v_pre * w_reduced_synapse
            + torch.sum(synapse_in, dim=1)
            - v_pre * w_reduced_sensory
        )
        f_prime = (1 / self.cm_t) * (self.gleak * (self.vleak - v_pre) + sum_in)
        return f_prime

    def _ode_step_explicit(self, inputs, state):
        """Run unfolded explicit Euler sub-steps with a fixed step size of 0.1."""
        v_pre = state
        h = 0.1
        for _ in range(self._ode_solver_unfolds):
            f_prime = self._f_prime(inputs, v_pre)
            v_pre = v_pre + h * f_prime
            if self._solver_clip > 0:
                v_pre = torch.clamp(v_pre, -self._solver_clip, self._solver_clip)
        return v_pre

    def _ode_step_runge_kutta(self, inputs, state):
        """Run unfolded fourth-order Runge-Kutta sub-steps with a step size of 0.1."""
        v_pre = state
        h = 0.1
        for _ in range(self._ode_solver_unfolds):
            k1 = h * self._f_prime(inputs, v_pre)
            k2 = h * self._f_prime(inputs, v_pre + 0.5 * k1)
            k3 = h * self._f_prime(inputs, v_pre + 0.5 * k2)
            k4 = h * self._f_prime(inputs, v_pre + k3)
            v_pre = v_pre + (1.0 / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
            if self._solver_clip > 0:
                v_pre = torch.clamp(v_pre, -self._solver_clip, self._solver_clip)
        return v_pre

    def _sigmoid(self, v_pre, mu, sigma):
        """Return ``sigmoid(sigma * (v_pre[..., None] - mu))``.

        The trailing axis added to ``v_pre`` broadcasts it against the
        two-dimensional ``mu`` / ``sigma`` matrices.
        """
        shifted = v_pre.unsqueeze(-1) - mu
        return torch.sigmoid(sigma * shifted)

    def constrain_parameters(self):
        """Clamp capacitance, leak conductance and synaptic weights in place."""
        self.cm_t.data.clamp_(min=self._cm_t_min_value, max=self._cm_t_max_value)
        self.gleak.data.clamp_(min=self._gleak_min_value, max=self._gleak_max_value)
        self.W.data.clamp_(min=self._w_min_value, max=self._w_max_value)
        self.sensory_W.data.clamp_(min=self._w_min_value, max=self._w_max_value)

class EncoderLTC(nn.Module):
    """Embeds a source token sequence and folds it into one hidden state.

    Args:
        input_size: Source vocabulary size (number of embedding rows).
        embed_size: Embedding dimension fed to the LTC cell.
        hidden_size: Number of LTC units, i.e. the size of the returned state.
    """

    def __init__(self, input_size, embed_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, embed_size)
        self.ltc_cell = LTCCell(embed_size, hidden_size)

    def forward(self, x):
        """Encode ``x`` of shape ``(seq_length, batch_size)``.

        Returns:
            The hidden state after the last time step, shaped
            ``(batch_size, hidden_size)``; all zeros for an empty sequence.
        """
        _, batch_size = x.shape
        state = torch.zeros(batch_size, self.hidden_size, device=x.device)
        # Walk the sequence one time step at a time along dim 0.
        for step_embedding in self.embedding(x).unbind(0):
            state, _ = self.ltc_cell(step_embedding, state)
        return state

class DecoderLTC(nn.Module):
    """Single-step decoder: embeds one token per batch item and predicts the next.

    Args:
        output_size: Target vocabulary size (embedding rows and logits width).
        embed_size: Embedding dimension fed to the LTC cell.
        hidden_size: Number of LTC units.
    """

    def __init__(self, output_size, embed_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, embed_size)
        self.ltc_cell = LTCCell(embed_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_state):
        """Run one decoding step.

        Args:
            x: Previous token ids, one per batch item.
            hidden_state: Current LTC state.

        Returns:
            ``(predictions, hidden_state)``: scores over the target
            vocabulary and the updated state.
        """
        # Add a leading step axis of length one, embed, then take that single step.
        step_embedding = self.embedding(x.unsqueeze(0))[0]
        cell_output, next_state = self.ltc_cell(step_embedding, hidden_state)
        return self.fc(cell_output), next_state

class Seq2SeqLTC(nn.Module):
    """Container that holds the encoder and decoder for inference.

    Simplified relative to a training model: there is no teacher-forcing
    forward pass, so callers use ``encoder`` and ``decoder`` directly.

    Args:
        encoder: Module that encodes the source sequence.
        decoder: Module that performs single decoding steps.
        target_vocab_size: Size of the target vocabulary.
        device: Device the model is meant to run on (stored, not applied).
    """

    def __init__(self, encoder, decoder, target_vocab_size, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.target_vocab_size = target_vocab_size
        self.device = device

    def forward(self, source, target, teacher_force_ratio=0.5):
        """Unsupported at inference time; always raises ``NotImplementedError``."""
        raise NotImplementedError("Use encoder and decoder directly for inference.")

# --- Tokenizer ---
class Tokenizer:
    """tiktoken-backed tokenizer with four reserved special ids in front.

    Ids 0-3 are ``<pad>``, ``<sos>``, ``<eos>`` and ``<unk>``; every id of
    the underlying encoding is shifted up by that offset.

    Args:
        encoding_name: Name passed to ``tiktoken.get_encoding``.
    """

    def __init__(self, encoding_name="cl100k_base"):
        # Imported lazily so the module can be loaded without tiktoken installed.
        import tiktoken

        self.special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
        self._encoding = tiktoken.get_encoding(encoding_name)
        self._base_vocab_size = self._encoding.n_vocab
        self._offset = len(self.special_tokens)
        self.pad_id, self.sos_id, self.eos_id, self.unk_id = 0, 1, 2, 3
        self.vocab_size = self._base_vocab_size + self._offset

    def encode(self, s, add_special_tokens=True):
        """Convert ``s`` to a list of shifted token ids.

        ``None`` is treated as the empty string and any other value is passed
        through ``str``. With ``add_special_tokens`` the result is wrapped in
        ``<sos>`` ... ``<eos>``.
        """
        text = str(s) if s is not None else ""
        shifted = [base_id + self._offset for base_id in self._encoding.encode(text)]
        if not add_special_tokens:
            return shifted
        return [self.sos_id, *shifted, self.eos_id]

    def decode(self, ids):
        """Convert shifted ids back to text, dropping the reserved special ids.

        Returns the empty string when nothing but special ids is given.
        """
        offset = self._offset
        base_ids = [token_id - offset for token_id in ids if token_id >= offset]
        return self._encoding.decode(base_ids) if base_ids else ""

# --- Preprocessing Functions ---
def unicode_to_ascii(s):
    """Strip combining marks (accents) from ``s``.

    The string is NFD-decomposed and every character in the Unicode category
    ``Mn`` is dropped; characters without a decomposition are left untouched.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

def preprocess_sentence(w):
    """Normalise a sentence before tokenisation.

    Steps, in order: convert to ``str``, lowercase, trim, strip accents, put
    spaces around ``? . ! , ¿``, collapse runs of spaces/double quotes into a
    single space, replace every run of other characters outside ``a-z`` /
    ``A-Z`` with a space, and trim again.
    """
    text = unicode_to_ascii(str(w).lower().strip())
    substitutions = (
        (r"([?.!,¿])", r" \1 "),       # isolate punctuation as separate tokens
        (r'[" "]+', " "),              # collapse spaces and double quotes
        (r"[^a-zA-Z?.!,¿]+", " "),     # drop everything else
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()

# ========================================================================================
# SECTION 2: Inference Configuration
# ========================================================================================
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint file holding the saved model state dict loaded by the main block.
MODEL_SAVE_PATH = "ltc_translator_best.pt"
# Embedding width and number of LTC units used to rebuild the architecture.
# NOTE(review): presumably these must equal the values used during training,
# otherwise the saved weights will not fit — confirm against the training script.
EMBED_SIZE = 256
HIDDEN_SIZE = 512

# ========================================================================================
# SECTION 3: Translation Function
# ========================================================================================
def translate_sentence(model, sentence, tokenizer_en, tokenizer_fr, device, max_length=50):
    """Greedily translate one English sentence to French.

    Args:
        model: Object exposing ``encoder`` and ``decoder``; it is switched to
            evaluation mode as a side effect.
        sentence: Raw source sentence.
        tokenizer_en: Tokenizer used to encode the preprocessed source.
        tokenizer_fr: Tokenizer providing ``sos_id`` / ``eos_id`` and used to
            decode the generated ids.
        device: Device on which the input tensors are created.
        max_length: Maximum number of decoding steps.

    Returns:
        The decoded translation as a string.
    """
    model.eval()

    # Source ids as a column: (seq_length, batch=1).
    source_ids = tokenizer_en.encode(preprocess_sentence(sentence))
    src_tensor = torch.tensor(source_ids).unsqueeze(1).to(device)

    # Inference only: no gradients are needed for encoding or decoding.
    with torch.no_grad():
        hidden_state = model.encoder(src_tensor)

        generated = [tokenizer_fr.sos_id]
        for _ in range(max_length):
            # Feed back the most recently produced token.
            previous = torch.tensor([generated[-1]], device=device)
            scores, hidden_state = model.decoder(previous, hidden_state)

            # Greedy choice: the highest-scoring vocabulary entry.
            next_id = scores.argmax(1).item()
            generated.append(next_id)
            if next_id == tokenizer_fr.eos_id:
                break

    return tokenizer_fr.decode(generated)


# ========================================================================================
# SECTION 4: Main Execution Block
# ========================================================================================
if __name__ == "__main__":
    # Interactive English -> French translator: rebuild the model, load the
    # saved weights, then translate sentences typed by the user.

    # Abort with a non-zero status when the checkpoint is missing.
    # SystemExit is raised directly because the bare exit() helper is
    # injected by the site module (not guaranteed to exist) and would
    # report success (status 0) for what is an error.
    if not os.path.exists(MODEL_SAVE_PATH):
        print(f"Error: Model file not found at '{MODEL_SAVE_PATH}'")
        print("Please run the training script first to generate the model file.")
        raise SystemExit(1)

    # --- Initialize tokenizers ---
    print("Initializing tokenizers...")
    tokenizer_en = Tokenizer()
    tokenizer_fr = Tokenizer()
    INPUT_DIM = tokenizer_en.vocab_size
    OUTPUT_DIM = tokenizer_fr.vocab_size

    # --- Rebuild the model architecture ---
    print("Initializing model architecture...")
    encoder = EncoderLTC(INPUT_DIM, EMBED_SIZE, HIDDEN_SIZE)
    decoder = DecoderLTC(OUTPUT_DIM, EMBED_SIZE, HIDDEN_SIZE)
    model = Seq2SeqLTC(encoder, decoder, OUTPUT_DIM, DEVICE).to(DEVICE)

    # --- Load the saved weights ---
    print(f"Loading model state from {MODEL_SAVE_PATH}...")
    # map_location makes the checkpoint load whether it was saved on CPU or GPU.
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (recent PyTorch releases offer weights_only=True).
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))

    print("\n--- English to French Translator ---")
    print('Type an English sentence to translate, or "quit" to exit.\n')

    # --- Interactive loop ---
    while True:
        try:
            sentence = input("English  > ")

            # Surrounding whitespace must not hide a quit command.
            command = sentence.strip().lower()
            if command in ("quit", "exit"):
                print("Exiting translator.")
                break

            # Nothing to translate on a blank line: prompt again.
            if not command:
                continue

            translation = translate_sentence(model, sentence, tokenizer_en, tokenizer_fr, DEVICE)
            print(f"French   > {translation}\n")

        except (KeyboardInterrupt, EOFError):
            # Ctrl-C, Ctrl-D or exhausted piped input all end the session
            # cleanly instead of surfacing a traceback.
            print("\nExiting translator.")
            break