<a href="https://colab.research.google.com/github/Ysydso/newtest/blob/main/Innoprenuer_final_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Install necessary libraries**

In [None]:

!pip install ultralytics opencv-python-headless pandas matplotlib torch torchvision pyyaml

**Import the required libraries**

In [None]:
import cv2
import numpy as np
from ultralytics import YOLO
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import zipfile
import os
import torch
import logging
from sklearn.model_selection import KFold
import yaml
from dataclasses import dataclass
from typing import Dict, Any, List
from pathlib import Path

**Set up logging system to track model progress**

In [None]:
# Route INFO-and-above records to both a log file ('training.log') and a
# stream handler, so progress can be followed live and reviewed afterwards.
logging.basicConfig(
    handlers=[logging.FileHandler('training.log'), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

**Configuration class for organization**

In [None]:
@dataclass
class YOLOConfig:
    """Container for the training hyper-parameters and runtime settings.

    All fields are required when constructing the class directly; use
    ``create_default`` to obtain an instance filled with the stock values.
    """

    k_folds: int
    batch_size: int
    epochs: int
    patience: int
    image_size: int
    model_type: str
    checkpoint_dir: str
    num_workers: int
    pin_memory: bool
    device: str

    @classmethod
    def create_default(cls):
        """Build a config with the default values used by this notebook.

        The device is 'cuda' when torch reports an available CUDA device,
        otherwise 'cpu'.
        """
        if torch.cuda.is_available():
            chosen_device = 'cuda'
        else:
            chosen_device = 'cpu'

        defaults = dict(
            k_folds=5,
            batch_size=16,
            epochs=50,
            patience=5,
            image_size=640,
            model_type='yolov8n-seg.yaml',
            checkpoint_dir='/content/checkpoints',
            num_workers=4,
            pin_memory=True,
            device=chosen_device,
        )
        return cls(**defaults)

**Metrics tracking class**

In [None]:
class MetricsTracker:
    """Keeps a per-metric history plus the best values seen so far."""

    def __init__(self):
        # One growing list per tracked metric name.
        self.metrics_history = {name: [] for name in ('train_loss', 'val_loss', 'map50', 'map95')}
        # Running bests: higher is better for mAP, lower is better for loss.
        self.best_metrics = {
            'best_map50': 0,
            'best_map95': 0,
            'lowest_val_loss': float('inf'),
        }

    def update(self, metrics: Dict[str, float]):
        """Record the given metrics and refresh the running best values.

        Keys that are not tracked in ``metrics_history`` are ignored.
        """
        history = self.metrics_history
        for name in metrics:
            if name in history:
                history[name].append(metrics[name])

        best = self.best_metrics

        # mAP values: keep the maximum.
        for metric_key, best_key in (('map50', 'best_map50'), ('map95', 'best_map95')):
            if metrics.get(metric_key, 0) > best[best_key]:
                best[best_key] = metrics[metric_key]

        # Validation loss: keep the minimum.
        if metrics.get('val_loss', float('inf')) < best['lowest_val_loss']:
            best['lowest_val_loss'] = metrics['val_loss']


**Early stopping class**

In [None]:
class EarlyStopping:
    """Signals a stop after `patience` consecutive checks without improvement.

    A check counts as an improvement only when the new loss is at least
    `min_delta` below the best loss recorded so far.
    """

    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False

    def __call__(self, val_loss):
        """Feed one validation loss; return True once stopping is warranted."""
        # The very first observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return False

        threshold = self.best_loss - self.min_delta
        stalled = val_loss > threshold

        if not stalled:
            # Sufficient improvement: new baseline, streak resets.
            self.counter = 0
            self.best_loss = val_loss
            return self.should_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return self.should_stop

**Dataset validation function**

In [None]:
def validate_dataset(data_yaml_path: str) -> bool:
    """Validate the YOLO dataset structure and yaml file.

    Checks that the YAML file parses to a mapping, that it has the
    'train', 'val' and 'names' keys, and that every train/val location
    exists on disk.

    Args:
        data_yaml_path: Path to the dataset's data.yaml file.

    Returns:
        True when all checks pass. False otherwise; the reason is logged at
        ERROR level. Any exception raised while validating is also reported
        as False rather than propagated.
    """
    try:
        yaml_path = Path(data_yaml_path)
        with open(yaml_path, 'r', encoding='utf-8') as f:
            data_config = yaml.safe_load(f)

        # safe_load returns None for an empty file and may return a scalar or
        # list for malformed content; only a mapping is a usable config.
        if not isinstance(data_config, dict):
            logging.error(f"Dataset config in {data_yaml_path} is empty or not a mapping")
            return False

        # 'path' is deliberately not required: exports such as Roboflow's omit
        # it, in which case the dataset root is the folder holding the yaml.
        # NOTE(review): believed to match how Ultralytics treats a missing
        # 'path' - confirm against the installed version.
        required_keys = ['train', 'val', 'names']
        if not all(key in data_config for key in required_keys):
            logging.error(f"Missing required keys in {data_yaml_path}")
            return False

        # A relative 'path' is resolved against the current working directory.
        base_path = Path(data_config.get('path') or yaml_path.parent).resolve()

        for split in ['train', 'val']:
            entry = data_config[split]
            # A split may be a single location or a list of locations.
            locations = [entry] if isinstance(entry, (str, Path)) else list(entry)
            for location in locations:
                location = str(location)
                split_path = base_path / location
                # Roboflow-style exports write '../train/images' although the
                # folder sits next to data.yaml; retry without the leading
                # '../' before declaring the split missing.
                if not split_path.exists() and location.startswith('../'):
                    fallback = base_path / location[3:]
                    if fallback.exists():
                        split_path = fallback
                if not split_path.exists():
                    logging.error(f"Missing {split} dataset path: {split_path}")
                    return False

        return True
    except Exception as e:
        # Validation is a yes/no gate for the caller, so every failure
        # (unreadable file, YAML syntax error, bad types) maps to False.
        logging.error(f"Dataset validation failed: {str(e)}")
        return False

**Checkpoint save function**

In [None]:
def save_checkpoint(model, fold, epoch, train_loss, val_loss, metrics, config, is_best=False):
    """Write a training checkpoint for one fold/epoch to config.checkpoint_dir.

    The checkpoint is a dict holding the fold, epoch, losses, metrics and the
    model's state dict (None when the model has no `state_dict` attribute).
    When `is_best` is true the same dict is additionally written to the
    fold's "best model" file.
    """
    if hasattr(model, 'state_dict'):
        weights = model.state_dict()
    else:
        weights = None

    payload = dict(
        fold=fold,
        epoch=epoch,
        train_loss=train_loss,
        val_loss=val_loss,
        metrics=metrics,
        model_state=weights,
    )

    # Make sure the target directory exists before writing anything.
    target_dir = config.checkpoint_dir
    os.makedirs(target_dir, exist_ok=True)

    # Per-epoch checkpoint is always written.
    epoch_file = f"{target_dir}/fold_{fold}_epoch_{epoch}.pt"
    torch.save(payload, epoch_file)
    logging.info(f"Saved checkpoint: {epoch_file}")

    if not is_best:
        return

    # Best-so-far copy for this fold.
    best_file = f"{target_dir}/fold_{fold}_best_model.pt"
    torch.save(payload, best_file)
    logging.info(f"Saved best model: {best_file}")

# Function to load checkpoint
def load_checkpoint(model, fold, config):
    """Load the latest checkpoint for a specific fold.

    Looks in config.checkpoint_dir for files named
    'fold_<fold>_epoch_<n>.pt', picks the one with the highest epoch number
    and, when the model exposes `load_state_dict` and the checkpoint carries
    a state dict, restores the weights into `model`.

    Args:
        model: Object to restore weights into (may lack `load_state_dict`).
        fold: Fold index used in the checkpoint file names.
        config: Object with a `checkpoint_dir` attribute.

    Returns:
        Tuple (next_epoch, val_loss, metrics). When no checkpoint exists or
        loading fails, (0, inf, {}) is returned so the caller starts fresh.
    """
    try:
        # On a first run the directory has not been created yet; that is a
        # normal "nothing to resume" situation, not an error worth logging.
        if not os.path.isdir(config.checkpoint_dir):
            return 0, float('inf'), {}

        prefix = f'fold_{fold}_epoch_'
        checkpoints = [f for f in os.listdir(config.checkpoint_dir)
                       if f.startswith(prefix)]

        if not checkpoints:
            return 0, float('inf'), {}

        # Compare epoch numbers numerically (epoch 10 must beat epoch 9).
        latest_checkpoint = max(
            checkpoints,
            key=lambda name: int(name[len(prefix):].split('.')[0]),
        )
        checkpoint_path = os.path.join(config.checkpoint_dir, latest_checkpoint)

        # Deserialize onto the CPU so a checkpoint written on a GPU machine
        # still loads where CUDA is unavailable; load_state_dict copies the
        # values into the model's own parameters afterwards.
        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # produced by this notebook.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # save_checkpoint stores None when the model had no state_dict.
        model_state = checkpoint.get('model_state')
        if model_state is not None and hasattr(model, 'load_state_dict'):
            model.load_state_dict(model_state)

        logging.info(f"Loaded checkpoint: {checkpoint_path}")
        return checkpoint['epoch'] + 1, checkpoint['val_loss'], checkpoint.get('metrics', {})
    except Exception as e:
        # Resuming is best-effort: report the problem and start from scratch.
        logging.error(f"Failed to load checkpoint: {str(e)}")
        return 0, float('inf'), {}

**Prediction visualization function**

In [None]:
def visualize_predictions(image: np.ndarray, predictions: list, output_path: str):
    """Visualize detection results on the image.

    Draws a green rectangle and the confidence (two decimals) for every
    prediction on a copy of `image` and writes the result to `output_path`.
    The input image is not modified.

    Args:
        image: Image array to draw on (a copy is annotated).
        predictions: Iterable of dicts with a 'box' entry whose first four
            values are used as (x1, y1, x2, y2) and a 'confidence' entry.
        output_path: Destination file for the annotated image.

    Returns:
        None. Failures are logged at ERROR level instead of being raised.
    """
    try:
        img_copy = image.copy()
        for pred in predictions:
            # asarray accepts lists/tuples as well as arrays; the drawing
            # calls are given plain Python ints.
            coords = np.asarray(pred['box']).astype(int)
            x1, y1, x2, y2 = (int(c) for c in coords[:4])
            conf = pred['confidence']
            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img_copy, f"{conf:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # imwrite reports failure through its boolean return value, so it
        # must be checked before claiming the file was saved.
        if not cv2.imwrite(output_path, img_copy):
            logging.error(f"Visualization failed: could not write image to {output_path}")
            return
        logging.info(f"Saved visualization to {output_path}")
    except Exception as e:
        logging.error(f"Visualization failed: {str(e)}")

def evaluate_model(model: YOLO, test_images_path: str) -> Dict[str, float]:
    """Run the model's validation and return a compact metrics dict.

    `test_images_path` is forwarded to `model.val` as its `data` argument.
    The returned dict has the keys 'precision', 'recall', 'mAP50' and
    'mAP95'; a metric missing from the validation results is reported as 0.
    On any exception the error is logged and an empty dict is returned.
    """
    # (output name, key in the validation results dict)
    wanted = (
        ('precision', 'metrics/precision(B)'),
        ('recall', 'metrics/recall(B)'),
        ('mAP50', 'metrics/mAP50(B)'),
        ('mAP95', 'metrics/mAP50-95(B)'),
    )
    try:
        results = model.val(data=test_images_path)
        return {label: results.results_dict.get(source_key, 0) for label, source_key in wanted}
    except Exception as e:
        logging.error(f"Evaluation failed: {str(e)}")
        return {}


**Main machine learning setup**

In [None]:
if __name__ == "__main__":
    try:
        # Initialize configuration
        config = YOLOConfig.create_default()

        # Extract dataset
        with zipfile.ZipFile("/content/banana01.v3i.yolov8.zip", 'r') as zip_ref:
            zip_ref.extractall("/content/yolo_dataset")
        data_yaml_path = "/content/yolo_dataset/data.yaml"
        logging.info("Dataset extracted successfully")

        # Validate dataset
        if not validate_dataset(data_yaml_path):
            raise ValueError("Dataset validation failed")

        # Initialize YOLO model
        yolo_model = YOLO(config.model_type)
        logging.info("YOLO model initialized successfully")

        # Initialize metrics tracker
        metrics_tracker = MetricsTracker()

        # Set up K-Fold cross-validation
        kfold = KFold(n_splits=config.k_folds, shuffle=True)

        # Training loop with K-Fold Cross-Validation
        best_models = []

        for fold, (train_idx, val_idx) in enumerate(kfold.split(range(config.epochs))):
            logging.info(f'Starting training fold {fold + 1}/{config.k_folds}')

            # Load checkpoint if exists
            start_epoch, best_val_loss, previous_metrics = load_checkpoint(yolo_model, fold, config)

            # Initialize early stopping
            early_stopping = EarlyStopping(patience=config.patience)

            # Training configuration for this fold
            training_config = {
                'data': data_yaml_path,
                'epochs': config.epochs,
                'imgsz': config.image_size,
                'batch': config.batch_size,
                'name': f'banana_detection_model_fold_{fold}',
                'patience': config.patience,
                'resume': start_epoch > 0,
                'device': config.device,
                'workers': config.num_workers,
                'pin_memory': config.pin_memory
            }

**Main machine learning execution**

In [None]:
            results = yolo_model.train(**training_config)

**Metrics update function**

In [None]:
            metrics = {
                'map50': results.results_dict.get('metrics/mAP50(B)', 0),
                'map95': results.results_dict.get('metrics/mAP50-95(B)', 0),
                'val_loss': results.results_dict.get('val/box_loss', 0)
            }
            metrics_tracker.update(metrics)

**Checkpoint save**

In [None]:
            save_checkpoint(
                yolo_model,
                fold,
                config.epochs,
                results.results_dict.get('train/box_loss', 0),
                metrics['val_loss'],
                metrics,
                config,
                is_best=metrics['map50'] > metrics_tracker.best_metrics['best_map50']
            )

            # Save model for this fold
            model_path = f'banana_detection_model_fold_{fold}.pt'
            yolo_model.save(model_path)

            best_models.append({
                'fold': fold,
                'model_path': model_path,
                'metrics': metrics
            })

            logging.info(f'Completed training fold {fold + 1}')
            logging.info(f'Best metrics for fold {fold + 1}: {metrics_tracker.best_metrics}')

**Early stopping with final evaluation**

In [None]:
        # Check early stopping
            if early_stopping(metrics['val_loss']):
                logging.info(f"Early stopping triggered in fold {fold + 1}")
                break

        # Final evaluation
        logging.info("Training completed. Final metrics:")
        for metric, value in metrics_tracker.best_metrics.items():
            logging.info(f"{metric}: {value}")

    except Exception as e:
        logging.error(f"Training failed: {str(e)}")
        raise

**Session statistics class**

In [None]:
class SessionStatistics:
    """Running tally of banana ripeness classifications for one session."""

    def __init__(self):
        self.total_bananas = 0
        self.counts = dict.fromkeys(("Unripe", "Ripe", "Overripe", "Rotten"), 0)
        self.timestamps = []
        self.confidence_scores = []

    def update(self, ripeness, confidence=None):
        """Register one classified banana.

        An unknown `ripeness` label is logged as an error and ignored. The
        confidence is stored only when one is supplied.
        """
        if ripeness in self.counts:
            self.counts[ripeness] += 1
            self.total_bananas += 1
            self.timestamps.append(datetime.now())
            if confidence is not None:
                self.confidence_scores.append(confidence)
        else:
            logging.error(f"Invalid ripeness value: {ripeness}")

    def calculate_quality_score(self):
        """Return 100 minus the percentage of overripe or rotten bananas.

        Returns 0 when nothing has been counted yet.
        """
        seen = self.total_bananas
        if seen == 0:
            return 0
        spoiled = self.counts["Rotten"] + self.counts["Overripe"]
        return 100 - (spoiled / seen) * 100

    def get_summary(self):
        """Return a dict summarising counts, percentages and confidence."""
        tally = self.counts

        if self.timestamps:
            session_start = self.timestamps[0]
        else:
            session_start = "N/A"

        if self.confidence_scores:
            mean_confidence = np.mean(self.confidence_scores)
        else:
            mean_confidence = 0

        summary = {"Total Bananas": self.total_bananas}
        for label in ("Unripe", "Ripe", "Overripe", "Rotten"):
            summary[label] = tally[label]
        # max(..., 1) guards the division when no bananas were counted.
        summary["Rotten Percentage"] = (tally["Rotten"] / max(self.total_bananas, 1)) * 100
        summary["Quality Score"] = self.calculate_quality_score()
        summary["Session Timestamp"] = session_start
        summary["Average Confidence"] = mean_confidence
        return summary

**Database class**

In [None]:
class SimpleDatabase:
    """In-memory store of session summaries with simple aggregate queries."""

    def __init__(self):
        # One summary dict per logged session, in logging order.
        self.sessions = []

    def log_session(self, session_stats):
        """Append the summary of `session_stats` to the store.

        `session_stats` must provide a `get_summary()` method. Logging is
        best-effort: a failure is reported at ERROR level and not raised.
        """
        try:
            self.sessions.append(session_stats.get_summary())
            logging.info("Session logged successfully")
        except Exception as e:
            logging.error(f"Failed to log session: {str(e)}")

    def retrieve_sessions(self):
        """Return all logged sessions as a DataFrame (one row per session)."""
        return pd.DataFrame(self.sessions)

    def get_statistics(self):
        """Return aggregate statistics across all logged sessions.

        Returns:
            Dict with 'total_sessions', 'total_bananas',
            'average_quality_score' and 'average_confidence'. With no logged
            sessions every value is zero.
        """
        df = self.retrieve_sessions()

        # An empty store yields a column-less DataFrame, so the column
        # lookups below would raise KeyError; report zeros instead.
        if df.empty:
            return {
                'total_sessions': 0,
                'total_bananas': 0,
                'average_quality_score': 0.0,
                'average_confidence': 0.0
            }

        return {
            'total_sessions': len(df),
            'total_bananas': df['Total Bananas'].sum(),
            'average_quality_score': df['Quality Score'].mean(),
            'average_confidence': df['Average Confidence'].mean()
        }