# Fine-tuning Llama 3 with Distributed Training on AWS

This notebook fine-tunes the Llama 3 8B model using distributed training with TensorFlow's MirroredStrategy on a g5.4xlarge instance with 2 NVIDIA A10G GPUs.

## Features:
- Distributed training across multiple GPUs using MirroredStrategy
- Lion optimizer for memory efficiency
- Dynamic sequence padding
- Configurable hyperparameters
- Checkpointing and early stopping

## Setup Instructions

Before running this notebook on AWS:

1. **Launch g5.4xlarge instance** with Deep Learning AMI
2. **Upload this notebook** to the instance
3. **Upload `massive_dataset.tar.gz`** to the instance (from your local massive_datasets_max folder)
4. **Run the setup cells** below to clone the repo and extract data
5. **Install dependencies**: `pip install transformers datasets keras tensorflow ipywidgets`

## AWS Instance Specs
- **Instance**: g5.4xlarge
- **GPUs**: 2x NVIDIA A10G (24GB each)
- **vCPUs**: 16
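- **RAM**: 128GB

> **Note:** Confirm these figures against the current AWS G5 instance table before launching. AWS documentation lists g5.4xlarge with a single A10G GPU and 64 GiB of RAM; MirroredStrategy only spreads work across the GPUs the instance actually exposes, so a multi-GPU size (for example g5.12xlarge) may be required for multi-GPU training.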

In [None]:
# Clone the project repository, then make its ml_pipeline directory the
# working directory so the relative paths used by later cells
# (../massive_dataset.tar.gz, ./processed_dataset, ./checkpoints, ./logs)
# resolve from there.
!git clone https://github.com/Mazzlabs/sys-scan-agent_MLops.git
%cd sys-scan-agent_MLops/ml_pipeline

In [None]:
# Extract the dataset archive into the current directory, then list the first
# entries of massive_datasets_max/ as a quick sanity check.
# Note: massive_dataset.tar.gz must be uploaded manually beforehand. It is read
# from the parent of the current directory (../), which is the repository root
# if the %cd in the clone cell has been run.
!tar -xzf ../massive_dataset.tar.gz -C ./
!ls -la massive_datasets_max/ | head -10

In [None]:
# Install the third-party packages this notebook imports:
# Transformers, Datasets, Keras, TensorFlow and ipywidgets.
!pip install transformers datasets keras tensorflow ipywidgets

In [None]:
import os

# Select the Keras backend BEFORE Keras is imported: Keras reads the
# KERAS_BACKEND environment variable once at import time, so assigning it
# after `import keras` has no effect on the running session.
os.environ['KERAS_BACKEND'] = 'tensorflow'

import keras
from keras import layers, ops
from transformers import AutoTokenizer, TFAutoModelForCausalLM
from datasets import load_from_disk
import numpy as np
import tensorflow as tf
import ipywidgets as widgets
from IPython.display import display

In [None]:
def preprocess_function(examples, tokenizer):
    """
    Tokenize the "text" entry of ``examples`` and attach labels.

    The tokenizer is called with truncation enabled, padding to the longest
    sequence in the call, and TensorFlow tensors requested as output. The
    resulting "input_ids" object is reused as-is under the "labels" key.

    Args:
        examples: Mapping with a "text" entry to tokenize.
        tokenizer: Callable tokenizer returning a mutable mapping that
            contains "input_ids".

    Returns:
        The tokenizer output with an added "labels" entry.
    """
    tokenizer_options = {
        "truncation": True,
        "padding": "longest",
        "return_tensors": "tf",
    }
    encoded = tokenizer(examples["text"], **tokenizer_options)

    # Labels are the input ids themselves (same object, not a copy).
    encoded["labels"] = encoded["input_ids"]
    return encoded

In [None]:
def create_tf_dataset(dataset, tokenizer, batch_size):
    """
    Creates a shuffled, dynamically padded TensorFlow dataset from a Hugging Face dataset.

    Every example's "text" field is tokenized on the fly with
    ``preprocess_function``. Examples are shuffled with a 1000-element buffer,
    grouped into batches that are padded to the longest sequence in that
    batch, and prefetched.

    Args:
        dataset: Iterable of examples, each a mapping with a "text" entry.
        tokenizer: Tokenizer forwarded to ``preprocess_function``. Its
            ``pad_token_id`` (0 if unset) is used to pad "input_ids" and
            "labels".
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding dicts with int32 "input_ids",
        "attention_mask" and "labels" tensors of shape
        (batch, longest sequence in the batch).
    """
    feature_names = ("input_ids", "attention_mask", "labels")

    def generator():
        for example in dataset:
            processed = preprocess_function({"text": example["text"]}, tokenizer)
            # The tokenizer is asked for TF tensors, which carry a leading
            # batch axis even for a single text. Flatten to 1-D so each
            # element matches the (None,) shape declared in output_signature.
            yield {
                name: tf.reshape(processed[name], [-1])
                for name in feature_names
            }

    output_signature = {
        name: tf.TensorSpec(shape=(None,), dtype=tf.int32)
        for name in feature_names
    }

    tf_dataset = tf.data.Dataset.from_generator(
        generator,
        output_signature=output_signature
    )

    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = 0

    # Labels mirror input_ids (see preprocess_function), so they are padded
    # with the same id; attention_mask is padded with 0 to mark padding.
    padding_values = {
        "input_ids": tf.constant(pad_token_id, dtype=tf.int32),
        "attention_mask": tf.constant(0, dtype=tf.int32),
        "labels": tf.constant(pad_token_id, dtype=tf.int32),
    }

    # Sequences differ in length, so plain .batch() cannot stack them;
    # padded_batch pads each batch to its own longest sequence.
    tf_dataset = (
        tf_dataset
        .shuffle(1000)
        .padded_batch(
            batch_size,
            padded_shapes={name: [None] for name in feature_names},
            padding_values=padding_values,
        )
        .prefetch(tf.data.AUTOTUNE)
    )
    return tf_dataset

In [None]:
# Interactive hyperparameter controls. The widget objects are kept as
# notebook-level names because the training function reads their .value.
epochs_widget = widgets.IntSlider(
    description='Epochs:',
    value=1,
    min=1,
    max=10,
)

# Logarithmic slider: min/max are exponents, not learning-rate values.
learning_rate_widget = widgets.FloatLogSlider(
    description='Learning Rate:',
    value=2e-4,
    min=-6,
    max=-2,
    step=0.1,
)

beta_1_widget = widgets.FloatSlider(
    description='Beta 1:',
    value=0.9,
    min=0.0,
    max=1.0,
    step=0.01,
)

beta_2_widget = widgets.FloatSlider(
    description='Beta 2:',
    value=0.99,
    min=0.0,
    max=1.0,
    step=0.01,
)

weight_decay_widget = widgets.FloatSlider(
    description='Weight Decay:',
    value=0.01,
    min=0.0,
    max=0.1,
    step=0.001,
)

hyperparameter_controls = (
    epochs_widget,
    learning_rate_widget,
    beta_1_widget,
    beta_2_widget,
    weight_decay_widget,
)
display(*hyperparameter_controls)

In [None]:
def train_with_keras(
    *,
    dataset_path="./processed_dataset",
    model_name="meta-llama/Meta-Llama-3-8B",
    new_model_name="sys-scan-llama-agent-keras3-lion",
    batch_size_per_replica=8,
):
    """
    Fine-tunes the Llama 3 model using TensorFlow's MirroredStrategy for multi-GPU training.

    Hyperparameters (epochs, learning rate, beta_1, beta_2, weight decay) are
    read from the notebook-level configuration widgets at call time.

    Args:
        dataset_path: Directory of the dataset saved with ``datasets``; it
            must provide 'train' and 'validation' splits.
        model_name: Identifier of the pretrained model and tokenizer to load.
        new_model_name: Output directory for the fine-tuned model and
            tokenizer; also used as the checkpoint file prefix.
        batch_size_per_replica: Examples per replica per step; the global
            batch size is this value times the number of replicas in sync.

    Returns:
        The history object returned by ``model.fit``.
    """
    strategy = tf.distribute.MirroredStrategy()
    print(f"✅ Found {strategy.num_replicas_in_sync} GPUs. Using MirroredStrategy.")

    # Get values from widgets
    epochs = epochs_widget.value
    learning_rate = learning_rate_widget.value
    beta_1 = beta_1_widget.value
    beta_2 = beta_2_widget.value
    weight_decay = weight_decay_widget.value

    print(f"Loading pre-processed dataset from {dataset_path}...")
    split_dataset = load_from_disk(dataset_path)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse the end-of-sequence token for padding and pad on the right.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    global_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync

    print(f"Batch size per GPU: {batch_size_per_replica}")
    print(f"Global batch size: {global_batch_size}")

    # The datasets are batched with the global batch size.
    train_dataset = create_tf_dataset(split_dataset['train'], tokenizer, global_batch_size)
    val_dataset = create_tf_dataset(split_dataset['validation'], tokenizer, global_batch_size)

    # Model, optimizer and compile() are created inside the strategy scope.
    with strategy.scope():
        print(f"Loading model {model_name} for Keras 3...")
        # NOTE(review): TFAutoModelForCausalLM needs a TensorFlow
        # implementation of the checkpoint's architecture, and the
        # Transformers TF classes may expect Keras 2 (tf-keras) rather than
        # Keras 3 - confirm this combination loads before a long run.
        model = TFAutoModelForCausalLM.from_pretrained(
            model_name,
            return_dict=True
        )

        print("Creating Lion optimizer with configured parameters...")
        optimizer = keras.optimizers.Lion(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            weight_decay=weight_decay
        )

        model.compile(
            optimizer=optimizer,
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=[keras.metrics.SparseCategoricalAccuracy()]
        )

    # Create the output directories before the callbacks that write to them.
    os.makedirs("./checkpoints", exist_ok=True)
    os.makedirs("./logs", exist_ok=True)

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            # Keras 3 requires weights-only checkpoints to end in
            # ".weights.h5"; any other suffix raises ValueError here.
            filepath=f"./checkpoints/{new_model_name}_epoch_{{epoch:02d}}.weights.h5",
            save_freq='epoch',
            save_weights_only=True
        ),
        keras.callbacks.TensorBoard(log_dir="./logs"),
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6)
    ]

    print("\n🚀 Starting fine-tuning with Keras, Lion, and MirroredStrategy...")

    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )

    print("✅ Fine-tuning completed!")

    print(f"Saving model to {new_model_name}...")
    model.save_pretrained(new_model_name)
    tokenizer.save_pretrained(new_model_name)

    print(f"🎉 Model and tokenizer saved to {new_model_name}")

    return history

In [None]:
# Run the training with the values currently selected in the configuration
# widgets, keeping the returned history for later inspection.
history = train_with_keras()