<a href="https://colab.research.google.com/github/MrSharon/Arona/blob/main/Arona.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Core Model Architecture

In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class RelativePositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    A table of ``max_len`` sine/cosine vectors is built once at construction
    time. ``forward`` adds row ``i`` of that table to the embedding at
    sequence position ``i``; positions past the end of the table reuse its
    last row.

    Args:
        d_model: Size of the embedding dimension.
        max_len: Number of positions stored in the table.
    """

    def __init__(self, d_model, max_len=10000):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        # Registered as a non-persistent buffer so the table follows
        # ``.to(device)`` / ``.cuda()`` without adding a state_dict entry.
        self.register_buffer(
            "pe", self._create_relative_positional_encoding(), persistent=False
        )

    def _create_relative_positional_encoding(self):
        """Build the ``[max_len, d_model]`` sinusoidal table."""
        position = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term
        pe = torch.zeros(self.max_len, self.d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the angles to the number of odd slots.
        pe[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        return pe

    def forward(self, x):
        """Return ``x`` with one encoding row added per sequence position.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, d_model]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        # One index per token. The previous code built 2 * seq_len - 1
        # offsets, which cannot be added to a length-seq_len input.
        positions = torch.arange(seq_len, device=self.pe.device)
        positions = positions.clamp(max=self.max_len - 1)
        return x + self.pe[positions]

class ExtendedAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional additive bias.

    Args:
        d_model: Size of the model dimension; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not a multiple of ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            # Otherwise the head reshape in forward() silently mixes features
            # or fails with an opaque view error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """Reshape ``[batch, seq, d_model]`` to ``[batch, heads, seq, head_dim]``."""
        return projected.view(
            batch_size, -1, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, q, k, v, mask=None, rel_pos=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q, k, v: Tensors whose last dimension is ``d_model`` and whose
                first dimension is the batch.
            mask: Optional tensor broadcastable to the score matrix; entries
                equal to 0 are suppressed before the softmax.
            rel_pos: Optional bias added to the raw scores; must broadcast to
                the score matrix.

        Returns:
            Tuple ``(output, attention)`` where ``output`` has the model
            dimension restored and ``attention`` holds the per-head softmax
            weights.
        """
        batch_size = q.size(0)

        q = self._split_heads(self.q_proj(q), batch_size)
        k = self._split_heads(self.k_proj(k), batch_size)
        v = self._split_heads(self.v_proj(v), batch_size)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if rel_pos is not None:
            scores = scores + rel_pos

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention = F.softmax(scores, dim=-1)
        context = torch.matmul(attention, v)

        # Merge the heads back into a single d_model-sized feature axis.
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.o_proj(context)

        return output, attention

class HierarchicalMemoryBlock(nn.Module):
    """Three fixed-size memory stores with similarity-based retrieval.

    The ``active``, ``short_term`` and ``long_term`` stores each hold
    ``memory_size`` rows of width ``d_model``. New rows are pushed at the
    front and the oldest rows fall off the end. The stores are registered as
    buffers rather than parameters: they are overwritten by ``store_memory``
    (assigning a plain tensor to an ``nn.Parameter`` attribute raises a
    ``TypeError``), they still move with the module across devices, and they
    keep the same state_dict keys.

    Args:
        d_model: Width of each memory row.
        memory_size: Number of rows per store.
    """

    # Minimum importance a row needs to enter each filtered store.
    _IMPORTANCE_THRESHOLDS = {"short_term": 0.5, "long_term": 0.8}

    def __init__(self, d_model, memory_size=1000):
        super().__init__()
        self.d_model = d_model
        self.memory_size = memory_size

        # Memory stores
        self.register_buffer("active_memory", torch.zeros(memory_size, d_model))
        self.register_buffer("short_term_memory", torch.zeros(memory_size, d_model))
        self.register_buffer("long_term_memory", torch.zeros(memory_size, d_model))

        # Query projection
        self.query_proj = nn.Linear(d_model, d_model)

        # Memory compression
        self.compressor = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, d_model)
        )

    def store_memory(self, input_data, memory_type="active"):
        """Compress ``input_data`` and push it into one of the stores.

        Args:
            input_data: Tensor whose last dimension is ``d_model``; leading
                dimensions are flattened into rows.
            memory_type: ``"active"`` stores every row; ``"short_term"`` and
                ``"long_term"`` store only rows whose importance exceeds 0.5
                and 0.8 respectively.

        Raises:
            ValueError: If ``memory_type`` is not one of the three stores.
        """
        if memory_type != "active" and memory_type not in self._IMPORTANCE_THRESHOLDS:
            raise ValueError(f"Unknown memory_type: {memory_type!r}")

        rows = input_data.reshape(-1, self.d_model)
        # Stored rows are detached so the store does not keep an autograd
        # graph alive across calls.
        compressed = self.compressor(rows).detach()

        threshold = self._IMPORTANCE_THRESHOLDS.get(memory_type)
        if threshold is not None:
            compressed = compressed[self._calculate_importance(rows) > threshold]

        self._push(f"{memory_type}_memory", compressed)

    def _push(self, buffer_name, new_rows):
        """Insert ``new_rows`` at the front of a store, dropping the oldest rows."""
        count = new_rows.size(0)
        if count == 0:
            # Nothing to store. Slicing with ``[:-0]`` would empty the store.
            return

        store = getattr(self, buffer_name)
        new_rows = new_rows.to(device=store.device, dtype=store.dtype)
        if count >= self.memory_size:
            updated = new_rows[: self.memory_size]
        else:
            updated = torch.cat([new_rows, store[: self.memory_size - count]], dim=0)
        setattr(self, buffer_name, updated)

    def _calculate_importance(self, data):
        """Score each row in ``(0, 1)`` from its L2 norm (sigmoid of norm - 5)."""
        norm = torch.norm(data, dim=-1)
        return torch.sigmoid(norm - 5.0)  # Example threshold

    def retrieve_memory(self, query, top_k=5):
        """Return a weighted blend of the ``top_k`` most similar memory rows.

        Cosine similarity between the projected query and every stored row is
        scaled per store (active 1.0, short-term 0.7, long-term 0.5); the
        stored rows are scaled by the same factors. The best ``top_k`` rows
        are combined with softmax weights over their scores.

        Args:
            query: Tensor whose last dimension is ``d_model``.
            top_k: Number of rows to blend; capped at the total row count.

        Returns:
            Tensor with the same shape as ``query``.
        """
        leading_shape = query.shape[:-1]
        projected = self.query_proj(query.reshape(-1, self.d_model))

        tiers = (
            (self.active_memory, 1.0),
            (self.short_term_memory, 0.7),
            (self.long_term_memory, 0.5),
        )
        combined_memories = torch.cat([rows * weight for rows, weight in tiers], dim=0)
        combined_scores = torch.cat(
            [
                F.cosine_similarity(projected.unsqueeze(1), rows.unsqueeze(0), dim=-1) * weight
                for rows, weight in tiers
            ],
            dim=1,
        )

        k = min(top_k, combined_scores.size(1))
        top_k_scores, top_k_indices = torch.topk(combined_scores, k=k, dim=1)
        top_k_weights = F.softmax(top_k_scores, dim=1)

        # Gather [n, k, d_model] in one indexing op instead of a per-row loop.
        selected = combined_memories[top_k_indices]
        retrieved = (selected * top_k_weights.unsqueeze(-1)).sum(dim=1)
        return retrieved.reshape(*leading_shape, self.d_model)

Training Architecture


In [2]:
class MultiObjectiveLoss(nn.Module):
    """Combine three loss terms into one weighted sum.

    Args:
        alpha: Weight applied to the language-modeling loss.
        beta: Weight applied to the context-retention loss.
        gamma: Weight applied to the conversation-structure loss.
    """

    def __init__(self, alpha=1.0, beta=0.5, gamma=0.3):
        super().__init__()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma

    def forward(self, lm_loss, context_loss, conv_loss):
        """Return ``alpha * lm_loss + beta * context_loss + gamma * conv_loss``."""
        weighted_lm = self.alpha * lm_loss
        weighted_context = self.beta * context_loss
        weighted_conv = self.gamma * conv_loss
        return weighted_lm + weighted_context + weighted_conv

def train_epoch(model, dataloader, optimizer, scheduler, device):
    """Run one training pass over ``dataloader`` and return the mean loss.

    Each batch is a mapping that must provide ``input_ids``,
    ``attention_mask``, ``labels``, ``context_ids``, ``context_labels`` and
    ``conv_labels`` tensors. The model output must expose ``logits``,
    ``context_predictions`` and ``conv_predictions`` attributes.

    Args:
        model: Module called with ``input_ids``, ``attention_mask`` and
            ``context_ids`` keyword arguments.
        dataloader: Iterable of batches.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Mean combined loss over the batches seen, or ``0.0`` when the
        dataloader yields no batches.
    """
    model.train()

    # The loss weights never change, so build the criterion once per epoch.
    criterion = MultiObjectiveLoss(alpha=1.0, beta=0.5, gamma=0.3)

    total_loss = 0.0
    # Count batches explicitly so iterables without __len__ work and an empty
    # loader does not divide by zero.
    num_batches = 0

    for batch in dataloader:
        # Get data
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        context_ids = batch["context_ids"].to(device)  # Additional context information

        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            context_ids=context_ids
        )

        # Calculate losses
        lm_loss = F.cross_entropy(outputs.logits.view(-1, outputs.logits.size(-1)), labels.view(-1))
        context_loss = calculate_context_retention_loss(outputs.context_predictions, batch["context_labels"].to(device))
        conv_loss = calculate_conversation_structure_loss(outputs.conv_predictions, batch["conv_labels"].to(device))

        # Combined loss
        loss = criterion(lm_loss, context_loss, conv_loss)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

# Distributed training setup
def setup_distributed_training(model, world_size):
    """Wrap ``model`` for distributed training and create a gradient scaler.

    Args:
        model: Module to wrap in ``DistributedDataParallel``.
        world_size: Accepted for interface compatibility; not read by this
            function.

    Returns:
        Tuple ``(wrapped_model, scaler)``.
    """
    # DistributedDataParallel wraps the whole module as given.
    # NOTE(review): presumably the caller has already initialised the
    # torch.distributed process group - confirm.
    ddp_model = torch.nn.parallel.DistributedDataParallel(model)

    # Gradient scaler for mixed-precision training.
    grad_scaler = torch.cuda.amp.GradScaler()

    return ddp_model, grad_scaler

Emotional Intelligence Component

In [3]:
class EmotionRecognitionModule(nn.Module):
    """Predict emotions at token, utterance and conversation level.

    Args:
        d_model: Width of the incoming hidden states.
        num_emotions: Number of emotion categories.
    """

    def __init__(self, d_model, num_emotions=8):
        super().__init__()
        self.d_model = d_model
        self.num_emotions = num_emotions

        # Token-level emotion detection
        self.token_emotion = nn.Linear(d_model, num_emotions)

        # Learned query used to attention-pool tokens into one utterance vector
        self.query_emotion = nn.Parameter(torch.randn(d_model))

        # Conversation-level tracking
        self.emotion_gru = nn.GRU(num_emotions, num_emotions, batch_first=True)

        # Dimensional emotion representation (valence, arousal, dominance)
        self.emotion_dimensions = nn.Linear(d_model, 3)

    def forward(self, hidden_states, utterance_boundaries=None):
        """Compute all emotion outputs for a batch.

        Args:
            hidden_states: Tensor of shape ``[batch, seq, d_model]``.
            utterance_boundaries: Optional per-sequence list of token indices;
                consecutive pairs delimit utterances. When omitted, or when no
                sequence contains a complete pair, the whole sequence is
                treated as one utterance.

        Returns:
            Dict with ``token_emotions`` ``[batch, seq, num_emotions]``,
            ``utterance_emotion`` ``[batch, num_emotions]``,
            ``conversation_emotions`` ``[batch, n_utterances, num_emotions]``
            and ``emotion_dimensions`` (``valence``/``arousal``/``dominance``,
            each ``[batch, seq]``).
        """
        batch_size = hidden_states.size(0)

        # Token-level emotion
        token_emotions = self.token_emotion(hidden_states)  # [batch, seq, num_emotions]

        # Utterance-level emotion: attention-pool token emotions with the
        # learned query.
        query = self.query_emotion.expand(batch_size, 1, self.d_model)
        attention_scores = torch.bmm(query, hidden_states.transpose(1, 2)).squeeze(1)
        attention_weights = F.softmax(attention_scores, dim=1).unsqueeze(1)
        utterance_emotion = torch.bmm(attention_weights, token_emotions).squeeze(1)

        conversation_emotions = None
        if utterance_boundaries is not None:
            conversation_emotions = self._track_conversation(token_emotions, utterance_boundaries)
        if conversation_emotions is None:
            # No usable boundaries: treat the sequence as one utterance.
            conversation_emotions = utterance_emotion.unsqueeze(1)

        # Dimensional emotion representation
        emotion_dims = self.emotion_dimensions(hidden_states)  # [batch, seq, 3]
        valence, arousal, dominance = emotion_dims.chunk(3, dim=-1)

        return {
            "token_emotions": token_emotions,
            "utterance_emotion": utterance_emotion,
            "conversation_emotions": conversation_emotions,
            "emotion_dimensions": {
                "valence": valence.squeeze(-1),
                "arousal": arousal.squeeze(-1),
                "dominance": dominance.squeeze(-1)
            }
        }

    def _track_conversation(self, token_emotions, utterance_boundaries):
        """Average token emotions per utterance and run them through the GRU.

        Returns ``None`` when no sequence in the batch yields an utterance.
        """
        batch_size = token_emotions.size(0)
        per_sequence = []

        for b in range(batch_size):
            boundaries = utterance_boundaries[b]
            averages = []
            for i in range(len(boundaries) - 1):
                start, end = int(boundaries[i]), int(boundaries[i + 1])
                span = token_emotions[b, start:end]
                if span.size(0) == 0:
                    # The mean of an empty span is NaN; use a neutral zero
                    # vector instead.
                    averages.append(token_emotions.new_zeros(self.num_emotions))
                else:
                    averages.append(span.mean(dim=0))
            per_sequence.append(averages)

        max_utterances = max((len(averages) for averages in per_sequence), default=0)
        if max_utterances == 0:
            return None

        # Zero-pad every sequence to the longest utterance count in the batch.
        padded = token_emotions.new_zeros(batch_size, max_utterances, self.num_emotions)
        for b, averages in enumerate(per_sequence):
            if averages:
                padded[b, :len(averages)] = torch.stack(averages)

        tracked, _ = self.emotion_gru(padded)
        return tracked

Emotion Response Generation

In [4]:
class EmotionalResponseGenerator(nn.Module):
    """Produce vocabulary logits conditioned on a target emotion vector.

    Args:
        vocab_size: Number of output logits per position.
        d_model: Width of the incoming hidden states.
        emotion_dim: Width of the emotion vector.
    """

    def __init__(self, vocab_size, d_model, emotion_dim=8):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.emotion_dim = emotion_dim

        # Maps the emotion vector into hidden-state space.
        self.emotion_conditioning = nn.Linear(emotion_dim, d_model)

        # Maps the emotion vector to a per-token logit offset.
        self.emotion_vocab_bias = nn.Linear(emotion_dim, vocab_size)

        # Hidden state -> vocabulary logits.
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, hidden_states, target_emotion):
        """Return logits for ``hidden_states`` shifted by ``target_emotion``.

        The projected emotion is added to every position of
        ``hidden_states`` before the output projection, and a vocabulary
        bias derived from the same emotion is added with a factor of 0.1.
        """
        conditioning = self.emotion_conditioning(target_emotion)
        base_logits = self.output_projection(hidden_states + conditioning.unsqueeze(1))

        vocab_offset = self.emotion_vocab_bias(target_emotion).unsqueeze(1)
        return base_logits + 0.1 * vocab_offset  # Scale factor for bias

    def generate_appropriate_emotion(self, user_emotion, context_state):
        """Derive a response emotion from the user's emotion.

        Column 0 is treated as valence: negative values are halved and
        shifted up by 0.2, non-negative values are multiplied by 1.1.
        Column 1 is scaled by 0.8. Remaining columns are copied unchanged.
        ``context_state`` is not read.
        """
        valence = user_emotion[:, 0]  # Assuming first dimension is valence
        dampened = valence * 0.5 + 0.2
        enhanced = valence * 1.1

        adjusted = user_emotion.clone()
        adjusted[:, 0] = torch.where(valence < 0, dampened, enhanced)
        adjusted[:, 1] = user_emotion[:, 1] * 0.8  # Moderate arousal
        return adjusted

 Conversational Initiative System
 Initiative Detection


In [5]:
class InitiativeSystem(nn.Module):
    """Decide whether to start a conversational turn and of which type.

    Args:
        d_model: Width of the context state vector.
        d_hidden: Hidden width of the two small prediction networks.
    """

    def __init__(self, d_model, d_hidden=128):
        super().__init__()
        self.d_model = d_model

        # Initiative prediction network (sigmoid output in (0, 1))
        self.initiative_predictor = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, 1),
            nn.Sigmoid()
        )

        # Initiative type classifier
        self.initiative_type = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, 4)  # 4 types: question, suggestion, follow-up, info-sharing
        )

        # Initiative threshold parameters
        self.base_threshold = 0.7
        self.decay_rate = 0.1

    def should_take_initiative(self, context_state, time_since_last_message):
        """Return the initiative decision and the predicted initiative type.

        Args:
            context_state: Tensor whose last dimension is ``d_model``.
            time_since_last_message: Elapsed time as a tensor broadcastable
                to the score, or a plain Python number.

        Returns:
            Tuple ``(should_initiate, initiative_type)``: a boolean tensor and
            the argmax index over the four initiative types.
        """
        # Calculate initiative score
        initiative_score = self.initiative_predictor(context_state).squeeze(-1)

        # Accept plain numbers and tensors from another device or dtype so the
        # comparison below cannot fail on a device mismatch.
        elapsed = torch.as_tensor(
            time_since_last_message,
            dtype=initiative_score.dtype,
            device=initiative_score.device,
        )

        # NOTE(review): this threshold is 0 at elapsed == 0 and rises towards
        # base_threshold as time passes - confirm that direction is intended.
        threshold = self.base_threshold * (1.0 - torch.exp(-self.decay_rate * elapsed))

        # Decision to take initiative
        should_initiate = initiative_score > threshold

        # The type is predicted for every item, whether or not it initiates.
        initiative_type_logits = self.initiative_type(context_state)
        initiative_type = torch.argmax(initiative_type_logits, dim=-1)

        return should_initiate, initiative_type

    def calculate_follow_up_score(self, project_info):
        """Return ``importance * (time_since_last_mentioned + 1) * status_factor``.

        Args:
            project_info: Mapping with ``importance``,
                ``time_since_last_mentioned`` and ``status_factor`` entries.

        Returns:
            The follow-up priority; higher means more urgent.
        """
        importance = project_info['importance']
        time_since_mention = project_info['time_since_last_mentioned']
        status = project_info['status_factor']

        # Higher score means higher priority for follow-up
        follow_up_score = importance * (time_since_mention + 1) * status

        return follow_up_score

Project and Activity Tracking

In [6]:
class ProjectTracker:
    """In-memory registry of projects with a simple logical clock.

    Projects are stored in ``self.projects`` keyed by id. ``self.current_time``
    is an integer tick advanced with ``update_time``.
    """

    # Floor applied to importance when scheduling, so importance <= 0 does
    # not divide by zero.
    _MIN_IMPORTANCE = 0.01
    # Keyword phrases (as word tuples); the word that follows is taken as the
    # project name.
    _KEYWORDS = (("project",), ("task",), ("work", "on"), ("assignment",))
    _IMPORTANCE_WORDS = ("urgent", "important", "critical", "crucial")
    _PUNCTUATION = ".,;:!?\"'()[]{}"

    def __init__(self):
        self.projects = {}
        self.current_time = 0

    def update_time(self, increment=1):
        """Advance the logical clock by ``increment`` ticks."""
        self.current_time += increment

    def add_project(self, project_id, name, details="", importance=0.5, status="active"):
        """Add a new project to tracking.

        The follow-up tick is ``current_time + int(10 / importance)``, so more
        important projects are scheduled earlier. Importance below
        ``_MIN_IMPORTANCE`` is clamped for this computation only; the stored
        ``importance`` keeps the value passed in.
        """
        follow_up_delay = int(10 / max(importance, self._MIN_IMPORTANCE))
        self.projects[project_id] = {
            "id": project_id,
            "name": name,
            "last_mentioned": self.current_time,
            "status": status,
            "details": details,
            "importance": importance,
            "follow_up_schedule": self.current_time + follow_up_delay,
        }

    def update_project(self, project_id, **kwargs):
        """Update existing fields of a tracked project.

        Unknown project ids and unknown field names are ignored.
        ``last_mentioned`` is set to the current tick unless the caller
        passes it explicitly.
        """
        project = self.projects.get(project_id)
        if project is None:
            return

        for key, value in kwargs.items():
            if key in project:
                project[key] = value

        if "last_mentioned" not in kwargs:
            project["last_mentioned"] = self.current_time

    def get_projects_for_follow_up(self):
        """Return non-completed projects whose follow-up score exceeds 5.0.

        The score is ``importance * ticks_since_last_mention * status_factor``
        where blocked projects get a factor of 1.5. Results are sorted by
        descending score.
        """
        follow_up_candidates = []

        for project_id, project in self.projects.items():
            if project["status"] == "completed":
                continue

            time_since_mention = self.current_time - project["last_mentioned"]
            status_factor = 1.5 if project["status"] == "blocked" else 1.0
            score = project["importance"] * time_since_mention * status_factor

            if score > 5.0:  # Threshold for follow-up
                follow_up_candidates.append((project_id, score))

        follow_up_candidates.sort(key=lambda item: item[1], reverse=True)

        return [self.projects[pid] for pid, _ in follow_up_candidates]

    def extract_project_info(self, text):
        """Extract project mentions from text (simplified keyword matching).

        The word following a keyword phrase becomes the project name, with
        surrounding punctuation removed. Importance is 0.8 when the text
        contains an urgency word and 0.5 otherwise. Nothing is added to
        ``self.projects``; the caller decides what to register.

        Returns:
            List of dicts with ``id``, ``name`` and ``importance`` keys. Ids
            are unique within one call.
        """
        words = text.split()
        if not words:
            return []

        lowered_text = text.lower()
        importance = 0.5
        if any(word in lowered_text for word in self._IMPORTANCE_WORDS):
            importance = 0.8

        tokens = [word.lower() for word in words]
        detected_projects = []

        for i in range(len(tokens)):
            for keyword in self._KEYWORDS:
                name_index = i + len(keyword)
                if name_index >= len(words):
                    continue
                if tuple(tokens[i:name_index]) != keyword:
                    continue

                project_name = words[name_index].strip(self._PUNCTUATION)
                if not project_name:
                    continue

                # Offset by the projects already found in this call so two
                # mentions in the same text do not share an id.
                project_id = f"proj_{len(self.projects) + len(detected_projects)}"
                detected_projects.append({
                    "id": project_id,
                    "name": project_name,
                    "importance": importance
                })

        return detected_projects

Integration Architecture

In [7]:
class ContextAwareEmotionalLLM(nn.Module):
    """Language model that combines memory, emotion and initiative modules.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Hidden width shared by all sub-modules.
        n_layers: Number of ``TransformerLayer`` blocks.
        n_heads: Attention heads per layer.
        max_seq_len: Number of positions in the positional-encoding table.
        emotion_dim: Width of the emotion vectors.
    """

    def __init__(
        self,
        vocab_size,
        d_model=768,
        n_layers=12,
        n_heads=12,
        max_seq_len=2048,
        emotion_dim=8
    ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Embedding layers
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = RelativePositionalEncoding(d_model, max_seq_len)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads) for _ in range(n_layers)
        ])

        # Memory system
        self.memory = HierarchicalMemoryBlock(d_model)

        # Emotion system
        self.emotion_recognition = EmotionRecognitionModule(d_model, emotion_dim)
        self.emotional_generator = EmotionalResponseGenerator(vocab_size, d_model, emotion_dim)

        # Initiative system
        self.initiative_system = InitiativeSystem(d_model)

        # Project tracking
        self.project_tracker = ProjectTracker()

        # Output layer
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids, attention_mask=None, context_ids=None):
        """Run the full pipeline for a batch of token ids.

        Args:
            input_ids: Integer tensor of shape ``[batch, seq]``.
            attention_mask: Optional mask forwarded to every transformer layer.
            context_ids: Optional integer tensor of context token ids used to
                query the memory block.

        Returns:
            Dict with ``logits``, ``emotion_data``, ``target_emotion`` and an
            ``initiative`` dict (``should_initiate``, ``initiative_type``).

        Side effects:
            Stores the last hidden state in active memory and may register
            projects in the project tracker.
        """
        batch_size = input_ids.size(0)

        # Embeddings. The positional encoding module returns the embeddings
        # with the positions already added, so its output is used directly;
        # adding token_embeds again would count the tokens twice.
        token_embeds = self.token_embedding(input_ids)
        hidden_states = self.positional_encoding(token_embeds)

        # Process through transformer layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Retrieve from memory if context provided
        if context_ids is not None:
            context_embeds = self.token_embedding(context_ids)
            # Pool the context tokens into one query vector per batch item so
            # the retrieved memory can be broadcast over every position,
            # whatever the context length.
            if context_embeds.dim() > 2:
                context_embeds = context_embeds.mean(dim=1)
            memory_states = self.memory.retrieve_memory(context_embeds)
            hidden_states = hidden_states + memory_states.unsqueeze(1)

        last_state = hidden_states[:, -1]

        # Recognize emotions
        emotion_data = self.emotion_recognition(hidden_states)
        user_emotion = emotion_data["utterance_emotion"]

        # Generate appropriate emotional response
        target_emotion = self.emotional_generator.generate_appropriate_emotion(
            user_emotion, last_state
        )

        # Check for initiative. Placeholder elapsed time, created on the same
        # device as the hidden states to avoid a CPU/GPU mismatch.
        time_since_last = torch.ones(batch_size, device=hidden_states.device)
        should_initiate, initiative_type = self.initiative_system.should_take_initiative(
            last_state, time_since_last
        )

        # Generate output logits with emotional conditioning
        logits = self.emotional_generator(hidden_states, target_emotion)

        # Store important information in memory
        self.memory.store_memory(last_state.detach(), "active")

        # Update project tracking from input
        for _ in range(batch_size):
            text = "Sample text"  # Would decode from input_ids
            for project in self.project_tracker.extract_project_info(text):
                # extract_project_info returns an "id" key, while add_project
                # takes "project_id", so map the fields explicitly.
                self.project_tracker.add_project(
                    project_id=project["id"],
                    name=project["name"],
                    importance=project["importance"],
                )

        return {
            "logits": logits,
            "emotion_data": emotion_data,
            "target_emotion": target_emotion,
            "initiative": {
                "should_initiate": should_initiate,
                "initiative_type": initiative_type
            }
        }

    @torch.no_grad()
    def generate(self, input_ids, max_length=100, temperature=0.7, do_sample=True,
                 eos_token_id=None):
        """Autoregressively extend ``input_ids`` by up to ``max_length`` tokens.

        Args:
            input_ids: Integer tensor of shape ``[batch, seq]``.
            max_length: Maximum number of new tokens.
            temperature: Softmax temperature used when sampling; must be
                positive when ``do_sample`` is true.
            do_sample: Sample from the distribution when true, otherwise take
                the argmax.
            eos_token_id: Optional end-of-sequence id; generation stops early
                once every sequence in the batch emits it in the same step.

        Returns:
            Tensor holding the prompt followed by the generated tokens.

        Raises:
            ValueError: If sampling is requested with a non-positive
                temperature.
        """
        if do_sample and temperature <= 0:
            raise ValueError("temperature must be positive when do_sample=True")

        current_ids = input_ids.clone()

        for _ in range(max_length):
            outputs = self(current_ids)
            next_token_logits = outputs["logits"][:, -1, :]

            if do_sample:
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # A positive temperature does not change the argmax.
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Concatenate new tokens
            current_ids = torch.cat([current_ids, next_token], dim=1)

            # Check for EOS
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return current_ids

Training and Optimization Framework

Multi-Objective Training Strategy

In [8]:
def training_framework(model, train_dataloader, val_dataloader, num_epochs=10,
                       test_dataloader=None, device=None):
    """Train ``model`` with AdamW and a warmup + cosine schedule.

    After each epoch the model is validated, and a checkpoint is saved
    whenever the validation loss improves. A final evaluation runs once
    training ends.

    Args:
        model: Module to train.
        train_dataloader: Sized iterable of training batches; its length
            determines the total number of scheduler steps.
        val_dataloader: Iterable of validation batches.
        num_epochs: Number of passes over ``train_dataloader``.
        test_dataloader: Batches for the final evaluation. Falls back to
            ``val_dataloader`` when omitted.
        device: Device used for training and evaluation. Defaults to the
            device of the model's first parameter.

    Returns:
        Tuple ``(model, test_metrics)``.
    """
    if device is None:
        # Previously read from an undefined global; infer it from the model.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Optimizer setup
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=5e-5,
        weight_decay=0.01,
        eps=1e-8
    )

    # Learning rate scheduler: the first 10% of all steps are warmup
    total_steps = len(train_dataloader) * num_epochs
    warmup_steps = int(0.1 * total_steps)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # Training loop
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Train
        train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, device)

        # Validate
        val_loss = validate_model(model, val_dataloader, device)

        # Track metrics
        print(f"Epoch {epoch+1}/{num_epochs} - Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")

        # Save checkpoint if improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, optimizer, scheduler, epoch, val_loss)

    # Final evaluation
    eval_dataloader = test_dataloader if test_dataloader is not None else val_dataloader
    test_metrics = evaluate_model(model, eval_dataloader, device)

    return model, test_metrics

Preference Learning and Reinforcement

In [9]:
class RewardModel(nn.Module):
    """Score a sequence of hidden states with a single scalar reward.

    Args:
        d_model: Width of the incoming hidden states.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        bottleneck = d_model // 2
        head_layers = [
            nn.Linear(d_model, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
        ]
        self.reward_head = nn.Sequential(*head_layers)

    def forward(self, hidden_states):
        """Mean-pool ``hidden_states`` over dim 1 and map the result to a reward."""
        sequence_summary = hidden_states.mean(dim=1)
        return self.reward_head(sequence_summary)

def train_with_dpo(model, optimizer, dataloader, reward_model, beta=0.1):
    """Run one preference-training pass and return the mean loss.

    For every batch the model is run on the preferred and the rejected
    response, both outputs are scored by ``reward_model``, and the loss
    ``-log(sigmoid(beta * (r_preferred - r_rejected)))`` is minimised.

    NOTE(review): no reference-policy log-probabilities are involved, so this
    is a reward-margin objective rather than the standard DPO loss - confirm
    the intended formulation.

    Args:
        model: Module called as ``model(contexts, responses)``; its output
            must expose ``hidden_states``.
        optimizer: Optimizer stepped once per batch.
        dataloader: Iterable of mappings with ``contexts``,
            ``preferred_responses`` and ``rejected_responses`` tensors.
        reward_model: Callable mapping hidden states to a reward with a
            trailing singleton dimension.
        beta: Scale applied to the reward difference.

    Returns:
        Mean loss over the batches seen, or ``0.0`` for an empty dataloader.
    """
    model.train()

    # Previously read from an undefined global; use the model's own device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        # Get preferred and rejected responses for each context
        contexts = batch["contexts"].to(device)
        preferred = batch["preferred_responses"].to(device)
        rejected = batch["rejected_responses"].to(device)

        # The forward passes must track gradients. Under torch.no_grad() the
        # loss has no graph and loss.backward() raises.
        preferred_outputs = model(contexts, preferred)
        rejected_outputs = model(contexts, rejected)

        preferred_rewards = reward_model(preferred_outputs.hidden_states).squeeze(-1)
        rejected_rewards = reward_model(rejected_outputs.hidden_states).squeeze(-1)

        # logsigmoid is the numerically stable form of log(sigmoid(x)).
        reward_diff = preferred_rewards - rejected_rewards
        loss = -F.logsigmoid(beta * reward_diff).mean()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

Continual Learning Framework

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, ConcatDataset, Subset
import copy
from typing import Dict, List, Tuple, Optional, Union, Callable, Any
import logging
import os
import json
import matplotlib.pyplot as plt
from collections import defaultdict
import random
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Root logging setup: INFO level, timestamp - level - message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger.
logger = logging.getLogger(__name__)

class ContinualLearningFramework:
    """
    A comprehensive framework for implementing various continual learning methods.
    This framework supports:
    - Experience Replay
    - Elastic Weight Consolidation (EWC)
    - Knowledge Distillation
    - Task-specific adapters
    - Learning without Forgetting (LwF)
    - Synaptic Intelligence (SI)
    - Functional Regularization
    - Progressive Neural Networks
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer_class: torch.optim.Optimizer = optim.Adam,
        optimizer_kwargs: Dict = {'lr': 0.001},
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',

        # Experience Replay parameters
        replay_buffer_size: int = 1000,
        replay_strategy: str = 'random',  # 'random', 'prioritized', 'class_balanced'
        replay_frequency: int = 5,  # Every N batches

        # EWC parameters
        use_ewc: bool = False,
        ewc_lambda: float = 100.0,
        ewc_fisher_sample_size: int = 200,
        ewc_mode: str = 'separate',  # 'separate' or 'online'

        # Distillation parameters
        use_distillation: bool = False,
        distillation_temp: float = 2.0,
        distillation_alpha: float = 0.5,

        # Adapter parameters
        use_adapters: bool = False,
        adapter_size: int = 64,
        freeze_base_model: bool = True,

        # LwF parameters
        use_lwf: bool = False,
        lwf_alpha: float = 1.0,
        lwf_temp: float = 2.0,

        # Synaptic Intelligence parameters
        use_si: bool = False,
        si_lambda: float = 1.0,
        si_omega_decay: float = 0.95,

        # Progressive networks parameters
        use_progressive_nets: bool = False,

        # Functional regularization parameters
        use_func_regularization: bool = False,
        func_reg_lambda: float = 1.0,
        func_sample_size: int = 200,

        # Misc parameters
        checkpoint_dir: str = './checkpoints',
        log_metrics: bool = True,
        verbose: bool = True
    ):
        """
        Initialize the continual learning framework.

        Args:
            model: The neural network model
            optimizer_class: The optimizer class to use
            optimizer_kwargs: Arguments for the optimizer
            device: Device to run the model on

            # Experience Replay parameters
            replay_buffer_size: Size of experience replay buffer
            replay_strategy: Strategy for sampling from replay buffer ('random', 'prioritized', 'class_balanced')
            replay_frequency: How often to perform replay (every N batches)

            # EWC parameters
            use_ewc: Whether to use Elastic Weight Consolidation
            ewc_lambda: EWC regularization strength
            ewc_fisher_sample_size: Number of samples to estimate Fisher information
            ewc_mode: 'separate' (maintain separate Fisher matrices) or 'online' (update single matrix)

            # Distillation parameters
            use_distillation: Whether to use knowledge distillation
            distillation_temp: Temperature for distillation
            distillation_alpha: Weight for distillation loss

            # Adapter parameters
            use_adapters: Whether to use task-specific adapters
            adapter_size: Size of adapter hidden representations
            freeze_base_model: Whether to freeze the base model when using adapters

            # LwF parameters
            use_lwf: Whether to use Learning without Forgetting
            lwf_alpha: Weight for LwF loss
            lwf_temp: Temperature for LwF

            # Synaptic Intelligence parameters
            use_si: Whether to use Synaptic Intelligence
            si_lambda: SI regularization strength
            si_omega_decay: Decay rate for importance weights

            # Progressive networks parameters
            use_progressive_nets: Whether to use Progressive Neural Networks

            # Functional regularization parameters
            use_func_regularization: Whether to use functional regularization
            func_reg_lambda: Weight for functional regularization
            func_sample_size: Number of samples for functional regularization

            # Misc parameters
            checkpoint_dir: Directory to save checkpoints
            log_metrics: Whether to log metrics
            verbose: Whether to print verbose output
        """
        self.model = model.to(device)
        self.base_model = copy.deepcopy(model) if use_progressive_nets else None
        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs
        self.device = device
        self.current_task_id = None
        self.seen_tasks = set()
        self.verbose = verbose

        # Experience Replay
        self.replay_buffer_size = replay_buffer_size
        self.replay_strategy = replay_strategy
        self.replay_frequency = replay_frequency
        self.replay_buffer = {}  # task_id -> [(x, y, importance), ...]
        self.replay_counter = 0

        # EWC
        self.use_ewc = use_ewc
        self.ewc_lambda = ewc_lambda
        self.ewc_fisher_sample_size = ewc_fisher_sample_size
        self.ewc_mode = ewc_mode
        self.fisher_dict = {}  # task_id -> {param_name: fisher}
        self.param_dict = {}   # task_id -> {param_name: param_value}

        # Knowledge Distillation
        self.use_distillation = use_distillation
        self.distillation_temp = distillation_temp
        self.distillation_alpha = distillation_alpha
        self.teacher_model = None

        # Task-specific Adapters
        self.use_adapters = use_adapters
        self.adapter_size = adapter_size
        self.freeze_base_model = freeze_base_model
        self.adapters = nn.ModuleDict()  # task_id -> adapter

        # Learning without Forgetting
        self.use_lwf = use_lwf
        self.lwf_alpha = lwf_alpha
        self.lwf_temp = lwf_temp
        self.lwf_models = {}  # task_id -> model

        # Synaptic Intelligence
        self.use_si = use_si
        self.si_lambda = si_lambda
        self.si_omega_decay = si_omega_decay
        self.si_omega = {}  # Parameter importance
        self.si_prev_params = {}  # Previous parameter values
        self.si_accumulated_delta = {}  # Accumulated delta*gradient

        # Progressive Neural Networks
        self.use_progressive_nets = use_progressive_nets
        self.column_models = nn.ModuleDict()  # task_id -> column model

        # Functional Regularization
        self.use_func_regularization = use_func_regularization
        self.func_reg_lambda = func_reg_lambda
        self.func_sample_size = func_sample_size
        self.func_samples = {}  # task_id -> [(x, output), ...]

        # Metrics and Logging
        self.log_metrics = log_metrics
        self.metrics = defaultdict(list)
        self.checkpoint_dir = checkpoint_dir
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        # Initialize optimizer
        self.optimizer = self.optimizer_class(self.model.parameters(), **self.optimizer_kwargs)

        # Initialize SI parameters if needed
        if self.use_si:
            self._init_si_params()

        if self.verbose:
            logger.info(f"Initialized Continual Learning Framework on device: {device}")
            logger.info(f"Active methods: EWC={use_ewc}, Distillation={use_distillation}, Adapters={use_adapters}, "
                      f"LwF={use_lwf}, SI={use_si}, ProgressiveNets={use_progressive_nets}, "
                      f"FunctionalReg={use_func_regularization}")

    def _init_si_params(self):
        """Initialize parameters for Synaptic Intelligence."""
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                self.si_omega[n] = torch.zeros_like(p.data)
                self.si_prev_params[n] = p.data.clone()
                self.si_accumulated_delta[n] = torch.zeros_like(p.data)

    def _compute_fisher_information(self, data_loader: DataLoader, num_samples: int = None) -> Dict[str, torch.Tensor]:
        """
        Compute the Fisher information matrix for parameters.

        Args:
            data_loader: DataLoader with the current task data
            num_samples: Number of samples to use for estimation

        Returns:
            Dictionary with parameter names and their Fisher information
        """
        if num_samples is None:
            num_samples = self.ewc_fisher_sample_size

        self.model.eval()
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters() if p.requires_grad}

        if self.verbose:
            logger.info(f"Computing Fisher information matrix using {num_samples} samples")

        sample_count = 0
        for data_batch in data_loader:
            if sample_count >= num_samples:
                break

            inputs, targets = data_batch
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            batch_size = inputs.shape[0]

            self.model.zero_grad()
            outputs = self.model(inputs)

            # For each sample, compute the fisher information
            for i in range(min(batch_size, num_samples - sample_count)):
                log_prob = torch.nn.functional.log_softmax(outputs[i:i+1], dim=1)

                # Use actual prediction for fisher computation
                pred = torch.argmax(log_prob, dim=1)
                loss = -log_prob[0, pred]
                loss.backward(retain_graph=(i < batch_size - 1))

                # Accumulate the gradients
                for n, p in self.model.named_parameters():
                    if p.grad is not None and p.requires_grad:
                        fisher[n] += p.grad.pow(2).detach() / num_samples

                self.model.zero_grad()
                sample_count += 1

        if self.verbose:
            logger.info(f"Fisher information matrix computed with {sample_count} samples")
        return fisher

    def _save_current_parameters(self) -> Dict[str, torch.Tensor]:
        """
        Save the current parameters of the model.

        Returns:
            Dictionary with parameter names and their values
        """
        return {n: p.detach().clone() for n, p in self.model.named_parameters() if p.requires_grad}

    def _add_to_replay_buffer(self, task_id: str, data: List[Tuple[torch.Tensor, torch.Tensor]], compute_importance: bool = False):
        """
        Add data to the replay buffer for a given task.

        Args:
            task_id: Identifier for the task
            data: List of (input, target) tuples
            compute_importance: Whether to compute importance scores for prioritized replay
        """
        if task_id not in self.replay_buffer:
            self.replay_buffer[task_id] = []

        # Compute importance scores if using prioritized replay
        if compute_importance and self.replay_strategy == 'prioritized':
            with torch.no_grad():
                for x, y in data:
                    x = x.to(self.device).unsqueeze(0)
                    y = y.to(self.device).unsqueeze(0)
                    output = self.model(x)
                    prob = torch.softmax(output, dim=1)
                    # Importance is inverse of confidence (lower confidence -> higher importance)
                    importance = 1.0 - prob[0, y.item()].item()
                    self.replay_buffer[task_id].append((x.squeeze(0).cpu(), y.squeeze(0).cpu(), importance))
        else:
            # Default importance of 0.5 for random or class-balanced replay
            self.replay_buffer[task_id].extend([(x, y, 0.5) for x, y in data])

        # Limit the buffer size by sampling according to the chosen strategy
        if len(self.replay_buffer[task_id]) > self.replay_buffer_size:
            if self.replay_strategy == 'random':
                indices = np.random.choice(
                    len(self.replay_buffer[task_id]),
                    self.replay_buffer_size,
                    replace=False
                )
                self.replay_buffer[task_id] = [self.replay_buffer[task_id][i] for i in indices]

            elif self.replay_strategy == 'prioritized':
                # Sort by importance (highest first) and keep top samples
                self.replay_buffer[task_id].sort(key=lambda x: x[2], reverse=True)
                self.replay_buffer[task_id] = self.replay_buffer[task_id][:self.replay_buffer_size]

            elif self.replay_strategy == 'class_balanced':
                # Group by class and sample equally from each class
                class_samples = {}
                for x, y, imp in self.replay_buffer[task_id]:
                    y_item = y.item()
                    if y_item not in class_samples:
                        class_samples[y_item] = []
                    class_samples[y_item].append((x, y, imp))

                # Calculate samples per class
                num_classes = len(class_samples)
                samples_per_class = self.replay_buffer_size // num_classes
                remainder = self.replay_buffer_size % num_classes

                # Create a balanced buffer
                balanced_buffer = []
                for cls, samples in class_samples.items():
                    # Add extra sample to classes with remainder
                    cls_samples = samples_per_class + (1 if remainder > 0 else 0)
                    remainder -= 1 if remainder > 0 else 0

                    # Random sample if we have more than needed
                    if len(samples) > cls_samples:
                        indices = np.random.choice(len(samples), cls_samples, replace=False)
                        balanced_buffer.extend([samples[i] for i in indices])
                    else:
                        balanced_buffer.extend(samples)

                self.replay_buffer[task_id] = balanced_buffer

        if self.verbose:
            logger.info(f"Added data to replay buffer for task {task_id}, buffer size: {len(self.replay_buffer[task_id])}")

    def _create_adapter(self, task_id: str, input_dim: int, output_dim: int):
        """
        Create a task-specific adapter module.

        Args:
            task_id: Task identifier
            input_dim: Input dimension of the adapter
            output_dim: Output dimension of the adapter
        """
        adapter = nn.Sequential(
            nn.Linear(input_dim, self.adapter_size),
            nn.ReLU(),
            nn.Linear(self.adapter_size, output_dim)
        ).to(self.device)

        self.adapters[task_id] = adapter
        if self.verbose:
            logger.info(f"Created adapter for task {task_id}")

    def _create_column_model(self, task_id: str):
        """
        Create a new column model for Progressive Neural Networks.

        Args:
            task_id: Task identifier
        """
        # Create a copy of the base model
        column_model = copy.deepcopy(self.base_model).to(self.device)

        # Create lateral connections to previous columns
        # This is a simplified implementation; a real implementation would need
        # specific lateral connection layers between columns

        self.column_models[task_id] = column_model
        if self.verbose:
            logger.info(f"Created column model for task {task_id}")

    def _update_si_weights(self):
        """Update the importance weights for Synaptic Intelligence."""
        for n, p in self.model.named_parameters():
            if n in self.si_accumulated_delta and p.requires_grad:
                # Calculate change in parameter
                delta = p.data - self.si_prev_params[n]

                # Update omega (importance) based on accumulated gradients
                if torch.norm(delta) > 0:
                    self.si_omega[n] += self.si_accumulated_delta[n] / (delta ** 2 + 1e-7)

                # Apply decay to previous omega values
                self.si_omega[n] *= self.si_omega_decay

                # Reset accumulated delta and update previous parameters
                self.si_accumulated_delta[n] = torch.zeros_like(p.data)
                self.si_prev_params[n] = p.data.clone()

    def _collect_functional_samples(self, data_loader: DataLoader, task_id: str, num_samples: int = None):
        """
        Collect input-output pairs for functional regularization.

        Args:
            data_loader: DataLoader with task data
            task_id: Task identifier
            num_samples: Number of samples to collect
        """
        if num_samples is None:
            num_samples = self.func_sample_size

        if task_id not in self.func_samples:
            self.func_samples[task_id] = []

        self.model.eval()
        sample_count = 0

        with torch.no_grad():
            for data_batch in data_loader:
                if sample_count >= num_samples:
                    break

                inputs, _ = data_batch
                inputs = inputs.to(self.device)
                batch_size = inputs.shape[0]

                outputs = self.model(inputs)

                for i in range(min(batch_size, num_samples - sample_count)):
                    self.func_samples[task_id].append((
                        inputs[i].detach().cpu(),
                        outputs[i].detach().cpu()
                    ))
                    sample_count += 1

        if self.verbose:
            logger.info(f"Collected {sample_count} functional samples for task {task_id}")

    def start_task(self, task_id: str):
        """
        Start training on a new task.

        Args:
            task_id: Identifier for the task
        """
        self.current_task_id = task_id

        # For Progressive Neural Networks
        if self.use_progressive_nets:
            if task_id not in self.column_models:
                self._create_column_model(task_id)

        # For task adapters
        if self.use_adapters:
            if task_id in self.adapters:
                if self.verbose:
                    logger.info(f"Switching to existing adapter for task {task_id}")

                # Freeze or unfreeze parameters based on mode
                if self.freeze_base_model:
                    for param in self.model.parameters():
                        param.requires_grad = False

                    # Only unfreeze the adapter parameters
                    for param in self.adapters[task_id].parameters():
                        param.requires_grad = True

        # Save teacher model for distillation if needed
        if (self.use_distillation or self.use_lwf) and task_id not in self.seen_tasks:
            if self.verbose:
                logger.info("Saving teacher model for knowledge distillation/LwF")
            self.teacher_model = copy.deepcopy(self.model)
            self.teacher_model.eval()

            if self.use_lwf:
                self.lwf_models[task_id] = self.teacher_model

        # Create a new optimizer for the task
        if self.use_adapters and task_id in self.adapters and self.freeze_base_model:
            # Only optimize the adapter parameters for existing tasks
            self.optimizer = self.optimizer_class(
                self.adapters[task_id].parameters(),
                **self.optimizer_kwargs
            )
        elif self.use_progressive_nets:
            # Only optimize the current column parameters
            self.optimizer = self.optimizer_class(
                self.column_models[task_id].parameters(),
                **self.optimizer_kwargs
            )
        else:
            # Optimize all parameters for new tasks
            self.optimizer = self.optimizer_class(
                self.model.parameters(),
                **self.optimizer_kwargs
            )

        if self.verbose:
            logger.info(f"Started task {task_id}")

    def _accumulate_si_gradients(self):
        """Accumulate gradients for Synaptic Intelligence."""
        for n, p in self.model.named_parameters():
            if n in self.si_accumulated_delta and p.grad is not None and p.requires_grad:
                # Accumulate gradient * (current param - initial param)
                self.si_accumulated_delta[n] -= p.grad * (p.data - self.si_prev_params[n])

    def _compute_si_loss(self) -> torch.Tensor:
        """
        Compute the Synaptic Intelligence regularization loss.

        Returns:
            SI loss
        """
        si_loss = 0
        for n, p in self.model.named_parameters():
            if n in self.si_omega and n in self.si_prev_params and p.requires_grad:
                # Compute quadratic penalty on parameter changes
                si_loss += (self.si_omega[n] * (p - self.si_prev_params[n]) ** 2).sum()
        return si_loss

    def _compute_functional_regularization_loss(self) -> torch.Tensor:
        """
        Compute the functional regularization loss.

        Returns:
            Functional regularization loss
        """
        func_loss = 0
        sample_count = 0

        # Sample from previous tasks
        for task_id, samples in self.func_samples.items():
            if task_id == self.current_task_id:
                continue

            # Randomly sample from the stored functional samples
            indices = np.random.choice(len(samples), min(len(samples), 20), replace=False)

            for idx in indices:
                x, y_prev = samples[idx]
                x = x.to(self.device).unsqueeze(0)
                y_prev = y_prev.to(self.device)

                # Forward pass with current model
                y_current = self.model(x).squeeze(0)

                # Mean squared error between previous and current outputs
                func_loss += torch.mean((y_current - y_prev) ** 2)
                sample_count += 1

        return func_loss / max(1, sample_count)

    def train_step(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        task_specific_loss_fn: Callable = nn.CrossEntropyLoss()
    ) -> Dict[str, float]:
        """
        Perform a single training step.

        Args:
            inputs: Input batch
            targets: Target batch
            task_specific_loss_fn: Loss function specific to the current task

        Returns:
            Dictionary with loss metrics
        """
        self.model.train()

        # Move data to device
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        # Forward pass
        self.optimizer.zero_grad()
        outputs = self.model(inputs)

        # Task-specific loss
        task_loss = task_specific_loss_fn(outputs, targets)
        total_loss = task_loss
        loss_metrics = {"task_loss": task_loss.item()}

        # Knowledge distillation loss
        if self.use_distillation and self.teacher_model is not None:
            with torch.no_grad():
                teacher_outputs = self.teacher_model(inputs)

            # Compute distillation loss
            distillation_loss = self._compute_distillation_loss(
                outputs, teacher_outputs, self.distillation_temp
            )

            # Combine losses
            total_loss = (
                (1 - self.distillation_alpha) * task_loss +
                self.distillation_alpha * distillation_loss
            )
            loss_metrics["distillation_loss"] = distillation_loss.item()

        # Learning without Forgetting loss
        if self.use_lwf and self.lwf_models:
            lwf_loss = 0
            for task_id, old_model in self.lwf_models.items():
                if task_id == self.current_task_id:
                    continue

                with torch.no_grad():
                    old_outputs = old_model(inputs)

                # Compute LwF loss
                lwf_task_loss = self._compute_distillation_loss(
                    outputs, old_outputs, self.lwf_temp
                )
                lwf_loss += lwf_task_loss

            if lwf_loss > 0:
                total_loss += self.lwf_alpha * lwf_loss
                loss_metrics["lwf_loss"] = lwf_loss.item()

        # EWC loss
        if self.use_ewc and self.fisher_dict and self.param_dict:
            ewc_loss = self._compute_ewc_loss()
            total_loss += self.ewc_lambda * ewc_loss
            loss_metrics["ewc_loss"] = ewc_loss.item()

        # Synaptic Intelligence loss
        if self.use_si and self.si_omega:
            si_loss = self._compute_si_loss()
            total_loss += self.si_lambda * si_loss
            loss_metrics["si_loss"] = si_loss.item()

        # Functional regularization loss
        if self.use_func_regularization and self.func_samples:
            func_loss = self._compute_functional_regularization_loss()
            total_loss += self.func_reg_lambda * func_loss
            loss_metrics["func_reg_loss"] = func_loss.item()

        # Backward pass and optimizer step
        total_loss.backward()

        # Accumulate gradients for Synaptic Intelligence
        if self.use_si:
            self._accumulate_si_gradients()

        self.optimizer.step()

        # Increment the replay counter
        self.replay_counter += 1

        loss_metrics["total_loss"] = total_loss.item()
        return loss_metrics

    def _compute_distillation_loss(
        self,
        outputs: torch.Tensor,
        teacher_outputs: torch.Tensor,
        temperature: float
    ) -> torch.Tensor:
        """
        Compute the knowledge distillation loss.

        Args:
            outputs: Model outputs
            teacher_outputs: Teacher model outputs
            temperature: Softmax temperature

        Returns:
            Distillation loss
        """
        soft_targets = torch.nn.functional.softmax(teacher_outputs / temperature, dim=1)
        log_probs = torch.nn.functional.log_softmax(outputs / temperature, dim=1)
        distillation_loss = -torch.sum(soft_targets * log_probs) / outputs.size(0)
        return distillation_loss * (temperature ** 2)

    def _compute_ewc_loss(self) -> torch.Tensor:
        """
        Compute the EWC regularization loss.

        Returns:
            EWC loss
        """
        ewc_loss = 0

        if self.ewc_mode == 'separate':
            # Use separate Fisher matrix for each task
            for task_id in self.seen_tasks:
                if task_id in self.fisher_dict and task_id in self.param_dict:
                    for n, p in self.model.named_parameters():
                        if n in self.fisher_dict[task_id] and n in self.param_dict[task_id] and p.requires_grad:
                            fisher = self.fisher_dict[task_id][n]
                            old_param = self.param_dict[task_id][n]
                            ewc_loss += torch.sum(fisher * (p - old_param) ** 2)
        elif self.ewc_mode == 'online':
            # Use a single Fisher matrix that's updated online
            if 'online' in self.fisher_dict and 'online' in self.param_dict:
                for n, p in self.model.named_parameters():
                    if n in self.fisher_dict['online'] and n in self.param_dict['online'] and p.requires_grad:
                        fisher = self.fisher_dict['online'][n]
                        old_param = self.param_dict['online'][n]
                        ewc_loss += torch.sum(fisher * (p - old_param) ** 2)

        return ewc_loss

    def _sample_from_replay_buffer(self, batch_size: int = 32) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample a batch from the replay buffer.

        Args:
            batch_size: Number of samples to draw

        Returns:
            Tuple of (inputs, targets)
        """
        all_experiences = []
        for task_id, experiences in self.replay_buffer.items():
            if task_id != self.current_task_id:  # Don't replay current task
                all_experiences.extend(experiences)

        if not all_experiences:
            return None, None

        # Sample according to strategy
        if self.replay_strategy == 'random':
            indices = np.random.choice(len(all_experiences), min(len(all_experiences), batch_size), replace=False)
            sampled_experiences = [all_experiences[i] for i in indices]

        elif self.replay_strategy == 'prioritized':
            # Sample based on importance (higher importance = higher probability)
            importance = np.array([exp[2] for exp in all_experiences])
            probs = importance / importance.sum()
            indices = np.random.choice(
                len(all_experiences),
                min(len(all_experiences), batch_size),
                replace=False,
                p=probs
            )
            sampled_experiences = [all_experiences[i] for i in indices]

        elif self.replay_strategy == 'class_balanced':
            # Group by class
            class_samples = {}
            for x, y, imp in all_experiences:
                y_item = y.item()
                if y_item not in class_