In [None]:
# install stuff
!pip install datasets==3.6.0
!pip install sentencepiece

In [None]:
!pip install sacrebleu
!pip install -q gdown

In [None]:
import os

# Debugging switch: make CUDA kernel launches synchronous so device-side errors are
# reported at the Python line that caused them. NOTE(review): this slows training
# noticeably; consider removing it once debugging is finished.
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

In [None]:
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from datasets import load_dataset
import torch.nn.functional as F
from collections import Counter
import sentencepiece as spm
from tqdm import tqdm
import torch.nn as nn
import sacrebleu
import random
import torch
import math

In [None]:
# HyperParams (15GB VRAM Friendly)
epochs = 10
embed_dim = 384
num_heads = 6                  # 384 / 6 = 64 per head
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads
num_layers = 8
ff_hidden_dim = 4 * embed_dim  # = 1536
seq_len = 128                  # context length = number of learned position embeddings
vocab_size = 32000             # requested tokenizer size; overwritten with the real size after training
dropout = 0.05
weight_decay = 0.001
batch_size = 32
early_stop_patience = 3
lr = 1e-4
betas = (0.9, 0.95)
clip_grad = 1.0
special_tokens = ["<pad>", "<unk>", "<s>", "</s>"]  # NOTE(review): not referenced by any visible cell
# Epochs already trained by the checkpoint being resumed (0 = fresh run). It also
# selects the local checkpoint file name, best_model_{prev_epochs}.pt.
prev_epochs = 0

# Seed every RNG the notebook draws from (data shuffling, weight init, dropout,
# sampling) so that Restart & Run All is reproducible.
seed = 42
torch.manual_seed(seed)
random.seed(seed)

In [None]:
# Mount Google Drive (Colab only) and make sure the checkpoint folder exists;
# save_model_checkpoint writes its best_model_epoch_*.pt files into drive_save_dir.
from google.colab import drive
drive.mount('/content/drive')
drive_save_dir = '/content/drive/MyDrive/model_checkpoints'
os.makedirs(drive_save_dir, exist_ok=True)

In [None]:
import gdown

# Fetch the checkpoint to resume from. The local file name must match the one the
# model cell loads (best_model_{prev_epochs}.pt), so derive it from prev_epochs
# instead of hard-coding an epoch number. Set file_id to the Drive file id to enable it.
file_id = ""
output_path = f"best_model_{prev_epochs}.pt"

if file_id:
    gdown.download(f"https://drive.google.com/uc?id={file_id}", output_path, quiet=False)
else:
    print("[INFO] No checkpoint file_id set; skipping download.")

In [None]:
# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
# Keep exactly one "best" checkpoint on Drive.
def save_model_checkpoint(model, epoch, prev_epochs=0, optimizer=None, scheduler=None, save_dir=None):
    """Save model (and optionally optimizer/scheduler) state as the single best checkpoint.

    The file is named best_model_epoch_{prev_epochs + epoch}.pt so the name carries the
    cumulative epoch count across resumed runs. Other best_model_epoch_* files in the
    directory are deleted only AFTER the new file has been written, so a failed save
    never leaves the directory without a checkpoint.

    Args:
        model: module whose state_dict() is saved.
        epoch: epoch number within the current run (stored under 'epoch').
        prev_epochs: epochs completed by earlier runs (default 0).
        optimizer: optional optimizer; its state_dict() is stored, or None.
        scheduler: optional LR scheduler; its state_dict() is stored, or None.
        save_dir: target directory; defaults to the notebook's drive_save_dir.

    Returns:
        The path of the written checkpoint.
    """
    if save_dir is None:
        save_dir = drive_save_dir
    os.makedirs(save_dir, exist_ok=True)

    # catch up with the epochs of the resumed model
    total_epochs = prev_epochs + epoch
    save_path = os.path.join(save_dir, f"best_model_epoch_{total_epochs}.pt")
    torch.save({
        'epoch': epoch,
        'total_epochs': total_epochs,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict() if optimizer is not None else None,
        'scheduler_state': scheduler.state_dict() if scheduler is not None else None,
    }, save_path)

    # Prune older checkpoints only now that the new one is safely on disk.
    new_name = os.path.basename(save_path)
    for name in os.listdir(save_dir):
        if name.startswith("best_model_epoch_") and name != new_name:
            os.remove(os.path.join(save_dir, name))

    print(f"[✔] Saved model at epoch {epoch} to: {save_path}")
    return save_path

In [None]:
# Build the training lines "<s> User: ... <sep> Bot: ... </s>" from two corpora.
# NOTE(review): the original note assumed the prosocial/offensive corpus is under 50%
# of the mix; verify with len(offensive_texts_train) vs len(clean_texts_train).
def _as_exchange(user_text, bot_text):
    """Format one user/bot exchange exactly as the tokenizer and model expect it."""
    return "<s> User: " + user_text + " <sep> Bot: " + bot_text + " </s>"


def _prosocial_lines(rows):
    """One training line per prosocial-dialog row (context -> response)."""
    return [_as_exchange(row["context"], row["response"]) for row in rows]


def _daily_dialog_lines(rows):
    """Opening exchange (turns 0 and 1) of every daily_dialog conversation with >= 2 turns."""
    return [
        _as_exchange(row["dialog"][0], row["dialog"][1])
        for row in rows
        if len(row["dialog"]) >= 2
    ]


# Offensive / prosocial corpus: train, test and validation splits
offensive_data_train = load_dataset("allenai/prosocial-dialog", split="train")
offensive_data_test = load_dataset("allenai/prosocial-dialog", split="test")
offensive_data_val = load_dataset("allenai/prosocial-dialog", split="validation")
offensive_texts_train, offensive_texts_test, offensive_texts_val = (
    _prosocial_lines(split_rows)
    for split_rows in (offensive_data_train, offensive_data_test, offensive_data_val)
)

# Clean everyday-conversation corpus: train, test and validation splits
clean_data_train = load_dataset("daily_dialog", split="train")
clean_data_test = load_dataset("daily_dialog", split="test")
clean_data_val = load_dataset("daily_dialog", split="validation")
clean_texts_train, clean_texts_test, clean_texts_val = (
    _daily_dialog_lines(split_rows)
    for split_rows in (clean_data_train, clean_data_test, clean_data_val)
)

# Merge per split; only the training mix is shuffled
combined_texts_train = offensive_texts_train + clean_texts_train
random.shuffle(combined_texts_train)
combined_texts_test = offensive_texts_test + clean_texts_test
combined_texts_val = offensive_texts_val + clean_texts_val

In [None]:
# get data content: the tokenizer corpus is the (already shuffled) training split only.
# This is an alias of combined_texts_train, not a copy.
combined_texts = combined_texts_train

In [None]:
# Dump the tokenizer-training corpus, one whitespace-trimmed example per line.
with open("train_text.txt", "w", encoding="utf-8") as corpus_file:
    corpus_file.writelines(f"{text.strip()}\n" for text in combined_texts)

In [None]:
# Train a BPE SentencePiece model on train_text.txt (written by the previous cell).
# Reserved ids: <pad>=0, <unk>=1, bos=2, eos=3; the assert cell further down checks them.
# The training-format markers are registered as user-defined symbols, presumably so
# they stay single pieces (the probe cell below prints how each one is segmented).
# NOTE(review): "<s>"/"</s>" are listed here AND are expected at the bos/eos ids 2/3
# (per the asserts below); confirm SentencePiece accepts this overlap and that the
# literal text "</s>" encodes to a single id.
spm.SentencePieceTrainer.train(
    input='train_text.txt',
    model_prefix='chatbot_tokenizer',
    vocab_size=vocab_size,
    model_type='bpe',
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
    user_defined_symbols=['<sep>', 'User:', 'Bot:',"<s>","</s>"]
)


In [None]:
class GPTDataset(torch.utils.data.Dataset):
    """Next-token-prediction dataset over raw text lines.

    Every text is encoded, truncated or right-padded with id 0 to exactly seq_len
    tokens, and returned as (input_ids, target_ids): two LongTensors of length
    seq_len - 1 where target_ids is input_ids shifted left by one position.
    `vocab_size` is accepted for call-site compatibility but is not used.
    """

    def __init__(self, texts, tokenizer, seq_len, vocab_size):
        self.texts = texts
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        ids = self.tokenizer.encode(self.texts[idx], out_type=int)

        # Fix the window to exactly seq_len tokens: cut long texts, pad short ones with 0.
        window = ids[:self.seq_len]
        window = window + [0] * (self.seq_len - len(window))

        input_ids = torch.tensor(window[:-1], dtype=torch.long)
        target_ids = torch.tensor(window[1:], dtype=torch.long)
        return input_ids, target_ids


In [None]:
# Load the freshly trained tokenizer; its real piece count becomes the vocab size.
sp = spm.SentencePieceProcessor()
sp.load("chatbot_tokenizer.model")
vocab_size = sp.get_piece_size()

# Wrap each split in a GPTDataset (texts are encoded lazily, per item).
train_dataset, test_dataset, validation_dataset = (
    GPTDataset(split_texts, sp, seq_len, vocab_size)
    for split_texts in (combined_texts_train, combined_texts_test, combined_texts_val)
)

In [None]:
# Sanity check: show how one training-format line is split into pieces.
print(sp.encode("<s> User: Hi! <sep> Bot: Hello </s>", out_type=str))

In [None]:
# Show how each marker / special token is segmented when it appears as plain text.
probe_tokens = ["<sep>", "User:", "Bot:", "<s>", "<unk>", "<pad>"]
for probe in probe_tokens:
    print(f"{probe} →", sp.encode(probe, out_type=str))



In [None]:
# The reserved ids configured when training the tokenizer must line up.
for piece, expected_id in {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3}.items():
    assert sp.piece_to_id(piece) == expected_id

# Typing the literal text of a special token must NOT yield that token's id.
assert sp.encode("<pad>")[0] != 0
assert sp.encode("<unk>")[0] != 1


In [None]:
# Debug spot check: which piece does id 15 map to?
print(sp.IdToPiece(15))

In [None]:
# The trained tokenizer's real size (may differ from the requested vocab_size).
vocab_size = sp.get_piece_size()
print(f"Actual tokenizer vocab size: {vocab_size}")

In [None]:
# create data loaders: only the training split is shuffled
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader, val_dataloader = (
    DataLoader(split_dataset, batch_size=batch_size, shuffle=False)
    for split_dataset in (test_dataset, validation_dataset)
)

In [None]:
# mask function
def casual_mask(size):
    """Return a lower-triangular causal mask of shape (1, 1, size, size).

    Entry [0, 0, i, j] is 1.0 when position i may attend to position j (j <= i)
    and 0.0 otherwise. The two leading singleton dims broadcast against the
    (batch, heads, T, T) attention scores in MultiHeadAttention.
    """
    allowed = torch.ones(size, size).tril()
    return allowed[None, None]

In [None]:
# Multi-Head attention block
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: input/output width E.
        num_heads: number of heads H.
        head_dim: per-head width D; H * D must equal E.

    forward(x, mask=None):
        x: (B, T, E) tensor.
        mask: optional tensor indexable as mask[:, :, :T, :T] and broadcastable to
            (B, H, T, T); positions where mask == 0 are not attended to.
        returns a (B, T, E) tensor.

    Raises:
        ValueError: num_heads * head_dim != embed_dim.
    """

    def __init__(self, embed_dim, num_heads, head_dim):
        super().__init__()
        if num_heads * head_dim != embed_dim:
            raise ValueError(
                f"num_heads * head_dim ({num_heads} * {head_dim}) must equal embed_dim ({embed_dim})"
            )
        self.num_heads = num_heads
        self.head_dim = head_dim

        # One fused projection produces Q, K and V; out_proj mixes the heads back to E.
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        B, T, E = x.shape
        qkv = self.qkv_proj(x).chunk(3, dim=-1)
        # (B, T, E) -> (B, H, T, D) for each of Q, K, V
        Q, K, V = [t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) for t in qkv]

        # Scaled dot-product attention divides by sqrt(head_dim). The previous
        # `head_dim ** 0.05` barely scaled the scores (~1.23 for D=64).
        # NOTE(review): checkpoints trained with the old scaling need fine-tuning.
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask[:, :, :T, :T] == 0, float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, T, E)
        return self.out_proj(out)

In [None]:
# Decoder Block
class DecoderModel(nn.Module):
    """One pre-norm transformer decoder block.

    Computes x + Dropout(Attn(LN(x), mask)) followed by x + Dropout(FFN(LN(x))).

    Args:
        embed_dim: model width.
        numb_heads: number of attention heads.
        head_dim: per-head width (numb_heads * head_dim must equal embed_dim).
        dropout: dropout after attention, inside the FFN, and after the FFN.
        ff_hidden_dim: FFN hidden width; defaults to 4 * embed_dim. This used to be
            read from the notebook-level `ff_hidden_dim` global, which the config
            cell also sets to 4 * embed_dim.
    """

    def __init__(self, embed_dim, numb_heads, head_dim, dropout=0.1, ff_hidden_dim=None):
        super().__init__()
        if ff_hidden_dim is None:
            ff_hidden_dim = 4 * embed_dim

        # self-attention sub-layer
        self.attn = MultiHeadAttention(embed_dim, numb_heads, head_dim)

        # pre-norm layers for the two sub-layers
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # dropout on each residual branch
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # position-wise feed-forward: embed_dim -> ff_hidden_dim -> embed_dim
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_hidden_dim, embed_dim),
        )

    def forward(self, x, mask):
        # Attention block: LayerNorm -> attention -> Dropout -> residual
        x = x + self.dropout1(self.attn(self.norm1(x), mask))
        # Feedforward block: LayerNorm -> FFN -> Dropout -> residual
        x = x + self.dropout2(self.ff(self.norm2(x)))
        return x

In [None]:
# model block (TinyGPT)
class TinyGPT(nn.Module):
    """Decoder-only language model with learned positions and an adaptive-softmax head.

    forward(x) -> (B, T, vocab_size) log-probabilities for every position.
    forward(x, target, ignore_index=0) -> the adaptive-softmax output, whose `.loss`
        is the mean negative log-likelihood over target positions != ignore_index
        (id 0 is the padding id used by the tokenizer and GPTDataset). If a batch
        has no such positions, every position is scored so the loss stays defined.
        Pass ignore_index=None to score every position.

    Raises:
        ValueError: the input is longer than seq_len (the number of learned positions).
    """

    def __init__(self, embed_dim, num_heads, num_layers, vocab_size, seq_len, head_dim, dropout=0.1):
        super().__init__()
        # token and learned absolute-position embeddings
        self.token_embd = nn.Embedding(vocab_size, embed_dim)
        self.pos_embd = nn.Embedding(seq_len, embed_dim)

        # stack of decoder blocks, one per layer
        self.blocks = nn.ModuleList([
            DecoderModel(embed_dim, num_heads, head_dim, dropout) for _ in range(num_layers)
        ])

        # final norm, embedding dropout, and an adaptive softmax over the whole vocabulary
        self.ln = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.head = nn.AdaptiveLogSoftmaxWithLoss(
            in_features=embed_dim,
            n_classes=vocab_size,
            cutoffs=[2000, 10000],  # cluster boundaries; must stay below vocab_size
            div_value=4.0,
            head_bias=False,
        )

        # NOTE(review): this is very likely NOT weight tying. AdaptiveLogSoftmaxWithLoss
        # keeps its projections in `self.head.head` / `self.head.tail`, and nothing in
        # this model reads `self.head.weight`. The assignment only registers the
        # embedding a second time under the state_dict key "head.weight". It is kept
        # so existing checkpoints (which contain that key) still load strictly.
        self.head.weight = self.token_embd.weight

    def forward(self, x, target=None, ignore_index=0):
        B, T = x.shape
        max_len = self.pos_embd.num_embeddings
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds the model's maximum of {max_len}")

        positions = torch.arange(T, device=x.device).unsqueeze(0)  # (1, T)

        # Embedding + positional encoding
        h = self.dropout(self.token_embd(x) + self.pos_embd(positions))  # (B, T, E)

        # causal mask for self-attention
        mask = casual_mask(T).to(h.device)
        for block in self.blocks:
            h = block(h, mask)

        # final layer norm, then flatten for the adaptive softmax: (B * T, E)
        h = self.ln(h).reshape(B * T, -1)

        if target is None:
            # Inference: full log-probabilities reshaped to (B, T, vocab_size)
            return self.head.log_prob(h).view(B, T, -1)

        target = target.reshape(-1)
        if ignore_index is not None:
            # Padding targets carry no signal; without this they dominate the loss
            # for short dialogs padded to seq_len.
            keep = target != ignore_index
            if keep.any():
                h, target = h[keep], target[keep]
        return self.head(h, target)  # object with `.output` and `.loss`


In [None]:
# create the model instance
model = TinyGPT(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    vocab_size=vocab_size,
    seq_len=seq_len,
    head_dim=head_dim,
    dropout=dropout,
)

# Resume from a local checkpoint when one exists; otherwise start from scratch.
# Load onto the CPU first so a GPU-saved checkpoint also opens on a CPU-only runtime;
# the model is moved to `device` in the next cell.
checkpoint_path = f"best_model_{prev_epochs}.pt"
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state'])
    print(f"[INFO] Resumed model weights from {checkpoint_path}")
else:
    checkpoint = None
    print(f"[INFO] No checkpoint at {checkpoint_path}; starting from scratch")

In [None]:
# Move the model's parameters to the selected device. The optimizer is built in a
# later cell, so its state is created for parameters that already live there.
model.to(device)

In [None]:
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    """LambdaLR with linear warmup followed by a half-cosine decay to 0.

    The multiplier rises linearly from 0 to 1 over `warmup_steps` calls to
    scheduler.step(). It then follows 0.5 * (1 + cos(pi * progress)) until
    `total_steps`, and is clamped at 0 afterwards.
    """
    warmup_span = max(1, warmup_steps)
    decay_span = max(1, total_steps - warmup_steps)

    def _multiplier(step):
        if step >= warmup_steps:
            progress = (step - warmup_steps) / decay_span
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        return step / warmup_span

    return LambdaLR(optimizer, lr_lambda=_multiplier)

In [None]:
# AdamW driven by the config-cell values (previously lr/betas were duplicated as literals here).
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
# The schedule is measured in optimizer steps (batches): epochs * steps_per_epoch,
# so scheduler.step() must be called once per batch.
total_steps = epochs * len(train_dataloader)
warmup_steps = int(0.1 * total_steps)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
# NOTE(review): not used by train_model/evaluate below; they read `.loss` from the
# model's adaptive-softmax head instead.
loss_fn = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

In [None]:

# Restore optimizer / LR-scheduler progress from the resumed checkpoint, if any.
# (Both must already be constructed for the same model, which the previous cell does.)
if checkpoint is not None and checkpoint.get('optimizer_state') is not None:
    optimizer.load_state_dict(checkpoint['optimizer_state'])
if checkpoint is not None and checkpoint.get('scheduler_state') is not None:
    scheduler.load_state_dict(checkpoint['scheduler_state'])
# NOTE(review): start_epoch is not passed to train_model, which always counts epochs from 1.
start_epoch = checkpoint['epoch'] + 1 if checkpoint is not None else 1

In [None]:
def compute_bleu_score(predicted_ids, target_ids, sp):
    """Corpus BLEU between decoded predictions and decoded references.

    Each element of `predicted_ids` / `target_ids` must support .tolist() (e.g. a
    1-D id tensor). Everything from the first id 0 (padding) onwards is dropped
    before decoding with `sp`. Prints the score and returns `bleu.score`.
    """
    def strip_padding(ids):
        ids = ids.tolist()
        return ids[:ids.index(0)] if 0 in ids else ids

    hypotheses, references = [], []
    for pred, ref in zip(predicted_ids, target_ids):
        hypotheses.append(sp.decode(strip_padding(pred)))
        references.append(sp.decode(strip_padding(ref)))

    bleu = sacrebleu.corpus_bleu(hypotheses, [references])
    print(f"[BLEU] Score: {bleu.score:.2f}")
    return bleu.score


In [None]:
def evaluate(model, dataloader, vocab_size, device, sp, silent=False):
    """Teacher-forced evaluation: mean loss plus BLEU of per-position argmax predictions.

    Each batch runs two forward passes: one with targets (for `.loss`) and one
    without (log-probs, whose argmax is compared to the targets via BLEU). This is
    not free-running generation. `vocab_size` is accepted but unused.

    Returns:
        (avg_loss, bleu_score)
    """
    model.eval()
    model.to(device)

    running_loss = 0
    predictions, references = [], []

    with torch.no_grad():
        progress = tqdm(dataloader, desc="Evaluating", disable=silent)
        for batch_inputs, batch_targets in progress:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            # loss from the adaptive-softmax head
            running_loss += model(batch_inputs, batch_targets).loss.item()

            # greedy next-token guesses at every position, for BLEU
            greedy_ids = model(batch_inputs).argmax(dim=-1)
            predictions.extend(greedy_ids.cpu())
            references.extend(batch_targets.cpu())

    avg_loss = running_loss / len(dataloader)
    bleu_score = compute_bleu_score(predictions, references, sp)

    if not silent:
        print(f"[Eval] Loss: {avg_loss:.4f}")
        print(f"[Eval] BLEU: {bleu_score:.2f}")

    return avg_loss, bleu_score


In [None]:
def _train_one_epoch(model, train_loader, optimizer, device, clip_grad, scheduler, epoch):
    """Run one optimisation pass over `train_loader` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0

    for input_ids, target_ids in tqdm(train_loader, desc=f"Epoch {epoch}"):
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        optimizer.zero_grad()
        loss = model(input_ids, target_ids).loss  # adaptive-softmax output carries `.loss`
        loss.backward()

        if clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

        optimizer.step()
        # The warmup/cosine schedule is sized in batches
        # (total_steps = epochs * len(train_dataloader)), so it must advance once per
        # optimizer step. Stepping it once per epoch kept the LR stuck near 0 in warmup.
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / max(1, num_batches)


def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    vocab_size,
    device,
    epochs,
    clip_grad,
    early_stop_patience,
    sp,  # SentencePiece tokenizer
    scheduler=None,
    save_path="best_model.pth"
):
    """Train with per-epoch validation, best-model checkpointing and early stopping.

    After every epoch the model is evaluated on `val_loader`. When the validation loss
    improves, the weights go to `save_path` and a full checkpoint is written through
    save_model_checkpoint (using the notebook-level `prev_epochs`). Training stops
    after `early_stop_patience` consecutive epochs without improvement.

    Returns:
        The best validation loss observed.
    """
    model = model.to(device)
    best_val_loss = float("inf")
    patience = 0

    for epoch in range(1, epochs + 1):
        print(f"\n[START] Epoch {epoch}")
        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, device, clip_grad, scheduler, epoch
        )

        print(f"[Epoch {epoch}] Train Loss: {avg_train_loss:.4f}")
        # math.exp overflows above ~709; report inf instead of crashing the run.
        train_ppl = math.exp(avg_train_loss) if avg_train_loss < 700 else float("inf")
        print(f"[Epoch {epoch}] Train PPL: {train_ppl:.2f}")

        val_loss, val_bleu = evaluate(model, val_loader, vocab_size, device, sp)
        print(f"[Epoch {epoch}] Val Loss: {val_loss:.4f}")
        print(f"[Epoch {epoch}] Val BLEU: {val_bleu:.2f}")

        # Early stopping on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = 0
            torch.save(model.state_dict(), save_path)
            save_model_checkpoint(model, epoch, prev_epochs, optimizer, scheduler)
            print(f"[✓] Saved best model to {save_path}")
        else:
            patience += 1
            print(f"[!] Patience: {patience}/{early_stop_patience}")
            if patience >= early_stop_patience:
                print("[!] Early stopping triggered.")
                break

    return best_val_loss


In [None]:
# Kick off training with the objects built above.
print("[START] starting training process...")
train_model(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    optimizer=optimizer,
    vocab_size=vocab_size,
    device=device,
    epochs=epochs,
    clip_grad=clip_grad,
    early_stop_patience=early_stop_patience,
    sp=sp,
    scheduler=scheduler,
    save_path="best_model.pth",
)
print("[FINISHED] Training is over")

In [None]:

# Save the model's current (final) state. The original call omitted the required
# prev_epochs/optimizer/scheduler arguments and raised TypeError.
# NOTE(review): this writes best_model_epoch_{prev_epochs}.pt and deletes the other
# best_model_epoch_* files, so after early stopping it replaces the best checkpoint
# from training with the last-epoch weights; drop this cell if that is not intended.
save_model_checkpoint(model, 0, prev_epochs, optimizer, scheduler)

In [None]:
# Evaluate the trained model on the held-out test split.
print("[START] start evaluation process...")
test_loss, test_bleu = evaluate(model, test_dataloader, vocab_size, device, sp, silent=False)
print("[FINISHED] Evaluation is over")

In [None]:
def check_input_ids(input_ids, vocab_size):
    """Validate a flat sequence of token ids; the first offending position is reported.

    Raises:
        TypeError: an element is not an ``int``.
        ValueError: an element lies outside ``[0, vocab_size)``.
    """
    for pos, tok in enumerate(input_ids):
        if isinstance(tok, int):
            if 0 <= tok < vocab_size:
                continue
            raise ValueError(f"[{pos}] Token out of bounds: {tok} (vocab_size={vocab_size})")
        raise TypeError(f"[{pos}] Token is not an int: {tok}")


In [None]:
def safe_piece_to_id(sp, token, fallback_id):
    """Return the id of piece `token` in tokenizer `sp`, or `fallback_id` if it is not a real piece.

    SentencePiece's piece_to_id does not signal an unknown piece with -1; it returns
    the <unk> id. So the id is mapped back with id_to_piece and accepted only when
    it round-trips to exactly `token`. Any tokenizer error also yields the fallback.
    """
    try:
        tid = sp.piece_to_id(token)
        if tid is None or tid < 0:
            return fallback_id
        return tid if sp.id_to_piece(tid) == token else fallback_id
    except Exception:
        # deliberate best-effort lookup: callers always get a usable id
        return fallback_id


In [None]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    logits = logits.clone()

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        threshold = torch.topk(logits, top_k)[0][..., -1, None]
        logits = torch.where(logits < threshold, torch.full_like(logits, filter_value), logits)

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = 0

        indices_to_remove = sorted_mask.scatter(1, sorted_indices, sorted_mask)
        logits = torch.where(indices_to_remove, torch.full_like(logits, filter_value), logits)

    return logits

In [None]:
def sample_response(
    model,
    sp,
    input_text,
    seq_len,
    device,
    max_gen=32,
    top_k=50,
    top_p=0.9,
    temperature=1.0,
    verbose=True
):
    """Sample a continuation of `input_text` with temperature + top-k / nucleus sampling.

    Args:
        model: model where model(ids) returns (B, T, V) log-probs and
            model.token_embd.num_embeddings is the vocabulary size.
        sp: SentencePiece processor used to encode the prompt and decode the result.
        input_text: prompt; should follow the training format "<s> User: ... <sep> Bot:".
        seq_len: model context length. The prompt is cut to seq_len - max_gen tokens
            (at least 1), and the model is never fed more than the last seq_len tokens.
        device: device for the token tensor.
        max_gen: maximum number of new tokens.
        top_k, top_p: forwarded to top_k_top_p_filtering.
        temperature: softmax temperature; a value <= 0 means greedy (argmax) decoding.
        verbose: print diagnostics and per-step output.

    Returns:
        The decoded prompt plus continuation; generation stops before "</s>".

    Raises:
        ValueError: the prompt encodes to no tokens.
        RuntimeError: the model produced NaN scores.
    """
    def log(*args):
        if verbose:
            print(*args)

    model.eval()
    vocab_size = model.token_embd.num_embeddings
    pad_id = safe_piece_to_id(sp, "<pad>", 0)
    unk_id = safe_piece_to_id(sp, "<unk>", 1)
    eos_id = safe_piece_to_id(sp, "</s>", 3)

    log("[INFO] Starting sample_response")
    log("[INFO] Vocab size (from model):", vocab_size)
    log(f"[INFO] Special token IDs — PAD: {pad_id}, UNK: {unk_id}, EOS: {eos_id}")

    # Tokenize, map ids the model cannot embed to <unk>, and leave room for generation.
    tokens = sp.encode(input_text, out_type=int)
    tokens = [t if 0 <= t < vocab_size else unk_id for t in tokens]
    tokens = tokens[:max(1, seq_len - max_gen)]
    if not tokens:
        raise ValueError("Prompt is empty after tokenization; nothing to condition on.")
    log("[DEBUG] Truncated & validated tokens:", tokens)

    generated = torch.tensor([tokens], dtype=torch.long, device=device)

    with torch.no_grad():
        for step in range(max_gen):
            # Learned positions exist only for seq_len slots: feed at most the last seq_len tokens.
            logits = model(generated[:, -seq_len:])
            if torch.isnan(logits).any():
                raise RuntimeError("❌ Logits contain NaN.")

            next_token_logits = logits[:, -1, :]
            if temperature <= 0:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                next_token_logits = next_token_logits / temperature
                # -inf log-probs are legitimate (zero probability); clamp them so softmax stays finite.
                next_token_logits = torch.where(
                    torch.isfinite(next_token_logits),
                    next_token_logits,
                    torch.full_like(next_token_logits, -1e10),
                )
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                probs = F.softmax(filtered_logits, dim=-1)

                if torch.isnan(probs).any() or torch.all(probs == 0):
                    # The fallback token is assigned regardless of verbosity.
                    log("⚠️ Sampling fallback to [UNK] (all probs invalid)")
                    next_token = torch.tensor([[unk_id]], device=device)
                else:
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                    next_token = torch.multinomial(probs, num_samples=1)

            tok_id = next_token.item()
            if not (0 <= tok_id < vocab_size):
                print(f"[WARN] Sampled token ID {tok_id} is invalid — replacing with [UNK]")
                tok_id = unk_id
                next_token = torch.tensor([[tok_id]], device=device)

            log(f"[STEP {step}] Sampled token ID: {tok_id} →", sp.decode([tok_id]))

            if tok_id == eos_id:
                log("[INFO] <EOS> token generated — stopping.")
                break

            generated = torch.cat((generated, next_token), dim=1)

    output_tokens = generated[0].tolist()
    log("[INFO] Final generated token IDs:", output_tokens)
    return sp.decode(output_tokens)

In [None]:
# Interactive chat. The prompt must mirror the training format built from the datasets,
# "<s> User: ... <sep> Bot: ... </s>". "[BOS]"/"[SEP]" never occur in the training data.
while True:
    try:
        user_input = input("User: ")
    except (EOFError, KeyboardInterrupt):
        break
    if user_input.strip().lower() in {"exit", "quit"}:
        break
    formatted_prompt = f"<s> User: {user_input} <sep> Bot:"
    try:
        response = sample_response(model, sp, formatted_prompt, seq_len, device, verbose=True)
        # The decoded text contains the prompt; keep only what follows the last "Bot:".
        reply = response.rsplit("Bot:", 1)[-1].replace("</s>", "").strip()
        print("Bot:", reply)
    except Exception as e:
        # deliberate: keep the chat loop alive and show the error
        print("❌ Error:", e)

In [None]:
def simple_greedy_inference(model, sp, input_text, seq_len, device, max_gen=32):
    """Greedy (argmax) continuation of `input_text`; returns the decoded prompt + continuation.

    The prompt is truncated to seq_len - max_gen tokens (at least 1) and fed WITHOUT
    padding. The old code sliced a padded tensor with [:, :-max_gen], so the model
    continued after a run of pad tokens. Generation stops before the "</s>" piece.
    The old lookups of "[PAD]"/"[EOS]" silently resolved to the <unk> id, because
    SentencePiece returns <unk> for unknown pieces. At most the last seq_len tokens
    are fed to the model.

    Raises:
        ValueError: a prompt token id is outside the tokenizer vocabulary, or the
            prompt encodes to no tokens.
    """
    model.eval()
    model.to(device)

    # Encode and validate the prompt
    tokens = sp.encode(input_text, out_type=int)
    print(f"input tokens are : {tokens}")
    vocab_size = sp.get_piece_size()

    for i, t in enumerate(tokens):
        if not (0 <= t < vocab_size):
            raise ValueError(f"❌ Token at pos {i} is invalid: {t} (vocab_size={vocab_size})")

    # Truncate the prompt to leave room for generation
    tokens = tokens[:max(1, seq_len - max_gen)]
    if not tokens:
        raise ValueError("❌ Prompt produced no tokens; nothing to condition on.")

    eos_id = sp.piece_to_id("</s>")
    generated = torch.tensor([tokens], dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_gen):
            logits = model(generated[:, -seq_len:])
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (1, 1)

            if next_token.item() == eos_id:
                break

            generated = torch.cat((generated, next_token), dim=1)

    # Decode output
    output_tokens = generated[0].tolist()
    assert sp.get_piece_size() > max(output_tokens), "Model is producing OOB token IDs"
    print(f"output tokens are {output_tokens}")
    return sp.decode(output_tokens)


In [None]:
# Greedy demo with a prompt in the training format ("<s> User: ... <sep> Bot:").
prompt_text = "<s> User: Hello there! <sep> Bot:"
response = simple_greedy_inference(model, sp, prompt_text, seq_len, device)
# The decoded text contains the prompt; keep only what follows the last "Bot:".
print("Bot:", response.rsplit("Bot:", 1)[-1].replace("</s>", "").strip())
