In [13]:

from datasets import load_metric
from transformers import EvalPrediction
import numpy as np

rouge = load_metric('rouge')

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=-1)
    
    if isinstance(predictions, np.ndarray):
        predictions = predictions.tolist()
    if isinstance(labels, np.ndarray):
        labels = labels.tolist()
    
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    
    labels = [[token if token != -100 else tokenizer.pad_token_id for token in label] for label in labels]
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    return result

from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq

model = T5ForConditionalGeneration.from_pretrained('t5-small')

training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()


eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

predictions = trainer.predict(tokenized_datasets['validation'])

if isinstance(predictions.predictions, tuple):
    predictions = predictions.predictions[0]
predictions = np.argmax(predictions, axis=-1)

decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)

for i, example in enumerate(tokenized_datasets['validation']['query'][:5]):
    print(f"Query: {example}")
    print(f"Generated Response: {decoded_predictions[i]}")
    print(f"Reference Response: {tokenized_datasets['validation']['response'][i]}")
    print()




Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,3.630521,33.038626,8.2136,26.56765,26.649188
2,No log,3.511202,33.335753,8.488953,27.186731,27.292487
3,No log,3.469456,33.946273,8.825708,27.64523,27.749164


Evaluation results: {'eval_loss': 3.469456434249878, 'eval_rouge1': 33.9462727331739, 'eval_rouge2': 8.82570798630188, 'eval_rougeL': 27.645229950659445, 'eval_rougeLsum': 27.749163932363658, 'eval_runtime': 3.8578, 'eval_samples_per_second': 7.776, 'eval_steps_per_second': 1.037, 'epoch': 3.0}
Query: Can I pre-order an item?
Generated Response: KannKönnen, I pre pre me price?? email.? the order address. you can send you. you-order your. available? Can Can Can Kann
Reference Response: Certainly. Can you please provide the product name or SKU and your email address so we can notify you when pre-orders are available?

Query: Do you offer gift wrapping?
Generated Response: Off offer not gift wrapping? gift gift? you offer offer gift gift?? gift.? you can offer yourif you wrapping is available. Offer Off Off
Reference Response: We do offer gift wrapping for select items. Can you please provide the product name or SKU so we can confirm if gift wrapping is available?

Query: Do you offer pri

In [19]:
from IPython.display import display, HTML
import ipywidgets as widgets

def get_reference_response(query):
    reference_response = None
    for idx, item in enumerate(tokenized_datasets['validation']['query']):
        if item == query:
            reference_response = tokenized_datasets['validation']['response'][idx]
            break
    
    return reference_response

input_box = widgets.Text(
    value='',
    placeholder='Type your query here...',
    description='Query:',
    disabled=False
)

output_box = widgets.Output()

def on_button_clicked(b):
    with output_box:
        output_box.clear_output()
        query = input_box.value
        reference_response = get_reference_response(query)
        if reference_response:
            print("Reference Response:", reference_response)
        else:
            print("Reference Response: Not found")

button = widgets.Button(description="Get Response")
button.on_click(on_button_clicked)

display(input_box, button, output_box)


Text(value='', description='Query:', placeholder='Type your query here...')

Button(description='Get Response', style=ButtonStyle())

Output()