In [None]:
!pip install transformers
!pip install accelerate
!pip install bitsandbytes
!pip install peft
!pip install datasets
!pip install tqdm

In [None]:
import torch
from tqdm import tqdm
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from huggingface_hub import notebook_login

In [None]:
# Notebook-wide settings.
model_name = "google/gemma-7b"
max_token = 2048  # length used for padding and for the length filter
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Build the tokenizer for `model_name`, plus a collator configured for
# causal (non-masked) language modelling.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [None]:
# Tokenization functions
def tokenize_function(row):
    """
    Tokenize ``row["text"]`` with the notebook-level ``tokenizer``.

    The text is padded up to ``max_token`` tokens. No truncation argument is
    passed, so the call does not ask the tokenizer to shorten longer texts.
    """
    encoded = tokenizer(row["text"], max_length=max_token, padding="max_length")
    return encoded

def is_shorter_than_max_token(row):
    """
    Return True when the row's ``input_ids`` holds at most ``max_token``
    entries, and False when it holds more.
    """
    token_count = len(row['input_ids'])
    return token_count <= max_token

def format_conversation(row):
    """
    Combine a row's ``article`` and ``abstract`` into one training string.

    Returns a dict with a single ``"text"`` key whose value has the form
    ``Paper: <article>`` / newline / ``Summary: <abstract>`` followed by the
    literal ``<eos>`` marker.
    """
    pieces = ("Paper: ", row["article"], "\nSummary: ", row["abstract"], "<eos>")
    return {"text": "".join(pieces)}

In [None]:
# Load the dataset and keep only the first 1000 training rows. Selecting
# before mapping avoids formatting every row of every split only to throw
# almost all of that work away.
dataset = load_dataset("ccdv/arxiv-summarization")
dataset = dataset['train'].select(range(1000))

# Turn each row into a single "Paper: ... Summary: ..." training string
dataset = dataset.map(format_conversation)

# Tokenize dataset
dataset = dataset.map(tokenize_function)

# Drop rows whose token count exceeds the token limit
dataset = dataset.filter(is_shorter_than_max_token)

# Split into train and test sets; the fixed seed makes the split
# reproducible across kernel restarts.
dataset = dataset.train_test_split(test_size=0.05, seed=42)

# Load the model and prepare it for training

In [None]:
# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base model with the quantization settings above.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
# Size the embedding table to the tokenizer's vocabulary.
model.resize_token_embeddings(len(tokenizer))
# PEFT helper that readies a k-bit quantized model for training.
model = prepare_model_for_kbit_training(model)

In [None]:
# LoRA adapter settings
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=32,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

# Wrap the prepared model with the adapters described by `config`.
model = get_peft_model(model, config)

# Training

In [None]:
from transformers import TrainerCallback

class EarlyStoppingCallback(TrainerCallback):
    """
    Trainer callback that requests a stop once the evaluation loss drops
    below a fixed threshold.
    """

    def __init__(self, threshold=0.8):
        # Evaluation-loss value below which training should stop.
        self.threshold = threshold

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        """
        Check ``metrics["eval_loss"]`` and set ``control.should_training_stop``
        when it is present and strictly below ``self.threshold``.
        """
        loss = metrics.get("eval_loss")
        if loss is None:
            return
        if loss < self.threshold:
            control.should_training_stop = True


callback = EarlyStoppingCallback()

In [None]:
# Training hyperparameters. Evaluation and checkpointing both happen once per
# epoch; the effective batch size per device is 2 * 5 accumulated steps.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
training_args = TrainingArguments(
    output_dir="output_dir",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=5,
    num_train_epochs=5,
    learning_rate=1e-4,
    evaluation_strategy="epoch",
    warmup_steps=50,
    weight_decay=1e-3,
    optim="paged_adamw_32bit",
    lr_scheduler_type="cosine",
    save_strategy="epoch"
)

trainer = Trainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    # Register the early-stopping callback; without this it is instantiated
    # but never invoked, so the loss threshold has no effect on training.
    callbacks=[callback]
)
trainer.train()

# Inference