In [23]:
!pip install pycocotools

Collecting pycocotools
  Downloading pycocotools-2.0.8-cp311-cp311-win_amd64.whl.metadata (1.1 kB)
Downloading pycocotools-2.0.8-cp311-cp311-win_amd64.whl (85 kB)
   ---------------------------------------- 0.0/85.3 kB ? eta -:--:--
   -------------- ------------------------- 30.7/85.3 kB 1.3 MB/s eta 0:00:01
   -------------- ------------------------- 30.7/85.3 kB 1.3 MB/s eta 0:00:01
   ---------------------------------------- 85.3/85.3 kB 682.4 kB/s eta 0:00:00
Installing collected packages: pycocotools
Successfully installed pycocotools-2.0.8


In [5]:
import torch
import torchvision
from torchvision.models.detection import RetinaNet
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models import mobilenet_v3_small
from torchvision.transforms import v2 as transforms
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.ops import box_iou
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import xml.etree.ElementTree as ET
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from datetime import datetime
import os
import json

In [7]:

class CoralDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.class_map = {'Healthy Coral': 2, 'Dead Coral': 1, 'Bleached Coral': 0}
        self.image_files = list(self.root_dir.glob('*.jpg')) + list(self.root_dir.glob('*.png'))
        self.annotation_cache = {}

    def __len__(self): return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        xml_path = self.root_dir / f'{img_path.stem}.xml'
        
        # Load image
        image = Image.open(img_path).convert("RGB")
        
        # Load cached or parse XML
        if img_path in self.annotation_cache:
            boxes, labels = self.annotation_cache[img_path]
        else:
            tree = ET.parse(xml_path)
            root = tree.getroot()
            size = root.find('size')
            width = int(size.find('width').text)
            height = int(size.find('height').text)
            
            boxes, labels = [], []
            for obj in root.iter('object'):
                class_name = obj.find('name').text
                if class_name not in self.class_map: continue
                
                bndbox = obj.find('bndbox')
                xmin = float(bndbox.find('xmin').text) / width
                ymin = float(bndbox.find('ymin').text) / height
                xmax = float(bndbox.find('xmax').text) / width
                ymax = float(bndbox.find('ymax').text) / height
                
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(self.class_map[class_name])
            
            self.annotation_cache[img_path] = (boxes, labels)

        # Convert to tensors
        # print(f"Labels for image {img_path.name}: {labels}")
        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        # print(f"Shape of labels tensor for {img_path.name}: {labels.shape}") # Add this line
        
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx]),
            'area': (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            'iscrowd': torch.zeros(len(boxes), dtype=torch.int64)
        }

        if self.transform:
            image = self.transform(image)
            
        return image, target

In [208]:
class CoralRetinaNetMobile:
    def __init__(self, num_classes=3, pretrained=True):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.class_names = ['Bleached', 'Dead', 'Healthy']
        self.checkpoint_dir = Path('checkpoints')
        self.checkpoint_dir.mkdir(exist_ok=True)
        self.best_score = 0.0
        self.best_combined = 0.0  # Add this
        self.best_dead_precision = 0.0  # Add this
        self.train_losses = []
        self.val_losses = [] 
        
        # Model setup
        backbone = mobilenet_v3_small(weights='DEFAULT' if pretrained else None).features
        backbone.out_channels = 576
        
        anchor_generator = AnchorGenerator(
            sizes=((16, 32, 64, 128),),
            aspect_ratios=((0.5, 1.0, 2.0),) * 4
        )
        
        self.model = RetinaNet(
            backbone,
            num_classes=num_classes,
            anchor_generator=anchor_generator,
            box_score_thresh=0.25
        ).to(self.device)
        
        # Checkpoint loading
        if (self.checkpoint_dir / 'last.pt').exists():
            self.load_checkpoint()
        else:
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5, weight_decay=1e-4)
            self.scaler = torch.amp.GradScaler()
            self.start_epoch = 0
            self.best_map = 0.0  # Track mAP instead of loss

    def load_checkpoint(self):
        checkpoint = torch.load(self.checkpoint_dir / 'last.pt', map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5, weight_decay=1e-4)  #lr=1e-5 wd=0
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.start_epoch = checkpoint['epoch'] + 1
        self.best_map = checkpoint.get('best_map', 0.0)
        self.scaler = torch.amp.GradScaler()
        print(f"Resuming training from epoch {self.start_epoch}")

    def get_transform(self, train=True):
        return transforms.Compose([
            transforms.ToImage(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Lambda(lambda img: img / 255.0),
            transforms.Resize((320, 320), antialias=True),
            transforms.RandomHorizontalFlip(p=0.3) if train else transforms.Identity(),
            # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            # transforms.RandomRotation(degrees=(-10, 10)),
            # transforms.RandomVerticalFlip(p=0.1),
        ])

    def create_loaders(self, data_root):
        train_loader = DataLoader(
            CoralDataset(Path(data_root)/'train', self.get_transform(True)),
            batch_size=16,
            shuffle=True,
            num_workers=0,  # Required for Windows stability
            pin_memory=True,
            collate_fn=lambda x: tuple(zip(*x))
        )
        val_loader = DataLoader(
            CoralDataset(Path(data_root)/'valid', self.get_transform(False)),
            batch_size=32,
            num_workers=0,
            pin_memory=True,
            collate_fn=lambda x: tuple(zip(*x))
        )
        return train_loader, val_loader

    def train(self, train_loader, epochs=80, val_loader=None, early_stop_patience=2):
        self.model.train()
        torch.backends.cudnn.benchmark = True
        early_stop_counter = 0
    # Ensure self.best_score is defined in __init__ (e.g., self.best_score = 0.0)

        for epoch in range(self.start_epoch, epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")
            epoch_loss = 0.0

        # Training phase
            for i, (images, targets) in enumerate(train_loader):
                losses = torch.tensor(0.0, device=self.device) # Initialize losses here
                try:
                    images = [img.to(self.device) for img in images]
                    targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
                    # print("First target in the batch:", targets[0])
                    # print("Image Shapes:", [img.shape for img in images])
                    # print("Target Shapes:", [t['boxes'].shape for t in targets], [t['labels'].shape for t in targets]) # Add this line
                    
                    with torch.amp.autocast(device_type='cuda', enabled=self.device.type == 'cuda'):
                        loss_dict = self.model(images, targets)
                        losses = loss_dict["classification"] + loss_dict["bbox_regression"]
                        #print("Loss Dictionary:", loss_dict)
                        # losses = sum(loss for loss in loss_dict.values())

                        # Handle loss calculation safely
                        if isinstance(loss_dict, dict):
                            losses = sum(torch.mean(loss) for loss in loss_dict.values())
                        elif isinstance(loss_dict, list):
                            losses = sum(
                            torch.mean(sum(item.values())) if isinstance(item, dict) else torch.mean(item)
                            for item in loss_dict
                            )
                        else:
                            raise ValueError(f"Unexpected loss type: {type(loss_dict)}")

                    self.optimizer.zero_grad()
                    self.scaler.scale(losses).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    epoch_loss += losses.item()


                # Log batch loss every 50 batches
                    if (i + 1) % 50 == 0:
                        print(f"Batch {i+1}/{len(train_loader)} - Loss: {losses.item():.4f}")
                except RuntimeError as e:
                    print(f"Error in batch {i+1}: {e}")
                    print(f"Image shapes: {[img.shape for img in images]}")
                    print(f"Target keys: {[list(t.keys()) for t in targets]}")
                    raise e
                except Exception as e:
                    print(f"Error in batch {i+1}/{len(train_loader)}: {str(e)}")
                    print(f"Image shapes: {[img.shape for img in images]}")
                    print(f"Target keys: {[list(t.keys()) for t in targets]}")
                    raise e

            avg_epoch_loss = epoch_loss / len(train_loader)
            self.train_losses.append(avg_epoch_loss)
            print(f"Epoch {epoch+1} - Train Loss: {avg_epoch_loss:.4f}")

        # Validation phase
            if val_loader:
                metrics = self.validate(val_loader)  # Assume it returns a dict with key 'f1'
                current_map = metrics['map']  # Will always exist
                current_f1 = metrics['f1']  # From cls_metrics
                dead_coral_recall = metrics['per_class_metrics']['Dead']['recall']
                
            # Use F1 score for early stopping
                # 1. Primary: Mean Average Precision (mAP)
                if metrics['map'] > self.best_map:
                    self.best_map = metrics['map']
                    self.save_checkpoint(epoch, 'best_map.pt')
    
    # 2. Secondary: Combined Detection+Classification Score
                combined_score = (0.7 * metrics['map'] + 
                        0.3 * metrics['f1'])
                if combined_score > self.best_combined:
                    self.best_combined = combined_score
                    self.save_checkpoint(epoch, 'best_combined.pt')
    
    # 3. Class-Specific: Critical Class Precision (e.g., "Dead" coral)
                # dead_precision = metrics['confusion_matrix'][1,1] / metrics['confusion_matrix'][:,1].sum()
                # if dead_precision > self.best_dead_precision:
                #     self.best_dead_precision = dead_precision
                #     self.save_checkpoint(epoch, 'best_dead_precision.pt')
    
    # Early stopping based on primary metric
                if metrics['map'] > self.best_map:
                    early_stop_counter = 0
                    self.best_map = metrics['map']
                else:
                    early_stop_counter += 1
    

                print(f"Epoch {epoch+1} - Train Loss: {epoch_loss/len(train_loader):.4f} | Cls.F1: {metrics['f1']:.2f}")
            else:
                print(f"Epoch {epoch+1} - Train Loss: {epoch_loss/len(train_loader):.4f}")
            self._plot_training_curves(self.train_losses, self.val_losses)
            self.save_checkpoint(epoch, 'last.pt')

        
    def validate(self, loader, test_mode=False):
        metric = MeanAveragePrecision()
        self.model.eval()
        all_preds = []
        all_targets = []
        detection_stats = {
        'total_true': 0,
        'total_pred': 0,
        'correct': 0
        }
        per_class_metrics = {class_name: {'tp': 0, 'fp': 0, 'fn': 0} 
                            for class_name in self.class_names}

        with torch.no_grad():
            for images, targets in loader:
                images = [img.to(self.device) for img in images]
                targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
            
                outputs = self.model(images)

                preds = []
                for out in outputs:
                    preds.append({
                        'boxes': out['boxes'],
                        'scores': out['scores'], 
                        'labels': out['labels']
                    })
                metric.update(preds, targets)
            
            # Process predictions and calculate matches
                for out, tgt in zip(outputs, targets):
                    if len(out['boxes']) == 0 or len(tgt['boxes']) == 0:
                        continue

                    # Inside the validation loop, for each image pair (out, tgt)
                    detection_stats['total_pred'] += len(out['boxes'])
                    detection_stats['total_true'] += len(tgt['boxes'])
                    
                    iou_matrix = box_iou(out['boxes'], tgt['boxes'])
                    matched_true = [False] * len(tgt['labels'])
                
                    for pred_idx in range(len(out['boxes'])):
                        best_iou = -1
                        best_true_idx = -1
                    
                        for true_idx in range(len(tgt['boxes'])):
                            if not matched_true[true_idx] and iou_matrix[pred_idx, true_idx] >= 0.5:
                                if iou_matrix[pred_idx, true_idx] > best_iou:
                                    best_iou = iou_matrix[pred_idx, true_idx]
                                    best_true_idx = true_idx
                    
                        if best_true_idx != -1:
                            matched_true[best_true_idx] = True
                            pred_label = out['labels'][pred_idx].item()
                            true_label = tgt['labels'][best_true_idx].item()
                        
                            if pred_label == true_label:
                                detection_stats['correct'] += 1
                                per_class_metrics[self.class_names[true_label]]['tp'] += 1
                            else:
                                per_class_metrics[self.class_names[pred_label]]['fp'] += 1
                                per_class_metrics[self.class_names[true_label]]['fn'] += 1
                        
                            all_preds.append(pred_label)
                            all_targets.append(true_label)
                        else:
                            pred_label = out['labels'][pred_idx].item()
                            per_class_metrics[self.class_names[pred_label]]['fp'] += 1
                
                # Count unmatched true boxes as false negatives
                    for true_idx, matched in enumerate(matched_true):
                        if not matched:
                            true_label = tgt['labels'][true_idx].item()
                            per_class_metrics[self.class_names[true_label]]['fn'] += 1

    # Calculate metrics
        precision = detection_stats['correct'] / detection_stats['total_pred'] if detection_stats['total_pred'] > 0 else 0
        recall = detection_stats['correct'] / detection_stats['total_true'] if detection_stats['total_true'] > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    # Only calculate confusion matrix in test mode
        cm = None
        if len(all_preds) > 0:
            cls_metrics = {
            'precision': precision_score(all_targets, all_preds, average='weighted', zero_division=0),
            'recall': recall_score(all_targets, all_preds, average='weighted', zero_division=0),
            'f1': f1_score(all_targets, all_preds, average='weighted', zero_division=0)
            }
            if test_mode:  # Only calculate CM during testing
                cm = confusion_matrix(all_targets, all_preds, labels=range(len(self.class_names)))
                self._plot_confusion_matrix(cm)
                cls_metrics['confusion_matrix'] = cm.tolist()
        else:
            cls_metrics = {
                'precision': 0,
                'recall': 0,
                'f1': 0,
                'confusion_matrix': []
            }
    
    # Calculate per-class metrics
        class_wise_results = {}
        for class_name in self.class_names:
            tp = per_class_metrics[class_name]['tp']
            fp = per_class_metrics[class_name]['fp']
            fn = per_class_metrics[class_name]['fn']
        
            class_wise_results[class_name] = {
                'precision': tp / (tp + fp) if (tp + fp) > 0 else 0,
            'recall': tp / (tp + fn) if (tp + fn) > 0 else 0,
                'f1': 2*tp/(2*tp + fp + fn) if (tp + fp + fn) > 0 else 0,
            'support': tp + fn
            }

        combined_metrics = {
            'map': metric.compute()['map'].item(),
            'detection_precision': precision,
            'detection_recall': recall,
            'detection_f1': f1,
            'per_class_metrics': class_wise_results,  # Keeps your exact key name
            **cls_metrics,
            'timestamp': datetime.now().isoformat()
        }
    
        self.model.train()
        return combined_metrics

    def analyze_test_set(self, data_root):
        test_dir = Path(data_root) / 'test'
        test_set = CoralDataset(
        root_dir=test_dir,
        transform=self.get_transform(train=False))
        test_loader = DataLoader(
            test_set, 
            batch_size=8,
            collate_fn=lambda x: tuple(zip(*x)))
    
        test_metrics = self.validate(test_loader, test_mode=True)

        print("\nPer-Class Test Metrics:")
        for class_name, metrics in test_metrics['per_class_metrics'].items():
            print(f"Class: {class_name}")
            print(f"  Precision: {metrics['precision']:.2%}")
            print(f"  Recall: {metrics['recall']:.2%}")
            print(f"  F1 Score: {metrics['f1']:.2%}")
            print(f"  Support: {metrics['support']}")
        
        print("\nFinal Detection Test Metrics:")
        print(f"Precision: {test_metrics['detection_precision']:.2%}")
        print(f"Recall: {test_metrics['detection_recall']:.2%}")
        print(f"F1 Score: {test_metrics['detection_f1']:.2%}")
    
        return test_metrics
        
    def _calculate_metrics(self, true_labels, pred_labels):
        # Classification metrics
        precision = precision_score(true_labels, pred_labels, average='weighted', zero_division=0)
        recall = recall_score(true_labels, pred_labels, average='weighted', zero_division=0)
        f1 = f1_score(true_labels, pred_labels, average='weighted', zero_division=0)
        cm = confusion_matrix(true_labels, pred_labels)
        
        self._plot_confusion_matrix(cm)
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'confusion_matrix': cm.tolist()
        }

    def _plot_confusion_matrix(self, cm):
        plt.figure(figsize=(10, 8))
        try:
            sns.heatmap(cm, annot=True, fmt='d',
                  xticklabels=self.class_names,
                  yticklabels=self.class_names,
                  cmap='Blues')
            plt.title('Confusion Matrix')
            plt.tight_layout()
        
        # Save to checkpoint directory with high resolution
            save_path = self.checkpoint_dir / 'confusion_matrix.png'
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Saved confusion matrix to: {save_path}")
        
        except Exception as e:
            print(f"Error plotting confusion matrix: {str(e)}")
        finally:
            plt.close()
    

    def _plot_training_curves(self, train_losses, val_losses):
        plt.figure(figsize=(10, 6))
        try:
            plt.plot(train_losses, label='Train Loss', linewidth=2)
            if val_losses:
                plt.plot(val_losses, label='Validation Loss', linewidth=2)
            
            plt.xlabel('Epochs', fontsize=12)
            plt.ylabel('Loss', fontsize=12)
            plt.legend(fontsize=12)
            plt.grid(True, alpha=0.3)
            plt.title('Training Curves', fontsize=14)
        
        # Save to checkpoint directory
            save_path = self.checkpoint_dir / 'training_curves.png'
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Saved training curves to: {save_path}")
        
        except Exception as e:
            print(f"Error plotting training curves: {str(e)}")
        finally:
            plt.close()

    # def analyze_test_set(self, data_root):
    #     test_dir = Path(data_root) / 'test'
    #     test_set = CoralDataset(
    #         root_dir=test_dir,
    #         transform=self.get_transform(train=False))
    #     test_loader = DataLoader(
    #         test_set, 
    #         batch_size=8,
    #         collate_fn=lambda x: tuple(zip(*x)))
    
    # # Single return value capture
    #     test_metrics = self.validate(test_loader, test_mode=True)

    #     print("\nPer-Class Test Metrics:")
    #     for class_name, metrics in test_metrics['per_class_metrics'].items():
    #         print(f"Class: {class_name}")
    #         print(f"  Precision: {metrics['precision']:.2%}")
    #         print(f"  Recall: {metrics['recall']:.2%}")
    #         print(f"  F1 Score: {metrics['f1']:.2%}")

    #     print("\nFinal Detection Test Metrics:")
    #     print(f"Precision: {test_metrics['detection_precision']:.2%}")
    #     print(f"Recall: {test_metrics['detection_recall']:.2%}")
    #     print(f"F1 Score: {test_metrics['detection_f1']:.2%}")
    
    # # Optional: Save test metrics separately
    #     with open(self.checkpoint_dir / 'test_metrics1.json', 'w') as f:
    #         json.dump(test_metrics, f, indent=4)
        
    #     return test_metrics

    def save_checkpoint(self, epoch, filename):
        torch.save({
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': epoch,
            'best_map': self.best_map
        }, self.checkpoint_dir / filename)


In [210]:
analyzer = CoralRetinaNetMobile(num_classes=3)

  checkpoint = torch.load(self.checkpoint_dir / 'last.pt', map_location=self.device)


Resuming training from epoch 22


In [212]:
data_root = r"C:\Users\Rugved\Downloads\marjan-segmentaion.v15i.voc"
    
train_loader, val_loader = analyzer.create_loaders(data_root)
print(f"Training batches: {len(train_loader)}, Validation batches: {len(val_loader)}")

Training batches: 310, Validation batches: 12


In [214]:
analyzer.train(train_loader, epochs=24, val_loader=val_loader)


Epoch 23/24
Batch 50/310 - Loss: 0.5097
Batch 100/310 - Loss: 0.7898
Batch 150/310 - Loss: 0.5264
Batch 200/310 - Loss: 0.4766
Batch 250/310 - Loss: 0.5872
Batch 300/310 - Loss: 0.4943
Epoch 23 - Train Loss: 0.6034




Epoch 23 - Train Loss: 0.6034 | Cls.F1: 0.49
Saved training curves to: checkpoints\training_curves.png

Epoch 24/24
Batch 50/310 - Loss: 0.5762
Batch 100/310 - Loss: 0.3928
Batch 150/310 - Loss: 0.6090
Batch 200/310 - Loss: 0.6087
Batch 250/310 - Loss: 0.5474
Batch 300/310 - Loss: 0.4911
Epoch 24 - Train Loss: 0.5443




Epoch 24 - Train Loss: 0.5443 | Cls.F1: 0.54
Saved training curves to: checkpoints\training_curves.png


In [216]:
test_metrics = analyzer.analyze_test_set(data_root)
print("\nFinal Test Metrics:")
print(f"Precision: {test_metrics['precision']:.2%}")
print(f"Recall: {test_metrics['recall']:.2%}")
print(f"F1 Score: {test_metrics['f1']:.2%}")

Saved confusion matrix to: checkpoints\confusion_matrix.png

Per-Class Test Metrics:
Class: Bleached
  Precision: 3.60%
  Recall: 3.44%
  F1 Score: 3.52%
  Support: 349
Class: Dead
  Precision: 5.26%
  Recall: 5.75%
  F1 Score: 5.50%
  Support: 313
Class: Healthy
  Precision: 6.03%
  Recall: 6.91%
  F1 Score: 6.44%
  Support: 304

Final Detection Test Metrics:
Precision: 4.99%
Recall: 5.28%
F1 Score: 5.13%

Final Test Metrics:
Precision: 67.88%
Recall: 56.04%
F1 Score: 56.99%




In [218]:
print(test_metrics)

{'map': -1.0, 'detection_precision': 0.04985337243401759, 'detection_recall': 0.052795031055900624, 'detection_f1': 0.05128205128205129, 'per_class_metrics': {'Bleached': {'precision': 0.036036036036036036, 'recall': 0.034383954154727794, 'f1': 0.03519061583577713, 'support': 349}, 'Dead': {'precision': 0.05263157894736842, 'recall': 0.05750798722044728, 'f1': 0.0549618320610687, 'support': 313}, 'Healthy': {'precision': 0.0603448275862069, 'recall': 0.06907894736842106, 'f1': 0.06441717791411043, 'support': 304}}, 'precision': 0.6788323457104176, 'recall': 0.5604395604395604, 'f1': 0.5698979591836734, 'confusion_matrix': [[12, 0, 13], [2, 18, 19], [3, 3, 21]], 'timestamp': '2025-03-30T19:44:09.985825'}


In [23]:
# class CoralRetinaNetMobile:   #vlast
#     def __init__(self, num_classes=3, pretrained=True):
#         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#         self.class_names = ['Bleached', 'Dead', 'Healthy']
#         self.checkpoint_dir = Path('checkpoints')
#         self.checkpoint_dir.mkdir(exist_ok=True)
        
#         # Model setup
#         backbone = mobilenet_v3_small(weights='DEFAULT' if pretrained else None).features
#         backbone.out_channels = 576
        
#         anchor_generator = AnchorGenerator(
#             sizes=((16, 32, 64, 128),),
#             aspect_ratios=((0.5, 1.0, 2.0),) * 4
#         )
        
#         self.model = RetinaNet(
#             backbone,
#             num_classes=num_classes,
#             anchor_generator=anchor_generator,
#             box_score_thresh=0.25
#         ).to(self.device)
        
#         # Checkpoint loading
#         if (self.checkpoint_dir / 'last.pt').exists():
#             self.load_checkpoint()
#         else:
#             self.optimizer = torch.optim.RAdam(self.model.parameters(), lr=3e-4)
#             self.scaler = torch.amp.GradScaler()
#             self.start_epoch = 0
#             self.best_map = 0.0  # Track mAP instead of loss

#     def load_checkpoint(self):
#         checkpoint = torch.load(self.checkpoint_dir / 'last.pt', map_location=self.device)
#         self.model.load_state_dict(checkpoint['model'])
#         self.optimizer = torch.optim.RAdam(self.model.parameters())
#         self.optimizer.load_state_dict(checkpoint['optimizer'])
#         self.start_epoch = checkpoint['epoch'] + 1
#         self.best_map = checkpoint.get('best_map', 0.0)
#         self.scaler = torch.amp.GradScaler()
#         print(f"Resuming training from epoch {self.start_epoch}")

#     def get_transform(self, train=True):
#         return transforms.Compose([
#             transforms.ToImage(),
#             transforms.ToDtype(torch.float32, scale=True),
#             transforms.Resize((320, 320), antialias=True),
#             transforms.RandomHorizontalFlip(p=0.3) if train else transforms.Identity(),
#         ])

#     def create_loaders(self, data_root):
#         train_loader = DataLoader(
#             CoralDataset(Path(data_root)/'train', self.get_transform(True)),
#             batch_size=16,
#             shuffle=True,
#             num_workers=0,  # Required for Windows stability
#             pin_memory=True,
#             collate_fn=lambda x: tuple(zip(*x))
#         )
#         val_loader = DataLoader(
#             CoralDataset(Path(data_root)/'valid', self.get_transform(False)),
#             batch_size=32,
#             num_workers=0,
#             pin_memory=True,
#             collate_fn=lambda x: tuple(zip(*x))
#         )
#         return train_loader, val_loader

#     def train(self, train_loader, epochs=80, val_loader=None, early_stop_patience=12):
#         self.model.train()
#         torch.backends.cudnn.benchmark = True
#         early_stop_counter = 0

#         for epoch in range(self.start_epoch, epochs):
#             print(f"\nEpoch {epoch+1}/{epochs}")
#             epoch_loss = 0.0

#         # Training phase
#             for i, (images, targets) in enumerate(train_loader):
#                 images = [img.to(self.device) for img in images]
#                 targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

#                 with torch.amp.autocast(device_type='cuda', enabled=self.device.type == 'cuda'):
#                     loss_dict = self.model(images, targets)
#                     losses = sum(loss for loss in loss_dict.values())

#                 self.optimizer.zero_grad()
#                 self.scaler.scale(losses).backward()
#                 self.scaler.step(self.optimizer)
#                 self.scaler.update()
#                 epoch_loss += losses.item()

#             # Log batch loss every 30 batches
#                 if (i + 1) % 30 == 0:
#                     print(f"Batch {i+1}/{len(train_loader)} - Loss: {losses.item():.4f}")

#         # Validation phase
#             if val_loader:
#                 map_score, metrics = self.validate(val_loader)
            
#                 if map_score > self.best_map:
#                     self.best_map = map_score
#                     self.save_checkpoint(epoch, 'best.pt')
#                     early_stop_counter = 0
#                 else:
#                     early_stop_counter += 1

#                 if early_stop_counter >= early_stop_patience:
#                     print(f"Early stopping at epoch {epoch+1}")
#                     break

#             self.save_checkpoint(epoch, 'last.pt')
#             print(f"Epoch {epoch+1} - Train Loss: {epoch_loss/len(train_loader):.4f} | mAP: {map_score:.4f}")
        
#     def validate(self, loader):
#         metric = MeanAveragePrecision()
#         self.model.eval()
#         all_preds = []
#         all_targets = []
        
#         with torch.no_grad():
#             for images, targets in loader:
#                 images = [img.to(self.device) for img in images]
#                 targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
                
#                 outputs = self.model(images)
                
#                 # For detection metrics
#                 metric.update(
#                     [{'boxes': o['boxes'], 'scores': o['scores'], 'labels': o['labels']} for o in outputs],
#                     [{'boxes': t['boxes'], 'labels': t['labels']} for t in targets]
#                 )
                
#                 # For classification-style metrics
#                 for out, tgt in zip(outputs, targets):
#                     if len(out['labels']) > 0:
#                         all_preds.extend(out['labels'].cpu().numpy())
#                         all_targets.extend(tgt['labels'].cpu().numpy())

#         # Detection metrics
#         detection_metrics = metric.compute()
        
#         # Classification-style metrics
#         if len(all_preds) > 0 and len(all_targets) > 0:
#             cls_metrics = self._calculate_metrics(np.array(all_targets), np.array(all_preds))
#         else:
#             cls_metrics = {}
        
#         # Combine metrics
#         combined_metrics = {
#             **detection_metrics,
#             **cls_metrics,
#             'timestamp': datetime.now().isoformat()
#         }
        
#         with open('metrics.json', 'w') as f:
#             json.dump(combined_metrics, f)
            
#         return detection_metrics['map'], combined_metrics

#     def _calculate_metrics(self, true_labels, pred_labels):
#         # Classification metrics
#         precision = precision_score(true_labels, pred_labels, 
#                                   average='weighted', zero_division=0)
#         recall = recall_score(true_labels, pred_labels,
#                             average='weighted', zero_division=0)
#         f1 = f1_score(true_labels, pred_labels,
#                     average='weighted', zero_division=0)
#         cm = confusion_matrix(true_labels, pred_labels)
        
#         # Plot confusion matrix
#         self._plot_confusion_matrix(cm)
        
#         return {
#             'precision': precision,
#             'recall': recall,
#             'f1': f1,
#             'confusion_matrix': cm.tolist()
#         }

#     def _plot_confusion_matrix(self, cm):
#         plt.figure(figsize=(10,8))
#         sns.heatmap(cm, annot=True, fmt='d', 
#                   xticklabels=self.class_names,
#                   yticklabels=self.class_names)
#         plt.title('Confusion Matrix')
#         plt.savefig('confusion_matrix.png')
#         plt.close()

#     def _plot_training_curves(self, train_losses, val_losses):
#         plt.figure()
#         plt.plot(train_losses, label='Train Loss')
#         if val_losses:
#             plt.plot(val_losses, label='Validation Loss')
#         plt.xlabel('Epochs')
#         plt.ylabel('Loss')
#         plt.legend()
#         plt.title('Training Curves')
#         plt.savefig('training_curves.png')
#         plt.close()

#     def analyze_test_set(self, data_root):
#         test_dir = Path(data_root) / 'test'
#         test_set = CoralDataset(
#             root_dir=test_dir,
#             transform=self.get_transform(train=False)
#         )
#         test_loader = DataLoader(
#             test_set, batch_size=8,
#             collate_fn=lambda x: tuple(zip(*x))
#         )
#         _, metrics = self.validate(test_loader)
#         return metrics

#     def save_checkpoint(self, epoch, filename):
#         torch.save({
#             'model': self.model.state_dict(),
#             'optimizer': self.optimizer.state_dict(),
#             'epoch': epoch,
#             'best_map': self.best_map
#         }, self.checkpoint_dir / filename)

In [15]:



# class CoralRetinaNetMobile:
#     def __init__(self, num_classes=3, pretrained=True):
#         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#         self.class_names = ['Bleached', 'Dead', 'Healthy']
#         self.checkpoint_dir = Path('checkpoints')
#         self.checkpoint_dir.mkdir(exist_ok=True)
        
#         # Load MobileNetV3-small backbone
#         backbone = mobilenet_v3_small(weights='DEFAULT' if pretrained else None).features
#         backbone.out_channels = 576  # Critical for MobileNetV3-small
        
#         # Custom anchor generator for coral sizes
#         anchor_generator = AnchorGenerator(
#             sizes=((16, 32, 64, 128),),
#             aspect_ratios=((0.5, 1.0, 2.0),) * 4
#         )
        
#         # Initialize model
#         self.model = RetinaNet(
#             backbone,
#             num_classes=num_classes,
#             anchor_generator=anchor_generator,
#             box_score_thresh=0.25
#         ).to(self.device)
        
#         # Load checkpoint if exists
#         if (self.checkpoint_dir / 'last.pt').exists():
#             self.load_checkpoint()
#         else:
#             self.optimizer = torch.optim.RAdam(self.model.parameters(), lr=3e-4)
#             self.scaler = self.create_scaler()
#             self.start_epoch = 0
#             self.best_map = 0.0  # Track mAP instead of loss

#     def create_scaler(self):
#         return torch.amp.GradScaler('cuda', enabled=self.device.type == 'cuda')

#     def load_checkpoint(self):
#         checkpoint = torch.load(self.checkpoint_dir / 'last.pt', map_location=self.device)
#         self.model.load_state_dict(checkpoint['model'])
#         self.optimizer = torch.optim.RAdam(self.model.parameters())
#         self.optimizer.load_state_dict(checkpoint['optimizer'])
#         self.start_epoch = checkpoint['epoch'] + 1
#         self.best_loss = checkpoint['best_loss']
#         self.scaler = self.create_scaler()
#         print(f"Resuming training from epoch {self.start_epoch}")

#     def get_transform(self, train=True):
#         return transforms.Compose([
#             transforms.ToImage(),
#             transforms.ToDtype(torch.float32, scale=True),
#             transforms.Resize((320, 320), antialias=True),
#             transforms.RandomHorizontalFlip(p=0.3) if train else transforms.Identity(),
#             transforms.ColorJitter(brightness=0.15, contrast=0.15) if train else transforms.Identity(),
#         ])

#     def create_loaders(self, data_root):
#         train_loader = DataLoader(
#             CoralDataset(Path(data_root)/'train', self.get_transform(True)),
#             batch_size=16,
#             shuffle=True,
#             num_workers=0,
#             pin_memory=True,
#             collate_fn=lambda x: tuple(zip(*x))
#         )
#         val_loader = DataLoader(
#             CoralDataset(Path(data_root)/'valid', self.get_transform(False)),
#             batch_size=32,
#             num_workers=0,
#             pin_memory=True,
#             collate_fn=lambda x: tuple(zip(*x))
#         )
#         return train_loader, val_loader

#     def train(self, train_loader, epochs=80, val_loader=None, early_stop_patience=12):
#         self.model.train()
#         torch.backends.cudnn.benchmark = True
#         train_losses, val_losses = [], []
#         early_stop_counter = 0

#         for epoch in range(self.start_epoch, epochs):
#             print(f"\nEpoch {epoch+1}/{epochs}")
#             epoch_loss = 0.0

#             for i, (images, targets) in enumerate(train_loader):
#                 print(f"Processing first batch (size: {len(images)})") if i == 0 else None

#                 try:
#                 # Move data to the device
#                     images = [img.to(self.device) for img in images]
#                     targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
#                     if i == 0:
#                         print(f"  Batch {i+1}: Images moved to device")
                
#                 # Forward pass with automatic mixed precision
#                     with torch.cuda.amp.autocast(enabled=self.device.type == 'cuda'):
#                         if i == 0:
#                             print(f"  Batch {i+1}: Starting forward pass")
#                         loss_dict = self.model(images, targets)
#                         if i == 0:
#                             print(f"  Batch {i+1}: Forward pass complete, computing loss")
#                         losses = sum(loss for loss in loss_dict.values())
#                         if i == 0:
#                             print(f"  Batch {i+1}: Loss computed: {losses.item():.4f}")
                
#                 # Zero gradients, backward pass, and step
#                     self.optimizer.zero_grad()
#                     if i == 0:
#                         print(f"  Batch {i+1}: Starting backward pass")
#                     self.scaler.scale(losses).backward()
#                     if i == 0:
#                         print(f"  Batch {i+1}: Backward pass complete")
#                     self.scaler.step(self.optimizer)
#                     self.scaler.update()

#                     epoch_loss += losses.item()

#                     if i % 20 == 0:
#                         print(f"  Batch {i+1}/{len(train_loader)} processed, current loss: {losses.item():.4f}")

#                 except Exception as e:
#                     print(f"Error in batch {i+1}: {str(e)}")
#                     print(f"Image shapes: {[img.shape for img in images]}")
#                     print(f"Target keys: {[list(t.keys()) for t in targets]}")
#                     raise e

#             avg_train_loss = epoch_loss / len(train_loader)
#             train_losses.append(avg_train_loss)
#             print(f"Train Loss: {avg_train_loss:.4f}")

#         # Validation
#             if val_loader:
#                 val_loss, metrics = self.validate(val_loader)
#                 val_losses.append(val_loss)

#                 if val_loss < self.best_loss:
#                     self.best_loss = val_loss
#                     self.save_checkpoint(epoch, 'best.pt')
#                     early_stop_counter = 0
#                 else:
#                     early_stop_counter += 1

#                 if early_stop_counter >= early_stop_patience:
#                     print(f"Early stopping at epoch {epoch+1}")
#                     break

#             self.save_checkpoint(epoch, 'last.pt')

#         self._plot_training_curves(train_losses, val_losses)
#         return train_losses, val_losses


#     def save_checkpoint(self, epoch, filename):
#         torch.save({
#             'model': self.model.state_dict(),
#             'optimizer': self.optimizer.state_dict(),
#             'epoch': epoch,
#             'best_loss': self.best_loss
#         }, self.checkpoint_dir / filename)

#     def validate(self, loader):
#         self.model.eval()
#         total_loss = 0.0
#         all_preds, all_targets = [], []
        
#         with torch.no_grad():
#             for images, targets in loader:
#                 images = [img.to(self.device) for img in images]
#                 targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
            
#             # Case 1: Get losses (requires gradient)
#                 with torch.amp.autocast(device_type='cuda', enabled=self.device.type == 'cuda'):
#                     loss_dict = self.model(images, targets)  # Only works in train mode
                
#                 # Handle both dict and list outputs
#                     if isinstance(loss_dict, dict):
#                         total_loss += sum(loss for loss in loss_dict.values()).item()
#                     else:
#                     # If no losses returned (eval mode), skip loss calculation
#                         total_loss = float('nan')
            
#             # Case 2: Get predictions
#                 outputs = self.model(images)  # Forward pass without targets
#                 for out, tgt in zip(outputs, targets):
#                     all_preds.extend(out['labels'].cpu().numpy())
#                     all_targets.extend(tgt['labels'].cpu().numpy())

#         avg_loss = total_loss if not isinstance(total_loss, float) else 0.0
#         metrics = self._calculate_metrics(np.array(all_targets), np.array(all_preds))
#         return avg_loss, metrics

#     # Keep _calculate_metrics, _plot_training_curves, and analyze_test_set 
#     # implementations from previous code (they remain the same)

#     def _calculate_metrics(self, true_labels, pred_labels):
#         # Classification metrics
#         precision = precision_score(true_labels, pred_labels, average='weighted', zero_division=0)
#         recall = recall_score(true_labels, pred_labels, average='weighted', zero_division=0)
#         f1 = f1_score(true_labels, pred_labels, average='weighted', zero_division=0)
#         cm = confusion_matrix(true_labels, pred_labels)
        
#         # Plot confusion matrix
#         plt.figure(figsize=(10,8))
#         sns.heatmap(cm, annot=True, fmt='d', xticklabels=self.class_names, yticklabels=self.class_names)
#         plt.title('Confusion Matrix')
#         plt.savefig('confusion_matrix.png')
#         plt.close()
        
#         # Save metrics
#         metrics = {
#             'precision': precision,
#             'recall': recall,
#             'f1': f1,
#             'confusion_matrix': cm.tolist(),
#             'timestamp': datetime.now().isoformat()
#         }
        
#         with open('metrics.json', 'w') as f:
#             json.dump(metrics, f)
        
#         print(f"Validation Metrics - Precision: {precision:.2%}, Recall: {recall:.2%}, F1: {f1:.2%}")
#         return metrics

#     def _plot_training_curves(self, train_losses, val_losses):
#         plt.figure()
#         plt.plot(train_losses, label='Train Loss')
#         if val_losses:
#             plt.plot(val_losses, label='Validation Loss')
#         plt.xlabel('Epochs')
#         plt.ylabel('Loss')
#         plt.legend()
#         plt.title('Training Curves')
#         plt.savefig('training_curves.png')
#         plt.close()

#     def analyze_test_set(self, data_root):
#         test_dir = Path(data_root) / 'test'
#         test_set = CoralDataset(
#             root_dir=test_dir,
#             transform=self.get_transform(train=False)
#         )
#         test_loader = DataLoader(
#             test_set, batch_size=8,
#             collate_fn=lambda x: tuple(zip(*x))
#         )
#         _, metrics = self.validate(test_loader)
#         return metrics


# # # Usage example
# # if __name__ == "__main__":
# #     analyzer = CoralRetinaNetMobile(num_classes=3)
# #     data_root = r"C:\Users\Rugved\Downloads\marjan-segmentaion.v15i.voc"
    
# #     train_loader, val_loader = analyzer.create_loaders(data_root)
# #     print(f"Training batches: {len(train_loader)}, Validation batches: {len(val_loader)}")
    
# #     analyzer.train(train_loader, epochs=80, val_loader=val_loader)
    
# #     test_metrics = analyzer.analyze_test_set(Path(data_root)/'test')
# #     print("\nFinal Test Metrics:")
# #     print(f"Precision: {test_metrics['precision']:.2%}")
# #     print(f"Recall: {test_metrics['recall']:.2%}")
# #     print(f"F1 Score: {test_metrics['f1']:.2%}")

In [None]:
def validate(self, loader, test_mode=False):
    metric = MeanAveragePrecision()
    self.model.eval()
    all_preds = []
    all_targets = []
    detection_stats = {
        'total_true': 0,
        'total_pred': 0,
        'correct': 0
    }
    per_class_metrics = {class_name: {'tp': 0, 'fp': 0, 'fn': 0} 
                        for class_name in self.class_names}

    with torch.no_grad():
        for images, targets in loader:
            images = [img.to(self.device) for img in images]
            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
            
            outputs = self.model(images)
            
            # Process predictions and calculate matches
            for out, tgt in zip(outputs, targets):
                if len(out['boxes']) == 0 or len(tgt['boxes']) == 0:
                    continue
                    
                iou_matrix = box_iou(out['boxes'], tgt['boxes'])
                matched_true = [False] * len(tgt['labels'])
                
                for pred_idx in range(len(out['boxes'])):
                    best_iou = -1
                    best_true_idx = -1
                    
                    for true_idx in range(len(tgt['boxes'])):
                        if not matched_true[true_idx] and iou_matrix[pred_idx, true_idx] >= 0.5:
                            if iou_matrix[pred_idx, true_idx] > best_iou:
                                best_iou = iou_matrix[pred_idx, true_idx]
                                best_true_idx = true_idx
                    
                    if best_true_idx != -1:
                        matched_true[best_true_idx] = True
                        pred_label = out['labels'][pred_idx].item()
                        true_label = tgt['labels'][best_true_idx].item()
                        
                        if pred_label == true_label:
                            detection_stats['correct'] += 1
                            per_class_metrics[self.class_names[true_label]]['tp'] += 1
                        else:
                            per_class_metrics[self.class_names[pred_label]]['fp'] += 1
                            per_class_metrics[self.class_names[true_label]]['fn'] += 1
                        
                        all_preds.append(pred_label)
                        all_targets.append(true_label)
                    else:
                        pred_label = out['labels'][pred_idx].item()
                        per_class_metrics[self.class_names[pred_label]]['fp'] += 1
                
                # Count unmatched true boxes as false negatives
                for true_idx, matched in enumerate(matched_true):
                    if not matched:
                        true_label = tgt['labels'][true_idx].item()
                        per_class_metrics[self.class_names[true_label]]['fn'] += 1

    # Calculate metrics
    precision = detection_stats['correct'] / detection_stats['total_pred'] if detection_stats['total_pred'] > 0 else 0
    recall = detection_stats['correct'] / detection_stats['total_true'] if detection_stats['total_true'] > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    # Only calculate confusion matrix in test mode
    cm = None
    if test_mode and len(all_preds) > 0:
        cm = confusion_matrix(all_targets, all_preds, labels=range(len(self.class_names)))
        self._plot_confusion_matrix(cm)
    
    # Calculate per-class metrics
    class_wise_results = {}
    for class_name in self.class_names:
        tp = per_class_metrics[class_name]['tp']
        fp = per_class_metrics[class_name]['fp']
        fn = per_class_metrics[class_name]['fn']
        
        class_wise_results[class_name] = {
            'precision': tp / (tp + fp) if (tp + fp) > 0 else 0,
            'recall': tp / (tp + fn) if (tp + fn) > 0 else 0,
            'f1': 2*tp/(2*tp + fp + fn) if (tp + fp + fn) > 0 else 0,
            'support': tp + fn
        }

    combined_metrics = {
        'map': metric.compute()['map'].item(),
        'detection_precision': precision,
        'detection_recall': recall,
        'detection_f1': f1,
        'per_class_metrics': class_wise_results,  # Keeps your exact key name
        'confusion_matrix': cm.tolist() if cm is not None else [],
        'timestamp': datetime.now().isoformat()
    }
    
    self.model.train()
    return combined_metrics

def analyze_test_set(self, data_root):
    test_dir = Path(data_root) / 'test'
    test_set = CoralDataset(
        root_dir=test_dir,
        transform=self.get_transform(train=False))
    test_loader = DataLoader(
        test_set, 
        batch_size=8,
        collate_fn=lambda x: tuple(zip(*x)))
    
    test_metrics = self.validate(test_loader, test_mode=True)

    print("\nPer-Class Test Metrics:")
    for class_name, metrics in test_metrics['per_class_metrics'].items():
        print(f"Class: {class_name}")
        print(f"  Precision: {metrics['precision']:.2%}")
        print(f"  Recall: {metrics['recall']:.2%}")
        print(f"  F1 Score: {metrics['f1']:.2%}")
        print(f"  Support: {metrics['support']}")

    print("\nFinal Detection Test Metrics:")
    print(f"Precision: {test_metrics['detection_precision']:.2%}")
    print(f"Recall: {test_metrics['detection_recall']:.2%}")
    print(f"F1 Score: {test_metrics['detection_f1']:.2%}")
    
    return test_metrics