# Import Statements

In [1]:
import os, math, random, pickle, re
from collections import Counter

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device used:", device)

Device used: mps


# Initial configuration

In [3]:
# data paramaters
ctx_turns = 2          
max_src_len = 128
max_tgt_len = 40
batch_size = 32

# model params
embed_dimen = 256
hidden_dimen = 256
dropout = 0.3
bidir_encoder = True   
teacher_forcing = 0.5
lr = 2e-3
epochs = 10

ckpt_dir = "../models/checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
best_ckpt = os.path.join(ckpt_dir, "seq2seq_attn_best.pt")
last_ckpt = os.path.join(ckpt_dir, "seq2seq_attn_last.pt")

# Preprocess and vocab (extend with special/control tokens)

In [4]:
try:
    from utils.preprocess_dailydialog import preprocess_text
except Exception:
    import nltk
    from nltk.corpus import stopwords
    from nltk.stem import WordNetLemmatizer
    nltk.download('punkt', quiet=True); nltk.download('stopwords', quiet=True); nltk.download('wordnet', quiet=True)
    _sw = set(stopwords.words('english')); _lem = WordNetLemmatizer()
    def preprocess_text(text):
        text = text.lower()
        text = re.sub(r"http\S+", "", text)
        text = re.sub(r"[^a-zA-Z\s]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        toks = nltk.word_tokenize(text)
        return [_lem.lemmatize(t) for t in toks if t not in _sw]

with open("../data/processed/vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

# Add needed special tokens if missing
SPECIALS = ["<PAD>", "<UNK>", "<BOS>", "<EOS>", "<CTX_SEP>",
            "<EMO_0>", "<EMO_1>", "<EMO_2>", "<EMO_3>", "<EMO_4>", "<EMO_5>", "<EMO_6>",
            "<ACT_0>", "<ACT_1>", "<ACT_2>", "<ACT_3>", "<ACT_4>"]
for tok in SPECIALS:
    if tok not in vocab:
        vocab[tok] = len(vocab)

pad_idx = vocab["<PAD>"]
unk_idx = vocab["<UNK>"]
bos_idx = vocab["<BOS>"]
eos_idx = vocab["<EOS>"]
sep_idx = vocab["<CTX_SEP>"]

def toks_to_ids(tokens, max_len):
    ids = [vocab.get(t, unk_idx) for t in tokens][:max_len]
    return ids

def pad_to(ids, pad, max_len):
    return ids + [pad] * (max_len - len(ids))

# Build context->response pairs from DailyDialog

In [5]:
def build_pairs(split_name="train", 
                ctx_turns=ctx_turns, 
                include_controls=True, 
                drop_dummy_act=True):
    
    ds = load_dataset("daily_dialog", split=split_name)
    pairs = []

    for dialog, emos, acts in zip(ds["dialog"], ds["emotion"], ds["act"]):
        toks_list = [preprocess_text(utt) for utt in dialog]
        for t in range(len(toks_list) - 1):
            start = max(0, t - ctx_turns + 1)
            ctx_utts = toks_list[start:t+1]     
            resp_utts = toks_list[t+1] 

            # flatten context with separator
            ctx = []
            for u in ctx_utts:
                if ctx: ctx.append("<CTX_SEP>")
                ctx.extend(u)

            # optional control tokens (gold labels from HF)
            if include_controls:
                emo_id = emos[t+1] if emos[t+1] is not None else 0
                act_id = acts[t+1] if acts[t+1] is not None else 0
                if drop_dummy_act and act_id == 0:
                    continue
                ctx = [f"<EMO_{emo_id}>", f"<ACT_{act_id}>"] + ctx

            pairs.append((ctx, resp_utts))

    return pairs

train_pairs = build_pairs("train")
val_pairs   = build_pairs("validation")
test_pairs  = build_pairs("test")
len(train_pairs), len(val_pairs), len(test_pairs)

Using the latest cached version of the dataset since daily_dialog couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /Users/anaghar/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd (last modified on Sun Jun 15 03:11:15 2025).
Using the latest cached version of the dataset since daily_dialog couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /Users/anaghar/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd (last modified on Sun Jun 15 03:11:15 2025).
Using the latest cached version of the dataset since daily_dialog couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /Users/anaghar/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd (las

(76052, 7069, 6740)

# Dataset & dynamic padding collate

In [6]:
class Seq2SeqDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self): return len(self.pairs)

    def __getitem__(self, idx):
        src_toks, tgt_toks = self.pairs[idx]
        # ids (truncate and add BOS/EOS to target)
        src_ids = toks_to_ids(src_toks, max_src_len)
        tgt_ids = [bos_idx] + toks_to_ids(tgt_toks, max_tgt_len-2) + [eos_idx]
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

def collate_batch(batch):
    srcs, tgts = zip(*batch)
    src_lens = [len(s) for s in srcs]
    tgt_lens = [len(t) for t in tgts]

    max_src = max(src_lens); max_tgt = max(tgt_lens)
    src_pad = torch.stack([F.pad(s, (0, max_src - len(s)), value=pad_idx) for s in srcs])
    tgt_pad = torch.stack([F.pad(t, (0, max_tgt - len(t)), value=pad_idx) for t in tgts])

    return src_pad, torch.tensor(src_lens), tgt_pad, torch.tensor(tgt_lens)

train_ds = Seq2SeqDataset(train_pairs)
val_ds   = Seq2SeqDataset(val_pairs)
test_ds  = Seq2SeqDataset(test_pairs)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

# Model: Encoder, Luong-Attention, Decoder, Seq2Seq wrapper

In [7]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pad_idx, bidirectional=True, dropout=0.):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.bidir = bidirectional
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=bidirectional)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_lens):
        # src: (B, T)
        emb = self.dropout(self.embedding(src))
        packed = nn.utils.rnn.pack_padded_sequence(emb, src_lens.cpu(), batch_first=True, enforce_sorted=False)
        enc_outs, (h, c) = self.rnn(packed)
        enc_outs, _ = nn.utils.rnn.pad_packed_sequence(enc_outs, batch_first=True)
        if self.bidir:
            # concat last forward/backward states
            h = torch.cat([h[-2], h[-1]], dim=1).unsqueeze(0)
            c = torch.cat([c[-2], c[-1]], dim=1).unsqueeze(0)
        return enc_outs, (h, c)  # enc_outs for attention

class LuongAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim):
        super().__init__()
        assert enc_dim == dec_dim, "For dot attention, enc_dim must equal dec_dim"
        
    def forward(self, dec_hidden, enc_outs, mask):
        scores = torch.bmm(enc_outs, dec_hidden.unsqueeze(2)).squeeze(2)
        scores = scores.masked_fill(~mask, -1e9)
        attn_weights = F.softmax(scores, dim=1)
        context = torch.bmm(attn_weights.unsqueeze(1), enc_outs).squeeze(1)  
        return context, attn_weights

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pad_idx, dropout=0.):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attn = LuongAttention(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward_step(self, input_tok, hidden, enc_outs, enc_mask):
        # input_tok: (B,) token ids for current step
        emb = self.dropout(self.embedding(input_tok.unsqueeze(1)))   
        out, hidden = self.rnn(emb, hidden)                           
        dec_h = out.squeeze(1)                                        
        context, _ = self.attn(dec_h, enc_outs, enc_mask)             
        logits = self.fc(torch.cat([dec_h, context], dim=1))
        return logits, hidden

class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pad_idx, bidir_encoder=True, dropout=0.):
        super().__init__()
        enc_out_dim = hidden_dim * (2 if bidir_encoder else 1)
        self.encoder = Encoder(vocab_size, embed_dim, hidden_dim, pad_idx, bidirectional=bidir_encoder, dropout=dropout)
        # For dot attention, make decoder hidden dim == enc_out_dim
        self.decoder = Decoder(vocab_size, embed_dim, enc_out_dim, pad_idx, dropout=dropout)
        self.bridge_h = nn.Linear(enc_out_dim, enc_out_dim)
        self.bridge_c = nn.Linear(enc_out_dim, enc_out_dim)
        self.pad_idx = pad_idx

    def make_src_mask(self, src_lens, max_time):
        # returns (B, max_time) True for valid timesteps
        B = src_lens.size(0)
        rng = torch.arange(max_time, device=src_lens.device).unsqueeze(0).expand(B, -1)
        return rng < src_lens.unsqueeze(1)

    def forward(self, src, src_lens, tgt, teacher_forcing=0.5):

        batch, Tt = tgt.size()
        enc_outs, (h, c) = self.encoder(src, src_lens)
        # init dec hidden with bridged encoder state
        h0 = torch.tanh(self.bridge_h(h[-1])).unsqueeze(0)
        c0 = torch.tanh(self.bridge_c(c[-1])).unsqueeze(0)

        logits = []
        inp = tgt[:, 0]
        enc_mask = self.make_src_mask(src_lens, enc_outs.size(1)) 

        hidden = (h0, c0)
        for t in range(1, Tt):
            logit_t, hidden = self.decoder.forward_step(inp, hidden, enc_outs, enc_mask)
            logits.append(logit_t.unsqueeze(1))
            teacher = (random.random() < teacher_forcing)
            inp = tgt[:, t] if teacher else logit_t.argmax(dim=1)

        return torch.cat(logits, dim=1)

# Train / validate (loss + BLEU) and save ckpts

In [8]:
def seq_ce_loss(logits, tgt):

    gold = tgt[:, 1:] 
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           gold.reshape(-1),
                           ignore_index=pad_idx)
    return loss

@torch.no_grad()
def greedy_decode(model, src, src_lens, max_len=max_tgt_len):
    model.eval()
    enc_outs, (h, c) = model.encoder(src, src_lens)
    h0 = torch.tanh(model.bridge_h(h[-1])).unsqueeze(0)
    c0 = torch.tanh(model.bridge_c(c[-1])).unsqueeze(0)
    enc_mask = model.make_src_mask(src_lens, enc_outs.size(1))

    batch = src.size(0)
    inputs = torch.full((batch,), bos_idx, dtype=torch.long, device=src.device)
    hidden = (h0, c0)

    hyps = [[] for _ in range(batch)]
    for _ in range(max_len-1):
        logits, hidden = model.decoder.forward_step(inputs, hidden, enc_outs, enc_mask)
        next_tok = logits.argmax(dim=1)
        inputs = next_tok
        for i in range(batch):
            if next_tok[i].item() == eos_idx:
                continue
            hyps[i].append(int(next_tok[i].item()))
    return hyps

def evaluate_bleu(model, loader, max_len=max_tgt_len):
    model.eval()
    chencherry = SmoothingFunction().method3
    refs, hyps = [], []
    with torch.no_grad():
        for src, src_lens, tgt, _ in loader:
            src, src_lens, tgt = src.to(device), src_lens.to(device), tgt.to(device)
            preds = greedy_decode(model, src, src_lens, max_len)
            for i in range(len(preds)):
                ref_ids = tgt[i, 1:].tolist()  # drop BOS
                if eos_idx in ref_ids:
                    ref_ids = ref_ids[:ref_ids.index(eos_idx)]
                elif pad_idx in ref_ids:
                    ref_ids = ref_ids[:ref_ids.index(pad_idx)]
                refs.append([[id_to_tok(x) for x in ref_ids if x != pad_idx]])
                hyps.append([id_to_tok(x) for x in preds[i] if x not in (pad_idx, bos_idx, eos_idx)])
    bleu = corpus_bleu(refs, hyps, smoothing_function=chencherry)
    return bleu

# small helpers
inv_vocab = {i: t for t, i in vocab.items()}
def id_to_tok(i): return inv_vocab.get(int(i), "<UNK>")

model = Seq2Seq(
    vocab_size=len(vocab),
    embed_dim=embed_dimen,
    hidden_dim=hidden_dimen,
    pad_idx=pad_idx,
    bidir_encoder=bidir_encoder,
    dropout=dropout
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr)

best_val = float("inf"); best_bleu = 0.0
for epoch in range(1, epochs+1):
    model.train()
    total = 0.0; steps = 0
    for src, src_lens, tgt, tgt_lens in train_loader:
        src, src_lens, tgt = src.to(device), src_lens.to(device), tgt.to(device)
        optimizer.zero_grad()
        logits = model(src, src_lens, tgt, teacher_forcing=teacher_forcing)
        loss = seq_ce_loss(logits, tgt)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total += loss.item(); steps += 1

    val_loss = 0.0; vsteps = 0
    model.eval()
    with torch.no_grad():
        for src, src_lens, tgt, tgt_lens in val_loader:
            src, src_lens, tgt = src.to(device), src_lens.to(device), tgt.to(device)
            logits = model(src, src_lens, tgt, teacher_forcing=0.0)
            loss = seq_ce_loss(logits, tgt)
            val_loss += loss.item(); vsteps += 1
    val_loss /= max(1, vsteps)

    # optional BLEU each epoch
    bleu = evaluate_bleu(model, val_loader)

    print(f"Epoch {epoch:02d} | train: {total/max(1,steps):.4f} | val: {val_loss:.4f} | BLEU {bleu:.4f}")

    # save best by BLEU first, else by val loss
    improved = False
    if bleu > best_bleu:
        best_bleu = bleu; improved = True
    elif val_loss < best_val:
        improved = True
    if improved:
        best_val = val_loss
        torch.save(model.state_dict(), best_ckpt)

torch.save(model.state_dict(), last_ckpt)
print("✓ Training done. Best and last checkpoints saved.")

Epoch 01 | train: 6.2307 | val: 5.9549 | BLEU 0.0006
Epoch 02 | train: 5.6418 | val: 5.8190 | BLEU 0.0015
Epoch 03 | train: 5.2511 | val: 5.8145 | BLEU 0.0020
Epoch 04 | train: 4.9879 | val: 5.8174 | BLEU 0.0026
Epoch 05 | train: 4.7917 | val: 5.8703 | BLEU 0.0044
Epoch 06 | train: 4.6410 | val: 5.8786 | BLEU 0.0049
Epoch 07 | train: 4.5254 | val: 5.9146 | BLEU 0.0058
Epoch 08 | train: 4.4234 | val: 5.9541 | BLEU 0.0070
Epoch 09 | train: 4.3519 | val: 5.9635 | BLEU 0.0065
Epoch 10 | train: 4.2757 | val: 5.9939 | BLEU 0.0077
✓ Training done. Best and last checkpoints saved.


# Greedy decode + quick samples

In [9]:
def decode_ids(ids):
    toks = [id_to_tok(i) for i in ids if i not in (pad_idx, bos_idx, eos_idx)]
    return " ".join(toks)

batch = next(iter(val_loader))
src, src_lens, tgt, _ = [x.to(device) if torch.is_tensor(x) else x for x in batch]
preds = greedy_decode(model, src, src_lens)

for i in range(min(5, src.size(0))):
    print("CTX:", decode_ids(src[i].tolist()))
    print("REF:", decode_ids(tgt[i].tolist()[1:]))
    print("GEN:", decode_ids(preds[i]))
    print("---")

CTX: <EMO_0> <ACT_1> good morning sir bank near
REF: one block away
GEN: yes morning
---
CTX: <EMO_0> <ACT_3> good morning sir bank near <CTX_SEP> one block away
REF: well thats <UNK> change money
GEN: yes sir
---
CTX: <EMO_0> <ACT_2> one block away <CTX_SEP> well thats <UNK> change money
REF: surely course kind currency got
GEN: oh
---
CTX: <EMO_0> <ACT_1> well thats <UNK> change money <CTX_SEP> surely course kind currency got
REF: rib
GEN: yes
---
CTX: <EMO_0> <ACT_2> surely course kind currency got <CTX_SEP> rib
REF: much would like change
GEN: would like euro
---


# User-facing inference wrapper (with optional emotion/act)

In [10]:
@torch.no_grad()
def beam_search_decode(model, src, src_lens, max_len, beam_size=5, length_penalty=0.8, min_len=6):
    model.eval()
    enc_outs, (h, c) = model.encoder(src, src_lens)
    h0 = torch.tanh(model.bridge_h(h[-1])).unsqueeze(0)
    c0 = torch.tanh(model.bridge_c(c[-1])).unsqueeze(0)
    enc_mask = model.make_src_mask(src_lens, enc_outs.size(1))

    assert src.size(0) == 1, "beam_search_decode expects batch=1"
    beams = [(0.0, [bos_idx], (h0, c0))]
    completed = []

    for _ in range(max_len - 1):
        new_beams = []
        for logp, seq, hid in beams:
            last = seq[-1]
            if last == eos_idx and len(seq) > 1:
                completed.append((logp, seq))
                continue

            # block EOS until min_len
            block_eos = (len(seq) - 1) < min_len

            inp = torch.tensor([last], device=src.device, dtype=torch.long)
            logits, new_hid = model.decoder.forward_step(inp, hid, enc_outs, enc_mask)
            probs = F.log_softmax(logits, dim=-1).squeeze(0)
            if block_eos:
                probs[eos_idx] = -1e9

            topk_logp, topk_idx = probs.topk(beam_size)
            for k in range(beam_size):
                new_beams.append((logp + float(topk_logp[k]),
                                  seq + [int(topk_idx[k])],
                                  new_hid))

        # prune with length penalty
        new_beams.sort(key=lambda x: x[0] / (len(x[1]) ** length_penalty), reverse=True)
        beams = new_beams[:beam_size]

    completed.extend([(lp, s) for lp, s, _ in beams])
    completed.sort(key=lambda x: x[0] / (len(x[1]) ** length_penalty), reverse=True)
    best = completed[0][1][1:]  # drop BOS
    if eos_idx in best:
        best = best[:best.index(eos_idx)]
    return [best]

@torch.no_grad()
def sample_decode(model, src, src_lens, max_len, temperature=0.9, top_k=50, top_p=0.9, min_len=6):
    model.eval()
    enc_outs, (h, c) = model.encoder(src, src_lens)
    h0 = torch.tanh(model.bridge_h(h[-1])).unsqueeze(0)
    c0 = torch.tanh(model.bridge_c(c[-1])).unsqueeze(0)
    enc_mask = model.make_src_mask(src_lens, enc_outs.size(1))

    inp = torch.tensor([bos_idx], device=src.device)
    hidden = (h0, c0)
    out_ids = []

    for t in range(max_len - 1):
        logits, hidden = model.decoder.forward_step(inp, hidden, enc_outs, enc_mask)
        logits = logits / max(1e-6, temperature)
        probs = F.softmax(logits, dim=-1).squeeze(0)

        # block EOS until min_len
        if t < min_len:
            probs[eos_idx] = 0.0
            probs = probs / probs.sum()

        # top-k
        if top_k is not None and top_k > 0:
            topk_vals, topk_idx = torch.topk(probs, k=min(top_k, probs.size(-1)))
            mask = torch.ones_like(probs, dtype=torch.bool)
            mask[topk_idx] = False
            probs = torch.where(mask, torch.zeros_like(probs), probs)

        # top-p (nucleus)
        if top_p is not None and 0 < top_p < 1:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            cutoff = cumsum > top_p
            cutoff[0] = False
            sorted_probs[cutoff] = 0
            probs = torch.zeros_like(probs).scatter(0, sorted_idx, sorted_probs)
            probs = probs / probs.sum()

        tok = torch.multinomial(probs, 1).item()
        if tok == eos_idx:
            break
        out_ids.append(tok)
        inp = torch.tensor([tok], device=src.device)

    return [out_ids]

In [14]:
_contraction_fixes = [
    (r"\bdont\b", "don't"), (r"\bdoesnt\b", "doesn't"), (r"\bdidnt\b", "didn't"),
    (r"\bhasnt\b", "hasn't"), (r"\bhavent\b", "haven't"), (r"\bhadnt\b", "hadn't"),
    (r"\bcant\b", "can't"), (r"\bcouldnt\b", "couldn't"), (r"\bshouldnt\b", "shouldn't"),
    (r"\bwouldnt\b", "wouldn't"), (r"\bwont\b", "won't"), (r"\bisnt\b", "isn't"),
    (r"\baren't\b", "aren't"), (r"\bim\b", "I'm"), (r"\bi'm\b", "I'm"),
    (r"\bive\b", "I've"), (r"\bi've\b", "I've"), (r"\bill\b", "I'll"), (r"\bi'll\b", "I'll"),
    (r"\biyoull\b", "you'll"), (r"\bthatll\b", "that'll"), (r"\btheres\b", "there's"),
    (r"\bwhats\b", "what's"), (r"\blets\b", "let's"),
]
def clean_generated_text(text: str) -> str:
    if not text:
        return text
    # tighten spacing before punctuation
    text = re.sub(r"\s+([,.;:!?])", r"\1", text)
    # apply contraction fixes (case-insensitive)
    for pat, rep in _contraction_fixes:
        text = re.sub(pat, rep, text, flags=re.IGNORECASE)
    # tidy casing and terminal punctuation
    text = text.strip()
    if text:
        text = text[0].upper() + text[1:]
        if text[-1] not in ".!?":
            text += "."
    return text

# encode context for inference (unchanged)
def encode_context(text, emo_id=None, act_id=None, ctx_turns=1):
    if isinstance(text, str):
        turns = [text]
    else:
        turns = text[-ctx_turns:]

    toks = []
    if emo_id is not None: toks.append(f"<EMO_{int(emo_id)}>")
    if act_id is not None: toks.append(f"<ACT_{int(act_id)}>")
    for t in turns:
        if len(toks) > 0: toks.append("<CTX_SEP>")
        toks.extend(preprocess_text(t))

    ids = toks_to_ids(toks, max_src_len)
    ids = pad_to(ids, pad_idx, max_src_len)
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)

@torch.no_grad()
def generate_reply(context, emo_id=None, act_id=None, max_len=max_tgt_len,
                   strategy="beam",
                   beam_size=5, length_penalty=0.8, min_len=6,
                   temperature=0.9, top_k=50, top_p=0.9):
    src = encode_context(context, emo_id, act_id).to(device)                   
    src_len = torch.tensor([(src != pad_idx).sum().item()], device=device)     
    if strategy == "beam":
        pred_ids = beam_search_decode(model, src, src_len, max_len=max_len,
                                      beam_size=beam_size, length_penalty=length_penalty,
                                      min_len=min_len)[0]
    else:
        pred_ids = sample_decode(model, src, src_len, max_len=max_len,
                                 temperature=temperature, top_k=top_k, top_p=top_p,
                                 min_len=min_len)[0]
    raw = decode_ids(pred_ids)
    return clean_generated_text(raw)


print(generate_reply(
    ["Could you send the file?", "Sure, I will send it by evening."],
    emo_id=4,
    act_id=1,
    strategy="beam",
    beam_size=5,
    length_penalty=0.8,
    min_len=8
))

Ok I'll call back sunday evening forward evening.
