In [None]:
import os
import logging
import time
from typing import List, Dict, Union

import numpy as np
import torch
from datasets import Dataset, concatenate_datasets
from peft import (
    LoraConfig,
    get_peft_model,
    PeftModel,
    prepare_model_for_kbit_training
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer
)

class ContinuousFineTuner:
    """
    Continuous fine-tuning manager for causal language models.

    Each call to ``continuous_fine_tune`` trains a single LoRA adapter on a new
    batch of text. The batch may be mixed with replayed samples from an on-disk
    memory buffer. The adapter is saved as a checkpoint, and the adapted model
    is kept for later rounds and for ``inference_with_continuous_model``.
    """

    # Strategies accepted by ``continuous_fine_tune``.
    _STRATEGIES = ("incremental", "adaptive", "memory_replay")

    def __init__(
        self,
        model_name: str,
        cache_dir: str = "./continuous_finetune_cache",
        max_memory_buffer: int = 10000
    ):
        """
        Initialize continuous fine-tuning manager.

        Creates ``cache_dir`` with ``checkpoints/`` and ``memory_buffer/``
        subdirectories and configures file logging. No model is loaded here.
        Loading happens lazily on the first training or inference call.

        Args:
            model_name: Base model to fine-tune (passed to ``from_pretrained``).
            cache_dir: Directory to store model checkpoints, replay data and logs.
            max_memory_buffer: Maximum number of recent samples kept on disk in
                (and loaded from) the memory buffer.
        """
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.max_memory_buffer = max_memory_buffer

        # Create directories
        os.makedirs(cache_dir, exist_ok=True)
        os.makedirs(os.path.join(cache_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(cache_dir, "memory_buffer"), exist_ok=True)

        # NOTE: basicConfig configures the *root* logger. It does nothing if the
        # root logger already has handlers (common inside notebooks).
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            filename=os.path.join(cache_dir, "continuous_finetuning.log")
        )
        self.logger = logging.getLogger(__name__)

        # Model components (populated lazily by load_base_model)
        self.base_model = None
        self.tokenizer = None
        self.current_model = None
        # Number of completed fine-tuning rounds; also seeds replay sampling so
        # that each round draws a different (but reproducible) replay subset.
        self.update_count = 0

    def load_base_model(self):
        """Load the base model and tokenizer and make the base model current."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Training pads to a fixed length, which needs a pad token. Many
        # causal-LM tokenizers ship without one, so reuse EOS. Padded positions
        # are excluded from the loss via ``labels = -100``.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.base_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.current_model = self.base_model

    def _memory_buffer_dir(self):
        """Return the directory holding replay samples (one ``.pt`` file per sample)."""
        return os.path.join(self.cache_dir, "memory_buffer")

    def _list_buffer_files(self):
        """Return the memory buffer's ``.pt`` file paths, newest first."""
        buffer_path = self._memory_buffer_dir()
        buffer_files = [
            os.path.join(buffer_path, f)
            for f in os.listdir(buffer_path)
            if f.endswith(".pt")
        ]
        # Sort by modification time. The path breaks ties between files written
        # within the file system's timestamp resolution.
        buffer_files.sort(key=lambda p: (os.path.getmtime(p), p), reverse=True)
        return buffer_files

    def _create_memory_buffer(self):
        """
        Load the most recent samples from the on-disk memory buffer.

        Returns:
            List of up to ``max_memory_buffer`` loaded samples, newest first.
            Each sample is whatever was saved by ``_save_to_memory_buffer``: a
            dict mapping column name to a list of token ids. Files that fail to
            load are skipped with a warning.
        """
        memory_buffer = []
        for file in self._list_buffer_files()[:max(self.max_memory_buffer, 0)]:
            try:
                # The buffer only stores plain dicts of int lists, so refuse
                # arbitrary pickled objects.
                memory_buffer.append(torch.load(file, weights_only=True))
            except Exception as e:
                self.logger.warning(f"Could not load memory buffer {file}: {e}")

        return memory_buffer

    def _save_to_memory_buffer(self, tokenized_dataset, stamp):
        """
        Append tokenized samples to the memory buffer, then prune the oldest.

        Args:
            tokenized_dataset: Dataset whose rows are each saved to their own
                ``.pt`` file.
            stamp: Unique per-round prefix for the file names.
        """
        buffer_path = self._memory_buffer_dir()
        for index, sample in enumerate(tokenized_dataset):
            file_path = os.path.join(buffer_path, f"{stamp}_{index:06d}.pt")
            torch.save(dict(sample), file_path)

        # Retain at most max_memory_buffer samples on disk.
        for stale in self._list_buffer_files()[max(self.max_memory_buffer, 0):]:
            try:
                os.remove(stale)
            except OSError as e:
                self.logger.warning(f"Could not prune memory buffer {stale}: {e}")

    def _sample_replay_dataset(self, reference, buffer_ratio):
        """
        Build a replay dataset from the memory buffer.

        Args:
            reference: Tokenized new data. Replay samples must have exactly its
                columns, and they are cast to its features so the two datasets
                can be concatenated.
            buffer_ratio: Fraction of the loaded buffer to replay, clamped to [0, 1].

        Returns:
            A ``Dataset`` of replay samples, or ``None`` if nothing is replayed
            (for example, an empty buffer on the first round).
        """
        expected_columns = set(reference.column_names)
        # Drop samples with a different schema (e.g. files written by an older
        # version of this class without ``labels``).
        samples = [
            sample for sample in self._create_memory_buffer()
            if isinstance(sample, dict) and set(sample) == expected_columns
        ]
        ratio = min(max(buffer_ratio, 0.0), 1.0)
        n_replay = int(round(len(samples) * ratio))
        if n_replay == 0:
            return None

        buffer_dataset = Dataset.from_list(samples).cast(reference.features)
        return buffer_dataset.shuffle(seed=self.update_count).select(range(n_replay))

    def _lora_target_modules(self, model):
        """
        Choose LoRA target modules for ``model``.

        Returns ``["q_proj", "v_proj"]`` when the model has modules with those
        names. Otherwise returns ``"all-linear"``, so architectures with
        differently named (e.g. fused ``qkv_proj``) attention projections still
        get adapters.
        """
        leaf_names = {name.rsplit(".", 1)[-1] for name, _ in model.named_modules()}
        if {"q_proj", "v_proj"} <= leaf_names:
            return ["q_proj", "v_proj"]
        return "all-linear"

    def continuous_fine_tune(
        self,
        new_data: Union[List[Dict], Dataset],
        strategy: str = "incremental",
        learning_rate: float = 1e-4,
        buffer_ratio: float = 0.2
    ):
        """
        Continuous fine-tuning with multiple learning strategies.

        Trains a LoRA adapter for one epoch and saves it under ``cache_dir``.
        The new samples are appended to the memory buffer, and the adapted
        model becomes ``current_model``. Later calls keep training the same
        adapter.

        Args:
            new_data: New training data; each record needs a ``'text'`` field.
            strategy: Fine-tuning approach
                - 'incremental': Train on the new data only.
                - 'adaptive': Scale the learning rate by
                  ``_compute_data_novelty`` of the new data.
                - 'memory_replay': Mix in samples from the memory buffer to
                  reduce catastrophic forgetting.
            learning_rate: Base learning rate
            buffer_ratio: Proportion of the loaded memory buffer replayed in
                training (used by 'memory_replay' only; clamped to [0, 1]).

        Raises:
            ValueError: If ``strategy`` is not one of the supported strategies.
        """
        if strategy not in self._STRATEGIES:
            raise ValueError(
                f"Unknown strategy {strategy!r}; expected one of {self._STRATEGIES}"
            )
        if len(new_data) == 0:
            self.logger.warning("continuous_fine_tune called with no new samples; skipping")
            return

        # Ensure model is loaded
        if self.base_model is None:
            self.load_base_model()

        # Unique per-round id for checkpoint and buffer file names.
        stamp = time.time_ns()

        # Prepare new dataset
        if isinstance(new_data, list):
            new_dataset = Dataset.from_list(new_data)
        else:
            new_dataset = new_data

        def tokenize_function(examples):
            encoded = self.tokenizer(
                examples['text'],
                truncation=True,
                padding='max_length',
                max_length=512
            )
            # A causal LM only returns a loss when given labels. Use the input
            # ids, with padded positions (attention_mask == 0) set to -100 so
            # they are ignored.
            encoded["labels"] = [
                [token if keep == 1 else -100 for token, keep in zip(ids, mask)]
                for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
            ]
            return encoded

        tokenized_new_data = new_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=new_dataset.column_names
        )

        # Strategies for continuous learning
        training_dataset = tokenized_new_data
        if strategy == "adaptive":
            # NOTE(review): the placeholder novelty score is vocabulary coverage,
            # which is tiny for typical batches, so this sharply shrinks the LR.
            novelty_score = self._compute_data_novelty(tokenized_new_data)
            learning_rate *= novelty_score
        elif strategy == "memory_replay":
            # Load replay samples *before* saving the new batch so the new data
            # is not counted twice.
            replay_dataset = self._sample_replay_dataset(tokenized_new_data, buffer_ratio)
            if replay_dataset is not None:
                training_dataset = concatenate_datasets([
                    tokenized_new_data,
                    replay_dataset
                ])

        if isinstance(self.current_model, PeftModel):
            # Later rounds keep training the existing adapter instead of
            # wrapping the already-wrapped model in another LoRA layer.
            peft_model = self.current_model
        else:
            # LoRA configuration for efficient fine-tuning
            peft_config = LoraConfig(
                r=16,  # Rank of LoRA adaptation
                lora_alpha=32,
                target_modules=self._lora_target_modules(self.current_model),
                lora_dropout=0.1,
                bias="none",
                task_type="CAUSAL_LM"
            )
            # Freezes the base weights. Per PEFT's implementation, on a
            # non-quantized fp16 model this also upcasts parameters to float32.
            model_to_train = prepare_model_for_kbit_training(self.current_model)
            peft_model = get_peft_model(model_to_train, peft_config)

        # Training arguments
        training_args = TrainingArguments(
            output_dir=os.path.join(self.cache_dir, "checkpoints"),
            learning_rate=learning_rate,
            per_device_train_batch_size=4,
            num_train_epochs=1,  # Short epochs for continuous learning
            logging_dir=os.path.join(self.cache_dir, "logs"),
            save_strategy="steps",
            save_steps=100,
            logging_steps=10
        )

        trainer = Trainer(
            model=peft_model,
            args=training_args,
            train_dataset=training_dataset
        )
        trainer.train()

        # Save the fine-tuned adapter. The stamp keeps rounds with the same
        # batch size from overwriting each other.
        checkpoint_path = os.path.join(
            self.cache_dir,
            f"checkpoint_{stamp}_{len(new_data)}_samples"
        )
        peft_model.save_pretrained(checkpoint_path)

        # Remember this batch for future replay, regardless of strategy.
        self._save_to_memory_buffer(tokenized_new_data, stamp)

        # Update current model
        self.current_model = peft_model
        self.update_count += 1

        # Log training metrics
        self.logger.info(f"Continuous fine-tuning completed with {len(new_data)} new samples")

    def _compute_data_novelty(self, dataset):
        """
        Compute a (placeholder) novelty score for new data.

        The score is the fraction of the tokenizer vocabulary that appears in
        ``dataset['input_ids']``, ignoring the pad token. More sophisticated
        alternatives include embedding-based distances, the entropy of model
        predictions, or surprise-based novelty detection.

        Args:
            dataset: Tokenized dataset with an ``input_ids`` column (lists of ids).

        Returns:
            Novelty score in [0, 1].
        """
        unique_tokens = set()
        for ids in dataset['input_ids']:
            unique_tokens.update(ids)
        # Padding is an artifact of fixed-length batching, not content.
        unique_tokens.discard(self.tokenizer.pad_token_id)

        vocab_size = len(self.tokenizer.vocab)
        if vocab_size == 0:
            return 0.0
        return min(len(unique_tokens) / vocab_size, 1.0)

    def inference_with_continuous_model(self, prompt: str):
        """
        Run inference with the continuously fine-tuned model.

        Args:
            prompt: Input prompt (truncated to 512 tokens).

        Returns:
            Decoded generation (prompt plus up to 100 sampled new tokens),
            with special tokens removed.
        """
        # Ensure model is loaded
        if self.current_model is None:
            self.load_base_model()

        # The model may still be in training mode after Trainer.train();
        # disable dropout (including LoRA dropout) for generation.
        self.current_model.eval()

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self.current_model.device)

        outputs = self.current_model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7
        )

        generated_text = self.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True
        )

        return generated_text

# Example usage: push ten synthetic batches through the tuner and print the
# model's answer to a fixed probe prompt after every round.
if __name__ == "__main__":
    tuner = ContinuousFineTuner(model_name="microsoft/Phi-4-mini-instruct")
    probe_prompt = "Explain the concept of continuous learning"

    for round_index in range(10):
        # 100 identical placeholder records stand in for a real data stream.
        sample_text = f"New sample from batch {round_index}"
        incoming_batch = [{"text": sample_text} for _ in range(100)]

        tuner.continuous_fine_tune(
            new_data=incoming_batch,
            strategy="memory_replay"
        )

        # Optional periodic check of what the model has picked up so far.
        reply = tuner.inference_with_continuous_model(probe_prompt)
        print(reply)