# 1. Imports

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from rouge_score import rouge_scorer
import torch

# 2. Load Dataset (small subset for CPU-friendly training)

In [3]:
dataset = load_dataset("cnn_dailymail", "3.0.0")
train_dataset = dataset['train'].select(range(2000))
val_dataset = dataset['validation'].select(range(500))

# 3. Preprocessing Function

In [4]:
def preprocess_text(text):
    text = text.replace("\n", " ")
    text = " ".join(text.split())
    text = text.strip()
    return text

train_dataset = train_dataset.map(
    lambda x: {
        "article": preprocess_text(x["article"]),
        "summary": preprocess_text(x["highlights"])
    }
)

val_dataset = val_dataset.map(
    lambda x: {
        "article": preprocess_text(x["article"]),
        "summary": preprocess_text(x["highlights"])
    }
)



Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

# 4. Load DistilBART (lighter)

In [5]:
model_name = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

# 5. Apply LoRA (makes training possible on CPU)

In [6]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()



trainable params: 786,432 || all params: 306,296,832 || trainable%: 0.2568


# 6. Tokenization Function

In [7]:
def tokenize(batch):
    inputs = tokenizer(
        batch["article"],
        max_length=512,
        truncation=True,
        padding="max_length"
    )

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch["summary"],
            max_length=128,
            truncation=True,
            padding="max_length"
        )

    inputs["labels"] = labels["input_ids"]
    return inputs

train_dataset = train_dataset.map(tokenize, batched=True)
val_dataset = val_dataset.map(tokenize, batched=True)

# Set format for PyTorch
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])



Map:   0%|          | 0/2000 [00:00<?, ? examples/s]



Map:   0%|          | 0/500 [00:00<?, ? examples/s]

# 7. Training Arguments

In [11]:
training_args = TrainingArguments(
    output_dir="bart_lora",
    num_train_epochs=2,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,   # effective batch size = 8
    optim="adafactor",
    learning_rate=5e-4,
    logging_steps=50,
    save_steps=300,
)



# 8. Trainer

In [12]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train()




Step,Training Loss
50,4.541
100,1.8685
150,1.6528
200,1.5996
250,1.544
300,1.4962
350,1.4941
400,1.4856
450,1.4778
500,1.4412


TrainOutput(global_step=500, training_loss=1.8600693817138672, metrics={'train_runtime': 7983.9807, 'train_samples_per_second': 0.501, 'train_steps_per_second': 0.063, 'total_flos': 3105487847424000.0, 'train_loss': 1.8600693817138672, 'epoch': 2.0})

# 9. Save Model

In [21]:
model.save_pretrained("../bart_lora")
tokenizer.save_pretrained("../bart_lora")

('../bart_lora/tokenizer_config.json',
 '../bart_lora/special_tokens_map.json',
 '../bart_lora/vocab.json',
 '../bart_lora/merges.txt',
 '../bart_lora/added_tokens.json',
 '../bart_lora/tokenizer.json')

# 10. Example Inference

In [17]:
example_text = " The two men killed as they floated holding onto their capsized boat in a secondary strike against a suspected drug vessel in early September did not appear to have radio or other communications devices, the top military official overseeing the strike told lawmakers on Thursday, according to three sources with direct knowledge of his congressional briefings.As far back as September, defense officials have been quietly pushing back on criticism that killing the two survivors amounted to a war crime by arguing, in part, that they were legitimate targets because they appeared to be radioing for help or backup — reinforcements that, if they had received it, could have theoretically allowed them to continue to traffic the drugs aboard their sinking ship. Defense officials made that claim in at least one briefing in September for congressional staff, according to a source familiar with the session, and several media outlets cited officials repeating that justification in the last week. But Thursday, Adm. Frank “Mitch” Bradley acknowledged that the two survivors of the military’s initial strike were in no position to make a distress call in his briefings to lawmakers. Bradley was in charge of Joint Special Operations Command at the time of the strike and was the top military officer directing the attack. The initial hit on the vessel, believed to be carrying cocaine, killed nine people immediately and split the boat in half, capsizing it and sending a massive smoke plume into the sky, the sources who viewed the video as part of the briefings said. Part of the surveillance video was a zoomed-in, higher-definition view of the two survivors clinging to a still-floating, capsized portion, they said. For a little under an hour — 41 minutes, according to a separate US official — Bradley and the rest of the US military command center discussed what to do as they watched the men struggle to overturn what was left of their boat, the sources said. "

encoded = tokenizer(example_text, return_tensors="pt", truncation=True, max_length=512)

summary_ids = model.generate(encoded["input_ids"], num_beams=4, max_length=128)
print("\nGenerated Summary:")
result = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print(result)


Generated Summary:
 The two men killed as they floated holding onto their capsized boat in a secondary strike in early September did not appear to have radio or other communications devices, the top military official overseeing the strike told lawmakers on Thursday . Adm. Frank "Mitch" Bradley was in charge of Joint Special Operations Command at the time of the strike .


In [20]:
print("Total characters: ",len(example_text))
print("Summerized characters: ",len(result))

Total characters:  1965
Summerized characters:  353
