In [1]:
!pip install -q pandas numpy torch scikit-learn tqdm huggingface_hub
!pip install -U bitsandbytes typing_extensions
!pip install -U peft transformers accelerate tensorboard

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting typing_extensions
  Downloading typing_extensions-4.13.0-py3-none-any.whl.metadata (3.0 kB)
Downloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl (76.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.0/76.0 MB[0m [31m98.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading typing_extensions-4.13.0-py3-none-any.whl (45 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.7/45.7 kB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: typing_extensions, bitsandbytes
  Attempting uninstall: typing_extensions
    Fo

# Gemma 3 Model Training with LoRA

This notebook implements the training pipeline for Google's Gemma 3 model using LoRA (Low-Rank Adaptation) for efficient fine-tuning.

Features:
1. LoRA implementation
2. Multi-metric early stopping
3. Evaluation metrics tracking
4. Multimodal capabilities (text classification)

In [3]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training
)
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    cohen_kappa_score,
    matthews_corrcoef,
    roc_auc_score
)
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

## Dataset Preparation

In [4]:
class FinancialTweetDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

## Early Stopping Implementation

In [5]:
class EarlyStoppingCallback:
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_metrics = None
        self.early_stop = False
        
    def __call__(self, metrics):
        if self.best_metrics is None:
            self.best_metrics = metrics
            return False
        
        # Check if any metric improved by min_delta
        improved = False
        for metric, value in metrics.items():
            if value > self.best_metrics[metric] + self.min_delta:
                improved = True
                self.best_metrics = metrics
                break
        
        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.counter = 0
        
        return self.early_stop

In [6]:
def calculate_metrics(predictions, labels):
    """Calculate multiple evaluation metrics"""
    pred_labels = np.argmax(predictions, axis=1)
    
    # Basic metrics
    accuracy = accuracy_score(labels, pred_labels)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, pred_labels, average='weighted'
    )
    
    # Additional metrics
    kappa = cohen_kappa_score(labels, pred_labels)
    mcc = matthews_corrcoef(labels, pred_labels)
    
    # ROC-AUC (multi-class)
    try:
        roc_auc = roc_auc_score(labels, predictions, multi_class='ovr')
    except:
        roc_auc = 0.0
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'kappa': kappa,
        'mcc': mcc,
        'roc_auc': roc_auc
    }

## Data Loading and Preprocessing

In [7]:
# Load labeled data
df = pd.read_csv('all_labeled_tweets.csv')

# Convert labels to numeric
label_map = {
    'STRONGLY_POSITIVE': 0,
    'POSITIVE': 1,
    'NEUTRAL': 2,
    'NEGATIVE': 3,
    'STRONGLY_NEGATIVE': 4
}
df['label'] = df['sentiment'].map(label_map)

# Split data
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['description'].values, df['label'].values,
    test_size=0.2, random_state=42
)

## Gemma 3 Model Setup

We'll be using the 12B parameter pretrained Gemma 3 model from Google.

In [8]:
# Initialize tokenizer for Gemma 3
model_name = "google/gemma-3-4b-pt"
# Add Hugging Face authentication
from huggingface_hub import login
# Replace 'your_token_here' with your actual token or use environment variables
# You can get a token from https://huggingface.co/settings/tokens
import os
hf_token = os.environ.get("HF_TOKEN", None)
if hf_token:
    login(token=hf_token)
else:
    print("⚠️ Hugging Face token not found! Please set the HF_TOKEN environment variable.")
    print("You need to log in to access Gemma 3, which is a gated model. Visit:")
    print("https://huggingface.co/google/gemma-3-4b-pt and accept the license")
    login()  # Interactive login if running in interactive environment

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Create datasets
train_dataset = FinancialTweetDataset(train_texts, train_labels, tokenizer)
val_dataset = FinancialTweetDataset(val_texts, val_labels, tokenizer)

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [9]:
# Initialize model with LoRA config
# Load in 8-bit precision to reduce memory requirements
# Initialize model with LoRA config
# Load in 8-bit precision to reduce memory requirements
from transformers import AutoModelForCausalLM

# Gemma 3 doesn't work with AutoModelForSequenceClassification
# Let's use AutoModelForCausalLM which works with Gemma 3 models
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=hf_token  # Pass the token for authentication
)

# Add a sequence classification head on top
from transformers.modeling_utils import PreTrainedModel
from torch import nn

class GemmaForSequenceClassification(nn.Module):
    def __init__(self, base_model, tokenizer, num_labels=5):
        super().__init__()
        self.base_model = base_model
        self.tokenizer = tokenizer
        self.num_labels = num_labels
        self._config = base_model.config  # Store the config
        
        # Create the prompt template
        self.prompt = """You are a financial sentiment analyzer. Classify the given tweet's sentiment into one of these categories:

                        STRONGLY_POSITIVE - Very bullish, highly confident optimistic outlook
                        POSITIVE - Generally optimistic, bullish view
                        NEUTRAL - Factual, balanced, or no clear sentiment
                        NEGATIVE - Generally pessimistic, bearish view
                        STRONGLY_NEGATIVE - Very bearish, highly confident pessimistic outlook

                        Examples:
                        "Breaking: Company XYZ doubles profit forecast!" -> STRONGLY_POSITIVE
                        "Expecting modest gains next quarter" -> POSITIVE
                        "Market closed at 35,000" -> NEUTRAL
                        "Concerned about rising rates" -> NEGATIVE
                        "Crash incoming, sell everything!" -> STRONGLY_NEGATIVE

                        Format: Return only one word from: STRONGLY_POSITIVE, POSITIVE, NEUTRAL, NEGATIVE, STRONGLY_NEGATIVE

                        Analyze the sentiment of this tweet: """        
        self.prompt_ids = tokenizer.encode(self.prompt, add_special_tokens=False)

        # For Gemma 3, use model_dim instead of hidden_size
        if hasattr(base_model.config, "model_dim"):
            hidden_dim = base_model.config.model_dim
        elif hasattr(base_model.config, "hidden_size"):
            hidden_dim = base_model.config.hidden_size
        else:
            # Default value for Gemma 3 4B model is 262208
            print("Warning: Using default hidden dimension of 262208 for Gemma 3 model")
            hidden_dim = 262208
            
        print(f"Using hidden dimension: {hidden_dim}")

        # Get device and dtype from base model
        device = next(base_model.parameters()).device
        dtype = next(base_model.parameters()).dtype
        print(f"Base model device: {device}, dtype: {dtype}")

        # Create classifier with matching dtype and move to correct device
        self.classifier = nn.Linear(hidden_dim, num_labels, dtype=dtype).to(device)
        print(f"Moved classifier to device: {device}")

    # Add a property to expose the config
    @property
    def config(self):
        return self._config
        
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        batch_size = input_ids.shape[0]
        
        # Prepend prompt to each input
        prompt_ids_tensor = torch.tensor([self.prompt_ids] * batch_size, device=input_ids.device)
        modified_input_ids = torch.cat([prompt_ids_tensor, input_ids], dim=1)
        
        # Adjust attention mask
        prompt_attention = torch.ones(batch_size, len(self.prompt_ids), device=attention_mask.device)
        modified_attention_mask = torch.cat([prompt_attention, attention_mask], dim=1)
        
        # Run the model with the prompt
        outputs = self.base_model(
            input_ids=modified_input_ids, 
            attention_mask=modified_attention_mask
        )
        
        # For Gemma 3, handle different output structures
        if hasattr(outputs, "last_hidden_state"):
            hidden_states = outputs.last_hidden_state
        elif hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
            # Use the last layer's hidden states if available
            hidden_states = outputs.hidden_states[-1]
        else:
            # For CausalLM models, we can often use the logits
            # but first reshape them to get the hidden representation
            # print("Using logits for classification - this might not give optimal results")
            hidden_states = outputs.logits
        
        # Use a pooled representation of the sequence after the prompt
        # This combines the prompt context with the tweet content
        prompt_length = len(self.prompt_ids)
        
        # Extract non-prompt tokens
        relevant_states = hidden_states[:, prompt_length:, :]
        relevant_mask = modified_attention_mask[:, prompt_length:]
        
        # Mean pooling over the relevant tokens
        mask_expanded = relevant_mask.unsqueeze(-1).expand(relevant_states.size()).to(dtype=relevant_states.dtype)
        sum_hidden = torch.sum(relevant_states * mask_expanded, 1)
        count = torch.clamp(torch.sum(mask_expanded, 1), min=1e-9)  # Avoid division by zero
        pooled_output = sum_hidden / count
        
        # Apply the classification head
        logits = self.classifier(pooled_output)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            
        class SequenceClassifierOutput:
            def __init__(self, loss, logits):
                self.loss = loss
                self.logits = logits
        
        return SequenceClassifierOutput(loss, logits)
    
    # Add methods required for PEFT with CAUSAL_LM
    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        This method is required by PEFT for CAUSAL_LM task type.
        It delegates to the base model's method.
        """
        return self.base_model.prepare_inputs_for_generation(*args, **kwargs)

    def get_output_embeddings(self):
        """Return the output embeddings from the base model if needed for generation"""
        return self.base_model.get_output_embeddings()

    # Forward all attribute requests that we don't handle to the base model
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If we don't have the attribute, try getting it from base_model
            return getattr(self.base_model, name)
    
# Wrap the model with our classification head
model = GemmaForSequenceClassification(model, tokenizer, num_labels=5)  # 5 classes for the sentiment labels
# LoRA configuration
lora_config = LoraConfig(
    r=8,  # rank
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"], # Target the attention layers
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM  # Changed from SEQ_CLS to CAUSAL_LM for Gemma 3
)

# Prepare model for LoRA
model = get_peft_model(model, lora_config)

# Print trainable parameters percentage
model.print_trainable_parameters()

config.json:   0%|          | 0.00/815 [00:00<?, ?B/s]



model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

Using hidden dimension: 262208
Base model device: cuda:0, dtype: torch.bfloat16
Moved classifier to device: cuda:0
trainable params: 3,223,552 || all params: 4,304,614,069 || trainable%: 0.0749


## Test For Model dtype


In [10]:
# Try alternative loading configurations for the Gemma 3 model
# First, make sure you have the necessary libraries
!pip install -q transformers accelerate bitsandbytes

# Load the model with more stable settings
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Option 1: Try with BF16 precision instead of FP16
print("Loading model with BF16 precision...")
try:
    model_bf16 = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # Use bfloat16 instead of float16
        device_map="auto",
        token=hf_token
    )
    
    # Simple test
    test_prompt = "Hello, how are you?"
    inputs = tokenizer(test_prompt, return_tensors="pt").to(model_bf16.device)
    
    with torch.no_grad():
        outputs = model_bf16.generate(
            inputs.input_ids,
            max_new_tokens=10,
            do_sample=False  # Deterministic generation first
        )
    
    print("BF16 test succeeded!")
    print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    working_model = model_bf16
    
except Exception as e:
    print(f"BF16 loading failed with error: {str(e)}")
    working_model = None

# Option 2: Try with 4-bit quantization
if working_model is None:
    print("\nTrying with 4-bit quantization...")
    try:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float32,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True
        )
        
        model_4bit = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map="auto",
            token=hf_token
        )
        
        # Simple test
        test_prompt = "Hello, how are you?"
        inputs = tokenizer(test_prompt, return_tensors="pt").to(model_4bit.device)
        
        with torch.no_grad():
            outputs = model_4bit.generate(
                inputs.input_ids,
                max_new_tokens=10,
                do_sample=False
            )
        
        print("4-bit quantization test succeeded!")
        print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
        working_model = model_4bit
        
    except Exception as e:
        print(f"4-bit loading failed with error: {str(e)}")

# Option 3: Try with 8-bit quantization
if working_model is None:
    print("\nTrying with 8-bit quantization...")
    try:
        model_8bit = AutoModelForCausalLM.from_pretrained(
            model_name,
            load_in_8bit=True,
            device_map="auto",
            token=hf_token
        )
        
        # Simple test
        test_prompt = "Hello, how are you?"
        inputs = tokenizer(test_prompt, return_tensors="pt").to(model_8bit.device)
        
        with torch.no_grad():
            outputs = model_8bit.generate(
                inputs.input_ids,
                max_new_tokens=10,
                do_sample=False
            )
        
        print("8-bit quantization test succeeded!")
        print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
        working_model = model_8bit
        
    except Exception as e:
        print(f"8-bit loading failed with error: {str(e)}")

# Option 4: Try with full 32-bit precision (will use more memory)
if working_model is None:
    print("\nTrying with full 32-bit precision...")
    try:
        model_32bit = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map="auto",
            token=hf_token
        )
        
        # Simple test
        test_prompt = "Hello, how are you?"
        inputs = tokenizer(test_prompt, return_tensors="pt").to(model_32bit.device)
        
        with torch.no_grad():
            outputs = model_32bit.generate(
                inputs.input_ids,
                max_new_tokens=10,
                do_sample=False
            )
        
        print("32-bit precision test succeeded!")
        print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
        working_model = model_32bit
        
    except Exception as e:
        print(f"32-bit loading failed with error: {str(e)}")

# If we found a working model configuration, now test it with our sentiment prompt
if working_model is not None:
    print("\n\nTesting working model with sentiment prompt...")
    
    # The sentiment prompt template
    prompt_template = """You are a financial sentiment analyzer. Classify the given tweet's sentiment into one of these categories:

STRONGLY_POSITIVE - Very bullish, highly confident optimistic outlook
POSITIVE - Generally optimistic, bullish view
NEUTRAL - Factual, balanced, or no clear sentiment
NEGATIVE - Generally pessimistic, bearish view
STRONGLY_NEGATIVE - Very bearish, highly confident pessimistic outlook

Examples:
"Breaking: Company XYZ doubles profit forecast!" -> STRONGLY_POSITIVE
"Expecting modest gains next quarter" -> POSITIVE
"Market closed at 35,000" -> NEUTRAL
"Concerned about rising rates" -> NEGATIVE
"Crash incoming, sell everything!" -> STRONGLY_NEGATIVE

Format: Return only one word from: STRONGLY_POSITIVE, POSITIVE, NEUTRAL, NEGATIVE, STRONGLY_NEGATIVE

Analyze the sentiment of this tweet: {}"""
    
    example_tweets = [
        "Breaking: Company XYZ doubles profit forecast!",
        "Expecting modest gains next quarter",
        "Market closed at 35,000",
        "Concerned about rising rates",
        "Crash incoming, sell everything!"
    ]
    
    # Test with the first example
    tweet = example_tweets[0]
    prompt = prompt_template.format(tweet)
    
    inputs = tokenizer(prompt, return_tensors="pt").to(working_model.device)
    
    try:
        with torch.no_grad():
            outputs = working_model.generate(
                inputs.input_ids,
                max_new_tokens=10,
                temperature=0.1,
                do_sample=False
            )
        
        full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = full_output[len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)):]
        
        print(f"\nTweet: {tweet}")
        print(f"Model response: '{response}'")
        
        if not response.strip():
            print("\nModel still not generating a response. Trying with different generation parameters...")
            
            with torch.no_grad():
                outputs = working_model.generate(
                    inputs.input_ids,
                    max_new_tokens=20,  # More tokens
                    temperature=0.7,    # Higher temperature
                    do_sample=True,     # Enable sampling
                    top_p=0.95          # Nucleus sampling
                )
            
            full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = full_output[len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)):]
            
            print(f"Model response with adjusted parameters: '{response}'")
    
    except Exception as e:
        print(f"Error testing sentiment prompt: {str(e)}")
else:
    print("\nAll model loading options failed. Please check your environment configuration.")
    
print("\nAfter you find a working model configuration, update your model loading code accordingly.")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Loading model with BF16 precision...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



BF16 test succeeded!
Output: Hello, how are you?

I'm trying to use the <code>


Testing working model with sentiment prompt...





Tweet: Breaking: Company XYZ doubles profit forecast!
Model response: '

Answer: STRONGLY_POSITIVE

'

After you find a working model configuration, update your model loading code accordingly.


## Training Setup and Hyperparameters

In [11]:
# Training parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# We don't need to move the model to device since we're using device_map="auto"
# which handles device placement automatically

# Define optimizer with weight decay
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# Training hyperparameters
num_epochs = 20
batch_size = 8  # Smaller batch size due to model size
learning_rate = 2e-4  # Lower learning rate for stability
weight_decay = 0.01
warmup_steps = 100

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Total steps for scheduler
total_steps = len(train_loader) * num_epochs

# Initialize optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=warmup_steps, 
    num_training_steps=total_steps
)

criterion = torch.nn.CrossEntropyLoss()

# Initialize early stopping
early_stopping = EarlyStoppingCallback(patience=3)

log_dir = '/workspace/logs/gemma3_training'
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")
print("To view training progress, connect to TensorBoard using:")
print(f"  1. In a new terminal: tensorboard --logdir={log_dir} --port=6006")
print("  2. Or access through RunPod port forwarding on port 6006")

Using device: cuda
TensorBoard logs will be saved to: /workspace/logs/gemma3_training
To view training progress, connect to TensorBoard using:
  1. In a new terminal: tensorboard --logdir=/workspace/logs/gemma3_training --port=6006
  2. Or access through RunPod port forwarding on port 6006


## Training Loop with Evaluation

In [12]:
# Training loop
best_metrics = None
best_model_state = None
training_history = {"loss": [], "val_metrics": []}

for epoch in range(num_epochs):
    # Training
    model.train()
    total_loss = 0
    epoch_steps = 0
    
    for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        optimizer.zero_grad()
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(outputs.logits, labels)
        
        loss.backward()
        
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        scheduler.step()

        # Log learning rate to TensorBoard
        current_lr = scheduler.get_last_lr()[0]
        writer.add_scalar('Learning_rate', current_lr, epoch * len(train_loader) + epoch_steps)
        
        total_loss += loss.item()
        epoch_steps += 1
    
        # Log batch loss periodically
        if epoch_steps % 10 == 0:
            writer.add_scalar('Loss/train_batch', loss.item(), epoch * len(train_loader) + epoch_steps)
    
    
    avg_loss = total_loss / epoch_steps
    training_history["loss"].append(avg_loss)
    print(f"\nAverage training loss: {avg_loss:.4f}")

    # Log epoch-level training loss
    writer.add_scalar('Loss/train_epoch', avg_loss, epoch)
    
    # Validation
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            all_preds.append(outputs.logits.float().cpu().numpy())  # Convert to float32 first
            all_labels.append(labels.cpu().numpy())
    
    predictions = np.vstack(all_preds)
    true_labels = np.concatenate(all_labels)
    
    # Calculate metrics
    metrics = calculate_metrics(predictions, true_labels)
    training_history["val_metrics"].append(metrics)
    
    print("\nValidation Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")
        # Log each validation metric to TensorBoard
        writer.add_scalar(f'Metrics/{metric}', value, epoch)
    
    # Early stopping check
    if early_stopping(metrics):
        print("\nEarly stopping triggered!")
        break
    
    # Save best model
    if best_metrics is None or metrics['f1'] > best_metrics['f1']:
        best_metrics = metrics
        # For PEFT models, we save the state_dict of the adapter instead of the entire model
        best_model_state = {k: v.clone() for k, v in model.state_dict().items() if "lora" in k}
        print("New best model saved!")
        
        # Log model improvement
        writer.add_text('Training/best_model_update', f"New best model at epoch {epoch+1} with f1: {metrics['f1']:.4f}", epoch)

Epoch 1/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.6676


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7533
precision: 0.7065
recall: 0.7533
f1: 0.6843
kappa: 0.2029
mcc: 0.2684
roc_auc: 0.0000
New best model saved!


Epoch 2/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.4936


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7775
precision: 0.7617
recall: 0.7775
f1: 0.7644
kappa: 0.4365
mcc: 0.4426
roc_auc: 0.0000
New best model saved!


Epoch 3/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.3892


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7732
precision: 0.7734
recall: 0.7732
f1: 0.7727
kappa: 0.4817
mcc: 0.4818
roc_auc: 0.0000
New best model saved!


Epoch 4/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.2895


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7750
precision: 0.7528
recall: 0.7750
f1: 0.7507
kappa: 0.3873
mcc: 0.4063
roc_auc: 0.0000


Epoch 5/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.2127


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7647
precision: 0.7548
recall: 0.7647
f1: 0.7580
kappa: 0.4307
mcc: 0.4328
roc_auc: 0.0000


Epoch 6/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.1604


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7743
precision: 0.7613
recall: 0.7743
f1: 0.7643
kappa: 0.4422
mcc: 0.4456
roc_auc: 0.0000


Epoch 7/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.1192


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7681
precision: 0.7629
recall: 0.7681
f1: 0.7648
kappa: 0.4542
mcc: 0.4547
roc_auc: 0.0000


Epoch 8/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0964


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7643
precision: 0.7564
recall: 0.7643
f1: 0.7597
kappa: 0.4401
mcc: 0.4410
roc_auc: 0.0000


Epoch 9/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0811


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7686
precision: 0.7572
recall: 0.7686
f1: 0.7611
kappa: 0.4385
mcc: 0.4404
roc_auc: 0.0000


Epoch 10/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0671


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7702
precision: 0.7571
recall: 0.7702
f1: 0.7611
kappa: 0.4375
mcc: 0.4401
roc_auc: 0.0000


Epoch 11/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0560


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7640
precision: 0.7595
recall: 0.7640
f1: 0.7614
kappa: 0.4489
mcc: 0.4491
roc_auc: 0.0000


Epoch 12/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0505


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7665
precision: 0.7522
recall: 0.7665
f1: 0.7570
kappa: 0.4247
mcc: 0.4279
roc_auc: 0.0000


Epoch 13/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0408


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7688
precision: 0.7597
recall: 0.7688
f1: 0.7624
kappa: 0.4467
mcc: 0.4479
roc_auc: 0.0000


Epoch 14/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0346


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7675
precision: 0.7624
recall: 0.7675
f1: 0.7640
kappa: 0.4546
mcc: 0.4551
roc_auc: 0.0000


Epoch 15/20:   0%|          | 0/2818 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)




Average training loss: 0.0297


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7728
precision: 0.7638
recall: 0.7728
f1: 0.7674
kappa: 0.4570
mcc: 0.4582
roc_auc: 0.0000


Epoch 16/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0247


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7753
precision: 0.7651
recall: 0.7753
f1: 0.7679
kappa: 0.4559
mcc: 0.4577
roc_auc: 0.0000


Epoch 17/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0203


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7750
precision: 0.7628
recall: 0.7750
f1: 0.7660
kappa: 0.4481
mcc: 0.4508
roc_auc: 0.0000


Epoch 18/20:   0%|          | 0/2818 [00:00<?, ?it/s]


Average training loss: 0.0165


Validation:   0%|          | 0/705 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)




Average training loss: 0.0132


Validation:   0%|          | 0/705 [00:00<?, ?it/s]


Validation Metrics:
accuracy: 0.7702
precision: 0.7610
recall: 0.7702
f1: 0.7644
kappa: 0.4498
mcc: 0.4510
roc_auc: 0.0000

Early stopping triggered!


## Save Model and Results

In [13]:
# Save final model and metrics
output_dir = "/workspace/models/gemma3"
os.makedirs(output_dir, exist_ok=True)

# Save the adapter files for both final and best models
model_path = os.path.join(output_dir, "gemma3_lora_adapter_final")
best_model_path = os.path.join(output_dir, "gemma3_lora_adapter_best")

# Save final model
print(f"Saving final model to {model_path}")
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

# Save best model if we have one
if best_model_state is not None:
    print(f"Saving best model to {best_model_path}")
    # Create directory for best model
    os.makedirs(best_model_path, exist_ok=True)
    
    # Save the best model state
    # For PEFT models we need to handle the state dict differently
    model.save_pretrained(best_model_path)
    tokenizer.save_pretrained(best_model_path)
    
    # Log model paths to TensorBoard
    writer.add_text('Models', f"Final model: {model_path}\nBest model: {best_model_path}", 0)

# Save training history
history_path = os.path.join(output_dir, "training_history.csv")
pd.DataFrame([{
    "epoch": i+1, 
    "loss": loss, 
    **metrics
} for i, (loss, metrics) in enumerate(zip(training_history["loss"], training_history["val_metrics"]))])\
    .to_csv(history_path, index=False)
print(f"Training history saved to {history_path}")

# Save final performance metrics
metrics_path = os.path.join(output_dir, "metrics.csv")
metrics_df = pd.DataFrame([best_metrics])
metrics_df.to_csv(metrics_path, index=False)
print(f"Best metrics saved to {metrics_path}")

Saving final model to /workspace/models/gemma3/gemma3_lora_adapter_final
Saving best model to /workspace/models/gemma3/gemma3_lora_adapter_best
Training history saved to /workspace/models/gemma3/training_history.csv
Best metrics saved to /workspace/models/gemma3/metrics.csv


## Load and Test the Fine-tuned Model

In [14]:
# Function to load the fine-tuned model
def load_finetuned_model(adapter_path, base_model):
    from peft import PeftModel, PeftConfig
    
    # Load the base model
    # Load the base model
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
        token=hf_token  # Pass the token for authentication
    )
    
    # Use our custom classification wrapper
    model = GemmaForSequenceClassification(base_model, num_labels=5)
    
    # Load the PEFT adapter
    model = PeftModel.from_pretrained(model, adapter_path)
    
    return model

# Example of loading and using the model
def predict_sentiment(text, model, tokenizer, label_map_reverse):
    model.eval()
    encoding = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(model.device)
    
    with torch.no_grad():
        outputs = model(**encoding)
    
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    
    return {
        "sentiment": label_map_reverse[predicted_class],
        "probabilities": {label_map_reverse[i]: prob.item() for i, prob in enumerate(probabilities[0])}
    }

In [15]:
# We'll uncomment and run this after training is complete

# # Reverse the label map for interpretation
# label_map_reverse = {v: k for k, v in label_map.items()}

# # Load the fine-tuned model
# loaded_model = load_finetuned_model(
#     adapter_path="../models/gemma3/gemma3_lora_adapter", 
#     base_model="google/gemma-3-12b-pt"
# )
# loaded_tokenizer = AutoTokenizer.from_pretrained("../models/gemma3/gemma3_lora_adapter")

# # Test with some example tweets
# example_tweets = [
#     "Just announced record profits for Q3! Our company is performing exceptionally well.",
#     "The market is down 2% today, concerning trend continues.",
#     "No significant changes in our stock price today, trading sideways.",
#     "Our competitor's latest product launch is worrying for our market position.",
#     "Just had lunch with friends, the weather is nice today."
# ]

# for tweet in example_tweets:
#     result = predict_sentiment(tweet, loaded_model, loaded_tokenizer, label_map_reverse)
#     print(f"\nTweet: {tweet}")
#     print(f"Predicted sentiment: {result['sentiment']}")
#     print("Probabilities:")
#     for sentiment, prob in sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True)[:3]:
#         print(f"  {sentiment}: {prob:.4f}")