In [None]:
!pip install huggingface huggingface_hub transformers

In [None]:
!pip install -U bitsandbytes

In [None]:
!pip install datasets accelerate

In [None]:
!pip install peft

In [None]:
!pip install torchao torchtune

In [None]:
!pip install --upgrade wandb

In [1]:
import os
# Disable the tokenizer parallelism warning to avoid spam
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler
from tqdm import tqdm
import os
import gc
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    Trainer,
    TrainingArguments,
    get_scheduler,
    BitsAndBytesConfig,
    EarlyStoppingCallback
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora import LoraLayer
from datasets import load_dataset
import wandb
import bitsandbytes as bnb
from typing import Any, Dict, List, Optional, Tuple, Union

from huggingface_hub import login, HfApi, create_repo

In [2]:
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
class KFACAutogradFunction(torch.autograd.Function):
    """Custom autograd function to capture activations and gradients for K-FAC."""
    
    @staticmethod
    def forward(ctx, module, output, input):
        ctx.module = module
        # We only need the input for the backward pass to compute activation stats
        ctx.save_for_backward(input.detach())
        # We pass the output through, it's what the next layer will see
        return output

    @staticmethod
    def backward(ctx, grad_wrt_output):
        # This grad is the gradient w.r.t the output of the LoraLayer
        module = ctx.module
        input, = ctx.saved_tensors # This is the original input to the LoraLayer
        manager = module.grit_manager
        manager.backward_step += 1

        if not module.training:
            # Pass gradients through without modification if not training
            return None, grad_wrt_output, None
            
        # --- Run covariance updates periodically to avoid bottlenecking ---
        if manager.backward_step % manager.config.grit_cov_update_freq != 0:
            return None, grad_wrt_output, None

        with torch.no_grad():
            # --- 1. Project activations into r-space and update covariance ---
            # a has shape (batch*seq, in_features)
            a = input.reshape(-1, input.shape[-1])
            if a.shape[0] > 0 and 'default' in module.lora_A:
                # A has shape (r, in_features), so A.T has (in_features, r)
                # Cast weight to the same dtype as activation `a` to prevent mismatch
                lora_A_T = module.lora_A['default'].weight.data.T.to(device=a.device, dtype=a.dtype, non_blocking=True)
                # projected_a has shape (batch*seq, r)
                projected_a = a @ lora_A_T
                a_cov_sample = projected_a.T @ projected_a
                
                # Online update for a_covs
                current_cov = manager.a_covs[module]
                n = manager.num_samples_a[module]
                new_n = n + projected_a.shape[0]
                if new_n > 0:
                    updated_cov = (current_cov.float() * n + a_cov_sample.cpu().float()) / new_n
                    manager.a_covs[module].copy_(updated_cov)
                    manager.num_samples_a[module] = new_n

            # --- 2. Project gradients into r-space and update covariance ---
            # g has shape (batch*seq, out_features)
            g = grad_wrt_output.reshape(-1, grad_wrt_output.shape[-1])
            if g.shape[0] > 0 and 'default' in module.lora_B:
                # B has shape (out_features, r)
                # Cast weight to the same dtype as gradient `g` to prevent mismatch
                lora_B = module.lora_B['default'].weight.data.to(device=g.device, dtype=g.dtype, non_blocking=True)
                # projected_g has shape (batch*seq, r)
                projected_g = g @ lora_B
                g_cov_sample = projected_g.T @ projected_g

                # Online update for g_covs
                current_cov = manager.g_covs[module]
                n = manager.num_samples_g[module]
                new_n = n + projected_g.shape[0]
                if new_n > 0:
                    updated_cov = (current_cov.float() * n + g_cov_sample.cpu().float()) / new_n
                    manager.g_covs[module].copy_(updated_cov)
                    manager.num_samples_g[module] = new_n

        # We return gradients for the inputs of the `forward` method:
        # (module, output, input)
        # 1. module: Not a tensor, so None
        # 2. output: The gradient is `grad_wrt_output`, pass it back to the LoraLayer
        # 3. input: This function does not depend on `input` for its output's value,
        #    so its gradient contribution w.r.t. `input` is zero.
        return None, grad_wrt_output, None

In [4]:
# Place this function outside your GRITManager class.
# It only knows how to work with Tensors.
def jit_invert_tensor_pair(a_cov: torch.Tensor, g_cov: torch.Tensor, kfac_damping: float):
    """
    JIT-compatible function to invert a SINGLE pair of covariance tensors.
    """
    with torch.no_grad():
        damping = max(kfac_damping, 1e-6)
        
        a_cov_damped = a_cov.float() + damping * torch.eye(a_cov.shape[0], device='cpu')
        g_cov_damped = g_cov.float() + damping * torch.eye(g_cov.shape[0], device='cpu')
        
        L_a, info_a = torch.linalg.cholesky_ex(a_cov_damped)
        L_g, info_g = torch.linalg.cholesky_ex(g_cov_damped)
        
        if info_a == 0 and info_g == 0:
            a_inv = torch.cholesky_inverse(L_a).half()
            g_inv = torch.cholesky_inverse(L_g).half()
            return a_inv, g_inv
        else:
            # Return empty tensors on failure, which we can check for later
            return torch.empty(0), torch.empty(0)

In [5]:
class GRITConfig:
    """Configuration class for GRIT with Unsloth support."""
    def __init__(self):
        # Using a 3B model, optimized for a powerful A100 GPU
        self.model_id = "meta-llama/Llama-3.2-3B"  # Using Unsloth's pre-optimized model
        
        # -------- Training Configuration (Now Optimized for Kaggle -> 16GB VRAM) --------
        self.batch_size = 16  # Lowered from 8 for memory constraints
        self.gradient_accumulation_steps = 4 # Increased from 1 to maintain effective batch size of 8
        self.num_epochs = 1
        self.learning_rate = 2e-5
        self.precision = "bf16" # bf16 is still optimal for modern GPUs
        self.max_length = 256

        # LoRA configuration (optimized for memory)
        self.lora_rank = 16   # Start with a higher rank to allow for pruning
        self.lora_alpha = 32 # Adjusted for new rank (2 * rank)
        self.lora_dropout = 0.0
        # 5W heuristic for layer selection: include attention and key MLP layers
        self.lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

        # GRIT parameters (tuned for speed vs. quality)
        self.kfac_update_freq = 100 # Invert less often
        self.kfac_damping = 0.001
        self.reprojection_freq = 100
        self.reprojection_k = 8
        self.grit_cov_update_freq = 15 # Update covs a bit more often with more budget
        
        # --- Rank-Adaptive LoRA Configuration ---
        self.enable_rank_adaptation = True
        self.rank_adaptation_threshold = 0.9  # Cumulative energy threshold
        self.min_lora_rank = 4                 # Minimum rank to prevent collapse

        # --- Convergence Control ---
        self.enable_early_stopping = True
        self.early_stopping_patience = 3 # Stop after 3 evaluations with no improvement

        # Unsloth-specific parameters
        self.use_unsloth = False
        self.unsloth_max_seq_length = self.max_length
        self.unsloth_dtype = torch.bfloat16 if self.precision == "bf16" else torch.float16
        self.unsloth_load_in_4bit = True  # Use 4-bit quantization
        
        # Data loading configuration
        self.dataset_name = "nyu-mll/glue" # Default dataset
        self.num_workers = 2 # Reduced from 8 for Kaggle's lower CPU/RAM resources
        self.pin_memory = False # This is generally fine
        self.drop_last = True

config = GRITConfig()

In [6]:
class GRITManager:
    """
    Memory-efficient GRIT Manager with CPU-based K-FAC storage and selective optimization.
    This addresses both performance and memory issues.
    """
    
    def __init__(self, model, config, device):
        self.model = model
        self.config = config
        self.device = device
        self.global_step = 0
        self.backward_step = 0
        
        # Adaptive frequency state
        self.loss_history = []
        self.loss_history_capacity = 20  # Corresponds to window size in _get_adaptive_freq
        self.last_kfac_update_step = 0
        self.last_reprojection_step = 0
        
        # Memory-efficient K-FAC state storage (CPU-based)
        self.a_covs = {}  # Input activation covariances (stored on CPU)
        self.g_covs = {}  # Output gradient covariances (stored on CPU)
        self.a_invs = {}  # Inverted input covariances (CPU)
        self.g_invs = {}  # Inverted output covariances (CPU)
        
        # Online estimation counters
        self.num_samples_a = {}
        self.num_samples_g = {}

        # Track which modules we're optimizing (subset for memory efficiency)
        self.optimized_modules = []
        
        self.factors_are_ready = False
        self._instrument_model()
        
        # Debug: Check how many LoRA layers we found
        total_modules = len(self.optimized_modules)
        print(f"GRITManager: Initialization complete.")
        print(f"🔍 Optimizing {total_modules} key LoRA modules.")
        print(f"💾 K-FAC matrices stored on CPU for memory efficiency.")

    def _instrument_model(self):
        """
        Replaces forward passes with a version that includes our autograd function.
        Optimized to only instrument q_proj modules in the last 8 layers and use r-dim
        covariance matrices to significantly reduce CPU RAM usage.
        """
        attention_modules = 0
        
        for name, module in self.model.named_modules():
            if isinstance(module, LoraLayer) and module.r['default'] > 0:
                module.lora_name = name # Store name for logging
                module.grit_manager = self 
                self.num_samples_a[module] = 0
                self.num_samples_g[module] = 0
                
                # Store original forward method
                module.original_forward = module.forward
                
                # Replace forward method to apply the hook *after* the original forward
                def new_forward(self, x):
                    y = self.original_forward(x)
                    return KFACAutogradFunction.apply(self, y, x)

                module.forward = new_forward.__get__(module, LoraLayer)

                # --- store covariances in r-dim space (MASSIVE memory saving) ---
                r = module.r['default']
                self.a_covs[module] = torch.zeros((r, r), device='cpu', dtype=torch.float16)
                self.g_covs[module] = torch.zeros((r, r), device='cpu', dtype=torch.float16)
                
                self.optimized_modules.append(module)
                attention_modules += 1
        
        print(f"🎯 Instrumented {attention_modules} modules for GRIT optimization with custom autograd:")
        print(f"   • {attention_modules} attention modules")
        print(f"🚀 Using r-dim ({self.config.lora_rank}x{self.config.lora_rank}) covariances for maximum memory efficiency.")

    def _get_adaptive_freq(self, base_freq, min_freq=1, max_freq=1000, window=20):
        """Calculates adaptive frequency based on loss stability."""
        if len(self.loss_history) < window:
            return base_freq
        
        # Check loss trend over the last `window` steps
        recent_losses = self.loss_history[-window:]
        first_half = sum(recent_losses[:window//2]) / (window//2)
        second_half = sum(recent_losses[window//2:]) / (window//2)
        
        # If loss is decreasing (learning is stable), decrease frequency (run less often)
        if second_half < first_half * 0.99:
            new_freq = int(base_freq * 1.5)
        # If loss is fluctuating or increasing, increase frequency (run more often)
        else:
            new_freq = int(base_freq * 0.75)
            
        return max(min_freq, min(new_freq, max_freq))

    def step(self, loss=None):
        """Called after each optimizer step to manage periodic updates"""
        self.global_step += 1
        if loss is not None:
            self.loss_history.append(loss)
            if len(self.loss_history) > self.loss_history_capacity:
                self.loss_history = self.loss_history[-self.loss_history_capacity:]

        # --- Second-order damping schedule ---
        if len(self.loss_history) > 10: # Need enough history to compute variance
            loss_variance = torch.tensor(self.loss_history).var().item()
            # Scale damping with variance, with min/max caps
            self.config.kfac_damping = max(1e-6, min(0.01, 0.001 + math.sqrt(loss_variance)))
            if self.global_step % 100 == 0: # Log periodically
                 wandb.log({"adaptive_kfac_damping": self.config.kfac_damping})

        # Adaptive K-FAC update frequency
        kfac_freq = self._get_adaptive_freq(self.config.kfac_update_freq)
        if self.global_step - self.last_kfac_update_step >= kfac_freq:
            self.update_and_invert_factors()
            self.last_kfac_update_step = self.global_step
        
        # Adaptive neural reprojection frequency
        reproj_freq = self._get_adaptive_freq(self.config.reprojection_freq, min_freq=1, max_freq=2000)
        if self.global_step - self.last_reprojection_step >= reproj_freq:
            self.neural_reprojection()
            self.last_reprojection_step = self.global_step

    # Inside your GRITManager class
    def update_and_invert_factors(self):
        print(f"\nGRITManager: Inverting K-FAC factors at step {self.global_step}...")

        # JIT script the new, simpler function ONCE outside the loop
        scripted_invert_fn = torch.jit.script(jit_invert_tensor_pair)

        # The loop over modules stays in regular Python, where module keys are OK
        for module in self.optimized_modules:
            a_cov = self.a_covs[module]
            g_cov = self.g_covs[module]

            # Call the JIT function with TENSORS ONLY
            a_inv, g_inv = scripted_invert_fn(
                a_cov=a_cov,
                g_cov=g_cov,
                kfac_damping=self.config.kfac_damping
            )

            # Check if the inversion was successful and update the dictionaries
            if a_inv.numel() > 0 and g_inv.numel() > 0:
                self.a_invs[module] = a_inv
                self.g_invs[module] = g_inv
            else:
                print(f"K-FAC inversion failed for a module. Skipping.")
        
        self.factors_are_ready = True

    def _invert_factors_fn(self):
        with torch.no_grad():
            for module in self.optimized_modules:
                # Ensure damping is non-negative
                damping = max(self.config.kfac_damping, 1e-6)
                
                # --- Invert Activation Covariance (a_cov) ---
                a_cov_damped = self.a_covs[module].float() + damping * torch.eye(
                    self.a_covs[module].shape[0], device='cpu'
                )
                
                # Use cholesky_ex to check for positive-definiteness (a proxy for invertibility here)
                # L_a is the Cholesky factor, info_a is 0 on success
                L_a, info_a = torch.linalg.cholesky_ex(a_cov_damped)
                
                # --- Invert Gradient Covariance (g_cov) ---
                g_cov_damped = self.g_covs[module].float() + damping * torch.eye(
                    self.g_covs[module].shape[0], device='cpu'
                )
                L_g, info_g = torch.linalg.cholesky_ex(g_cov_damped)

                # Check if both decompositions succeeded
                if info_a == 0 and info_g == 0:
                    # If successful, compute inverse from Cholesky factor (more stable)
                    self.a_invs[module] = torch.cholesky_inverse(L_a).half()
                    self.g_invs[module] = torch.cholesky_inverse(L_g).half()
                else:
                    # This block runs if inversion would have failed
                    print(f"K-FAC inversion check failed for a module. Skipping update.")
                    # You can still log this event if needed, but wandb calls might not be JIT-friendly.
                    # It's better to handle logging outside the JIT-compiled function.
                    continue

    def precondition_gradients(self):
        """Apply K-FAC preconditioning with efficient CPU-GPU transfers"""
        if not self.factors_are_ready:
            return
        if self.global_step % self.config.kfac_update_freq == 0:
            print(f"\nGRITManager: Applying Natural Gradient preconditioner at step {self.global_step} to {self.global_step + 50}...")
            
        with torch.no_grad():
            # Determine correct dtype based on config
            dtype = torch.float16 if self.config.precision == "fp16" else torch.bfloat16

            for module in self.optimized_modules:
                if module not in self.a_invs or module not in self.g_invs:
                    continue
                    
                lora_a = module.lora_A['default']
                lora_b = module.lora_B['default']
                
                if (lora_a is None or lora_b is None or 
                    lora_a.weight.grad is None or lora_b.weight.grad is None):
                    continue

                try:
                    # Move inverse matrices to GPU only when needed, using correct dtype
                    a_inv = self.a_invs[module].to(self.device, dtype=dtype)
                    g_inv = self.g_invs[module].to(self.device, dtype=dtype)
                    
                    # --- True Natural Gradient Computation: grad' = G_inv @ grad @ A_inv ---
                    
                    # Precondition LoRA B gradient (the "easy" part)
                    # Original grad_b has shape (out, r)
                    grad_b = lora_b.weight.grad.to(dtype)
                    # Corrected multiplication: (out, r) @ (r, r) -> (out, r)
                    preconditioned_b_grad = grad_b @ g_inv
                    
                    # Precondition LoRA A gradient (the "hard" part)
                    # Original grad_a has shape (r, in)
                    grad_a = lora_a.weight.grad.to(dtype)
                    
                    # Safety reshape for larger ranks
                    r = module.r['default']
                    if r > 0:
                        grad_a = grad_a.view(r, -1)
                    
                    # Instead, we apply the preconditioning to each matrix's gradient.
                    # grad_a' = A_inv @ grad_a
                    # Corrected multiplication: (r, r) @ (r, in) -> (r, in)
                    preconditioned_a_grad = a_inv @ grad_a

                    # Copy back the preconditioned gradients
                    lora_a.weight.grad.copy_(preconditioned_a_grad.to(lora_a.weight.grad.dtype))
                    lora_b.weight.grad.copy_(preconditioned_b_grad.to(lora_b.weight.grad.dtype))
                    
                    # Clean up GPU tensors immediately
                    del a_inv, g_inv, grad_a, grad_b, preconditioned_a_grad, preconditioned_b_grad
                    
                except Exception as e:
                    print(f"Gradient preconditioning failed: {e}")
                    wandb.log({"preconditioning_error": 1})
                    continue

    def neural_reprojection(self):
        """Perform neural reprojection and log the parameter reduction."""
        print(f"\nGRITManager: Neural reprojection at step {self.global_step}...")
        
        initial_params = 0
        final_params = 0
        log_dict = {}

        with torch.no_grad():
            # First, calculate initial effective parameters for all GRIT-optimized modules
            for module in self.optimized_modules:
                if hasattr(module, 'in_features') and hasattr(module, 'out_features'):
                    # Effective parameters in LoRA = r * (in_features + out_features)
                    initial_params += module.r['default'] * (module.in_features + module.out_features)

            for module in self.optimized_modules:
                try:
                    lora_a = module.lora_A['default']
                    lora_b = module.lora_B['default']
                    
                    if lora_a is None or lora_b is None:
                        continue
                        
                    A = lora_a.weight.data.float()
                    B = lora_b.weight.data.float()
                    
                    r = A.shape[0]  # Current LoRA rank
                    k = r # Initialize k to the current rank as a fallback
                    
                    # Form the r x r covariance matrix M = A @ A.T
                    M = A @ A.T
                    
                    # --- Numerical Stability Check ---
                    if torch.isnan(M).any() or torch.isinf(M).any():
                        print(f"WARNING: Covariance matrix M for {module.lora_name} contains NaN/Inf. Skipping reprojection for this module.")
                        log_dict[f"reprojection_errors/nan_inf/{module.lora_name}"] = 1
                        # If skipped, the rank does not change.
                        final_params += r * (module.in_features + module.out_features)
                        continue

                    # --- Rank Adaptation & Pruning Logic ---
                    if self.config.enable_rank_adaptation and r > self.config.min_lora_rank:
                        # --- Full Eigendecomposition for Adaptive Rank ---
                        try:
                            eigenvals, V = torch.linalg.eigh(M)
                        except Exception as e:
                            print(f"Eigendecomposition failed for module: {e}", "error")
                            log_dict["reprojection_errors/eigh_error"] = log_dict.get("reprojection_errors/eigh_error", 0) + 1
                            final_params += r * (module.in_features + module.out_features)
                            continue
                        
                        sorted_indices = torch.argsort(eigenvals, descending=True)
                        sorted_eigenvals = eigenvals[sorted_indices]
                        V = V[:, sorted_indices]

                        total_energy = torch.sum(sorted_eigenvals)
                        if total_energy > 1e-6:
                            cumulative_energy = torch.cumsum(sorted_eigenvals, dim=0) / total_energy
                            k = (cumulative_energy < self.config.rank_adaptation_threshold).sum().item() + 1
                        
                        k = max(k, self.config.min_lora_rank) # Enforce minimum rank
                        rank_reduction_percent = (1 - k / r) * 100 if r > 0 else 0
                        
                        log_dict[f"GRIT/Effective Rank (k)/{module.lora_name}"] = k
                        log_dict[f"GRIT/Rank Reduction (%)/{module.lora_name}"] = rank_reduction_percent

                        # Log the sorted eigenvalue distribution to show energy decay
                        # --- Enhanced WandB Logging for Eigen-spectra ---
                        # --- Log Data Tables for Manual Plotting ---
                        try:
                            # Create a shorter name for logging
                            short_lora_name = module.lora_name.replace("base_model.model.", "")
                        
                            # 1. Log Eigen value spectra (This part is fine)
                            eigen_spectra_list = sorted_eigenvals.cpu().numpy().tolist()
                            data_for_table = [[s] for s in eigen_spectra_list]
                            table = wandb.Table(data=data_for_table, columns=["eigenvalue"])
                            hist_plot = wandb.plot.histogram(table, "eigenvalue", title=f"Eigenvalue Spectrum - {short_lora_name}")
                            log_dict[f"Charts/Eigenvalue Spectrum/{short_lora_name}"] = hist_plot
                        
                            # 2. Log ALL (Original) Eigenvectors
                            V_all_T = V.T.cpu().numpy()  # <-- CORRECT: Use the full V matrix
                            heatmap_data_all = []
                            for x, eigenvector in enumerate(V_all_T):
                                for y, component in enumerate(eigenvector):
                                    heatmap_data_all.append([x, y, component])
                            heatmap_table_all = wandb.Table(
                                data=heatmap_data_all,
                                columns=["eigenvector_idx", "component_idx", "value"]
                            )
                            log_dict[f"Data_Tables/Eigenvectors_All/{short_lora_name}"] = heatmap_table_all
                        
                            # 3. Log the CHOSEN TOP-K eigenvectors
                            V_k_T = V[:, :k].T.cpu().numpy() # This is correct for the chosen vectors
                            heatmap_data_chosen = []
                            for x, eigenvector in enumerate(V_k_T):
                                for y, component in enumerate(eigenvector):
                                    heatmap_data_chosen.append([x, y, component])
                            heatmap_table_chosen = wandb.Table(
                                data=heatmap_data_chosen,
                                columns=["eigenvector_idx", "component_idx", "value"]
                            )
                            log_dict[f"Data_Tables/Eigenvectors_Chosen_TopK/{short_lora_name}"] = heatmap_table_chosen
                        
                        except Exception as e:
                            print(f"Data logging for {module.lora_name} failed: {e}")
                        
                        V_k = V[:, :k]
                    else:
                        # --- Fixed Rank Reprojection ---
                        k = min(self.config.reprojection_k, r)
                        try:
                            if k < r and k > 0:
                                _, V_k = torch.lobpcg(lambda v: M @ v, X=torch.randn(r, k, device=M.device, dtype=M.dtype), k=k, largest=True)
                            else:
                                _, V = torch.linalg.eigh(M)
                                eigenvals, _ = torch.linalg.eigh(M)
                                V_k = V[:, torch.argsort(eigenvals, descending=True)][:, :k]
                        except Exception as e:
                            print(f"LOBPCG failed for module, falling back to eigh: {e}")
                            _, V = torch.linalg.eigh(M)
                            eigenvals, _ = torch.linalg.eigh(M)
                            V_k = V[:, torch.argsort(eigenvals, descending=True)][:, :k]
                    
                    # Accumulate the new effective parameter count for this module
                    if hasattr(module, 'in_features') and hasattr(module, 'out_features'):
                        final_params += k * (module.in_features + module.out_features)

                    # --- Project onto the top-k eigenvectors and prune via zeroing ---
                    A_proj = V_k.T @ A
                    B_proj = B @ V_k

                    A_new, B_new = torch.zeros_like(A), torch.zeros_like(B)
                    A_new[:k, :], B_new[:, :k] = A_proj, B_proj

                    lora_a.weight.data.copy_(A_new.to(A.dtype))
                    lora_b.weight.data.copy_(B_new.to(B.dtype))
                        
                except Exception as e:
                    print(f"Neural reprojection failed for module {module.lora_name}: {e}")
                    log_dict["reprojection_errors/general_error"] = log_dict.get("reprojection_errors/general_error", 0) + 1
                    # On failure, assume rank does not change for this module
                    if hasattr(module, 'in_features') and hasattr(module, 'out_features'):
                        final_params += module.r['default'] * (module.in_features + module.out_features)
                    continue

        # --- Log the final results ---
        if initial_params > 0:
            param_reduction = initial_params - final_params
            reduction_percent = (param_reduction / initial_params) * 100
            print(f"✅ Neural reprojection completed. Effective parameter count reduced.")
            print(f"   - Initial GRIT params: {initial_params:,}")
            print(f"   - Final GRIT params:   {final_params:,}")
            print(f"   - Reduction:           {param_reduction:,} ({reduction_percent:.2f}%)")
            
            log_dict.update({
                "Parameters/GRIT Initial Params": initial_params,
                "Parameters/GRIT Final Params": final_params,
                "Parameters/GRIT Param Reduction (%)": reduction_percent,
            })
        else:
            print(f"✅ Neural reprojection completed for {len(self.optimized_modules)} modules.")
        
        if log_dict:
            wandb.log(log_dict, step=self.global_step)
        
        # Strategic memory cleanup after expensive operations
        gc.collect()
        torch.cuda.synchronize()

In [7]:
from torch.optim import Optimizer

class GritOptimizer(Optimizer):
    """
    A wrapper for a PyTorch optimizer that applies GRIT's K-FAC gradient
    preconditioning before the actual optimization step. This ensures that the
    optimizer uses the natural gradient instead of the standard gradient.
    """
    def __init__(self, optimizer: Optimizer, grit_manager: 'GRITManager'):
        self.optimizer = optimizer
        self.grit_manager = grit_manager

    @property
    def state(self):
        return self.optimizer.state

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optimizer.param_groups = value

    def step(self, closure=None):
        """
        Performs a single optimization step.

        1.  Applies K-FAC preconditioning to the accumulated gradients.
        2.  Calls the underlying optimizer's step function.
        """
        # Apply K-FAC preconditioning to the accumulated gradients
        if self.grit_manager.factors_are_ready:
            self.grit_manager.precondition_gradients()

        # Call the underlying optimizer's step function
        self.optimizer.step(closure)

    def zero_grad(self, set_to_none: bool = False):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group: dict):
        self.optimizer.add_param_group(param_group)

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict: dict):
        self.optimizer.load_state_dict(state_dict)

    def __repr__(self):
        return f"GritOptimizer({self.optimizer.__repr__()})"

In [8]:
from transformers import TrainerCallback

class GritCallback(TrainerCallback):
    """
    This callback injects GRIT's logic into the training loop.
    """
    def __init__(self, grit_manager):
        self.grit_manager = grit_manager

    def on_step_end(self, args, state, control, **kwargs):
        """
        Triggered at the end of each training step.
        """
        last_loss = state.log_history[-1].get("loss") if state.log_history else None
        self.grit_manager.step(loss=last_loss)

In [9]:
class GritTrainer(Trainer):
    """
    Optimized GRIT Trainer that addresses the core performance issues and
    correctly applies gradient preconditioning by wrapping the optimizer.
    """

    def __init__(self, grit_manager, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.grit_manager = grit_manager
        print("GritTrainer: Initialized with GRIT implementation.")

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Overrides the base method to wrap the created optimizer with our
        GritOptimizer, which handles gradient preconditioning.
        """
        super().create_optimizer_and_scheduler(num_training_steps)

        if self.optimizer is not None:
            print("🎁 Wrapping the optimizer with GRIT preconditioning logic.")
            self.optimizer = GritOptimizer(self.optimizer, self.grit_manager)

    # --- THIS IS THE CORRECTED METHOD ---
    def training_step(self, model, inputs, num_items_in_batch=None):
        """
        The training step's signature is now aligned with the parent class.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()

        self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps
        
    def evaluate(self, *args, **kwargs):
        """Overrides the default evaluate method to add aggressive memory cleanup."""
        print("\n🧹 Clearing VRAM before evaluation...")
        gc.collect()
        torch.cuda.empty_cache()

        return super().evaluate(*args, **kwargs)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:

        has_labels = "labels" in inputs
        if has_labels:
            labels = inputs.pop("labels")
        else:
            labels = None

        _, generated_tokens, _ = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys
        )

        if has_labels:
            inputs["labels"] = labels

        loss = None
        if has_labels:
            with torch.no_grad():
                loss = self.compute_loss(model, inputs.copy()).detach()

        if generated_tokens is not None and type(generated_tokens).__name__ == 'EmptyLogits':
            batch_size = inputs["input_ids"].shape[0]
            seq_length = labels.shape[1] if labels is not None else inputs["input_ids"].shape[1]
            vocab_size = self.model.config.vocab_size

            generated_tokens = torch.zeros(
                (batch_size, seq_length, vocab_size),
                device=self.accelerator.device,
                dtype=config.unsloth_dtype
            )

        if labels is not None:
            if len(labels.shape) == 3:
                labels = torch.argmax(labels, dim=-1)
            elif len(labels.shape) == 1:
                batch_size = inputs["input_ids"].shape[0]
                labels = labels.view(batch_size, -1)

        return loss, generated_tokens, labels

In [10]:
# Your original QLoRA loading code
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16 if config.precision == "fp16" else torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(config.model_id, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForSequenceClassification.from_pretrained(
    config.model_id,
    quantization_config=quantization_config,
    num_labels=2,
    device_map="auto",
)

model.config.pad_token_id = tokenizer.pad_token_id

target_modules = []
for i in range(0, 32):
    for module in config.lora_target_modules:
        if "proj" in module:  # attention modules
            target_modules.append(f"model.layers.{i}.self_attn.{module}")
        else:  # MLP modules  
            target_modules.append(f"model.layers.{i}.mlp.{module}")
            
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=config.lora_rank,
    lora_alpha=config.lora_alpha,
    target_modules=target_modules,
    lora_dropout=config.lora_dropout,
    bias="none",
    task_type="SEQ_CLS",
    inference_mode=False,
    use_rslora=True,
)
model = get_peft_model(model, lora_config)

model.print_trainable_parameters()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-3.2-3B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 9,181,184 || all params: 3,221,937,152 || trainable%: 0.2850


In [11]:
# ---------------- Dataset Preparation ----------------

train_dataset = load_dataset(config.dataset_name, 'qnli', split="train")
val_dataset = load_dataset(config.dataset_name, 'qnli', split="validation")

# --- Code to calculate average token length ---
print("Calculating average token length...")

def get_token_length(example):
    """Simple function to tokenize and return length."""
    return {"length": len(tokenizer(example['question'], example['sentence']).input_ids)}

# Calculate lengths on a subset for speed, or the full dataset for accuracy
lengths_dataset = train_dataset.map(get_token_length, num_proc=4)

average_length = np.mean(lengths_dataset['length'])
print(f"📊 Average token length in training set: {average_length:.2f}")
print(f"👉 Recommendation: Set config.max_length to a value like {2**int(math.log2(average_length) + 1)} (e.g., 128, 256)")
# --- End of token length calculation ---

def tokenize(example):
    """ Tokenizer for QNLI, using 'question' and 'sentence' columns. """
    return tokenizer(
        example['question'],
        example['sentence'],
        truncation=True,
        max_length=config.max_length
    )

tokenized_train_dataset = train_dataset.map(tokenize, remove_columns=['question', 'sentence', 'idx'])
tokenized_val_dataset = val_dataset.map(tokenize, remove_columns=['question', 'sentence', 'idx'])


# Compute total training steps
train_len = len(tokenized_train_dataset)
eff_batch = config.batch_size * config.gradient_accumulation_steps
total_steps = math.ceil(train_len / eff_batch) * config.num_epochs
print(f"🏁 Total training steps: {total_steps}")

Calculating average token length...
📊 Average token length in training set: 49.58
👉 Recommendation: Set config.max_length to a value like 64 (e.g., 128, 256)
🏁 Total training steps: 1637


In [12]:
run_name = f"grit-{config.model_id.split('/')[-1]}-{config.dataset_name.split('/')[-1]}-QNLI"

import wandb

# Start a new wandb run to track this script.
run = wandb.init(
    # Set the wandb project where this run will be logged.
    entity="RAAPID",
    project="GRIT-Final",
    name=run_name,
    job_type="training",
)

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchandrav[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
import numpy as np
import evaluate

# Load the accuracy and F1 metrics from the library
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    # *** KEY CHANGE: Use average="binary" for F1 score in a two-class task ***
    f1 = f1_metric.compute(predictions=predictions, references=labels, average="binary")

    return {
        "accuracy": accuracy["accuracy"],
        "f1": f1["f1"],
    }

In [14]:
# ---------------- Training Arguments ----------------
training_args = TrainingArguments(
    output_dir="./results",
    run_name=run_name,

    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    num_train_epochs=config.num_epochs,
    learning_rate=config.learning_rate,

    #eval_strategy="steps",
    #eval_steps=1000,                 # Evaluate less often
    logging_steps=200,
    #eval_accumulation_steps=1,
    #max_steps=200,

    save_strategy="epoch",
    #save_steps=250,                 # Match eval_steps for early stopping
    save_total_limit=2,

    fp16=config.precision == "fp16",
    bf16=config.precision == "bf16",
    gradient_checkpointing=True,
    dataloader_num_workers=config.num_workers,
    dataloader_pin_memory=config.pin_memory,
    remove_unused_columns=False,

    report_to="wandb",
    #metric_for_best_model="bleu",
    #greater_is_better=True,
    #predict_with_generate=True,
    
    # --- Generation settings for faster evaluation ---
    # generation_max_length=config.max_length + 128,
    # generation_num_beams=1,

    # --- Early Stopping ---
    #load_best_model_at_end=config.enable_early_stopping,

    optim="paged_adamw_8bit"
)

torch.cuda.empty_cache()
torch.cuda.synchronize()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("🚀 Initializing GRITManager on device:", device)
grit_manager = GRITManager(model, config, device)

# Instantiate the new callback
grit_callback = GritCallback(grit_manager)

# Let the Trainer use its default high-performance optimizer
# The custom 8-bit optimizer is not needed when memory is abundant

🚀 Initializing GRITManager on device: cuda
🎯 Instrumented 112 modules for GRIT optimization with custom autograd:
   • 112 attention modules
🚀 Using r-dim (16x16) covariances for maximum memory efficiency.
GRITManager: Initialization complete.
🔍 Optimizing 112 key LoRA modules.
💾 K-FAC matrices stored on CPU for memory efficiency.


In [15]:
def count_parameters(model):
    """Count and display parameter statistics"""
    total_params = 0
    trainable_params = 0
    lora_params = 0
    
    layer_params = {}  # Track params by layer
    
    for name, param in model.named_parameters():
        total_params += param.numel()
        
        if param.requires_grad:
            trainable_params += param.numel()
            if "lora_" in name:
                lora_params += param.numel()
                
                # Extract layer number for statistics
                for i in range(32):
                    if f'layers.{i}.' in name:
                        if i not in layer_params:
                            layer_params[i] = 0
                        layer_params[i] += param.numel()
                        break
    
    print(f"\n📊 Parameter Statistics:")
    print(f"🔢 Total parameters: {total_params:,}")
    print(f"🔥 Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.3f}%)")
    print(f"📐 LoRA parameters: {lora_params:,}")
    
    print(f"\n📍 Layer-wise LoRA Distribution:")
    active_layers = sorted(layer_params.keys())
    for layer_id in active_layers:
        print(f"   Layer {layer_id}: {layer_params[layer_id]:,} params")
    
    if active_layers:
        print(f"\n🎯 Strategy: LoRA + GRIT applied to layers {min(active_layers)}-{max(active_layers)}")

    return total_params, trainable_params

# Call this after model setup
total_params, trainable_params = count_parameters(model)
if wandb.run:
    wandb.config.update({
        "total_model_params": total_params,
        "lora_trainable_params": trainable_params,
        "initial_lora_rank_r": config.lora_rank,
    })


📊 Parameter Statistics:
🔢 Total parameters: 1,812,651,008
🔥 Trainable parameters: 9,181,184 (0.507%)
📐 LoRA parameters: 9,175,040

📍 Layer-wise LoRA Distribution:
   Layer 0: 327,680 params
   Layer 1: 327,680 params
   Layer 2: 327,680 params
   Layer 3: 327,680 params
   Layer 4: 327,680 params
   Layer 5: 327,680 params
   Layer 6: 327,680 params
   Layer 7: 327,680 params
   Layer 8: 327,680 params
   Layer 9: 327,680 params
   Layer 10: 327,680 params
   Layer 11: 327,680 params
   Layer 12: 327,680 params
   Layer 13: 327,680 params
   Layer 14: 327,680 params
   Layer 15: 327,680 params
   Layer 16: 327,680 params
   Layer 17: 327,680 params
   Layer 18: 327,680 params
   Layer 19: 327,680 params
   Layer 20: 327,680 params
   Layer 21: 327,680 params
   Layer 22: 327,680 params
   Layer 23: 327,680 params
   Layer 24: 327,680 params
   Layer 25: 327,680 params
   Layer 26: 327,680 params
   Layer 27: 327,680 params

🎯 Strategy: LoRA + GRIT applied to layers 0-27


In [16]:
# Trainer uses GRITManager for K-FAC steps
trainer = GritTrainer(
    grit_manager=grit_manager,
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    tokenizer=tokenizer,
    callbacks=[grit_callback],
    compute_metrics=compute_metrics,
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=config.early_stopping_patience)] if config.enable_early_stopping else [],
)

# Final memory cleanup
import gc
for _ in range(3):
    gc.collect()
    torch.cuda.empty_cache()
torch.cuda.synchronize()

print(f"🎯 Effective batch (per step): {config.batch_size * config.gradient_accumulation_steps}")
print("⚡ Mixed precision:", config.precision)
print("📏 Max sequence length:", config.max_length)
print("🔧 LoRA rank:", config.lora_rank)
print("🚀 Starting training now...")

trainer.train()

print("🎉 Training completed successfully!")

  super().__init__(*args, **kwargs)
No label_names provided for model class `PeftModelForSequenceClassification`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


GritTrainer: Initialized with GRIT implementation.
🎯 Effective batch (per step): 64
⚡ Mixed precision: bf16
📏 Max sequence length: 256
🔧 LoRA rank: 16
🚀 Starting training now...
🎁 Wrapping the optimizer with GRIT preconditioning logic.


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
200,0.4467
400,0.2556
600,0.2406
800,0.2309
1000,0.2234
1200,0.2176
1400,0.2055
1600,0.2085



GRITManager: Inverting K-FAC factors at step 100...

GRITManager: Neural reprojection at step 100...
✅ Neural reprojection completed. Effective parameter count reduced.
   - Initial GRIT params: 9,175,040
   - Final GRIT params:   8,601,600
   - Reduction:           573,440 (6.25%)

GRITManager: Applying Natural Gradient preconditioner at step 100 to 150...

GRITManager: Inverting K-FAC factors at step 200...

GRITManager: Neural reprojection at step 200...
✅ Neural reprojection completed. Effective parameter count reduced.
   - Initial GRIT params: 9,175,040
   - Final GRIT params:   8,028,160
   - Reduction:           1,146,880 (12.50%)

GRITManager: Applying Natural Gradient preconditioner at step 200 to 250...

GRITManager: Inverting K-FAC factors at step 275...

GRITManager: Neural reprojection at step 275...
✅ Neural reprojection completed. Effective parameter count reduced.
   - Initial GRIT params: 9,175,040
   - Final GRIT params:   7,454,720
   - Reduction:           1,720,3

In [17]:
trainer.evaluate()
wandb.finish()


🧹 Clearing VRAM before evaluation...


0,1
GRIT/Effective Rank (k)/base_model.model.model.layers.0.self_attn.k_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
GRIT/Effective Rank (k)/base_model.model.model.layers.0.self_attn.o_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
GRIT/Effective Rank (k)/base_model.model.model.layers.0.self_attn.q_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
GRIT/Effective Rank (k)/base_model.model.model.layers.0.self_attn.v_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
GRIT/Effective Rank (k)/base_model.model.model.layers.1.self_attn.k_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
GRIT/Effective Rank (k)/base_model.model.model.layers.1.self_attn.o_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
GRIT/Effective Rank (k)/base_model.model.model.layers.1.self_attn.q_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
GRIT/Effective Rank (k)/base_model.model.model.layers.1.self_attn.v_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
GRIT/Effective Rank (k)/base_model.model.model.layers.10.self_attn.k_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
GRIT/Effective Rank (k)/base_model.model.model.layers.10.self_attn.o_proj,█▇▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
GRIT/Effective Rank (k)/base_model.model.model.layers.0.self_attn.k_proj,9.0
GRIT/Effective Rank (k)/base_model.model.model.layers.0.self_attn.o_proj,9.0
GRIT/Effective Rank (k)/base_model.model.model.layers.0.self_attn.q_proj,9.0
GRIT/Effective Rank (k)/base_model.model.model.layers.0.self_attn.v_proj,9.0
GRIT/Effective Rank (k)/base_model.model.model.layers.1.self_attn.k_proj,9.0
GRIT/Effective Rank (k)/base_model.model.model.layers.1.self_attn.o_proj,9.0
GRIT/Effective Rank (k)/base_model.model.model.layers.1.self_attn.q_proj,9.0
GRIT/Effective Rank (k)/base_model.model.model.layers.1.self_attn.v_proj,9.0
GRIT/Effective Rank (k)/base_model.model.model.layers.10.self_attn.k_proj,9.0
GRIT/Effective Rank (k)/base_model.model.model.layers.10.self_attn.o_proj,9.0


In [18]:
# === Hugging Face Upload ===
from huggingface_hub import create_repo, upload_folder, HfApi

print("\n🚀 Uploading fine-tuned model to Hugging Face Hub...")

# === Hugging Face Upload ===
print("\n🚀 Uploading fine-tuned model to Hugging Face Hub...")

HF_USERNAME = "te4bag" # Replace with your Hugging Face username
FULL_MODEL_NAME = f"{HF_USERNAME}/GRIT-Full-GLUE-QNLI-llama-3.2-3B-Energy-0.9"

try:
    # Save model and tokenizer locally first
    output_dir = "./grit_trained_model"
    model.save_pretrained(output_dir, safe_serialization=True)
    tokenizer.save_pretrained(output_dir)

    # Create/update model card (with hardware and hyperparams)
    model_card_content = f"""---
tags:
- llama
- Natural Language Inference
- grit
- lora
- qlora
- unsloth
- instruction-tuning
- fine-tuned
base_model: {config.model_id}
library_name: peft
license: apache-2.0
datasets:
- {config.dataset_name}
language:
- en
pipeline_tag: Sequence Classification
---

# {config.model_id} Fine-tuned with GRIT and QLoRA

This model is a fine-tuned version of [{config.model_id}](https://huggingface.co/{config.model_id}) using the **GRIT** (Geometric Reprojection Instruction Tuning) algorithm and **QLoRA** on the [{config.dataset_name} dataset](https://huggingface.co/datasets/{config.dataset_name}).

The base model is quantized to 4-bit (NF4) to enable efficient fine-tuning.

## 🚀 Training Details

### GRIT Algorithm
- **K-FAC Updates**: Every {config.kfac_update_freq} steps (adaptive) for second-order preconditioning.
- **Neural Reprojection**: Every {config.reprojection_freq} steps (adaptive) for rank optimization.
- **Rank Adaptation**: {'Enabled' if config.enable_rank_adaptation else 'Disabled'} (Threshold: {config.rank_adaptation_threshold}, Min Rank: {config.min_lora_rank}).
- **Optimized LoRA Modules**: {config.lora_target_modules}

### Fine-tuning Configuration
- **Base Model**: {config.model_id}
- **Quantization**: 4-bit (NF4) with {config.precision} compute.
- **LoRA Rank**: {config.lora_rank}
- **LoRA Alpha**: {config.lora_alpha}
- **Batch Size**: {config.batch_size} (per device)
- **Gradient Accumulation**: {config.gradient_accumulation_steps} (Effective batch = {config.batch_size * config.gradient_accumulation_steps})
- **Learning Rate**: {config.learning_rate:.1e}
- **Precision**: {config.precision} mixed precision
- **Sequence Length**: {config.max_length} tokens
- **Gradient Checkpointing**: Enabled

### Performance Improvements
- ✅ **Faster Convergence**: K-FAC preconditioning aligns updates with curvature.
- ✅ **Memory-Efficient**: 4-bit quantization (QLoRA) and gradient checkpointing used.
- ✅ **Adaptive Rank**: Dynamically prunes LoRA rank to improve parameter efficiency.

## 📊 Training Metrics
- **Total Steps**: {trainer.state.global_step if 'trainer' in locals() else 'N/A'}
- **Final Loss**: {trainer.state.log_history[-1].get('train_loss', 'N/A') if 'trainer' in locals() and trainer.state.log_history else 'N/A'}
- **Trainable Params**: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}

## 📝 Algorithm Details
- **K-FAC Preconditioning** (Natural Gradient) and **Neural Reprojection** as per GRIT method.
- **Memory Efficient**: Covariance matrices on CPU to reduce GPU load.

## 🏆 Results
In benchmark comparisons, GRIT has shown **faster convergence and better stability** than standard LoRA or fine-tuning, making it well-suited for efficient single-epoch training. The use of Unsloth further accelerates this process.

## 📝 Citation
If you use this model, please cite the original GRIT paper and:
```bibtex
@misc{{grit-lora-{config.model_id.split('/')[-1]}-{config.dataset_name.split('/')[-1]}}},
  title={{ {config.model_id} Fine-tuned with GRIT on {config.dataset_name} }},
  author={{{HF_USERNAME}}},
  year={{2024}},
  publisher={{Hugging Face}},
  url={{https://huggingface.co/{FULL_MODEL_NAME}}}
}}
```

## ⚖️ License
This model inherits the Apache 2.0 license.
"""
    with open(f"{output_dir}/README.md", "w") as f:
        f.write(model_card_content)
    print("📝 Model card created with training details.")

    # Create repository on the Hub and upload the folder
    print(f"🌐 Uploading model to https://huggingface.co/{FULL_MODEL_NAME} ...")
    create_repo(FULL_MODEL_NAME, exist_ok=True)
    api = HfApi()
    api.upload_folder(
        folder_path=output_dir,
        repo_id=FULL_MODEL_NAME,
        commit_message=f"GRIT fine-tuned {config.model_id.split('/')[-1]} on {config.dataset_name.split('/')[-1]}"
    )
    print(f"✅ Model successfully uploaded: https://huggingface.co/{FULL_MODEL_NAME}")

except Exception as e:
    print(f"❌ Upload failed: {e}\\n(Model saved locally in {output_dir})")

print("🎉 GRIT fine-tuning script completed!")


🚀 Uploading fine-tuned model to Hugging Face Hub...

🚀 Uploading fine-tuned model to Hugging Face Hub...
📝 Model card created with training details.
🌐 Uploading model to https://huggingface.co/te4bag/GRIT-Full-GLUE-QNLI-llama-3.2-3B-Energy-0.9 ...


- The pipeline tag "Sequence Classification" is not in the official list: text-classification, token-classification, table-question-answering, question-answering, zero-shot-classification, translation, summarization, feature-extraction, text-generation, fill-mask, sentence-similarity, text-to-speech, text-to-audio, automatic-speech-recognition, audio-to-audio, audio-classification, audio-text-to-text, voice-activity-detection, depth-estimation, image-classification, object-detection, image-segmentation, text-to-image, image-to-text, image-to-image, image-to-video, unconditional-image-generation, video-classification, reinforcement-learning, robotics, tabular-classification, tabular-regression, tabular-to-text, table-to-text, multiple-choice, text-ranking, text-retrieval, time-series-forecasting, text-to-video, image-text-to-text, visual-question-answering, document-question-answering, zero-shot-image-classification, graph-ml, mask-generation, zero-shot-object-detection, text-to-3d, ima

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  /grit_trained_model/tokenizer.json    : 100%|##########| 17.2MB / 17.2MB            

  ...ned_model/adapter_model.safetensors:   0%|          | 26.1kB / 36.8MB            

✅ Model successfully uploaded: https://huggingface.co/te4bag/GRIT-Full-GLUE-QNLI-llama-3.2-3B-Energy-0.9
🎉 GRIT fine-tuning script completed!
