<a href="https://colab.research.google.com/github/Eric-Chung-0511/Learning-Record/blob/main/Data%20Science%20Projects/PawMatchAI/%5BICARL%5D%5B88_70%5DPawMatchAI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import collections
import copy
import json
import math
import os
import re
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from torchvision import transforms

In [None]:
class ICARLIncrementalLearner:
    """Incrementally fine-tune a breed classifier on one breed at a time.

    A deep copy of the loaded model is kept in eval mode as a teacher. The
    working model (``base_model``) is trained with a weighted sum of
    cross-entropy, temperature-scaled knowledge distillation, feature
    matching and a feature-protection term. Per-breed features, prototypes,
    importance vectors and pairwise prototype similarities are kept in
    dictionaries keyed by breed name.
    """

    def __init__(self, model_path, json_path, evaluator, device='cuda', memory_size=2000):
        """
        Set up configuration, load breed metadata and load the model.

        Parameter Description:
        - model_path: Path to the pre-trained model file
        - json_path: Path to a JSON file whose top-level 'breeds' key holds the breed list
        - evaluator: Evaluator object; this class calls its evaluate_breed_performance,
          generate_enhancement_report, save_enhanced_model, plot_lr_finder and
          plot_training_curves methods
        - device: Computing device, defaults to CUDA
        - memory_size: Stored on the instance; not otherwise used by this class

        Raises:
            RuntimeError: If the model cannot be built or the checkpoint cannot be loaded.
        """
        # Basic setup
        self.device = device
        self.json_path = json_path
        self.memory_size = memory_size
        self.evaluator = evaluator

        # Temperature bounds for knowledge distillation
        self.initial_temperature = 1  # Kept for reference; not read by this class
        self.max_temperature = 2.5    # Used once progress reaches 70% of the epochs
        self.min_temperature = 1      # Used until progress passes 30% of the epochs

        # Load breed information
        with open(json_path, 'r', encoding='utf-8') as f:
            self.breeds_data = json.load(f)
        self.breeds = self.breeds_data['breeds']

        # Load pre-trained model (also creates the teacher copy)
        self.model_path = model_path
        self._load_model()

        # Memory-related components
        self.feature_memory = {}      # breed name -> stacked feature tensor
        self.prototype_memory = {}    # breed name -> L2-normalised mean feature

        # Training-related bookkeeping
        self.best_loss = float('inf')
        self.best_model_state = None

        # Protection mechanism state
        self.feature_importance = {}  # breed name -> per-dimension importance
        self.similarity_matrix = {}   # breed name -> {other breed: cosine similarity}

    def _load_model(self):
        """
        Build the base model, load its weights and create the teacher copy.

        Supported checkpoint layouts: a mapping with a 'base_model' key, a
        mapping with a 'model_state_dict' key, or a bare OrderedDict state dict.

        Raises:
            RuntimeError: On any failure; the original exception is chained.
        """
        try:
            # Step 1: Create base model instance
            self.base_model = BaseModel(
                num_classes=len(self.breeds),
                device=self.device
            ).to(self.device)

            # Step 2: Load model weights
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # that come from a trusted source.
            checkpoint = torch.load(self.model_path, map_location=self.device)

            # Step 3: Handle different checkpoint formats. The isinstance guard
            # keeps a non-mapping checkpoint from failing on the `in` test.
            is_mapping = isinstance(checkpoint, dict)
            if is_mapping and 'base_model' in checkpoint:
                state_dict = checkpoint['base_model']
            elif is_mapping and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            elif isinstance(checkpoint, collections.OrderedDict):
                state_dict = checkpoint
            else:
                raise ValueError("Unexpected checkpoint format")
            self.base_model.load_state_dict(state_dict)

            # Step 4: Create teacher model (for knowledge distillation)
            self.teacher_model = copy.deepcopy(self.base_model).to(self.device)
            self.teacher_model.eval()

            print(f"Successfully loaded model from {self.model_path}")

        except Exception as e:
            raise RuntimeError(f"Error loading model: {str(e)}") from e

    def _match_breed_name(self, query_breed):
        """
        Function: Match breed names with fuzzy matching support

        Matching is case-insensitive and tried in this order: exact match,
        substring match, then substring match after stripping parenthetical
        text and treating '-' and '_' as spaces.

        Args:
            query_breed (str): The breed name to query

        Returns:
            str: Complete breed name, or None if nothing matches or the query
            is empty/whitespace.
        """
        query = query_breed.lower().strip()

        # An empty string is a substring of every name; without this guard the
        # first breed in the list would be returned for a blank query.
        if not query:
            return None

        # Exact match
        for breed in self.breeds:
            if breed.lower() == query:
                return breed

        # Partial match
        for breed in self.breeds:
            if query in breed.lower():
                return breed

        # More lenient matching rules
        normalized_query = query.replace('-', ' ').replace('_', ' ')
        for breed in self.breeds:
            # Compare after removing parenthetical content
            clean_breed = re.sub(r'\([^)]*\)', '', breed.lower()).strip()
            if query in clean_breed:
                return breed

            # Handle hyphen, underscore and space variants
            normalized_breed = clean_breed.replace('-', ' ').replace('_', ' ')
            if normalized_query in normalized_breed:
                return breed

        return None

    def find_best_lr(self, train_loader, init_value=1e-8, final_value=1e-5, num_iter=100):
        """
        Run a learning-rate range test and return the best learning rate found.

        The learning rate grows exponentially from init_value towards
        final_value over num_iter optimisation steps. The model weights are
        restored to their pre-test values before returning, even if a step fails.

        Args:
            train_loader: Iterable yielding (inputs, labels) batches; it is
                restarted whenever it is exhausted.
            init_value: Starting learning rate.
            final_value: Learning rate approached after num_iter steps.
            num_iter: Number of optimisation steps to run.

        Returns:
            float: The learning rate at which the lowest batch loss was observed.
        """
        # Step 1: Save the current model state
        model_state = copy.deepcopy(self.base_model.state_dict())

        # Step 2: Set up the optimizer and learning rate range
        optimizer = optim.AdamW([{
            'params': self.base_model.parameters(),
            'initial_lr': init_value
        }], lr=init_value)

        # Step 3: Compute the learning rate growth factor
        gamma = (final_value / init_value) ** (1 / num_iter)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

        # Step 4: Initialize recording lists
        losses = []
        learning_rates = []
        best_loss = float('inf')
        best_lr = init_value

        # Step 5: Conduct the learning rate test
        self.base_model.train()

        # Keep one iterator alive across steps; rebuilding it on every step
        # would only ever read the first batch of a fresh pass.
        train_iter = iter(train_loader)
        try:
            for _ in range(num_iter):
                try:
                    inputs, labels = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    inputs, labels = next(train_iter)

                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                # Forward pass and loss computation
                optimizer.zero_grad()
                outputs, _ = self.base_model(inputs)
                loss = F.cross_entropy(outputs, labels)
                loss_value = loss.item()

                # Track the learning rate with the lowest loss so far
                current_lr = optimizer.param_groups[0]['lr']
                if loss_value < best_loss:
                    best_loss = loss_value
                    best_lr = current_lr

                losses.append(loss_value)
                learning_rates.append(current_lr)

                # Backpropagation and update
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
        finally:
            # Always undo the weight changes made by the sweep
            self.base_model.load_state_dict(model_state)

        self.evaluator.plot_lr_finder(learning_rates, losses)
        return best_lr

    def update_feature_memory(self, images, breed_name):
        """
        Function: Update Feature Memory
        Purpose: Replace the stored features of breed_name with features
        extracted from images by the current base model.

        Key Parameters:
        images: Iterable of single-image tensors (each gets a batch dimension added)
        breed_name: Key under which the features are stored
        """
        # Evaluation mode so that layers such as batch norm/dropout behave deterministically
        self.base_model.eval()
        with torch.no_grad():
            features = []
            for img in images:
                # The model returns (logits, features); only the features are kept
                _, feature = self.base_model(img.unsqueeze(0).to(self.device))
                features.append(feature)
            self.feature_memory[breed_name] = torch.cat(features, dim=0)

    def update_prototypes(self):
        """
        Function: Update Prototype Memory
        Purpose: Recompute every breed prototype as the L2-normalised mean of
        its stored features.
        """
        for breed, features in self.feature_memory.items():
            mean_feature = features.mean(dim=0)
            self.prototype_memory[breed] = F.normalize(mean_feature, p=2, dim=0)

    def _calculate_feature_importance(self):
        """
        Function: Feature Importance Calculation

        For each breed in feature memory, stores the mean absolute value of
        every feature dimension as that breed's importance vector.

        Side effect: switches base_model to evaluation mode.
        """
        self.base_model.eval()

        with torch.no_grad():
            for breed, features in self.feature_memory.items():
                self.feature_importance[breed] = torch.abs(features).mean(dim=0)

    def _update_similarity_matrix(self):
        """
        Function Purpose:
        Fill in missing cosine similarities between breed prototypes.

        Every breed in self.breeds gets a row. A pair is computed only when
        both breeds have a prototype and the pair is not already stored;
        existing entries are left untouched.
        """
        for breed1 in self.breeds:
            row = self.similarity_matrix.setdefault(breed1, {})
            feat1 = self.prototype_memory.get(breed1)
            if feat1 is None:
                # No prototype yet: nothing can be computed for this row
                continue

            for breed2 in self.breeds:
                if breed1 == breed2 or breed2 in row:
                    continue
                feat2 = self.prototype_memory.get(breed2)
                if feat2 is not None:
                    row[breed2] = F.cosine_similarity(feat1, feat2, dim=0).item()

    def _get_most_similar_breeds(self, target_breed, top_k=5):
        """
        Function: Retrieve the most similar breeds to a target breed.

        Parameter Explanation:
            - target_breed: The name of the target breed.
            - top_k: The number of most similar breeds to return, default is 5.

        Returns: A list of (breed, similarity) tuples in descending similarity
        order; empty if the target breed has no row in the similarity matrix.
        """
        similarities = self.similarity_matrix.get(target_breed)
        if similarities is None:
            return []
        return sorted(similarities.items(), key=lambda x: x[1], reverse=True)[:top_k]

    def _calculate_protection_loss(self, student_features, teacher_features, breed_name):
        """
        Function: Calculate Feature Protection Loss

        Sums importance-weighted MSE terms between student and teacher
        features: one term for breed_name (weight 1) and one for each of its
        most similar breeds (weight 0.5 * similarity). Breeds without a stored
        importance vector contribute nothing.

        Returns: The accumulated loss, or the integer 0 when no term applies.
        """
        protection_loss = 0

        # Importance-weighted term for the target breed
        if breed_name in self.feature_importance:
            importance_mask = self.feature_importance[breed_name]
            protection_loss += F.mse_loss(
                student_features * importance_mask,
                teacher_features * importance_mask,
                reduction='none'
            ).mean()

        # Terms for similar breeds, scaled down by similarity
        for similar_breed, similarity in self._get_most_similar_breeds(breed_name):
            if similar_breed not in self.feature_importance:
                continue
            sim_importance = self.feature_importance[similar_breed]
            sim_weight = similarity * 0.5
            protection_loss += F.mse_loss(
                student_features * sim_importance,
                teacher_features * sim_importance,
                reduction='none'
            ).mean() * sim_weight

        return protection_loss

    def _prepare_data_loader(self, images, batch_size, breed_name):
        """
        Function: Prepare Data Loader

        Pairs every image with the index of breed_name in self.breeds and
        returns a shuffling DataLoader over the result.

        Parameter Description:
            - images: Image tensor (first dimension indexes the images)
            - batch_size: Size of each batch
            - breed_name: Breed name; must be present in self.breeds
        """
        breed_idx = self.breeds.index(breed_name)
        labels = torch.full((len(images),), breed_idx, dtype=torch.long)
        dataset = torch.utils.data.TensorDataset(images, labels)

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True
        )

    def enhance_breed(self, images, breed_name, epochs=100, lr=None, batch_size=8):
        """
        Function: Main entry function for breed enhancement.

        Execution Process:
            1. Breed name validation.
            2. Data splitting (first 20% for evaluation, the rest for training).
            3. Pre-enhancement evaluation.
            4. Perform enhancement learning (learning rate searched when lr is None).
            5. Post-enhancement evaluation and report.
            6. Update feature memory and prototypes, then save the model.

        Returns:
            bool: True on success. False when training produced no model or an
            exception occurred; in both cases the weights captured before
            training are restored.

        Raises:
            ValueError: If breed_name cannot be matched to a known breed.
        """
        full_breed_name = self._match_breed_name(breed_name)
        if not full_breed_name:
            raise ValueError(f"Breed {breed_name} not found in database")

        split_at = len(images) // 5
        eval_images = images[:split_at]   # 20% for evaluation
        train_images = images[split_at:]  # remaining 80% for training

        # Snapshot taken before any training so both failure paths can roll back
        original_state = copy.deepcopy(self.base_model.state_dict())

        # Pre-enhancement evaluation - using the matched breed name
        pre_enhancement_results = self.evaluator.evaluate_breed_performance(
            self.base_model,
            eval_images,
            full_breed_name,
            self.breeds,
            self.device
        )

        # Automatically determine the learning rate
        if lr is None:
            train_loader = self._prepare_data_loader(train_images, batch_size, full_breed_name)
            lr = self.find_best_lr(train_loader)
            print(f"Found optimal learning rate: {lr:.2e}")

        try:
            best_model = self._perform_enhancement(
                train_images, full_breed_name, epochs, lr, batch_size)

            if best_model is None:
                print("Enhancement failed, rolling back...")
                self.base_model.load_state_dict(original_state)
                return False

            self.base_model.load_state_dict(best_model)

            # Post-enhancement evaluation - using the matched breed name
            post_enhancement_results = self.evaluator.evaluate_breed_performance(
                self.base_model,
                eval_images,
                full_breed_name,
                self.breeds,
                self.device
            )

            self.evaluator.generate_enhancement_report(
                full_breed_name,
                pre_enhancement_results,
                post_enhancement_results
            )

            # Store under the matched name: prototypes, importance vectors and
            # similarities are all looked up by the names in self.breeds.
            self.update_feature_memory(train_images, full_breed_name)
            self.update_prototypes()

            # The caller's spelling is kept here so saved file names are unchanged
            self.evaluator.save_enhanced_model(self.base_model, breed_name)
            return True

        except Exception as e:
            print(f"Enhancement failed: {str(e)}")
            self.base_model.load_state_dict(original_state)
            return False

    def _calculate_temperature(self, current_epoch, total_epochs):
        """
        Function: Implement smooth temperature adjustment using cosine function

        The temperature stays at min_temperature for the first 30% of
        training, rises along a half-cosine between 30% and 70%, and stays at
        max_temperature afterwards.

        Args:
            current_epoch (int): Current training epoch
            total_epochs (int): Total number of training epochs

        Returns:
            float: Calculated current temperature value
        """
        if total_epochs <= 0:
            # No schedule to follow; avoid dividing by zero
            return self.min_temperature

        progress = current_epoch / total_epochs

        if progress <= 0.3:  # Early learning phase: maintain low temperature
            return self.min_temperature
        if progress >= 0.7:  # Late learning phase: maintain high temperature
            return self.max_temperature

        # Mid-phase: map progress in (0.3, 0.7) onto (0, pi); the cosine then
        # falls from 1 to -1, which is rescaled to a 0 -> 1 ramp.
        normalized_progress = (progress - 0.3) / 0.4 * math.pi
        cos_value = math.cos(normalized_progress)
        return self.min_temperature + (self.max_temperature - self.min_temperature) * (1 - (cos_value + 1) / 2)

    def _perform_enhancement(self, images, breed_name, epochs, lr, batch_size):
        """
        Function: Execute the core enhancement learning process.

        Trains base_model on images of a single breed with AdamW and cosine
        learning-rate annealing. The loss is
        0.3 * distillation + 0.4 * cross-entropy + 0.2 * feature MSE
        + 0.1 * protection loss. Training stops early after 10 consecutive
        epochs without a lower average loss.

        Returns:
            The state dict of the epoch with the lowest average training
            loss, or None if no epoch improved on the initial value.

        Raises:
            ValueError: If no training images or batches are available.
        """
        if len(images) == 0:
            raise ValueError("No training images were provided for enhancement")

        self.base_model.train()
        train_loader = self._prepare_data_loader(images, batch_size, breed_name)

        # Initialize optimizer and scheduler
        optimizer = optim.AdamW(self.base_model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=epochs,
            eta_min=1e-7
        )

        # Compute initial feature importance and similarity matrix
        # NOTE(review): _calculate_feature_importance() switches base_model to
        # eval mode, so the loop below runs with the model in eval mode despite
        # the train() call above. Behaviour kept as-is - confirm this is intended.
        self._calculate_feature_importance()
        self._update_similarity_matrix()

        # Early Stopping settings
        best_loss = float('inf')
        best_model_state = None
        patience = 10
        patience_counter = 0

        # Training records
        train_losses = []
        learning_rates = []
        temperatures = []

        for epoch in range(epochs):
            epoch_loss = 0
            batch_count = 0

            # Temperature used by both teacher and student for this epoch
            current_temperature = self._calculate_temperature(epoch, epochs)
            temperatures.append(current_temperature)

            for batch_images, labels in train_loader:
                batch_images = batch_images.to(self.device)
                labels = labels.to(self.device)

                # Teacher prediction (no gradients)
                with torch.no_grad():
                    teacher_logits, teacher_features = self.teacher_model(batch_images)
                    teacher_probs = F.softmax(teacher_logits / current_temperature, dim=1)

                # Student prediction (parameters to be optimized)
                optimizer.zero_grad()
                student_logits, student_features = self.base_model(batch_images)

                # 1. Knowledge distillation loss, rescaled by T^2
                distill_loss = F.kl_div(
                    F.log_softmax(student_logits / current_temperature, dim=1),
                    teacher_probs,
                    reduction='batchmean'
                ) * (current_temperature ** 2)

                # 2. Cross-entropy loss
                ce_loss = F.cross_entropy(student_logits, labels)

                # 3. Feature distillation loss
                feature_loss = F.mse_loss(student_features, teacher_features)

                # 4. Feature protection loss
                protection_loss = self._calculate_protection_loss(
                    student_features,
                    teacher_features,
                    breed_name
                )

                loss = (
                    (0.3 * distill_loss) +
                    (0.4 * ce_loss) +
                    (0.2 * feature_loss) +
                    (0.1 * protection_loss)
                )

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.base_model.parameters(), max_norm=1.0)
                optimizer.step()

                epoch_loss += loss.item()
                batch_count += 1

            if batch_count == 0:
                raise ValueError("The training data loader produced no batches")

            avg_loss = epoch_loss / batch_count
            current_lr = optimizer.param_groups[0]['lr']

            train_losses.append(avg_loss)
            learning_rates.append(current_lr)

            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, LR: {current_lr:.2e}, Temperature: {current_temperature:.4f}")

            # Early Stopping check
            if avg_loss < best_loss:
                best_loss = avg_loss
                best_model_state = copy.deepcopy(self.base_model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break

            scheduler.step()

        # Use evaluator to plot training curves (including temperature changes)
        self.evaluator.plot_training_curves(train_losses, learning_rates, temperatures)

        return best_model_state

In [None]:
class PerformanceEvaluator:

    def __init__(self, save_dir):
        """
        Initialize the PerformanceEvaluator.

        Args:
            save_dir: Directory where evaluation results will be saved.
        """
        self.save_dir = save_dir
        self.visualization_dir = os.path.join(save_dir, 'visualizations')
        os.makedirs(self.visualization_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    def evaluate(self, model, test_loader, breeds, device):
        """
        Run the model over a test loader and write the evaluation artifacts.

        The model is put in evaluation mode and called without gradients. It
        is expected to return a ``(logits, features)`` pair. After inference
        the classification report, confusion matrix and confidence histogram
        are produced via the corresponding helper methods.

        Args:
            model: The model to be evaluated.
            test_loader: Iterable yielding ``(images, labels)`` batches.
            breeds: List of breed names.
            device: Computing device.

        Returns:
            tuple: (All ground-truth labels, All predicted labels)
        """
        model.eval()

        predicted = []
        expected = []
        probabilities = []

        with torch.no_grad():
            for batch_images, batch_labels in test_loader:
                batch_images = batch_images.to(device)
                batch_labels = batch_labels.to(device)

                batch_logits, _ = model(batch_images)
                batch_probs = F.softmax(batch_logits, dim=1)
                batch_preds = torch.argmax(batch_logits, dim=1)

                # Collect everything on the CPU as NumPy values
                predicted.extend(batch_preds.cpu().numpy())
                expected.extend(batch_labels.cpu().numpy())
                probabilities.extend(batch_probs.cpu().numpy())

        # Text/CSV report, confusion-matrix figure, confidence histogram
        self._save_classification_report(expected, predicted, breeds)
        self._plot_confusion_matrix(expected, predicted, breeds)
        self._analyze_confidence_distribution(probabilities, expected, breeds)

        return expected, predicted

    def _save_classification_report(self, y_true, y_pred, breeds):
        """
        Generate, save (CSV) and print a classification report.

        The label set is pinned to ``0 .. len(breeds) - 1`` so that the report
        can still be built, with rows aligned to ``breeds``, when some breeds
        do not occur in ``y_true``/``y_pred``. Without this, scikit-learn
        rejects a ``target_names`` list whose length differs from the number
        of labels it finds in the data.

        Args:
            y_true: Ground-truth labels (assumed to be indices into ``breeds``).
            y_pred: Predicted labels (assumed to be indices into ``breeds``).
            breeds: List of breed names.
        """
        label_ids = list(range(len(breeds)))

        report = classification_report(y_true, y_pred,
                                       labels=label_ids,
                                       target_names=breeds,
                                       output_dict=True)

        # Save as CSV: one row per breed plus the summary rows
        report_df = pd.DataFrame(report).transpose()
        report_path = os.path.join(self.save_dir,
                                   f'classification_report_{self.timestamp}.csv')
        report_df.to_csv(report_path)

        # Print the human-readable form of the same report
        print("\nClassification Report:")
        print(classification_report(y_true, y_pred,
                                    labels=label_ids,
                                    target_names=breeds))

    def _plot_confusion_matrix(self, y_true, y_pred, breeds):
        """
        Plot and save the row-normalised confusion matrix (in percent).

        The matrix always has one row/column per entry of ``breeds`` so the
        tick labels line up even when some breeds are absent from the data.
        Rows with no true samples are shown as 0% instead of NaN.

        Args:
            y_true: Ground-truth labels (assumed to be indices into ``breeds``).
            y_pred: Predicted labels (assumed to be indices into ``breeds``).
            breeds: List of breed names.
        """
        label_ids = list(range(len(breeds)))
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)

        # Compute percentage values; skip the division where a row sums to
        # zero so those rows stay at 0 rather than becoming NaN.
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm_percent = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals > 0,
        ) * 100

        plt.figure(figsize=(20, 20))

        # Use seaborn to plot the heatmap
        sns.heatmap(cm_percent, annot=True, fmt='.1f',
                    xticklabels=breeds,
                    yticklabels=breeds,
                    cmap='Blues')

        plt.title('Confusion Matrix (%)')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')

        # Rotate labels for better readability
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=45)

        # Adjust layout and save
        plt.tight_layout()
        cm_path = os.path.join(self.save_dir,
                               f'confusion_matrix_{self.timestamp}.png')
        plt.savefig(cm_path, bbox_inches='tight', dpi=300)
        plt.close()

    def _analyze_confidence_distribution(self, pred_probs, true_labels, breeds):
        """
        Plot and save histograms of the top predicted probability.

        Two overlaid histograms are drawn: one for samples whose arg-max
        prediction equals the true label and one for the rest. The figure is
        written to ``confidence_distribution_<timestamp>.png`` in ``save_dir``.

        Args:
            pred_probs: List of predicted probabilities (one row per sample).
            true_labels: Ground-truth labels.
            breeds: List of breed names (not used by this method).
        """
        prob_matrix = np.array(pred_probs)
        label_array = np.array(true_labels)

        plt.figure(figsize=(15, 8))

        # Highest probability per sample and whether that class was right
        top_confidence = np.max(prob_matrix, axis=1)
        is_correct = np.argmax(prob_matrix, axis=1) == label_array

        histogram_specs = (
            (is_correct, 'Correct Predictions', 'green'),
            (~is_correct, 'Wrong Predictions', 'red'),
        )
        for mask, legend_label, colour in histogram_specs:
            plt.hist(top_confidence[mask], bins=50, alpha=0.5,
                     label=legend_label, color=colour)

        plt.title('Prediction Confidence Distribution')
        plt.xlabel('Confidence')
        plt.ylabel('Count')
        plt.legend()

        # Save the figure
        file_name = f'confidence_distribution_{self.timestamp}.png'
        plt.savefig(os.path.join(self.save_dir, file_name))
        plt.close()

    def get_evaluation_summary(self):
        """
        Return an evaluation summary.

        Returns:
            dict: Summary of evaluation results.
        """
        return {
            'results_dir': self.save_dir,
            'timestamp': self.timestamp,
            'files': {
                'classification_report': f'classification_report_{self.timestamp}.csv',
                'confusion_matrix': f'confusion_matrix_{self.timestamp}.png',
                'confidence_distribution': f'confidence_distribution_{self.timestamp}.png'
            }
        }

    def save_enhanced_model(self, model, breed_name):
        """
        Save the enhanced model.

        Args:
            model: The model to be saved.
            breed_name: The name of the breed.
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        save_path = os.path.join(
            self.save_dir,
            f'enhanced_{breed_name}_{timestamp}.pth'
        )

        checkpoint = {
            'model_state_dict': model.state_dict(),
            'breed_info': {
                'breed_name': breed_name,
                'timestamp': timestamp
            }
        }

        torch.save(checkpoint, save_path)
        print(f"Enhanced model saved to: {save_path}")

    def evaluate_breed_performance(self, model, eval_images, breed_name, breeds, device):
        """
        Evaluate the performance of a specific breed.

        Every image is sent through the model on its own (a batch dimension
        is added with unsqueeze(0)) with gradients disabled. The arg-max class
        of the logits and the highest softmax probability are recorded.

        Args:
            model: The model to be evaluated. It is switched to eval mode and
                must return a 2-tuple whose first element is the logits.
            eval_images: Iterable of image tensors without a batch dimension.
            breed_name: Name of the target breed; must be present in breeds.
            breeds: List of breed names; the position of breed_name in it is
                used as the target class index.
            device: Computing device the images are moved to.

        Returns:
            dict: 'predictions' (list of predicted class indices),
                'confidences' (list of top softmax probabilities),
                'breed_name', 'mean_confidence' (Python float; 0.0 when no
                images were supplied) and 'correct_predictions' (number of
                predictions equal to the target class index).

        Raises:
            ValueError: If breed_name is not in breeds.
        """
        model.eval()
        breed_idx = breeds.index(breed_name)
        predictions = []
        confidences = []

        with torch.no_grad():
            for img in eval_images:
                logits, _ = model(img.unsqueeze(0).to(device))
                probs = F.softmax(logits, dim=1)

                predictions.append(torch.argmax(logits, dim=1).item())
                confidences.append(torch.max(probs, dim=1)[0].item())

        # np.mean([]) warns and yields NaN (which also serializes to invalid
        # JSON), so report 0.0 when nothing was evaluated. The check is on the
        # collected list because eval_images may be a tensor, whose truth
        # value is ambiguous.
        mean_confidence = float(np.mean(confidences)) if confidences else 0.0

        return {
            'predictions': predictions,
            'confidences': confidences,
            'breed_name': breed_name,
            'mean_confidence': mean_confidence,
            'correct_predictions': predictions.count(breed_idx)
        }

    def generate_enhancement_report(self, breed_name, pre_results, post_results):
        """
        Write a before/after comparison of an enhancement run to a JSON file.

        Only 'mean_confidence' and 'correct_predictions' are copied from each
        results mapping; any other keys are ignored.

        Args:
            breed_name: The breed that was enhanced.
            pre_results: Mapping with the metrics measured before enhancement.
            post_results: Mapping with the metrics measured after enhancement.

        Saves:
            'enhancement_report_<breed_name>_<self.timestamp>.json' inside
            self.save_dir, indented by four spaces.
        """
        tracked_keys = ('mean_confidence', 'correct_predictions')
        report = {
            'breed_name': breed_name,
            'pre_enhancement': {key: pre_results[key] for key in tracked_keys},
            'post_enhancement': {key: post_results[key] for key in tracked_keys},
        }

        # Resolve the destination, then serialize the report
        target_dir = self.save_dir
        file_name = f'enhancement_report_{breed_name}_{self.timestamp}.json'
        report_path = os.path.join(target_dir, file_name)
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=4)

    def plot_training_curves(self, losses, learning_rates, temperatures=None):
        """
        Draw the training history as side-by-side panels and save it as a PNG.

        Panels, left to right: training loss, learning rate (log-scaled
        y-axis) and, only when temperatures is given, temperature. The image
        is written to self.visualization_dir as
        'training_curves_<timestamp>.png' and the path is printed.

        Args:
            losses: List of loss values recorded during training.
            learning_rates: List of learning rates used during training.
            temperatures: (Optional) List of temperature values; when None
                the temperature panel is omitted.
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        # One entry per panel: (series, title, y-axis label, log-scaled y-axis?)
        panels = [
            (losses, 'Training Loss', 'Loss', False),
            (learning_rates, 'Learning Rate', 'Learning Rate', True),
        ]
        if temperatures is not None:
            panels.append((temperatures, 'Temperature', 'Temperature', False))

        plt.figure(figsize=(15, 5))
        for position, (series, title, y_label, log_y) in enumerate(panels, start=1):
            plt.subplot(1, len(panels), position)
            plt.plot(series)
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(y_label)
            if log_y:
                # Log scale keeps small learning rates readable
                plt.yscale('log')

        # Lay out the panels and write the figure to disk
        plt.tight_layout()
        file_name = f'training_curves_{timestamp}.png'
        save_path = os.path.join(self.visualization_dir, file_name)
        plt.savefig(save_path)
        plt.close()
        print(f"Training curves saved to: {save_path}")

    def plot_lr_finder(self, learning_rates, losses):
        """
        Plot loss against learning rate and save the figure as a PNG.

        The x-axis (learning rate) is log-scaled. The image is written to
        self.visualization_dir as 'lr_finder_<timestamp>.png' and the path is
        printed.

        Args:
            learning_rates: List of tested learning rates (x values).
            losses: Loss value recorded for each learning rate (y values).
        """
        run_tag = datetime.now().strftime('%Y%m%d_%H%M%S')

        plt.figure(figsize=(10, 6))
        plt.plot(learning_rates, losses)

        # Axis styling: learning rates span orders of magnitude, hence log x
        plt.xscale('log')
        plt.xlabel('Learning Rate')
        plt.ylabel('Loss')
        plt.title('Learning Rate Finder')

        # Write the figure to disk and release it
        output_name = f'lr_finder_{run_tag}.png'
        save_path = os.path.join(self.visualization_dir, output_name)
        plt.savefig(save_path)
        plt.close()
        print(f"Learning rate finder plot saved to: {save_path}")

In [None]:
def main():
    """
    Main function to perform breed enhancement using incremental learning.

    Steps:
        1. Build a PerformanceEvaluator and an ICARLIncrementalLearner from
           hard-coded Google Drive paths.
        2. Load and preprocess every .png/.jpg/.jpeg image in the
           enhancement image directory.
        3. Run learner.enhance_breed on those images with a manually chosen
           learning rate.
        4. If the learner has a 'test_loader' attribute, run the evaluator on
           learner.base_model and print where the results were saved.

    Raises:
        ValueError: If the enhancement image directory contains no supported
            image files.
    """

    # Initialize performance evaluator
    evaluator = PerformanceEvaluator(
        save_dir='/content/drive/Othercomputers/My MacBook Pro/Learning/ICARL_PawMatchAI/enhanced_models'
    )

    # Initialize the incremental learner
    learner = ICARLIncrementalLearner(
        model_path='/content/drive/Othercomputers/My MacBook Pro/Learning/ICARL_PawMatchAI/enhanced_models/enhanced_toy_poodle_20250201_063556.pth',
        json_path='/content/drive/Othercomputers/My MacBook Pro/Learning/ICARL_PawMatchAI/data/dog_breeds.json',
        evaluator=evaluator,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )

    # Define image transformation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    # Load enhancement images
    def load_enhancement_images(image_dir):
        """
        Load and preprocess images for breed enhancement.

        Args:
            image_dir (str): Directory containing images.

        Returns:
            torch.Tensor: A batch of processed images, stacked in file-name
                order.

        Raises:
            ValueError: If no .png/.jpg/.jpeg file is found in image_dir.
        """
        images = []
        # os.listdir returns entries in arbitrary order; sort so the batch
        # order is reproducible between runs
        for img_name in sorted(os.listdir(image_dir)):
            if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            img_path = os.path.join(image_dir, img_name)
            # Close the underlying file as soon as the pixels are converted
            # instead of leaving the handle open until garbage collection
            with Image.open(img_path) as img:
                images.append(transform(img.convert('RGB')))

        if not images:
            # torch.stack on an empty list fails with an opaque message
            raise ValueError(f"No .png/.jpg/.jpeg images found in: {image_dir}")
        return torch.stack(images)

    # Load images from the specified directory
    image_dir = '/content/drive/Othercomputers/My MacBook Pro/Learning/ICARL_PawMatchAI/enhanced_images/bichon_frise'
    enhancement_images = load_enhancement_images(image_dir)

    # Perform enhancement learning

    # Method 1: Automatically determine the learning rate
    # learner.enhance_breed(
    #     images=enhancement_images,
    #     breed_name='havanese',
    #     epochs=100,  # Default number of training epochs
    #     batch_size=16  # Default batch size
    # )

    # Method 2: Manually set the learning rate
    learner.enhance_breed(
        images=enhancement_images,
        breed_name='bichon_frise',
        epochs=100,
        lr=7e-7,  # Manually set the learning rate
        batch_size=8
    )

    # Evaluate the enhancement results
    if hasattr(learner, 'test_loader'):
        # The return value is not used here; only the saved artifacts are reported
        learner.evaluator.evaluate(
            learner.base_model,
            learner.test_loader,
            learner.breeds,
            learner.device
        )

        # Retrieve evaluation summary
        summary = learner.evaluator.get_evaluation_summary()
        print("\nEvaluation results saved in:", summary['results_dir'])

# Run the enhancement workflow only when this file is executed directly
# (including as a notebook cell), not when it is imported as a module.
if __name__ == '__main__':
    main()