In [None]:
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install datasets torch torchvision torchaudio huggingface_hub tqdm psutil GPUtil flash-attn

: 

In [None]:
import os

from huggingface_hub import login

# SECURITY: the access token must never be committed to the notebook. It is
# read from the HF_TOKEN environment variable instead; when the variable is
# unset, `token=None` makes huggingface_hub ask for the token interactively.
# Any token that was previously pasted into this cell should be revoked.
login(token=os.environ.get("HF_TOKEN"), add_to_git_credential=True)

## SF8 TRAINER (bfloat16) Flash Attention

In [None]:
import math
import os

import torch
import torch.nn as nn
import numpy as np
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
from datasets import load_dataset
from torch.utils.data import DataLoader
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
from flash_attn.flash_attention import FlashMHA
from flash_attn.bert_padding import unpad_input, pad_input
from tqdm.auto import tqdm
import psutil
import GPUtil
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
!mkdir sf8_llama_quantized

def setup_ddp(rank, world_size):
    """Initialize the distributed environment for one worker process.

    Args:
        rank: Index of this process in the group; also used as the CUDA
            device index for the process.
        world_size: Total number of processes in the group.
    """
    # Respect a rendezvous address that a launcher already exported and only
    # fall back to the single-node defaults when nothing is configured.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Initialize process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Set device for this process
    torch.cuda.set_device(rank)

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

def cleanup_ddp():
    """Tear down the default process group if one is active.

    Safe to call when initialization never happened or already failed, so it
    can be used from error-handling paths without masking the original error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

# Module-level flag kept for code that inspects it; nothing in this cell sets it.
out_of_range_detected = False


def modified_tanh(x):
    """Squash ``x`` with tanh and scale by 0.99609375 so |result| <= 0.99609375."""
    return torch.tanh(x) * 0.99609375


def activation_check_hook(module, input, output):
    """Forward hook that squashes a module's primary output into the SF8 range.

    Args:
        module: The module the hook is attached to (unused).
        input: The positional inputs of the forward call (unused).
        output: A tensor, or a tuple whose first element is the tensor to squash.

    Returns:
        The squashed tensor, or a tuple with only element 0 squashed. Every
        other tuple element is passed through untouched; the earlier version
        overwrote every tensor in the tuple with the squashed first element.
    """
    if isinstance(output, tuple):
        if not output:
            return output
        return (modified_tanh(output[0]),) + output[1:]
    return modified_tanh(output)

# Flash Attention with MHA
# class FlashLlamaAttention(nn.Module):
#     """Flash Attention implementation for LLaMA"""
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.hidden_size = config.hidden_size
#         self.num_heads = config.num_attention_heads
#         self.head_dim = self.hidden_size // self.num_heads
#         self.max_position_embeddings = config.max_position_embeddings

#         # Initialize Flash attention
#         self.flash_attention = FlashMHA(
#             embed_dim=self.hidden_size,
#             num_heads=self.num_heads,
#             bias=False,
#             batch_first=True,
#             causal=True,
#             device="cuda"
#         )
        
#         # Original LLaMA projection layers
#         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
#         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
#         self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
#         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        
#     def forward(self, hidden_states, attention_mask=None, past_key_value=None):
#         batch_size, seq_length = hidden_states.shape[:2]
        
#         # Project queries, keys, and values
#         query_states = self.q_proj(hidden_states)
#         key_states = self.k_proj(hidden_states)
#         value_states = self.v_proj(hidden_states)
        
#         # Reshape for flash attention
#         query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim)
#         key_states = key_states.view(batch_size, seq_length, self.num_heads, self.head_dim)
#         value_states = value_states.view(batch_size, seq_length, self.num_heads, self.head_dim)
        
#         # Handle padding if attention mask is provided
#         if attention_mask is not None:
#             unpad_masks = attention_mask.bool()
#             query_states, indices, cu_seqlens, max_seqlen = unpad_input(query_states, unpad_masks)
#             key_states, _, _, _ = unpad_input(key_states, unpad_masks)
#             value_states, _, _, _ = unpad_input(value_states, unpad_masks)
            
#             # Apply Flash Attention
#             attn_output = flash_attn_func(
#                 query_states, key_states, value_states,
#                 cu_seqlens=cu_seqlens,
#                 max_seqlen=max_seqlen,
#                 causal=True
#             )
            
#             # Pad output back
#             attn_output = pad_input(attn_output, indices, batch_size, seq_length)
#         else:
#             # Direct Flash Attention without padding handling
#             attn_output = self.flash_attention(
#                 query_states,
#                 key_states,
#                 value_states
#             )
        
#         attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)
#         attn_output = self.o_proj(attn_output)
        
#         # Clamp outputs to SF8 range
#         attn_output = torch.clamp(attn_output, min=-0.99609375, max=0.99609375)
        
#         return attn_output, None

# Flash Attention with KV Packed Function
class FlashLlamaAttention(nn.Module):
    """Causal self-attention for LLaMA built on FlashAttention's packed-KV kernels.

    The projections mirror the layout of the attention module being replaced:
    queries use ``num_attention_heads`` heads while keys/values use
    ``num_key_value_heads`` heads when the config defines them (grouped-query
    attention) and fall back to the query head count otherwise.

    No key/value cache is maintained: ``past_key_value`` is accepted for
    interface compatibility and ignored.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        # Grouped-query checkpoints have fewer K/V heads than query heads.
        self.num_key_value_heads = getattr(config, "num_key_value_heads", None) or self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        kv_dim = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    @staticmethod
    def _rotate_half(x):
        """Swap the two halves of the last dimension, negating the second half."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    @classmethod
    def _apply_rotary(cls, query_states, key_states, position_embeddings):
        """Apply rotary position embeddings to (batch, seq, heads, head_dim) tensors.

        ``position_embeddings`` is a ``(cos, sin)`` pair.
        NOTE(review): assumes each is shaped (batch, seq, head_dim) as produced
        by the model's rotary module -- confirm against the installed
        transformers version.
        """
        cos, sin = position_embeddings
        # Insert a head axis so the tables broadcast over every head.
        cos = cos.unsqueeze(2).to(query_states.dtype)
        sin = sin.unsqueeze(2).to(query_states.dtype)
        rotated_q = query_states * cos + cls._rotate_half(query_states) * sin
        rotated_k = key_states * cos + cls._rotate_half(key_states) * sin
        return rotated_q, rotated_k

    def forward(self, hidden_states, attention_mask=None, past_key_value=None,
                position_embeddings=None, **kwargs):
        """Run causal flash attention over ``hidden_states``.

        Args:
            hidden_states: Tensor whose first two dimensions are (batch, seq).
            attention_mask: Optional 2-D (batch, seq) padding mask. Masks of any
                other rank are not usable by the var-len kernel and are ignored
                in favour of plain causal attention.
            past_key_value: Ignored (no cache support).
            position_embeddings: Optional ``(cos, sin)`` rotary tables.
            **kwargs: Extra keyword arguments from the calling decoder layer
                (position ids, cache flags, ...) are accepted and ignored.

        Returns:
            ``(attn_output, None)`` with ``attn_output`` clamped to the SF8 range.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        # Project and split into heads: (batch, seq, heads, head_dim).
        query_states = self.q_proj(hidden_states).view(
            batch_size, seq_length, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim)

        if position_embeddings is not None:
            query_states, key_states = self._apply_rotary(
                query_states, key_states, position_embeddings)

        # Pack K and V: (batch, seq, 2, kv_heads, head_dim).
        kv_states = torch.stack([key_states, value_states], dim=2)
        softmax_scale = self.head_dim ** -0.5

        if attention_mask is not None and attention_mask.dim() == 2:
            # The padded layout needs the var-len kernel; the dense kernel does
            # not accept cu_seqlens / max_seqlen arguments.
            from flash_attn import flash_attn_varlen_kvpacked_func

            padding_mask = attention_mask.bool()
            # Slice to four values: newer flash-attn releases return an extra item.
            query_unpad, indices, cu_seqlens, max_seqlen = unpad_input(
                query_states, padding_mask)[:4]
            kv_unpad = unpad_input(kv_states, padding_mask)[0]

            attn_output = flash_attn_varlen_kvpacked_func(
                query_unpad,
                kv_unpad,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                softmax_scale=softmax_scale,
                causal=True,
            )

            # Scatter the packed tokens back to (batch, seq, heads, head_dim).
            attn_output = pad_input(attn_output, indices, batch_size, seq_length)
        else:
            attn_output = flash_attn_kvpacked_func(
                query_states,
                kv_states,
                softmax_scale=softmax_scale,
                causal=True,
            )

        attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # Clamp outputs to SF8 range
        attn_output = torch.clamp(attn_output, min=-0.99609375, max=0.99609375)

        return attn_output, None
    
def register_hooks(model):
    """Attach ``activation_check_hook`` to every decoder layer and the final norm.

    Returns the hook handles in registration order (layers first, norm last)
    so the caller can remove them later.
    """
    handles = [
        decoder_layer.register_forward_hook(activation_check_hook)
        for decoder_layer in model.model.layers
    ]
    final_norm = model.model.norm
    handles.append(final_norm.register_forward_hook(activation_check_hook))
    return handles

def get_memory_usage():
    """Return a one-line summary of GPU memory (first GPU GPUtil lists) and process RAM."""
    gpu_part = ""
    if torch.cuda.is_available():
        first_gpu = GPUtil.getGPUs()[0]
        used_mb = first_gpu.memoryUsed
        total_mb = first_gpu.memoryTotal
        percent = first_gpu.memoryUtil * 100
        gpu_part = f"GPU Memory: {used_mb:.0f}MB/{total_mb:.0f}MB ({percent:.1f}%)"

    # Resident set size of the current process, converted from bytes to MB.
    rss_mb = psutil.Process().memory_info().rss / 1024 / 1024
    ram_part = f"RAM: {rss_mb:.0f}MB"
    return f"{gpu_part} | {ram_part}"

class SF8:
    """
    Super Float 8 (SF8) helper.

    Described by its author as 1 sign bit plus 7 mantissa bits with values in
    the open interval (-1, 1). This class only enforces the value range: it
    clamps to +/-0.99609375 and performs no bit-level quantization.
    """

    def __init__(self, tensor):
        self.tensor = tensor

    @staticmethod
    def to_sf8(tensor):
        """Return ``tensor`` clamped to [-0.99609375, 0.99609375]."""
        limit = 0.99609375
        return torch.clamp(tensor, min=-limit, max=limit)

    @staticmethod
    def from_sf8(tensor):
        """Identity: SF8 values are stored in the tensor's own dtype."""
        return tensor

class SF8Parameter(nn.Parameter):
    """``nn.Parameter`` whose initial data is clamped to the SF8 range."""

    def __new__(cls, data=None, requires_grad=True):
        # Leave ``None`` alone so the base class applies its own default.
        if data is None:
            clamped = None
        else:
            clamped = SF8.to_sf8(data)
        return super().__new__(cls, clamped, requires_grad)

def reclamp_parameters(model):
    """Clamp every parameter of ``model`` to the SF8 range and return the model."""
    lower, upper = -0.99609375, 0.99609375
    for _, parameter in model.named_parameters():
        if not isinstance(parameter, nn.Parameter):
            continue
        parameter.data = torch.clamp(parameter.data, min=lower, max=upper)
    return model

def convert_model_to_sf8_with_flash(model):
    """Convert parameters to bfloat16/SF8 and swap in Flash Attention modules.

    Steps:
      1. Cast every parameter to bfloat16 and clamp it to the SF8 range.
      2. Replace each decoder layer's ``self_attn`` with a
         ``FlashLlamaAttention`` that reuses the original projection weights.
      3. Register ``activation_check_hook`` on each decoder layer.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Rewrite each parameter's data in place. Assigning replacements into
    # ``model._parameters`` under dotted names does not touch the submodules
    # that own the tensors; it only registers extra root-level parameters that
    # never take part in the forward pass.
    for _, param in model.named_parameters():
        param.data = SF8.to_sf8(param.data.to(torch.bfloat16))

    # Replace attention layers with Flash Attention
    for layer in model.model.layers:
        original_attention = layer.self_attn
        flash_attention = FlashLlamaAttention(model.config)

        # Share the already-converted projection weights with the new module.
        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            source = getattr(original_attention, proj_name)
            getattr(flash_attention, proj_name).weight.data = source.weight.data

        layer.self_attn = flash_attention

        # Add activation clamping
        layer.register_forward_hook(activation_check_hook)

    return model

class SF8Optimizer(torch.optim.Adam):
    """Adam that keeps gradients and updated parameters inside the SF8 range."""

    def step(self, closure=None):
        """Perform exactly one Adam update with SF8 clamping.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is evaluated once, before gradients are clamped.

        Returns:
            The closure's loss, or ``None`` when no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        params_with_grad = [
            p
            for group in self.param_groups
            for p in group['params']
            if p.grad is not None
        ]

        # Clamp every gradient first ...
        with torch.no_grad():
            for p in params_with_grad:
                p.grad.clamp_(-0.99609375, 0.99609375)

        # ... then run a single Adam update. Calling the base step inside the
        # parameter loop would advance the optimizer once per parameter.
        super().step()

        # Finally pull the updated parameters back into range.
        with torch.no_grad():
            for p in params_with_grad:
                p.clamp_(-0.99609375, 0.99609375)

        return loss

def prepare_dataset(tokenizer, rank, world_size, max_length=512):
    """Tokenize the WikiText-2 training split and build its distributed sampler.

    Args:
        tokenizer: Callable tokenizer applied to batches of raw text.
        rank: Index of the calling process, forwarded to the sampler.
        world_size: Number of replicas, forwarded to the sampler.
        max_length: Length every example is truncated/padded to.

    Returns:
        ``(tokenized_dataset, sampler)``.
    """
    raw_train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

    def _encode(examples):
        # Every row is truncated and padded to exactly ``max_length`` tokens.
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt"
        )

    # Drop the raw columns so only the tokenizer's outputs remain.
    tokenized_dataset = raw_train.map(
        _encode,
        batched=True,
        remove_columns=raw_train.column_names
    )

    shard_sampler = DistributedSampler(
        tokenized_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    return tokenized_dataset, shard_sampler

def collate_fn(batch):
    """Stack per-example ``input_ids`` / ``attention_mask`` lists into batch tensors."""

    def _stack(field):
        # One tensor per example, stacked along a new leading batch axis.
        return torch.stack([torch.tensor(example[field]) for example in batch])

    return {
        'input_ids': _stack('input_ids'),
        'attention_mask': _stack('attention_mask'),
    }

def check_sf8_params(model):
    """Report whether every non-scalar parameter lies inside the SF8 range.

    Prints the first offending parameter name and returns ``False``, or prints
    a success message and returns ``True``.
    """
    low, high = -0.99609375, 0.99609375
    for name, param in model.named_parameters():
        values = param.data
        if values.dim() == 0:
            # Scalar tensors are skipped.
            continue
        within = (values >= low) & (values <= high)
        if not within.all():
            print(f"Parameter {name} out of range!")
            return False
    print("All parameters are within the SF8 range.")
    return True

def train_llama_sf8_ddp(rank, world_size):
    """Train the SF8-converted LLaMA model on one DDP worker.

    Args:
        rank: Index of this worker; selects ``cuda:{rank}``.
        world_size: Total number of workers.

    Rank 0 owns the progress bars and writes checkpoints under
    ``sf8_llama_quantized/``. The process group is torn down even when
    training raises.
    """
    # Initialize distributed setup
    setup_ddp(rank, world_size)

    try:
        # Set device
        device = torch.device(f"cuda:{rank}")

        # Initialize model and tokenizer
        model_name = "meta-llama/Llama-3.2-1B"
        model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./')
        model = model.to(torch.bfloat16).to(device)

        # Enable gradient checkpointing
        model.gradient_checkpointing_enable()

        # Convert model to SF8 with Flash Attention
        model = reclamp_parameters(convert_model_to_sf8_with_flash(model))

        # Wrap model with DDP
        model = DDP(
            model,
            device_ids=[rank],
            output_device=rank,
            find_unused_parameters=False
        )

        # Rank 0 loads the tokenizer first so the files are fetched once;
        # the other ranks wait at the barrier and then read the local cache.
        if rank == 0:
            tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./')
            tokenizer.pad_token = tokenizer.eos_token

        dist.barrier()
        if rank != 0:
            tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./')
            tokenizer.pad_token = tokenizer.eos_token

        # Prepare dataset with distributed sampler
        tokenized_dataset, sampler = prepare_dataset(tokenizer, rank, world_size)

        # Create dataloader with distributed sampler
        dataloader = DataLoader(
            tokenized_dataset,
            batch_size=4,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=True,
            num_workers=4
        )

        # Initialize optimizer
        optimizer = SF8Optimizer(model.parameters())

        # Training loop
        num_epochs = 100
        max_grad_norm = 0.99609375
        best_loss = float('inf')

        # Only create progress bars on rank 0
        if rank == 0:
            epoch_pbar = tqdm(range(num_epochs), desc="Training", position=0)
        else:
            epoch_pbar = range(num_epochs)

        for epoch in epoch_pbar:
            # Reshuffle the sampler's shard assignment for this epoch.
            sampler.set_epoch(epoch)

            model.train()
            total_loss = 0
            correct_predictions = 0
            total_predictions = 0

            # Only create batch progress bar on rank 0
            if rank == 0:
                batch_pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}", position=1, leave=False)
            else:
                batch_pbar = dataloader

            for batch in batch_pbar:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                # NOTE(review): labels are the raw input ids, so padded
                # positions contribute to the loss -- confirm this is intended.
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=input_ids
                )
                loss = outputs.loss

                # Next-token accuracy: prediction at position t vs token t+1.
                predictions = outputs.logits.argmax(dim=-1)
                labels = input_ids[:, 1:]
                pred = predictions[:, :-1]
                mask = attention_mask[:, 1:]

                del outputs, input_ids, attention_mask
                torch.cuda.empty_cache()

                # Average the loss across workers for reporting only. The
                # reduction runs on a detached copy: DDP already averages
                # gradients during backward, so the tensor used for backward
                # must stay the local, unscaled loss.
                reduced_loss = loss.detach().clone()
                dist.all_reduce(reduced_loss)
                reduced_loss /= world_size

                correct_pred = ((pred == labels) * mask).sum()
                total_pred = mask.sum()

                dist.all_reduce(correct_pred)
                dist.all_reduce(total_pred)

                correct_predictions += correct_pred.item()
                total_predictions += total_pred.item()

                optimizer.zero_grad()
                loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()

                batch_loss = reduced_loss.item()
                total_loss += batch_loss

                # Update progress bar on rank 0
                if rank == 0:
                    current_lr = optimizer.param_groups[0]['lr']
                    batch_pbar.set_postfix({
                        'loss': f"{batch_loss:.4f}",
                        'memory': get_memory_usage(),
                        'lr': f"{current_lr:.2e}"
                    })

            # Calculate epoch metrics
            avg_loss = total_loss / len(dataloader)
            accuracy = 100 * correct_predictions / total_predictions if total_predictions > 0 else 0

            # Update progress bar on rank 0
            if rank == 0:
                epoch_pbar.set_postfix({
                    'avg_loss': f"{avg_loss:.4f}",
                    'accuracy': f"{accuracy:.2f}%",
                    'memory': get_memory_usage()
                })

                # Save checkpoints only on rank 0
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    torch.save(model.module.state_dict(), 'sf8_llama_quantized/sf8_llama_flash_ddp_best_model.pt')

                if (epoch + 1) % 10 == 0:
                    torch.save(
                        model.module.state_dict(),
                        f'sf8_llama_quantized/sf8_llama_flash_ddp_checkpoint_epoch_{epoch+1}.pt'
                    )
    finally:
        # Clean up
        cleanup_ddp()

def main():
    """Spawn one ``train_llama_sf8_ddp`` worker per visible GPU.

    Raises:
        ValueError: If fewer than two CUDA devices are visible.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count < 2:
        raise ValueError("Need at least 2 GPUs for distributed training!")

    print(f"Starting distributed training on {gpu_count} GPUs...")

    # Each worker receives its rank as the first argument, followed by ``args``.
    mp.spawn(train_llama_sf8_ddp, args=(gpu_count,), nprocs=gpu_count, join=True)

if __name__ == "__main__":
    main()

## SF8 TRAINER (bfloat16)

In [None]:
import torch
import torch.nn as nn
import numpy as np
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
from datasets import load_dataset
from torch.utils.data import DataLoader

from tqdm.auto import tqdm
import psutil
import GPUtil

!mkdir sf8_llama_quantized

# Module-level flag kept for code that inspects it; nothing in this cell sets it.
out_of_range_detected = False


def modified_tanh(x):
    """Squash ``x`` with tanh and scale by 0.99609375 so |result| <= 0.99609375."""
    return torch.tanh(x) * 0.99609375


def activation_check_hook(module, input, output):
    """Forward hook that squashes a module's primary output into the SF8 range.

    Args:
        module: The module the hook is attached to (unused).
        input: The positional inputs of the forward call (unused).
        output: A tensor, or a tuple whose first element is the tensor to squash.

    Returns:
        The squashed tensor, or a tuple with only element 0 squashed. Every
        other tuple element is passed through untouched; the earlier version
        overwrote every tensor in the tuple with the squashed first element.
    """
    if isinstance(output, tuple):
        if not output:
            return output
        return (modified_tanh(output[0]),) + output[1:]
    return modified_tanh(output)

def register_hooks(model):
    """Attach ``activation_check_hook`` to every decoder layer and the final norm.

    Returns the hook handles in registration order (layers first, norm last)
    so the caller can remove them later.
    """
    # One hook per decoder layer ...
    handles = [
        decoder_layer.register_forward_hook(activation_check_hook)
        for decoder_layer in model.model.layers
    ]

    # ... plus one on the final layer norm.
    final_norm = model.model.norm
    handles.append(final_norm.register_forward_hook(activation_check_hook))

    return handles

def get_memory_usage():
    """Return a one-line summary of GPU memory (first GPU GPUtil lists) and process RAM."""
    gpu_part = ""
    if torch.cuda.is_available():
        first_gpu = GPUtil.getGPUs()[0]
        used_mb = first_gpu.memoryUsed
        total_mb = first_gpu.memoryTotal
        percent = first_gpu.memoryUtil * 100
        gpu_part = f"GPU Memory: {used_mb:.0f}MB/{total_mb:.0f}MB ({percent:.1f}%)"

    # Resident set size of the current process, converted from bytes to MB.
    rss_mb = psutil.Process().memory_info().rss / 1024 / 1024
    ram_part = f"RAM: {rss_mb:.0f}MB"
    return f"{gpu_part} | {ram_part}"

class SF8LlamaAttention(nn.Module):
    """Wrap an attention module and clamp its primary output to the SF8 range.

    The class derives from ``nn.Module`` only, so delegating to
    ``super().forward`` can never reach a real attention implementation.
    Instead the wrapped module is stored as a submodule and called directly.

    Args:
        attention: The attention module to wrap. Optional so the class can
            still be instantiated without arguments; ``forward`` then raises
            until a module has been provided.
    """

    def __init__(self, attention=None):
        super().__init__()
        self.attention = attention

    def forward(self, *args, **kwargs):
        """Call the wrapped module and clamp its (first) output tensor."""
        inner = getattr(self, "attention", None)
        if inner is None:
            raise NotImplementedError(
                "SF8LlamaAttention has no wrapped attention module; "
                "construct it as SF8LlamaAttention(original_attention)."
            )

        outputs = inner(*args, **kwargs)

        # Clamp attention outputs; extra tuple entries are passed through.
        if isinstance(outputs, tuple):
            return (
                torch.clamp(outputs[0], min=-0.99609375, max=0.99609375),
                *outputs[1:]
            )
        return torch.clamp(outputs, min=-0.99609375, max=0.99609375)

class SF8:
    """
    Super Float 8 (SF8) helper.

    Described by its author as 1 sign bit plus 7 mantissa bits with values in
    the open interval (-1, 1). This class only enforces the value range: it
    clamps to +/-0.99609375 and performs no bit-level quantization.
    """

    def __init__(self, tensor):
        self.tensor = tensor

    @staticmethod
    def to_sf8(tensor):
        """Return ``tensor`` clamped to [-0.99609375, 0.99609375]."""
        limit = 0.99609375
        return torch.clamp(tensor, min=-limit, max=limit)

    @staticmethod
    def from_sf8(tensor):
        """Identity: no conversion is performed, the tensor is returned as-is."""
        return tensor

class SF8Parameter(nn.Parameter):
    """``nn.Parameter`` whose initial data is clamped to the SF8 range."""

    def __new__(cls, data=None, requires_grad=True):
        # Leave ``None`` alone so the base class applies its own default.
        if data is None:
            clamped = None
        else:
            clamped = SF8.to_sf8(data)
        return super().__new__(cls, clamped, requires_grad)
    
def reclamp_parameters(model):
    """Clamp every parameter of ``model`` to the SF8 range and return the model."""
    lower, upper = -0.99609375, 0.99609375
    for _, parameter in model.named_parameters():
        if not isinstance(parameter, nn.Parameter):
            continue
        parameter.data = torch.clamp(parameter.data, min=lower, max=upper)
    return model

def convert_model_to_sf8(model):
    """Convert parameters to bfloat16/SF8 and clamp attention and layer outputs.

    Steps:
      1. Cast every parameter to bfloat16 and clamp it to the SF8 range.
      2. Register a forward hook on each layer's ``self_attn`` that clamps the
         attention output to the SF8 range.
      3. Register ``activation_check_hook`` on each decoder layer.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Rewrite each parameter's data in place. Assigning replacements into
    # ``model._parameters`` under dotted names does not touch the submodules
    # that own the tensors; it only registers extra root-level parameters that
    # never take part in the forward pass.
    for _, param in model.named_parameters():
        param.data = SF8.to_sf8(param.data.to(torch.bfloat16))

    def _clamp_attention_output(module, inputs, output):
        # Clamp only the primary tensor; pass any extra tuple entries through.
        if isinstance(output, tuple):
            return (SF8.to_sf8(output[0]), *output[1:])
        return SF8.to_sf8(output)

    for layer in model.model.layers:
        # A hook keeps the original attention module (and its forward) intact.
        # Swapping in an ``nn.Module`` subclass that borrows the original's
        # ``__dict__`` leaves the replacement without a usable ``forward``.
        layer.self_attn.register_forward_hook(_clamp_attention_output)

        # Add activation clamping after feed-forward
        layer.register_forward_hook(activation_check_hook)

    return model

def check_sf8_params(model):
    """Report whether every non-scalar parameter lies inside the SF8 range.

    Prints the first offending parameter name and returns ``False``, or prints
    a success message and returns ``True``.
    """
    low, high = -0.99609375, 0.99609375
    for name, param in model.named_parameters():
        values = param.data
        if values.dim() == 0:
            # Scalar tensors are skipped.
            continue
        within = (values >= low) & (values <= high)
        if not within.all():
            print(f"Parameter {name} out of range!")
            return False
    print("All parameters are within the SF8 range.")
    return True

class SF8Optimizer(torch.optim.Adam):
    """Adam that keeps gradients and updated parameters inside the SF8 range."""

    def step(self, closure=None):
        """Perform exactly one Adam update with SF8 clamping.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is evaluated once, before gradients are clamped.

        Returns:
            The closure's loss, or ``None`` when no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        params_with_grad = [
            p
            for group in self.param_groups
            for p in group['params']
            if p.grad is not None
        ]

        # Clamp gradients before the update
        with torch.no_grad():
            for p in params_with_grad:
                p.grad.clamp_(-0.99609375, 0.99609375)

        # Regular Adam update, run once. Calling the base step inside the
        # parameter loop would advance the optimizer once per parameter.
        super().step()

        # Clamp parameters after the update
        with torch.no_grad():
            for p in params_with_grad:
                p.clamp_(-0.99609375, 0.99609375)

        return loss

def prepare_dataset(tokenizer, max_length=512):
    """Tokenize the WikiText-2 training split.

    Args:
        tokenizer: Callable tokenizer applied to batches of raw text.
        max_length: Length every example is truncated/padded to.

    Returns:
        The tokenized dataset with the raw columns removed.
    """
    raw_train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

    def _encode(examples):
        # Every row is truncated and padded to exactly ``max_length`` tokens.
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt"
        )

    return raw_train.map(
        _encode,
        batched=True,
        remove_columns=raw_train.column_names
    )

def collate_fn(batch):
    """Stack per-example ``input_ids`` / ``attention_mask`` lists into batch tensors."""

    def _stack(field):
        # One tensor per example, stacked along a new leading batch axis.
        return torch.stack([torch.tensor(example[field]) for example in batch])

    return {
        'input_ids': _stack('input_ids'),
        'attention_mask': _stack('attention_mask'),
    }

def train_llama_sf8():
    """Quantize LLaMA to the SF8 range, save it, reload it and fine-tune it.

    Flow:
      1. Load the pretrained model in bfloat16 and convert/clamp it to SF8.
      2. Save the clamped weights and the tokenizer to ``sf8_llama_quantized/``.
      3. Drop the model, reload the architecture, and load the saved weights.
      4. Train on WikiText-2 with ``SF8Optimizer``, saving the best model and
         a checkpoint every 10 epochs.

    NOTE(review): the reloaded model is a plain ``LlamaForCausalLM``; the
    forward hooks installed by ``convert_model_to_sf8`` are not part of the
    state dict, so only the optimizer enforces the SF8 range during training.
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize model and tokenizer
    model_name = "meta-llama/Llama-3.2-1B"
    quantized_path = "sf8_llama_quantized/sf8_llama_quantized.pt"
    model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./')
    model = model.to(torch.bfloat16).to(device)

    # Enable gradient checkpointing to save memory
    model.gradient_checkpointing_enable()

    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./')
    tokenizer.pad_token = tokenizer.eos_token

    # Convert model to SF8
    model = reclamp_parameters(convert_model_to_sf8(model))

    # Check if all parameters are in the SF8 range
    if not check_sf8_params(model):
        raise ValueError("Some parameters are out of the SF8 range.")

    # Save the quantized model before training
    torch.save(model.state_dict(), quantized_path)
    tokenizer.save_pretrained("sf8_llama_quantized")
    print("Saved quantized model")

    # Clear VRAM
    if torch.cuda.is_available():
        del model
        torch.cuda.empty_cache()
        print("Cleared VRAM")

    # Reload model from saved files. The weights are read onto the CPU first
    # so the checkpoint does not occupy a second copy of GPU memory.
    model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./')
    model.load_state_dict(torch.load(quantized_path, map_location="cpu"))
    # A freshly constructed model uses the library's default dtype and has
    # gradient checkpointing off; restore both settings that were applied to
    # the first copy before training.
    model = model.to(torch.bfloat16).to(device)
    model.gradient_checkpointing_enable()
    print("Reloaded model to fresh VRAM")

    # Prepare dataset
    tokenized_dataset = prepare_dataset(tokenizer)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
        pin_memory=True
    )

    # Initialize optimizer
    optimizer = SF8Optimizer(model.parameters())

    num_epochs = 100
    max_grad_norm = 0.99609375
    best_loss = float('inf')

    # Create epoch progress bar
    epoch_pbar = tqdm(range(num_epochs), desc="Training", position=0)

    for epoch in epoch_pbar:
        model.train()
        total_loss = 0
        correct_predictions = 0
        total_predictions = 0

        # Create batch progress bar
        batch_pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}", position=1, leave=False)

        for batch in batch_pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # NOTE(review): labels are the raw input ids, so padded positions
            # contribute to the loss -- confirm this is intended.
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids
            )
            loss = outputs.loss

            # Next-token accuracy: prediction at position t vs token t+1.
            predictions = outputs.logits.argmax(dim=-1)
            labels = input_ids[:, 1:]
            pred = predictions[:, :-1]
            mask = attention_mask[:, 1:]

            del outputs, input_ids, attention_mask
            torch.cuda.empty_cache()

            correct_predictions += ((pred == labels) * mask).sum().item()
            total_predictions += mask.sum().item()

            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss

            # Update batch progress bar
            current_lr = optimizer.param_groups[0]['lr']
            batch_pbar.set_postfix({
                'loss': f"{batch_loss:.4f}",
                'memory': get_memory_usage(),
                'lr': f"{current_lr:.2e}"
            })

        # Calculate epoch metrics
        avg_loss = total_loss / len(dataloader)
        accuracy = 100 * correct_predictions / total_predictions if total_predictions > 0 else 0

        # Update epoch progress bar
        epoch_pbar.set_postfix({
            'avg_loss': f"{avg_loss:.4f}",
            'accuracy': f"{accuracy:.2f}%",
            'memory': get_memory_usage()
        })

        # Save checkpoint if loss improved
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), 'sf8_llama_quantized/sf8_llama_best_model.pt')

        # Regular checkpoint saving
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f'sf8_llama_quantized/sf8_llama_checkpoint_epoch_{epoch+1}.pt')
            
if __name__ == "__main__":
    train_llama_sf8()

In [None]:
# Important: Only run this when pushing files to Hub!
# from huggingface_hub import login

# login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)  # needs `import os`; keep the token out of the notebook

# huggingface-cli upload-large-folder aoxo/llama-3.2-sf8 --repo-type=model /kaggle/working/sf8_llama_quantized --num-workers=16

## SF16 TRAINER (fp32)

In [None]:
import torch
import torch.nn as nn
import numpy as np
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
from datasets import load_dataset
from torch.utils.data import DataLoader

from tqdm.auto import tqdm
import psutil
import GPUtil

!mkdir sf16_llama_quantized

# Module-level flag kept for code that inspects it; nothing in this cell sets it.
out_of_range_detected = False


def modified_tanh(x):
    """Squash ``x`` with tanh and scale by 0.9999847412109375 (the SF16 bound)."""
    return torch.tanh(x) * 0.9999847412109375


def activation_check_hook(module, input, output):
    """Forward hook that squashes a module's primary output into the SF16 range.

    Args:
        module: The module the hook is attached to (unused).
        input: The positional inputs of the forward call (unused).
        output: A tensor, or a tuple whose first element is the tensor to squash.

    Returns:
        The squashed tensor, or a tuple with only element 0 squashed. Every
        other tuple element is passed through untouched; the earlier version
        overwrote every tensor in the tuple with the squashed first element.
    """
    if isinstance(output, tuple):
        if not output:
            return output
        return (modified_tanh(output[0]),) + output[1:]
    return modified_tanh(output)

def register_hooks(model):
    """Attach ``activation_check_hook`` to each decoder layer and to the final norm.

    Returns the list of hook handles, decoder layers first and the final layer
    norm last, so the caller can ``remove()`` them later.
    """
    # Decoder layers in order, followed by the model-level final norm
    hooked_modules = [*model.model.layers, model.model.norm]
    return [
        target.register_forward_hook(activation_check_hook)
        for target in hooked_modules
    ]

def get_memory_usage():
    """Return a one-line summary of GPU and process RAM usage for progress bars.

    Returns:
        A string of the form ``"<gpu info> | RAM: <n>MB"``. The GPU part is empty
        when CUDA is unavailable or when GPUtil reports no devices.
    """
    gpu_memory = ""
    if torch.cuda.is_available():
        # GPUtil returns an empty list when it cannot query nvidia-smi, even if
        # torch sees a CUDA device; indexing [0] blindly would raise IndexError
        # in the middle of a training loop.
        gpus = GPUtil.getGPUs()
        if gpus:
            gpu = gpus[0]
            gpu_memory = f"GPU Memory: {gpu.memoryUsed:.0f}MB/{gpu.memoryTotal:.0f}MB ({gpu.memoryUtil*100:.1f}%)"
    # Resident set size of the current process, converted from bytes to MB
    ram_memory = f"RAM: {psutil.Process().memory_info().rss / 1024 / 1024:.0f}MB"
    return f"{gpu_memory} | {ram_memory}"

class SF16LlamaAttention(nn.Module):
    """Attention wrapper that clamps the primary output of ``super().forward``.

    The result of the parent ``forward`` is clamped to
    [-0.9999847412109375, 0.9999847412109375]; for tuple results only the first
    element is clamped and the rest is returned as-is.

    NOTE(review): with ``nn.Module`` as the only base, ``super().forward`` is the
    unimplemented ``nn.Module.forward`` and raises ``NotImplementedError``. A
    concrete attention class has to follow this class in the MRO for the
    delegation to work; confirm how instances are created.
    """

    def forward(self, *args, **kwargs):
        bound = 0.9999847412109375
        result = super().forward(*args, **kwargs)

        if not isinstance(result, tuple):
            return torch.clamp(result, min=-bound, max=bound)

        # Clamp the attention output, keep the trailing entries untouched
        clamped_head = torch.clamp(result[0], min=-bound, max=bound)
        return (clamped_head,) + tuple(result[1:])

class SF16:
    """
    Super Float 16 (SF16) helper.

    Values stay ordinary tensors; "conversion" here only clamps them to
    [-0.9999847412109375, 0.9999847412109375] and performs no bit-level packing
    or rounding.

    NOTE(review): this docstring previously read "1 bit for sign, 7 bits for
    mantissa", which looks copied from the SF8 variant. Confirm the intended
    SF16 bit layout.
    """

    def __init__(self, tensor):
        # Keep a reference to the wrapped tensor; no conversion happens here
        self.tensor = tensor

    @staticmethod
    def to_sf16(tensor):
        """Return ``tensor`` clamped to the SF16 bound (a new tensor)."""
        limit = 0.9999847412109375
        return torch.clamp(tensor, min=-limit, max=limit)

    @staticmethod
    def from_sf16(tensor):
        """Return ``tensor`` unchanged; SF16 values are already regular tensors."""
        return tensor

class SF16Parameter(nn.Parameter):
    """``nn.Parameter`` subclass whose initial data is passed through ``SF16.to_sf16``."""

    def __new__(cls, data=None, requires_grad=True):
        # ``None`` is forwarded untouched so the base class can apply its own default
        if data is None:
            clamped = None
        else:
            clamped = SF16.to_sf16(data)
        return super().__new__(cls, clamped, requires_grad)
    
def reclamp_parameters(model):
    """Clamp the data of every model parameter to the SF16 bound and return the model.

    Each parameter's ``.data`` is replaced by a clamped copy limited to
    [-0.9999847412109375, 0.9999847412109375].
    """
    limit = 0.9999847412109375
    for _, weight in model.named_parameters():
        if not isinstance(weight, nn.Parameter):
            continue
        weight.data = torch.clamp(weight.data, min=-limit, max=limit)
    return model

def convert_model_to_sf16(model):
    """Clamp all parameters to the SF16 range and install SF16 attention on each decoder layer.

    Args:
        model: A causal LM exposing its decoder layers as ``model.model.layers``,
            each with a ``self_attn`` submodule.

    Returns:
        The same ``model`` object, modified in place:
        parameters are clamped to [-0.9999847412109375, 0.9999847412109375],
        every ``layer.self_attn`` is replaced by an SF16-clamping variant that
        shares the original module's state, and ``activation_check_hook`` is
        registered as a forward hook on every decoder layer.
    """
    sf16_bound = 0.9999847412109375

    # Clamp in place. ``named_parameters()`` yields dotted names such as
    # "model.layers.0...weight"; storing new tensors under those keys in the root
    # module's ``_parameters`` would not replace the real parameters, it would
    # only add stray full-size copies to the root module. In-place clamping also
    # keeps parameter identity, so tied weights stay tied.
    with torch.no_grad():
        for param in model.parameters():
            param.clamp_(-sf16_bound, sf16_bound)

    # One derived class per original attention class, reused across layers.
    sf16_attention_classes = {}

    for layer in model.model.layers:
        original_attention = layer.self_attn

        # Skip layers that already carry the SF16 mixin (e.g. on a second call).
        if not isinstance(original_attention, SF16LlamaAttention):
            original_cls = type(original_attention)
            sf16_cls = sf16_attention_classes.get(original_cls)
            if sf16_cls is None:
                # SF16LlamaAttention.forward delegates to ``super().forward``. Putting
                # the original attention class right after it in the MRO makes that
                # call reach the real attention implementation instead of the
                # unimplemented ``nn.Module.forward``.
                sf16_cls = type(
                    f"SF16{original_cls.__name__}",
                    (SF16LlamaAttention, original_cls),
                    {},
                )
                sf16_attention_classes[original_cls] = sf16_cls

            # Build the instance without running ``__init__`` (the original class
            # needs config arguments) and share the original module's state.
            sf16_attention = sf16_cls.__new__(sf16_cls)
            sf16_attention.__dict__ = original_attention.__dict__.copy()
            layer.self_attn = sf16_attention

        # Add activation clamping on the decoder layer output
        layer.register_forward_hook(activation_check_hook)

    return model

def check_sf16_params(model):
    """Report whether every non-scalar parameter lies within the SF16 bound.

    Prints the name of the first offending parameter and returns ``False``;
    otherwise prints a confirmation and returns ``True``. Zero-dimensional
    parameters are not inspected.
    """
    lower, upper = -0.9999847412109375, 0.9999847412109375
    for name, param in model.named_parameters():
        values = param.data
        if values.dim() == 0:
            continue  # scalar tensors are skipped
        in_range = (values >= lower) & (values <= upper)
        if not in_range.all():
            print(f"Parameter {name} out of range!")
            return False
    print("All parameters are within the SF16 range.")
    return True

class SF16Optimizer(torch.optim.Adam):
    """Adam variant that keeps gradients and parameters inside the SF16 bound.

    Each ``step`` clamps all gradients to
    [-0.9999847412109375, 0.9999847412109375], performs exactly one Adam update,
    and then clamps the updated parameters to the same interval.
    """

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single clamped optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is called exactly once, with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        bound = 0.9999847412109375

        loss = None
        if closure is not None:
            # Evaluate here rather than inside Adam.step so the gradients it
            # produces are clamped before the update uses them.
            with torch.enable_grad():
                loss = closure()

        updated = [
            p
            for group in self.param_groups
            for p in group['params']
            if p.grad is not None
        ]

        # Clamp every gradient first, so the single Adam update below only ever
        # sees clamped gradients.
        for p in updated:
            p.grad.clamp_(-bound, bound)

        # Exactly one Adam update per call. Running the parent step inside the
        # per-parameter loop would apply one full update per parameter tensor and
        # advance Adam's step counters that many times.
        super().step()

        # Bring the freshly updated parameters back into the SF16 range
        for p in updated:
            p.clamp_(-bound, bound)

        return loss

def prepare_dataset(tokenizer, max_length=512):
    """Load the WikiText-2 (raw) training split and tokenize it.

    Every example is truncated and padded to ``max_length`` tokens; the original
    dataset columns are dropped so only the tokenizer outputs remain.
    """
    def tokenize_function(examples):
        # Fixed-length encoding of one batch of raw text
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt"
        )

    raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    return raw_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=raw_dataset.column_names
    )

def collate_fn(batch):
    """Stack the per-example ``input_ids`` and ``attention_mask`` into batch tensors."""
    def _stack(field):
        # One tensor per example, stacked along a new leading batch dimension
        return torch.stack([torch.tensor(example[field]) for example in batch])

    stacked_ids = _stack('input_ids')
    stacked_mask = _stack('attention_mask')
    return {'input_ids': stacked_ids, 'attention_mask': stacked_mask}

def train_llama_sf16():
    """Clamp Llama-3.2-1B to the SF16 range, sanity-check generation, then train on WikiText-2.

    Steps:
      1. Load the pretrained model and tokenizer, convert/clamp the model and
         verify that all parameters are inside the SF16 range.
      2. Save the clamped state dict and tokenizer under ``sf16_llama_quantized/``.
      3. Drop the model, reload a fresh ``LlamaForCausalLM`` and load the clamped
         weights into it.
      4. Generate a sample with the activation hooks attached, then remove them.
      5. Train for 100 epochs (batch size 1) with ``SF16Optimizer``, saving the
         best-loss state dict and a checkpoint every 10 epochs.

    Raises:
        ValueError: If any parameter is outside the SF16 range after conversion.
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Initialize model and tokenizer
    model_name = "meta-llama/Llama-3.2-1B"
    model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./').to(device)
    
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./')
    
    # The tokenizer needs a pad token for ``padding="max_length"``; reuse EOS
    tokenizer.pad_token = tokenizer.eos_token
    
    # Convert model to SF16 and clamp parameters again afterwards
    model = reclamp_parameters(convert_model_to_sf16(model))

    # Check if all parameters are in the SF16 range
    if not check_sf16_params(model):
        raise ValueError("Some parameters are out of the SF16 range.")
    
    # Save the quantized model before training
    torch.save(model.state_dict(), "sf16_llama_quantized/sf16_llama_quantized.pt")
    tokenizer.save_pretrained("sf16_llama_quantized")
    print("Saved quantized model")
    
    # Clear VRAM
    if torch.cuda.is_available():
        del model
        torch.cuda.empty_cache()
        print("Cleared VRAM")
    
    # Reload model from saved files
    # NOTE(review): the reloaded model is a plain LlamaForCausalLM that only receives
    # the clamped weights; the SF16 attention modules and per-layer hooks installed by
    # convert_model_to_sf16 are not carried over. Confirm this is intended.
    model = LlamaForCausalLM.from_pretrained(model_name, cache_dir = './')
    model.load_state_dict(torch.load("sf16_llama_quantized/sf16_llama_quantized.pt"))
    model = model.to(device)
    print("Reloaded model to fresh VRAM")
    
    input_text = "Sing me a song"
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    
    global out_of_range_detected
    out_of_range_detected = False
    
    # Hooks are only attached for this sample generation
    hooks = register_hooks(model)
    
    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=1024,
            num_return_sequences=1
        )
    # Remove hooks after inference
    for hook in hooks:
        hook.remove()
    
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(generated_text)
    
    # Enable gradient checkpointing to save memory. This has to be done on the
    # reloaded model: enabling it only on the instance that was deleted above
    # leaves the model that is actually trained without checkpointing.
    model.gradient_checkpointing_enable()
    
    # Prepare dataset
    tokenized_dataset = prepare_dataset(tokenizer)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    # Initialize optimizer
    optimizer = SF16Optimizer(model.parameters())
    
    # Training configuration
    num_epochs = 100
    max_grad_norm = 0.9999847412109375
    best_loss = float('inf')
    
    # Create epoch progress bar
    epoch_pbar = tqdm(range(num_epochs), desc="Training", position=0)
    
    for epoch in epoch_pbar:
        model.train()
        total_loss = 0
        correct_predictions = 0
        total_predictions = 0
        
        # Create batch progress bar
        batch_pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}", position=1, leave=False)
        
        for batch in batch_pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            # NOTE(review): labels are the raw input_ids, so padded positions (pad == EOS)
            # presumably contribute to the loss; masking them needs care because a batch
            # with no unmasked targets would produce an undefined loss. Confirm intent.
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids
            )
            loss = outputs.loss
            
            # Calculate accuracy (using next token prediction)
            predictions = outputs.logits.argmax(dim=-1)
            labels = input_ids[:, 1:]  # Token at position t+1 is the target for position t
            pred = predictions[:, :-1]  # Last prediction has no target
            mask = attention_mask[:, 1:] # Only count non-padded target positions
            
            # Drop references early to release memory before the backward pass
            del outputs, input_ids, attention_mask
            torch.cuda.empty_cache()

            correct_predictions += ((pred == labels) * mask).sum().item()
            total_predictions += mask.sum().item()
            
            optimizer.zero_grad()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            
            total_loss += loss.item()
            
            # Update batch progress bar
            current_lr = optimizer.param_groups[0]['lr']
            batch_pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'memory': get_memory_usage(),
                'lr': f"{current_lr:.2e}"
            })
        
        # Calculate epoch metrics
        avg_loss = total_loss / len(dataloader)
        accuracy = 100 * correct_predictions / total_predictions if total_predictions > 0 else 0
        
        # Update epoch progress bar
        epoch_pbar.set_postfix({
            'avg_loss': f"{avg_loss:.4f}",
            'accuracy': f"{accuracy:.2f}%",
            'memory': get_memory_usage()
        })
        
        # Save checkpoint if loss improved
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), 'sf16_llama_quantized/sf16_llama_best_model.pt')
        
        # Regular checkpoint saving
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f'sf16_llama_quantized/sf16_llama_checkpoint_epoch_{epoch+1}.pt')
            
# Entry point: start the SF16 training run when this cell/module is executed directly
# (inside a notebook kernel ``__name__`` is "__main__", so running the cell triggers it).
if __name__ == "__main__":
    train_llama_sf16()

In [None]:
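# Important: Only run this cell when pushing files to the Hub!
# Never hard-code the access token in the notebook; read it from the environment
# (any token that was previously committed here should be revoked).
# import os
# from huggingface_hub import login

# login(os.environ["HF_TOKEN"], add_to_git_credential=True)

# !huggingface-cli upload-large-folder aoxo/llama-3.2-sf16 --repo-type=model /kaggle/working/sf16_llama_quantized --num-workers=16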