In [1]:
from datasets import load_dataset, load_from_disk
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import evaluate
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the cleaned PubMed QA JSON file as a single "train" split.
dataset = load_dataset(
    "json",
    data_files={"train": "../data/cleaned_pubmed_qa.json"},
    split="train",
)
print("Loaded full dataset with", len(dataset), "samples")

Loaded full dataset with 195696 samples


In [3]:
# Hold out 10% of the data for evaluation. The split is seeded so that a
# Restart & Run All reproduces the same train/eval partition (the subset
# sampling further down already uses seed=42).
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
full_train_data = dataset_split["train"]
full_eval_data = dataset_split["test"]

In [16]:
# Draw small, deterministically shuffled subsets for a quick experiment:
# the first 1000 shuffled training rows and the first 100 shuffled eval rows.
shuffled_train = full_train_data.shuffle(seed=42)
shuffled_eval = full_eval_data.shuffle(seed=42)
small_train_data = shuffled_train.select(range(1000))
small_eval_data = shuffled_eval.select(range(100))

In [17]:
# Display the feature names and row count of the sampled training subset.
small_train_data

Dataset({
    features: ['question', 'context', 'answer'],
    num_rows: 1000
})

In [18]:
# Checkpoint identifier shared by the tokenizer and the model so both are
# loaded from the same pretrained weights/vocabulary.
model_checkpoint = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

In [19]:
def preprocess_function(examples):
    """Tokenize a batch of QA examples into encoder inputs and decoder labels.

    Args:
        examples: Batch dict with parallel lists under the keys
            "question", "context" and "answer".

    Returns:
        The tokenizer output for the prompts (truncated/padded to 512
        tokens) with an added "labels" entry holding the input ids of the
        tokenized answers (truncated/padded to 128 tokens).
    """
    # Build one "question: ... context: ..." prompt per example.
    inputs = [
        "question: " + q + " context: " + c
        for q, c in zip(examples["question"], examples["context"])
    ]
    targets = examples["answer"]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager and tokenizes the answers as targets.
    labels = tokenizer(text_target=targets, max_length=128, truncation=True, padding="max_length")
    # NOTE(review): with padding="max_length" the labels keep the pad token id
    # in padded positions, so those positions presumably count toward the
    # loss; confirm whether they should be replaced by -100 instead.
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [20]:
# Apply the batched preprocessing to the train subset first, then the eval
# subset (same order as the progress bars below).
tokenized_train, tokenized_eval = (
    subset.map(preprocess_function, batched=True)
    for subset in (small_train_data, small_eval_data)
)
print("Tokenization complete")

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map: 100%|██████████| 1000/1000 [00:01<00:00, 748.96 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 669.18 examples/s]

Tokenization complete





In [21]:
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Decode predictions and labels and score them with ROUGE.

    Args:
        eval_pred: ``(predictions, labels)`` pair. ``predictions`` may be a
            tuple whose first element holds the model outputs, an array of
            logits (one more dimension than ``labels``; the argmax over the
            last axis is taken), or an array of already-selected token ids.

    Returns:
        dict mapping every key returned by ``rouge.compute`` to its score
        multiplied by 100.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    if predictions.ndim == labels.ndim + 1:
        # Logits: pick the highest-scoring token at every position.
        # NOTE(review): with the plain `Trainer` these logits presumably come
        # from a teacher-forced forward pass, so the scores are optimistic
        # compared with free-running generation; confirm before reporting.
        pred_ids = np.argmax(predictions, axis=-1)
        if pred_ids.shape == labels.shape:
            # Positions the labels mark as ignored carry no reference token,
            # so do not let their predictions leak into the decoded text.
            pred_ids = np.where(labels != -100, pred_ids, tokenizer.pad_token_id)
    else:
        # Already token ids (e.g. produced by generation or a logits
        # preprocessing hook); an argmax here would collapse each sequence.
        pred_ids = predictions

    # -100 is not a valid vocabulary id; swap it for the pad token so that
    # decoding succeeds and `skip_special_tokens` drops it.
    pred_ids = np.where(pred_ids != -100, pred_ids, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    final_result = {}
    for key, value in result.items():
        # Accept both aggregate objects exposing `.mid.fmeasure` and plain
        # numeric scores.
        if hasattr(value, 'mid'):
            final_result[key] = value.mid.fmeasure * 100
        else:
            final_result[key] = value * 100
    return final_result

In [22]:
# Training configuration for the small smoke-test run.
training_args = TrainingArguments(
    output_dir="../model/final_qa_small",
    eval_strategy="steps",
    eval_steps=50,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    logging_dir="../model/logs",
    logging_steps=10,
    # Disable intermediate checkpoints explicitly instead of relying on a
    # huge float `save_steps` value that is never reached.
    save_strategy="no",
    # 1 sample per step x 8 accumulation steps = effective batch size of 8.
    gradient_accumulation_steps=8,
    warmup_steps=10,
    weight_decay=0.01,
    report_to="none"
)

In [23]:
# Collator that batches the tokenized seq2seq examples for the model.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    # `processing_class` supersedes the deprecated `tokenizer` argument of
    # `Trainer.__init__` (which emits a FutureWarning).
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

  trainer = Trainer(


In [24]:
# Fine-tune, then write the final model to the configured output directory.
# Reusing `training_args.output_dir` keeps the save location in sync with the
# training configuration instead of repeating the path literal.
trainer.train()
trainer.save_model(training_args.output_dir)

Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
50,2.077,1.761953,39.44637,12.131592,33.966887,33.923502
100,1.6209,1.656699,40.122698,12.829875,35.266787,35.225219


In [25]:
# Run a final evaluation pass on the small eval subset and print every
# returned metric with two decimals.
metrics = trainer.evaluate(eval_dataset=tokenized_eval)
print("Final ROUGE Metrics on Small Eval Set:")
for metric_name in metrics:
    print(f"{metric_name}: {metrics[metric_name]:.2f}")



Final ROUGE Metrics on Small Eval Set:
eval_loss: 1.64
eval_rouge1: 40.15
eval_rouge2: 12.97
eval_rougeL: 35.42
eval_rougeLsum: 35.39
eval_runtime: 46.09
eval_samples_per_second: 2.17
eval_steps_per_second: 2.17
epoch: 1.00
