In [2]:
import os
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
import csv

def load_token_mapping(csv_path):
    token_to_string = {}
    with open(csv_path, 'r', newline='', encoding='utf-8') as file:
        reader = csv.reader(file)
        for row in reader:
            token_id, token_str = int(row[0]), row[1]
            token_to_string[token_id] = token_str
    return token_to_string

token_mapping = load_token_mapping('tokens.csv')
model_path = '../saved_models/epoch_36'

class SimpleBART(torch.nn.Module):
    def __init__(self):
        super(SimpleBART, self).__init__()

        if os.path.exists(model_path):
            self.bart = BartForConditionalGeneration.from_pretrained(model_path)
            self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
        else:
            raise ValueError(f"Model not found at {model_path}. Please make sure the path is correct or set load_from_saved to False.")

    def forward(self, input_ids, attention_mask):
        return self.bart(input_ids=input_ids, attention_mask=attention_mask)
    
    def decode_beam_to_tokens(self, beam_output):
        decoded_tokens = []
        for token_id in beam_output[2:]:
            if token_id == self.tokenizer.eos_token_id:
                break
            decoded_tokens.append(token_mapping.get(token_id.item(), "<UNK>"))  # default to "<UNK>" if token not found
        return decoded_tokens

model = SimpleBART()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = torch.nn.CrossEntropyLoss()

In [6]:
source_text = input("Enter the source text: ")

# Tokenize the input text
input_ids = model.tokenizer.encode(source_text, return_tensors="pt")

print(f"> {source_text}")

while True:
    model.train()

    # Beam search for top translations based on the input text
    translation_ids = model.bart.generate(
        input_ids,
        num_beams=4,
        max_length=200,
        num_return_sequences=4,
        eos_token_id=model.tokenizer.eos_token_id
    )

    # Display the options
    for i, option_ids in enumerate(translation_ids, 1):
        # Print the decoded strings for each option
        print(f"Option {i}: {model.decode_beam_to_tokens(option_ids)}")
        print(f"  {option_ids.tolist()}", flush=True)

    choice = int(input("Which option is correct? (1/2/3/4): "))
    chosen_translation_ids = translation_ids[choice - 1]

    target_ids = chosen_translation_ids.unsqueeze(0) # Add batch dimension

    # Run the model again to get the logits for the original input
    outputs = model(input_ids, None)
    logits = outputs[0]

    # Resize for the loss function
    target_ids = target_ids[:, :logits.shape[1]]
    logits = logits[:, :target_ids.shape[1]]

    loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Loss: {loss.item()}\n")

> red horse
Option 1: ['horse']
  [2, 0, 457, 2, 1]
Option 2: ['horse', 'red']
  [2, 0, 457, 578, 2]
Option 3: ['horse', 'rough']
  [2, 0, 457, 211, 2]
Option 4: ['horse', 'grey']
  [2, 0, 457, 121, 2]


ValueError: invalid literal for int() with base 10: ''