In [2]:
from dotenv import load_dotenv
from huggingface_hub import login
import os
# Pull variables from a local .env file (if present) into the environment,
# then authenticate this session against the Hugging Face Hub.
load_dotenv()
token = os.environ.get('HUGGINGFACE_HUB_TOKEN')  # None when the variable is unset
login(token=token)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# fine_tune_distilbart_peft_samsum.py
import os
import math
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
import evaluate

# ---------- Settings ----------
# Checkpoint and output locations
MODEL_NAME = "sshleifer/distilbart-cnn-12-6"   # base seq2seq checkpoint to adapt
PEFT_OUTPUT_DIR = "./lora_samsum"              # directory the LoRA adapter is saved to

# Token budgets applied when truncating text
MAX_INPUT_LENGTH = 1024     # dialogue side
MAX_TARGET_LENGTH = 128     # summary side

# Optimisation hyper-parameters
BATCH_SIZE = 8              # per device; change to fit GPU
NUM_EPOCHS = 3
LEARNING_RATE = 2e-4
WEIGHT_DECAY = 0.01

# Reproducibility
SEED = 42
# ------------------------------

# Seed the NumPy and PyTorch global random number generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# 1) Dataset: SAMSum from the Hugging Face Hub (knkarthick/samsum)
dataset = load_dataset("knkarthick/samsum")
print(dataset)  # lists the splits and their columns ('dialogue' and 'summary' are used below)

# 2) Base checkpoint: tokenizer plus seq2seq model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Disable the config-level cache flag, as recommended when training with adapters
model.config.use_cache = False

# 3) Preprocess function
def preprocess(batch):
    """Tokenize a batch of rows into encoder inputs and decoder labels.

    Args:
        batch: Mapping holding parallel lists under ``"dialogue"`` and
            ``"summary"`` (what ``Dataset.map(..., batched=True)`` passes in).

    Returns:
        The tokenizer output for the dialogues, with an added ``"labels"``
        entry containing the token ids of the summaries.

    Sequences are truncated to ``MAX_INPUT_LENGTH`` / ``MAX_TARGET_LENGTH`` but
    are intentionally NOT padded here. Padding every example to the maximum
    length makes each batch carry up to 1024 positions regardless of the real
    dialogue length. The ``DataCollatorForSeq2Seq`` used for batching pads
    each batch to its longest member and fills label padding with -100, so
    padded label positions are still ignored by the loss.
    """
    # A missing entry (None) is not valid tokenizer input; treat it as empty text.
    dialogues = [text if text is not None else "" for text in batch["dialogue"]]
    summaries = [text if text is not None else "" for text in batch["summary"]]

    model_inputs = tokenizer(
        dialogues,
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
    )

    # `text_target` tokenizes the summaries as decoder targets.
    labels = tokenizer(
        text_target=summaries,
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# 4) Run the preprocessing over every split in batches; the original columns
#    are dropped so only the tokenizer outputs and labels remain
tokenized_datasets = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)

# 5) Collator that assembles seq2seq batches for the trainer
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# 6) LoRA adapter configuration (PEFT)
# The adapters target q_proj and v_proj, a common choice for BART-like models;
# further module names can be added to target_modules if desired.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

# Wrap the base model so that only the LoRA parameters are trained
model = get_peft_model(model, lora_config)

# 7) Training arguments (Seq2Seq)
training_args = Seq2SeqTrainingArguments(
    output_dir="./distilbart_peft_out",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluate by generating summaries (needed for ROUGE) instead of teacher forcing.
    predict_with_generate=True,
    # Cap evaluation-time generation at the same length the labels are truncated to,
    # rather than whatever default the base checkpoint ships with.
    generation_max_length=MAX_TARGET_LENGTH,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=200,
    save_total_limit=3,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=NUM_EPOCHS,
    # The trainer seeds from this argument, so pass SEED through explicitly;
    # otherwise editing SEED in the settings would not affect training.
    seed=SEED,
    fp16=torch.cuda.is_available(),  # mixed precision only when a CUDA device is present
    remove_unused_columns=True,
    push_to_hub=False,
    hub_model_id="Dhyanesh-AN/distilbart-samsum-lora",
    hub_strategy="end",
    report_to="none",
)

# 8) Metric (ROUGE)
rouge = evaluate.load("rouge")

def postprocess_text(preds, labels):
    """Strip leading/trailing whitespace from every prediction and reference.

    Args:
        preds: Iterable of decoded prediction strings.
        labels: Iterable of decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple of new lists holding the stripped strings.
    """
    def _trim_all(texts):
        # Build a fresh list; the input is not modified.
        return [text.strip() for text in texts]

    return _trim_all(preds), _trim_all(labels)

def compute_metrics(eval_pred):
    """Compute ROUGE scores and mean generated length for one evaluation pass.

    Args:
        eval_pred: Pair ``(generated_tokens, label_tokens)`` of token-id
            arrays. ``generated_tokens`` may instead be a tuple, in which case
            its first element is used.

    Returns:
        Dict with one entry per ROUGE score, scaled to percentages and rounded
        to 4 decimals, plus ``"gen_len"``: the mean token count of the decoded
        predictions after re-encoding them with the tokenizer.
    """
    generated_tokens, label_tokens = eval_pred
    if isinstance(generated_tokens, tuple):
        generated_tokens = generated_tokens[0]

    pad_id = tokenizer.pad_token_id
    # -100 is a masking sentinel, not a vocabulary id, and cannot be decoded.
    # Labels always carry it; generated ids may carry it too when predictions
    # of different lengths are padded together, so both arrays are cleaned.
    generated_tokens = np.where(generated_tokens != -100, generated_tokens, pad_id)
    label_tokens = np.where(label_tokens != -100, label_tokens, pad_id)

    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)
    preds, labels = postprocess_text(decoded_preds, decoded_labels)

    result = rouge.compute(predictions=preds, references=labels, use_stemmer=True)
    # Express every score as a percentage.
    result = {k: round(v * 100, 4) for k, v in result.items()}
    # Average length of the generated summaries, measured in tokenizer tokens.
    result["gen_len"] = np.mean([len(tokenizer.encode(p)) for p in preds])
    return result

# 9) Trainer: wires together the model, arguments, data splits, batching and metrics
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# 10) Run fine-tuning
trainer.train()

# 11) Persist the LoRA adapter together with the tokenizer files
os.makedirs(PEFT_OUTPUT_DIR, exist_ok=True)
model.save_pretrained(PEFT_OUTPUT_DIR)      # PEFT adapter
tokenizer.save_pretrained(PEFT_OUTPUT_DIR)  # tokenizer
print(f"Saved LoRA adapter to {PEFT_OUTPUT_DIR}")

# 12) Inference demo: start again from the base checkpoint and attach the
#     adapter saved above (this is how the model is re-loaded at inference time)
from transformers import AutoModelForSeq2SeqLM
base_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
base_model.config.use_cache = True
peft_model = PeftModel.from_pretrained(model=base_model, model_id=PEFT_OUTPUT_DIR)

def summarize(text, max_length=128):
    """Generate a summary of ``text`` with ``peft_model``.

    Args:
        text: Input handed to the tokenizer; truncated to ``MAX_INPUT_LENGTH``.
        max_length: Value forwarded to ``generate`` as ``max_length``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    encoded = tokenizer(
        text,
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    # Place the input tensors on the same device as the model.
    encoded = encoded.to(peft_model.device)
    output_ids = peft_model.generate(
        **encoded,
        num_beams=4,
        early_stopping=True,
        max_length=max_length,
    )
    first_sequence = output_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Smoke test: summarize the first dialogue of the validation split and print
# the input next to the generated summary.
example = dataset["validation"][0]["dialogue"]
print("CONVERSATION:\n", example)
print("SUMMARY:\n", summarize(example))

Generating train split: 100%|██████████| 14731/14731 [00:00<00:00, 95611.26 examples/s]
Generating validation split: 100%|██████████| 818/818 [00:00<00:00, 107086.38 examples/s]
Generating test split: 100%|██████████| 819/819 [00:00<00:00, 108418.60 examples/s]


DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14731
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
})


Cancellation requested; stopping current tasks.


KeyboardInterrupt: 

In [None]:
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

api = HfApi()
repo_id = "Dhyanesh-AN/distilbart-samsum-lora"

# Push the LoRA adapter only if the repo does not already exist.
try:
    api.model_info(repo_id)
except RepositoryNotFoundError:
    # Only a confirmed "repository not found" answer triggers an upload.
    # Any other failure (network, auth, rate limit, ...) propagates instead of
    # being mistaken for a missing repo and causing an unintended push.
    model.push_to_hub(repo_id)
    tokenizer.push_to_hub(repo_id)
    print(f"Model pushed to https://huggingface.co/{repo_id}")
else:
    print(f"Model https://huggingface.co/{repo_id} already exists — skipping push")


README.md:   0%|          | 0.00/613 [00:00<?, ?B/s]

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...adapter_model.safetensors:   9%|8         |  557kB / 6.30MB            

Model pushed to https://huggingface.co/Dhyanesh-AN/distilbart-samsum-lora


In [None]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel

# 1. Base checkpoint: model weights and the matching tokenizer
base_model_name = "sshleifer/distilbart-cnn-12-6"

base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# 2. Attach the LoRA adapter from the Hub and switch to evaluation mode
model = PeftModel.from_pretrained(base_model, "Dhyanesh-AN/distilbart-samsum-lora")
model.eval()

# 3. Dialogue to summarize
dialogue = """
#Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today?
#Person2#: I found it would be a good idea to get a check-up.
#Person1#: Yes, well, you haven't had one for 5 years. You should have one every year.
#Person2#: I know. I figure as long as there is nothing wrong, why go see the doctor?
#Person1#: Well, the best way to avoid serious illnesses is to find out about them early.
#Person2#: Ok.
#Person1#: Let me see here. Your eyes and ears look fine. Do you smoke?
#Person2#: Yes.
#Person1#: Smoking causes lung cancer and heart disease. You should quit.
#Person2#: I've tried many times but can't quit.
#Person1#: We have classes and medications that might help.
"""

# 4. Encode the dialogue as PyTorch tensors, truncating at 1024 tokens
inputs = tokenizer(dialogue, max_length=1024, truncation=True, return_tensors="pt")

# 5. Beam-search generation with gradient tracking disabled
with torch.no_grad():
    summary_ids = model.generate(
        **inputs,
        num_beams=4,
        min_length=30,
        max_length=80,
        length_penalty=2.0,
        early_stopping=True,
    )

# 6. Turn the first generated sequence back into text and show it
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print("SUMMARY:\n", summary)


SUMMARY:
 #Person1# hasn't had a check-up for 5 years. #Person2# has smoked, but he can't quit smoking.
