# In [None]:
from datasets import load_dataset
from transformers import LlamaForSequenceClassification, LlamaTokenizer, TrainingArguments, Trainer

# Load the dataset
dataset = load_dataset("ag_news")

# Preprocess the data
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

# LLaMA tokenizers normally define no padding token, and the tokenizer cannot
# honour padding="max_length" (used below) until one exists. Register a
# dedicated pad token when the checkpoint does not provide one.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "<pad>"})


def preprocess_function(examples):
    """Tokenize a batch of examples for sequence classification.

    Args:
        examples: A batch as supplied by ``dataset.map(..., batched=True)``;
            only its ``"text"`` entry is read.

    Returns:
        The tokenizer output for the batch, with truncation enabled and every
        sequence padded to ``max_length=512``.
    """
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)


tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Fine-tune the LLaMA model
model = LlamaForSequenceClassification.from_pretrained("decapoda-research/llama-7b-hf", num_labels=4)

# Keep the model in sync with the tokenizer. If a pad token was added above,
# the embedding matrix needs a row for it; only ever grow the matrix so a
# checkpoint with a deliberately larger vocabulary is left untouched.
if len(tokenizer) > model.config.vocab_size:
    model.resize_token_embeddings(len(tokenizer))

# The sequence-classification head uses pad_token_id to find the last real
# token of each row; without it, batches larger than one are rejected.
model.config.pad_token_id = tokenizer.pad_token_id

# Hyperparameters for the fine-tuning run, grouped by concern and then
# unpacked into TrainingArguments.
training_config = {
    # Output location and Hub publishing
    "output_dir": "llama-text-classification",
    "push_to_hub": False,
    # Evaluate and checkpoint once per epoch
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    # Optimisation schedule
    "num_train_epochs": 3,
    "learning_rate": 2e-5,
    "weight_decay": 0.01,
    # Per-device batch sizes
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
}
training_args = TrainingArguments(**training_config)

# Pick the tokenized splits: "train" for optimisation, "test" for evaluation.
train_split = tokenized_dataset["train"]
eval_split = tokenized_dataset["test"]

# Wire the model, hyperparameters, data and tokenizer into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    tokenizer=tokenizer,
)

# Run the fine-tuning loop.
trainer.train()