In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset
from peft import get_peft_model, LoraConfig, TaskType
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Step 1: Load base model (FLAN-T5)
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)



In [4]:
# Step 2: Apply LoRA
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none"
)
model = get_peft_model(base_model, peft_config)

In [None]:
# Step 3: Prepare QA data (your domain QA pairs)
def preprocess(example):
    input_text = f"question: {example['question']} context: {example['context']}"
    target_text = example['answer']
    tokenized = tokenizer(
        input_text,
        max_length=512,
        truncation=True,
        padding="max_length"
    )
    label_tokenized = tokenizer(
        target_text,
        max_length=64,
        truncation=True,
        padding="max_length"
    )
    tokenized["labels"] = label_tokenized["input_ids"]
    return tokenized

# QA dataset
data = Dataset.from_list([
    {"question": "What is the side effect of paracetamol?",
     "context": "Paracetamol may cause liver damage in high doses.",
     "answer": "liver damage"},
    {"question": "What is the function of the liver?",
     "context": "The liver helps detoxify chemicals and metabolize drugs.",
     "answer": "helps detoxify chemicals and metabolize drugs"}
])
dataset = data.map(preprocess, remove_columns=data.column_names)

Map: 100%|██████████| 2/2 [00:00<00:00, 305.58 examples/s]


In [None]:
print(data['question'])
print(data['context'])
print(data['answer'])

['liver damage', 'helps detoxify chemicals and metabolize drugs']

In [16]:
print(dataset['input_ids'])
print(dataset['attention_mask'])
print(dataset['labels'])

[[822, 10, 363, 19, 8, 596, 1504, 13, 3856, 9113, 9, 4641, 58, 2625, 10, 4734, 9113, 9, 4641, 164, 1137, 11501, 1783, 16, 306, 6742, 7, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [17]:
# Step 4: Training setup
args = TrainingArguments(
    output_dir="./flan_t5_domain_qa",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    learning_rate=5e-4,
    logging_steps=10,
    save_total_limit=2,
    fp16=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model)
)

trainer.train()

  trainer = Trainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss


TrainOutput(global_step=3, training_loss=0.0, metrics={'train_runtime': 1.5294, 'train_samples_per_second': 3.923, 'train_steps_per_second': 1.962, 'total_flos': 4124851568640.0, 'train_loss': 0.0, 'epoch': 3.0})

In [18]:
text = "question: What is the side effect of paracetamol? context: Paracetamol may cause liver damage."
inputs = tokenizer(text, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

a tyretinine metabolite
