# Fine-tuning Hyena-GLT Models

This notebook demonstrates how to fine-tune pre-trained Hyena-GLT models for specific genomic tasks.

## Learning Objectives
- Understand fine-tuning strategies for genomic models
- Implement task-specific adaptations
- Apply LoRA and other parameter-efficient methods
- Monitor and optimize fine-tuning performance

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Hyena-GLT imports
from hyena_glt.models import HyenaGLT
from hyena_glt.tokenizer import GenomicTokenizer
from hyena_glt.training import Trainer
from hyena_glt.data import GenomicDataset
from hyena_glt.utils import (
    plot_training_curves,
    count_parameters,
    validate_model_config
)

# Fix the NumPy and PyTorch RNG seeds so every cell below is reproducible.
np.random.seed(42)
torch.manual_seed(42)

print("Fine-tuning environment ready!")

## 1. Fine-tuning Strategies

### Full Fine-tuning vs Parameter-Efficient Methods

In [None]:
class FineTuningStrategy:
    """Base class for fine-tuning strategies.

    Subclasses implement :meth:`prepare_model` (freeze/unfreeze weights,
    attach heads or adapters) and may override
    :meth:`get_trainable_parameters`.

    Args:
        model: The ``nn.Module`` to adapt.
        task_type: Task label; the subclasses in this notebook check for
            ``'classification'`` and ``'regression'``.
    """

    def __init__(self, model, task_type='classification'):
        self.model = model
        self.task_type = task_type

    def prepare_model(self):
        """Prepare ``self.model`` for fine-tuning (must be overridden)."""
        raise NotImplementedError

    def get_trainable_parameters(self):
        """Return the parameters an optimizer should be built from.

        Returns a *list* of all model parameters rather than the one-shot
        generator produced by ``model.parameters()``, so callers can append
        to it (e.g. task-head parameters) or iterate it more than once.
        Frozen parameters are included on purpose: an optimizer holding them
        will start updating them once they are unfrozen later.
        """
        return list(self.model.parameters())

class FullFineTuning(FineTuningStrategy):
    """Full fine-tuning: every weight of the base model is trainable.

    ``prepare_model`` unfreezes all existing parameters and then attaches a
    freshly initialised linear head whose sizes are read from
    ``model.config``.
    """

    def prepare_model(self):
        """Unfreeze all weights and attach a new task head.

        ``'classification'`` adds ``model.classifier``
        (hidden_size -> num_classes); ``'regression'`` adds
        ``model.regressor`` (hidden_size -> 1); any other ``task_type``
        attaches no head.
        """
        for weight in self.model.parameters():
            weight.requires_grad = True

        # Newly built layers default to requires_grad=True, so heads train too.
        if self.task_type == 'classification':
            cfg = self.model.config
            self.model.classifier = nn.Linear(cfg.hidden_size, cfg.num_classes)
        elif self.task_type == 'regression':
            cfg = self.model.config
            self.model.regressor = nn.Linear(cfg.hidden_size, 1)

        trainable = count_parameters(self.model)
        print(f"Full fine-tuning prepared. Trainable parameters: {trainable}")

class LoRAFineTuning(FineTuningStrategy):
    """LoRA (Low-Rank Adaptation) fine-tuning.

    The base model is frozen. Each selected ``nn.Linear`` gets two trainable
    matrices registered on the layer itself:

    * ``lora_A`` with shape ``(rank, in_features)``
    * ``lora_B`` with shape ``(out_features, rank)``

    It also gets a forward hook that adds
    ``x @ lora_A.T @ lora_B.T * (alpha / rank)`` to the layer's output.
    ``lora_B`` starts at zero, so right after preparation the model computes
    exactly what the base model did.

    Args:
        model: Model to adapt.
        task_type: Passed through to :class:`FineTuningStrategy`.
        rank: Rank of the low-rank update.
        alpha: Scaling numerator; the update is multiplied by ``alpha / rank``.
        target_modules: Substrings matched against qualified module names
            from ``named_modules()`` to choose which linear layers to adapt.
    """

    def __init__(self, model, task_type='classification', rank=16, alpha=32,
                 target_modules=('attention', 'mlp')):
        super().__init__(model, task_type)
        self.rank = rank
        self.alpha = alpha
        self.target_modules = tuple(target_modules)
        # qualified name -> {'A', 'B', 'original', 'hook'}
        self.lora_modules = {}

    def add_lora_layer(self, module, name):
        """Attach a LoRA update to ``module`` if it is an ``nn.Linear``.

        Args:
            module: Candidate layer; anything other than ``nn.Linear`` is
                ignored, as is a layer that already carries LoRA matrices.
            name: Qualified module name, used as the key in
                ``self.lora_modules``.
        """
        if not isinstance(module, nn.Linear) or hasattr(module, 'lora_A'):
            return

        weight = module.weight
        # Create the matrices next to the frozen weight (same device/dtype).
        lora_A = nn.Parameter(
            torch.randn(self.rank, module.in_features,
                        device=weight.device, dtype=weight.dtype) * 0.01)
        lora_B = nn.Parameter(
            torch.zeros(module.out_features, self.rank,
                        device=weight.device, dtype=weight.dtype))

        # Freeze original weights
        weight.requires_grad = False
        if module.bias is not None:
            module.bias.requires_grad = False

        # Register on the layer itself with dot-free attribute names. A
        # model-level register_parameter with the qualified name would raise
        # KeyError, because qualified names contain dots. Registering here
        # also puts the matrices in model.parameters()/state_dict() and makes
        # them follow model.to(...).
        module.register_parameter('lora_A', lora_A)
        module.register_parameter('lora_B', lora_B)

        scaling = self.alpha / self.rank

        def _apply_lora(mod, inputs, output):
            # Add the scaled low-rank correction to the frozen layer's output.
            return output + (inputs[0] @ mod.lora_A.t() @ mod.lora_B.t()) * scaling

        handle = module.register_forward_hook(_apply_lora)

        self.lora_modules[name] = {
            'A': lora_A,
            'B': lora_B,
            'original': module,
            'hook': handle,
        }

    def prepare_model(self):
        """Freeze the base model and add LoRA to every matching linear layer."""
        # Freeze base model
        for param in self.model.parameters():
            param.requires_grad = False

        targets = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, nn.Linear)
            and any(key in name for key in self.target_modules)
        ]
        for name, module in targets:
            self.add_lora_layer(module, name)

        # The freeze loop above also covers adapters added by an earlier call;
        # re-enable them so preparing twice stays consistent.
        for entry in self.lora_modules.values():
            entry['A'].requires_grad = True
            entry['B'].requires_grad = True

        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"LoRA fine-tuning prepared. Trainable parameters: {trainable_params}")

    def get_trainable_parameters(self):
        """Return only the parameters that currently require gradients."""
        return [p for p in self.model.parameters() if p.requires_grad]

# Example usage
config = {
    'vocab_size': 4096,
    'hidden_size': 512,
    'num_layers': 8,
    'num_classes': 10
}

model = HyenaGLT(config)
base_params = count_parameters(model)
print(f"Base model parameters: {base_params}")

# Compare strategies. Both wrap the *same* model object and neither calls
# prepare_model() here; preparing both on one model would stack their changes.
full_ft = FullFineTuning(model, 'classification')
lora_ft = LoRAFineTuning(model, 'classification', rank=16)

# LoRA adds rank * (in_features + out_features) parameters per adapted linear
# layer, i.e. rank * 2 * hidden_size for a square hidden_size x hidden_size
# projection. The total grows with the number of adapted layers.
lora_per_layer = lora_ft.rank * 2 * config['hidden_size']

print("\nFine-tuning strategies comparison:")
print(f"Full fine-tuning: Updates all {base_params} parameters")
print(f"LoRA (rank {lora_ft.rank}): Updates ~{lora_per_layer} parameters "
      f"per adapted {config['hidden_size']}x{config['hidden_size']} linear layer")

## 2. Task-Specific Adaptations

### Sequence Classification

In [None]:
class GenomicClassificationHead(nn.Module):
    """Task-specific head for genomic sequence classification.

    Mean-pools token representations over ``dim=1``, then applies LayerNorm,
    dropout and a linear classifier.

    Args:
        hidden_size: Size of the last dimension of ``hidden_states``.
        num_classes: Number of output logits.
        dropout: Dropout probability applied to the pooled vector.
    """

    def __init__(self, hidden_size, num_classes, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_classes)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        """Return class logits.

        Args:
            hidden_states: Backbone output, assumed ``(batch, seq, hidden)``;
                pooling is over ``dim=1``.
            attention_mask: Optional ``(batch, seq)`` tensor, nonzero for real
                tokens. When given, padded positions are excluded from the
                mean. When ``None``, every position is averaged.
        """
        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)  # Average pooling
        else:
            mask = (attention_mask != 0).unsqueeze(-1).to(hidden_states.dtype)
            # clamp avoids 0/0 for rows that are entirely padding
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        pooled = self.layer_norm(pooled)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)

class GenomicRegressionHead(nn.Module):
    """Task-specific head for genomic regression tasks.

    Mean-pools over ``dim=1``, applies LayerNorm and dropout, then a
    two-layer MLP (hidden -> hidden // 2 -> output_dim).

    Args:
        hidden_size: Size of the last dimension of ``hidden_states``.
        output_dim: Number of regression targets.
        dropout: Dropout probability (pooled vector and inside the MLP).
    """

    def __init__(self, hidden_size, output_dim=1, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.regressor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, output_dim)
        )
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        """Return regression outputs.

        Args:
            hidden_states: Backbone output, assumed ``(batch, seq, hidden)``.
            attention_mask: Optional ``(batch, seq)`` tensor, nonzero for real
                tokens; padded positions are excluded from the mean.
        """
        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            mask = (attention_mask != 0).unsqueeze(-1).to(hidden_states.dtype)
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        pooled = self.layer_norm(pooled)
        # self.dropout was built but never applied; use it on the pooled
        # vector exactly like GenomicClassificationHead does.
        pooled = self.dropout(pooled)
        return self.regressor(pooled)

class MultiTaskHead(nn.Module):
    """Multi-task learning head: a shared projection plus one linear head per task.

    Args:
        hidden_size: Size of the last dimension of the backbone output.
        task_configs: Mapping task name -> dict whose ``'type'`` is
            ``'classification'`` (requires ``'num_classes'``) or
            ``'regression'`` (optional ``'output_dim'``, default 1).

    Raises:
        ValueError: If a task has an unsupported ``'type'``.
    """

    def __init__(self, hidden_size, task_configs):
        super().__init__()
        self.task_configs = task_configs
        self.shared_layer = nn.Linear(hidden_size, hidden_size)

        self.task_heads = nn.ModuleDict()
        for task_name, task_cfg in task_configs.items():
            task_type = task_cfg['type']
            if task_type == 'classification':
                self.task_heads[task_name] = nn.Linear(hidden_size, task_cfg['num_classes'])
            elif task_type == 'regression':
                self.task_heads[task_name] = nn.Linear(hidden_size, task_cfg.get('output_dim', 1))
            else:
                # Fail here rather than silently dropping the task, which
                # would only surface later as a KeyError in forward().
                raise ValueError(
                    f"Unsupported task type {task_type!r} for task {task_name!r}"
                )

    def forward(self, hidden_states, task_name=None, attention_mask=None):
        """Return one task's output, or a dict of all task outputs.

        Args:
            hidden_states: Backbone output, assumed ``(batch, seq, hidden)``.
            task_name: If truthy, only that task's head is evaluated.
            attention_mask: Optional ``(batch, seq)`` tensor, nonzero for real
                tokens; padded positions are excluded from the mean.
        """
        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            mask = (attention_mask != 0).unsqueeze(-1).to(hidden_states.dtype)
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        # Shared representation
        shared = torch.relu(self.shared_layer(pooled))

        if task_name:
            return self.task_heads[task_name](shared)
        return {name: head(shared) for name, head in self.task_heads.items()}

# Three example tasks: two classification problems and one regression target.
task_configs = {
    'promoter_prediction': {'type': 'classification', 'num_classes': 2},
    'expression_level': {'type': 'regression', 'output_dim': 1},
    'regulatory_elements': {'type': 'classification', 'num_classes': 5}
}

# Build one head of each kind on top of a 512-d backbone output.
hidden_size = 512
classification_head = GenomicClassificationHead(hidden_size, num_classes=10)
regression_head = GenomicRegressionHead(hidden_size)
multitask_head = MultiTaskHead(hidden_size, task_configs)

print("Task-specific heads created successfully!")
for _label, _head in (("Classification", classification_head),
                      ("Regression", regression_head),
                      ("Multi-task", multitask_head)):
    print(f"{_label} head parameters: {count_parameters(_head)}")

## 3. Fine-tuning Implementation

In [None]:
class FineTuner:
    """Complete fine-tuning pipeline: preparation, optimisation, evaluation.

    Args:
        model: Backbone whose forward output exposes ``last_hidden_state``.
        strategy: Object with ``prepare_model()`` and, optionally,
            ``get_trainable_parameters()`` (a list or any iterable).
        task_head: ``nn.Module`` mapping ``last_hidden_state`` to predictions.
            It is registered as ``model.task_head``, so it moves with the
            model and is included in ``model.parameters()`` / ``state_dict()``.
        device: Device the model and every batch are moved to.
    """

    def __init__(self, model, strategy, task_head, device='cuda'):
        self.model = model
        self.strategy = strategy
        self.task_head = task_head
        self.device = device

        # Let the strategy freeze/unfreeze weights or add adapters first.
        self.strategy.prepare_model()

        # Add task head as a submodule so .to()/.train()/.eval() cover it.
        self.model.task_head = task_head
        self.model.to(device)

    def create_optimizer(self, learning_rate=1e-4, weight_decay=0.01):
        """Create an AdamW optimizer over strategy and task-head parameters.

        The strategy's parameters are materialised into a list, since they may
        come from a one-shot generator. They are then de-duplicated by
        identity: because the head is a submodule of ``self.model``, parameter
        lists built from ``model.parameters()`` already contain it.
        """
        if hasattr(self.strategy, 'get_trainable_parameters'):
            candidates = list(self.strategy.get_trainable_parameters())
        else:
            candidates = [p for p in self.model.parameters() if p.requires_grad]

        # Add task head parameters
        candidates.extend(p for p in self.task_head.parameters() if p.requires_grad)

        params, seen = [], set()
        for p in candidates:
            if id(p) not in seen:
                seen.add(id(p))
                params.append(p)

        return torch.optim.AdamW(params, lr=learning_rate, weight_decay=weight_decay)

    def create_scheduler(self, optimizer, num_training_steps):
        """Cosine-anneal the learning rate over ``num_training_steps`` steps."""
        from torch.optim.lr_scheduler import CosineAnnealingLR
        return CosineAnnealingLR(optimizer, T_max=num_training_steps)

    def train_step(self, batch, optimizer, criterion):
        """Run one optimisation step on ``batch``.

        Returns:
            ``(loss_value, logits)``. The logits are detached so metric code
            does not keep the autograd graph alive.
        """
        self.model.train()

        input_ids = batch['input_ids'].to(self.device)
        labels = batch['labels'].to(self.device)

        # Forward pass
        outputs = self.model(input_ids)
        logits = self.task_head(outputs.last_hidden_state)

        loss = criterion(logits, labels)

        # Backward pass with global-norm gradient clipping at 1.0
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        optimizer.step()

        return loss.item(), logits.detach()

    def evaluate_step(self, batch, criterion):
        """Compute loss and logits on ``batch`` without tracking gradients."""
        self.model.eval()

        with torch.no_grad():
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            outputs = self.model(input_ids)
            logits = self.task_head(outputs.last_hidden_state)

            loss = criterion(logits, labels)

        return loss.item(), logits

    @staticmethod
    def _count_correct(logits, labels):
        """Return ``(#correct argmax predictions, #samples)`` for class labels.

        Floating-point labels (regression) or logits without a class axis give
        ``(0, 0)``, so accuracy is reported as NaN instead of a meaningless value.
        """
        if labels.is_floating_point() or logits.dim() < 2:
            return 0, 0
        predicted = logits.argmax(dim=1)
        return (predicted == labels).sum().item(), labels.size(0)

    @staticmethod
    def _safe_div(numerator, denominator):
        """Return ``numerator / denominator``, or NaN when the denominator is 0."""
        return numerator / denominator if denominator else float('nan')

    def fine_tune(self, train_loader, val_loader, num_epochs=5,
                  learning_rate=1e-4, save_path=None, criterion=None):
        """Run the complete fine-tuning loop.

        Args:
            train_loader: Sized iterable of dict batches with ``'input_ids'``
                and ``'labels'`` tensors.
            val_loader: Same format; it may be empty, in which case validation
                metrics are NaN.
            num_epochs: Number of passes over ``train_loader``.
            learning_rate: Initial AdamW learning rate (cosine-annealed per step).
            save_path: If set, a checkpoint is written whenever validation
                loss improves.
            criterion: Loss module; defaults to ``nn.CrossEntropyLoss()``.

        Returns:
            dict of per-epoch lists: ``train_loss``, ``val_loss``,
            ``train_acc``, ``val_acc`` (accuracy in percent, NaN when undefined).
        """
        # Setup training components
        optimizer = self.create_optimizer(learning_rate)
        num_training_steps = max(1, len(train_loader) * num_epochs)
        scheduler = self.create_scheduler(optimizer, num_training_steps)

        # Loss function (classification by default)
        if criterion is None:
            criterion = nn.CrossEntropyLoss()

        history = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': []
        }

        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 50)

            # Training phase
            train_loss, train_correct, train_total = 0.0, 0, 0
            for batch_idx, batch in enumerate(train_loader):
                loss, logits = self.train_step(batch, optimizer, criterion)
                train_loss += loss

                correct, total = self._count_correct(logits, batch['labels'].to(self.device))
                train_correct += correct
                train_total += total

                scheduler.step()

                if batch_idx % 50 == 0:
                    print(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss:.4f}")

            # Validation phase
            val_loss, val_correct, val_total = 0.0, 0, 0
            for batch in val_loader:
                loss, logits = self.evaluate_step(batch, criterion)
                val_loss += loss

                correct, total = self._count_correct(logits, batch['labels'].to(self.device))
                val_correct += correct
                val_total += total

            # Epoch metrics; guarded against empty loaders
            train_loss = self._safe_div(train_loss, len(train_loader))
            val_loss = self._safe_div(val_loss, len(val_loader))
            train_acc = self._safe_div(100 * train_correct, train_total)
            val_acc = self._safe_div(100 * val_correct, val_total)

            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            # Track the best loss even when not saving (NaN never compares lower).
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                if save_path:
                    torch.save({
                        'model_state_dict': self.model.state_dict(),
                        'task_head_state_dict': self.task_head.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch,
                        'val_loss': val_loss
                    }, save_path)
                    print(f"Best model saved to {save_path}")

        return history

print("Fine-tuning pipeline ready!")

## 4. Practical Fine-tuning Example: Preparing Synthetic Data

In [None]:
# Create synthetic fine-tuning data
def create_fine_tuning_data(num_samples=1000, seq_length=512):
    """Build a toy GC-content classification dataset.

    Each sequence is ``seq_length`` bases drawn uniformly from A/T/G/C with
    ``np.random.choice`` (one call per sample). Its label is 1 when the GC
    fraction is strictly greater than 0.5 and 0 otherwise.

    Returns:
        ``(sequences, labels)``: a list of str and a list of int.
    """
    alphabet = ['A', 'T', 'G', 'C']
    sequences = [''.join(np.random.choice(alphabet, seq_length))
                 for _ in range(num_samples)]
    labels = [int((s.count('G') + s.count('C')) / len(s) > 0.5)
              for s in sequences]
    return sequences, labels

# Create dataset: 256-bp sequences labelled high (1) vs low (0) GC content.
train_sequences, train_labels = create_fine_tuning_data(800, 256)
val_sequences, val_labels = create_fine_tuning_data(200, 256)

# minlength=2 always yields one count per class, even if a class is absent,
# so the two-category bar chart below always gets exactly two heights.
train_counts = np.bincount(train_labels, minlength=2)
val_counts = np.bincount(val_labels, minlength=2)

print(f"Created training data: {len(train_sequences)} samples")
print(f"Created validation data: {len(val_sequences)} samples")
print(f"Label distribution - Train: {train_counts}")
print(f"Label distribution - Val: {val_counts}")

# Visualize data distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# GC content distribution, with the 0.5 labelling threshold marked
train_gc = [(seq.count('G') + seq.count('C')) / len(seq) for seq in train_sequences]
val_gc = [(seq.count('G') + seq.count('C')) / len(seq) for seq in val_sequences]

axes[0].hist(train_gc, alpha=0.7, label='Train', bins=20)
axes[0].hist(val_gc, alpha=0.7, label='Val', bins=20)
axes[0].axvline(0.5, color='k', linestyle='--', alpha=0.5, label='Label threshold')
axes[0].set_xlabel('GC Content')
axes[0].set_ylabel('Frequency')
axes[0].set_title('GC Content Distribution')
axes[0].legend()

# Label distribution: offset the two series so Val bars do not cover Train bars.
positions = np.arange(2)
bar_width = 0.35
axes[1].bar(positions - bar_width / 2, train_counts, bar_width, alpha=0.7, label='Train')
axes[1].bar(positions + bar_width / 2, val_counts, bar_width, alpha=0.7, label='Val')
axes[1].set_xticks(positions)
axes[1].set_xticklabels(['Low GC', 'High GC'])
axes[1].set_ylabel('Count')
axes[1].set_title('Label Distribution')
axes[1].legend()

plt.tight_layout()
plt.show()

## 5. Training Results Visualization (Simulated Histories)

In [None]:
def plot_fine_tuning_results(history, strategy_name):
    """Draw a 2x2 dashboard for one fine-tuning run and return summary metrics.

    Panels:

    * top-left: train/val loss
    * top-right: train/val accuracy
    * bottom-left: val-minus-train loss gap
    * bottom-right: text summary

    Args:
        history: dict with equal-length per-epoch lists under 'train_loss',
            'val_loss', 'train_acc' and 'val_acc'.
        strategy_name: Label used in the curve panel titles.

    Returns:
        dict with the final train/val loss and accuracy and the best val accuracy.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    epochs = range(1, len(history['train_loss']) + 1)

    # (axis, history key suffix, legend noun, title suffix, y label)
    curve_panels = (
        (axes[0, 0], 'loss', 'Loss', 'Loss Curves', 'Loss'),
        (axes[0, 1], 'acc', 'Accuracy', 'Accuracy Curves', 'Accuracy (%)'),
    )
    for ax, suffix, noun, title, ylabel in curve_panels:
        ax.plot(epochs, history[f'train_{suffix}'], 'b-', label=f'Training {noun}')
        ax.plot(epochs, history[f'val_{suffix}'], 'r-', label=f'Validation {noun}')
        ax.set_title(f'{strategy_name} - {title}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    # A positive gap (val above train) suggests overfitting.
    gap_ax = axes[1, 0]
    gap = np.subtract(history['val_loss'], history['train_loss'])
    gap_ax.plot(epochs, gap, 'g-', linewidth=2)
    gap_ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    gap_ax.set_title('Overfitting Analysis (Val Loss - Train Loss)')
    gap_ax.set_xlabel('Epoch')
    gap_ax.set_ylabel('Loss Difference')
    gap_ax.grid(True)

    summary = {
        'Final Train Loss': history['train_loss'][-1],
        'Final Val Loss': history['val_loss'][-1],
        'Final Train Acc': history['train_acc'][-1],
        'Final Val Acc': history['val_acc'][-1],
        'Best Val Acc': max(history['val_acc'])
    }

    text_ax = axes[1, 1]
    text_ax.text(0.1, 0.5, '\n'.join(f'{name}: {value:.3f}' for name, value in summary.items()),
                 fontsize=12, verticalalignment='center', fontfamily='monospace')
    text_ax.set_title('Final Metrics')
    text_ax.axis('off')

    plt.tight_layout()
    plt.show()

    return summary

# Simulate training results for demonstration
def simulate_training_history(num_epochs=5, strategy='LoRA'):
    """Generate a synthetic, plausible-looking training history.

    Losses decay, and accuracies rise, exponentially. Each epoch draws one
    N(0, 0.02) noise value, which all four curves share at different scales.
    Validation curves decay at 80% of the training rate. Results are clamped:
    losses >= 0.1, train accuracy in [50, 95], val accuracy in [45, 90].

    Args:
        num_epochs: Number of epochs to simulate.
        strategy: 'LoRA' selects the slower profile; any other value selects
            the full fine-tuning profile.

    Returns:
        dict of per-epoch lists: train_loss, val_loss, train_acc, val_acc.
    """
    # LoRA: slower, steadier convergence. Full fine-tuning: faster start and decay.
    if strategy == 'LoRA':
        start, rate = (1.2, 1.3, 60, 58), 0.15
    else:
        start, rate = (1.0, 1.1, 65, 62), 0.25
    train_loss0, val_loss0, train_acc0, val_acc0 = start

    history = {key: [] for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc')}
    for epoch in range(num_epochs):
        noise = np.random.normal(0, 0.02)
        train_decay = np.exp(-rate * epoch)
        val_decay = np.exp(-rate * epoch * 0.8)

        history['train_loss'].append(max(0.1, train_loss0 * train_decay + noise))
        history['val_loss'].append(max(0.1, val_loss0 * val_decay + noise * 1.5))
        history['train_acc'].append(
            min(95, max(50, 100 - (100 - train_acc0) * train_decay + noise * 2)))
        history['val_acc'].append(
            min(90, max(45, 100 - (100 - val_acc0) * val_decay + noise * 3)))

    return history

# Compare strategies on simulated histories (no real training happens here).
lora_history = simulate_training_history(num_epochs=5, strategy='LoRA')
full_history = simulate_training_history(num_epochs=5, strategy='Full')

print("LoRA Fine-tuning Results:")
lora_metrics = plot_fine_tuning_results(history=lora_history,
                                        strategy_name='LoRA Fine-tuning')

print("\nFull Fine-tuning Results:")
full_metrics = plot_fine_tuning_results(history=full_history,
                                        strategy_name='Full Fine-tuning')

## 6. Best Practices and Tips

### Fine-tuning Guidelines

In [None]:
class FineTuningBestPractices:
    """Printable reference of fine-tuning recommendations.

    Every method is a ``staticmethod`` that prints one section to stdout and
    returns ``None``. Call them on the class or on an instance.
    """

    @staticmethod
    def learning_rate_recommendations():
        """Print the LR range, typical value and rationale for each strategy."""
        table = [
            # (strategy, LR range, typical LR, reasoning)
            ('Full Fine-tuning', '1e-5 to 5e-5', 2e-5,
             'Lower LR to preserve pre-trained features'),
            ('LoRA', '1e-4 to 1e-3', 3e-4,
             'Higher LR for adaptation layers only'),
            ('Task Head Only', '1e-3 to 1e-2', 5e-3,
             'Random initialization requires higher LR'),
        ]
        out = ["Learning Rate Recommendations:", "=" * 50]
        for strategy, lr_range, typical, reasoning in table:
            out.extend([
                "",
                f"{strategy}:",
                f"  Range: {lr_range}",
                f"  Typical: {typical}",
                f"  Reasoning: {reasoning}",
            ])
        print("\n".join(out))

    @staticmethod
    def data_requirements():
        """Print minimum / recommended / optimal dataset sizes."""
        tiers = [
            ('Minimum', 1000, 'For simple binary classification with LoRA'),
            ('Recommended', 10000, 'For robust performance across tasks'),
            ('Optimal', 100000, 'For full fine-tuning and complex tasks'),
        ]
        out = ["", "Data Size Requirements:", "=" * 50]
        for level, samples, note in tiers:
            out.append(f"{level}: {samples:,} samples")
            out.append(f"  Note: {note}")
        print("\n".join(out))

    @staticmethod
    def regularization_strategies():
        """Print regularisation techniques with a value or strategy and where to apply them."""
        techniques = [
            # (technique, recommended value or None, strategy or None, where)
            ('Dropout', 0.1, None, 'Task head and attention layers'),
            ('Weight Decay', 0.01, None, 'All trainable parameters'),
            ('Gradient Clipping', 1.0, None, 'Global norm clipping'),
            ('Layer Freezing', None, 'Gradually unfreeze layers',
             'Start with embeddings, then lower layers'),
        ]
        out = ["", "Regularization Strategies:", "=" * 50]
        for technique, value, strategy, where in techniques:
            out += ["", f"{technique}:"]
            if value is not None:
                out.append(f"  Recommended value: {value}")
            if strategy is not None:
                out.append(f"  Strategy: {strategy}")
            out.append(f"  Application: {where}")
        print("\n".join(out))

    @staticmethod
    def evaluation_checklist():
        """Print a checklist for a sound evaluation of a fine-tuned model."""
        checklist = (
            "✓ Hold-out test set (never used during training)",
            "✓ Stratified sampling for balanced evaluation",
            "✓ Multiple random seeds for statistical significance",
            "✓ Cross-validation for small datasets",
            "✓ Task-specific metrics (not just accuracy)",
            "✓ Confusion matrix analysis",
            "✓ Learning curve analysis",
            "✓ Computational cost assessment",
            "✓ Comparison with baseline models",
            "✓ Error analysis and failure cases",
        )
        print("\n".join(["", "Evaluation Checklist:", "=" * 50, *checklist]))
# Display best practices
practices = FineTuningBestPractices()
practices.learning_rate_recommendations()
practices.data_requirements()
practices.regularization_strategies()
practices.evaluation_checklist()

## 7. Advanced Techniques

In [None]:
class GradualUnfreezing:
    """Gradually unfreeze layer groups according to an epoch schedule.

    Args:
        model: Module exposing ``embeddings`` and an indexable ``layers``
            container; these are the two group kinds handled below.
        unfreeze_schedule: Mapping ``group_name -> epoch``, where group_name
            is ``'embeddings'`` or ``'layer_<idx>'``.

    NOTE: an optimizer only updates parameters it holds. Newly unfrozen
    parameters must already be in its parameter groups (for example, built
    from all model parameters) to start training.
    """

    def __init__(self, model, unfreeze_schedule):
        self.model = model
        self.unfreeze_schedule = unfreeze_schedule
        self.current_epoch = 0
        # Groups already unfrozen, so each is unfrozen and reported only once.
        self.unfrozen_groups = set()

    def update_frozen_layers(self, epoch):
        """Unfreeze every scheduled group whose start epoch has been reached.

        Returns:
            List of the group names newly unfrozen by this call.
        """
        self.current_epoch = epoch
        newly_unfrozen = []

        for layer_group, unfreeze_epoch in self.unfreeze_schedule.items():
            if epoch >= unfreeze_epoch and layer_group not in self.unfrozen_groups:
                self._unfreeze_layer_group(layer_group)
                self.unfrozen_groups.add(layer_group)
                newly_unfrozen.append(layer_group)
                print(f"Epoch {epoch}: Unfroze {layer_group}")

        return newly_unfrozen

    def _unfreeze_layer_group(self, layer_group):
        """Set ``requires_grad=True`` on every parameter of one layer group.

        Raises:
            ValueError: If ``layer_group`` is neither ``'embeddings'`` nor
                ``'layer_<idx>'``. Silently ignoring a misspelled group would
                leave it frozen forever.
        """
        if layer_group == 'embeddings':
            module = self.model.embeddings
        elif layer_group.startswith('layer_'):
            layer_idx = int(layer_group.split('_')[1])
            module = self.model.layers[layer_idx]
        else:
            raise ValueError(
                f"Unknown layer group {layer_group!r}; expected 'embeddings' or 'layer_<idx>'"
            )

        for param in module.parameters():
            param.requires_grad = True

class DiscriminativeLearningRates:
    """Build optimizer parameter groups with depth-dependent learning rates.

    Groups are emitted in this order, each only if the model has the
    corresponding attribute:

    * ``task_head``            -> ``base_lr``
    * upper half of ``layers`` -> ``base_lr * decay_factor``
    * lower half of ``layers`` -> ``base_lr * decay_factor ** 2``
    * ``embeddings``           -> ``base_lr * decay_factor ** 3``

    Parameters outside these attributes are not placed in any group.
    """

    def __init__(self, model, base_lr=1e-4, decay_factor=0.5):
        self.model = model
        self.base_lr = base_lr
        self.decay_factor = decay_factor

    def create_param_groups(self):
        """Return a list of ``{'params': ..., 'lr': ...}`` dicts for ``torch.optim``."""
        net = self.model
        lr, decay = self.base_lr, self.decay_factor
        groups = []

        if hasattr(net, 'task_head'):
            groups.append({'params': net.task_head.parameters(), 'lr': lr})

        if hasattr(net, 'layers'):
            midpoint = len(net.layers) // 2
            # Upper half gets one decay step; lower half gets two.
            tiers = ((1, net.layers[midpoint:]), (2, net.layers[:midpoint]))
            for steps, block in tiers:
                groups.append({
                    'params': [p for layer in block for p in layer.parameters()],
                    'lr': lr * decay ** steps,
                })

        if hasattr(net, 'embeddings'):
            groups.append({'params': net.embeddings.parameters(), 'lr': lr * decay ** 3})

        return groups

class CurriculumLearning:
    """Curriculum learning: train on the easiest fraction of the data first.

    The fraction of samples used grows linearly from ``initial_ratio`` at
    epoch 0 to ``final_ratio`` at the last epoch (``total_epochs - 1``).

    Args:
        difficulty_fn: Callable mapping a sample to a sortable score
            (lower = easier).
        initial_ratio: Fraction of the dataset used at epoch 0.
        final_ratio: Fraction of the dataset used at the final epoch.
    """

    def __init__(self, difficulty_fn, initial_ratio=0.3, final_ratio=1.0):
        self.difficulty_fn = difficulty_fn
        self.initial_ratio = initial_ratio
        self.final_ratio = final_ratio

    def get_curriculum_subset(self, dataset, epoch, total_epochs):
        """Return the indices of the easiest samples for this epoch.

        Args:
            dataset: Sized iterable of samples accepted by ``difficulty_fn``.
            epoch: 0-based epoch index, as produced by ``range(total_epochs)``.
            total_epochs: Total number of epochs in the schedule.

        Returns:
            numpy array of dataset indices, easiest first. It is never empty
            when ``dataset`` is non-empty.
        """
        # Map epochs 0 .. total_epochs-1 onto progress 0 .. 1 so the last
        # epoch really uses final_ratio; clamp guards out-of-range epochs
        # and max() guards total_epochs <= 1.
        span = max(total_epochs - 1, 1)
        progress = min(max(epoch / span, 0.0), 1.0)
        # Endpoint-exact interpolation: gives exactly initial/final at 0/1.
        current_ratio = self.initial_ratio * (1.0 - progress) + self.final_ratio * progress

        # Sort by difficulty; a stable sort keeps dataset order among ties.
        difficulties = [self.difficulty_fn(sample) for sample in dataset]
        sorted_indices = np.argsort(difficulties, kind='stable')

        num_samples = int(len(dataset) * current_ratio)
        if len(dataset) > 0:
            num_samples = max(1, num_samples)
        return sorted_indices[:num_samples]

# Example difficulty functions
def sequence_length_difficulty(sample):
    """Score a sample by raw sequence length: longer sequences rank as harder."""
    sequence = sample['sequence']
    return len(sequence)

def gc_content_difficulty(sample):
    """Difficulty as the distance of GC content from 50%.

    Counting is case-insensitive, so soft-masked (lowercase) bases are
    included. The result lies in [0, 0.5]. An empty sequence scores 0.0
    instead of raising ZeroDivisionError.
    """
    seq = sample['sequence'].upper()
    if not seq:
        return 0.0
    gc_content = (seq.count('G') + seq.count('C')) / len(seq)
    return abs(gc_content - 0.5)

print("Advanced fine-tuning techniques implemented!")
print("\n".join([
    "\nTechniques available:",
    "- Gradual Unfreezing: Progressively unfreeze model layers",
    "- Discriminative Learning Rates: Different LR for different layers",
    "- Curriculum Learning: Train on easier examples first",
]))

## Summary

This notebook covered comprehensive fine-tuning strategies for Hyena-GLT models:

### Key Concepts:
1. **Fine-tuning Strategies**: Full fine-tuning vs parameter-efficient methods such as LoRA
2. **Task Adaptation**: Classification, regression, and multi-task heads
3. **Training Pipeline**: Complete implementation with monitoring
4. **Best Practices**: Learning rates, regularization, evaluation
5. **Advanced Techniques**: Gradual unfreezing, discriminative LR, curriculum learning

### Next Steps:
- Experiment with different fine-tuning strategies for your specific task
- Implement proper evaluation and comparison frameworks
- Consider computational constraints when choosing strategies
- Monitor for overfitting and apply appropriate regularization

Continue to the next notebook: **07_generation.ipynb** for genomic sequence generation with Hyena-GLT!