In [2]:
import torch
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import DataLoader, Dataset
import cv2
import numpy as np
import os
import json
import shutil
import matplotlib.pyplot as plt
import sys

# Try importing kagglehub for auto-download
try:
    import kagglehub
except ImportError:
    kagglehub = None
    print("Warning: 'kagglehub' not installed. Auto-download might fail.")
    print("Run: pip install kagglehub")

# ---------------------------------------------------------
# HELPER: DEBUG FILE STRUCTURE
# ---------------------------------------------------------
def print_directory_structure(startpath):
    """Print an indented tree of ``startpath`` to help debug dataset layouts.

    Every directory is printed as ``name/``, indented four spaces per level
    below ``startpath``. At most the first three files (alphabetical order) of
    each directory are listed, followed by a count of the remaining files.

    Args:
        startpath: Directory to walk. A trailing path separator is tolerated.
    """
    print(f"\n--- Debugging Structure for: {startpath} ---")
    for root, dirs, files in os.walk(startpath):
        # Sorting in place makes both the traversal order and the sampled
        # file names reproducible across filesystems.
        dirs.sort()
        files.sort()

        # Depth is derived from the path relative to the start directory.
        # Stripping the start path textually miscounts when it ends with a
        # separator or when it re-occurs inside a sub-path.
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1

        indent = ' ' * 4 * level
        # normpath drops a trailing separator so the root is not shown as "/".
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")

        subindent = ' ' * 4 * (level + 1)
        # Print first 3 files only to keep log short
        for f in files[:3]:
            print(f"{subindent}{f}")
        if len(files) > 3:
            print(f"{subindent}... ({len(files)-3} more files)")

# ---------------------------------------------------------
# 1. DATASET SETUP & DOWNLOADER
# ---------------------------------------------------------
def _fetch_kaggle_dataset(label, handle, local_path):
    """Make sure the Kaggle dataset ``handle`` is available in ``local_path``.

    Nothing is downloaded when ``local_path`` is already a non-empty
    directory. Download or copy failures are reported on stdout and leave an
    empty ``local_path`` behind so that a later run retries the download.

    Args:
        label: Human-readable dataset name used in log messages.
        handle: Dataset identifier passed to ``kagglehub.dataset_download``.
        local_path: Destination directory for the dataset files.
    """
    if os.path.isdir(local_path) and os.listdir(local_path):
        print(f"   -> {label} already exists at {local_path}")
        return

    print(f"   -> Fetching {label} (via kagglehub)...")
    try:
        if kagglehub is None:
            # The optional import at the top of the file failed; say so
            # instead of surfacing an AttributeError on None.
            raise RuntimeError("'kagglehub' is not installed (run: pip install kagglehub)")
        cache_path = kagglehub.dataset_download(handle)
        print(f"      Cached at: {cache_path}")
        print(f"      Moving to local dir: {local_path}...")
        if os.path.exists(local_path):
            shutil.rmtree(local_path)
        shutil.copytree(cache_path, local_path, dirs_exist_ok=True)
        print(f"      {label} ready at: {local_path}")
    except Exception as e:
        # Deliberately broad: this is a best-effort step and the caller
        # copes with an empty directory by skipping the dataset.
        print(f"      Error downloading {label}: {e}")
        print("      If this persists, please download manually.")
        # Remove any partial copy; otherwise the next run would treat the
        # half-populated directory as a complete dataset.
        shutil.rmtree(local_path, ignore_errors=True)
        os.makedirs(local_path, exist_ok=True)


def download_and_prep_datasets():
    """
    Downloads the SeaClear and TrashCan datasets using kagglehub.

    The datasets are copied into ``./seaclear_data`` and ``./trashcan_data``,
    i.e. relative to the current working directory. Failures are logged, not
    raised; a failed dataset is left as an empty directory.

    Returns:
        Tuple ``(seaclear_path, trashcan_path)`` of the two local directories.
    """
    print("\n[1/4] Checking Datasets...")

    local_seaclear_path = "./seaclear_data"
    local_trashcan_path = "./trashcan_data"

    _fetch_kaggle_dataset(
        "SeaClear",
        "jocelyndumlao/seaclear-marine-debris-detection-and-segmentation",
        local_seaclear_path,
    )
    _fetch_kaggle_dataset("TrashCan", "yasht123/trashcan", local_trashcan_path)

    return local_seaclear_path, local_trashcan_path

def find_dataset_components(root_dir, dataset_name):
    """
    Recursively searches for the image folder and an annotation JSON file.

    Annotation choice, in order of preference: a JSON whose path (relative to
    ``root_dir``) contains both "train" and "instance", then one containing
    "train", then the first JSON found. The image folder is the directory
    that directly contains the most ``.jpg``/``.png``/``.jpeg`` files, and
    needs more than five of them to qualify. Ties are broken alphabetically,
    so the result does not depend on filesystem enumeration order.

    Args:
        root_dir: Directory to search.
        dataset_name: Name used in log messages only.

    Returns:
        Tuple ``(img_dir, ann_file)``; either element is ``None`` when nothing
        suitable was found.
    """
    print(f"   Searching {dataset_name} in: {root_dir}")

    image_exts = ('.jpg', '.png', '.jpeg')
    json_candidates = []
    candidate_img_dirs = []

    # One walk collects both the JSON files and the per-directory image counts.
    for root, _dirs, files in os.walk(root_dir):
        for name in files:
            if name.endswith(".json"):
                json_candidates.append(os.path.join(root, name))
        image_count = sum(1 for name in files if name.lower().endswith(image_exts))
        if image_count > 5:
            candidate_img_dirs.append((root, image_count))

    json_candidates.sort()

    def _rel_lower(path):
        # Match keywords only on the part below root_dir, so a root such as
        # ".../training/" does not make every file look like a train split.
        return os.path.relpath(path, root_dir).lower()

    ann_file = next(
        (p for p in json_candidates
         if "train" in _rel_lower(p) and "instance" in _rel_lower(p)),
        None,
    )
    if ann_file is None:
        ann_file = next((p for p in json_candidates if "train" in _rel_lower(p)), None)
    if ann_file is None and json_candidates:
        ann_file = json_candidates[0]

    # Largest image count first; path as a deterministic tie-breaker.
    candidate_img_dirs.sort(key=lambda item: (-item[1], item[0]))
    img_dir = candidate_img_dirs[0][0] if candidate_img_dirs else None

    if img_dir and ann_file:
        print(f"      [OK] Found images in: .../{os.path.basename(img_dir)} ({candidate_img_dirs[0][1]} images)")
        print(f"      [OK] Found annotation: .../{os.path.basename(ann_file)}")
    else:
        print(f"      [FAIL] Could not locate components.")
        print_directory_structure(root_dir)

    return img_dir, ann_file

# ---------------------------------------------------------
# 2. COLOR SPACE TRANSFORMATIONS
# ---------------------------------------------------------
class ColorSpaceTransform:
    """Convert an RGB image array into a chosen colour space and a CHW tensor.

    Recognised names (case-insensitive): 'HSV', 'LAB', 'HSL' and 'GRAY'.
    Any other name leaves the pixel values as they are.
    """

    def __init__(self, color_space='RGB'):
        # Stored upper-cased so lookups in __call__ are case-insensitive.
        self.color_space = color_space.upper()

    def __call__(self, image_rgb):
        """Return ``image_rgb`` converted, scaled by 1/255 and permuted to (C, H, W).

        Args:
            image_rgb: Image array accepted by ``cv2.cvtColor``; the final
                permute requires a three-dimensional (H, W, C) result.

        Returns:
            A float32 ``torch.Tensor`` with the channel axis first.
        """
        # Name of the OpenCV conversion constant for each supported space.
        conversion_names = {
            'HSV': 'COLOR_RGB2HSV',
            'LAB': 'COLOR_RGB2LAB',
            'HSL': 'COLOR_RGB2HLS',
            'GRAY': 'COLOR_RGB2GRAY',
        }
        code_name = conversion_names.get(self.color_space)

        if code_name is None:
            converted = image_rgb
        else:
            converted = cv2.cvtColor(image_rgb, getattr(cv2, code_name))
            if self.color_space == 'GRAY':
                # Replicate the single channel so the output stays 3-channel.
                converted = cv2.merge([converted, converted, converted])

        scaled = converted.astype(np.float32) / 255.0
        return torch.from_numpy(scaled).permute(2, 0, 1)

# ---------------------------------------------------------
# 3. DATASET LOADER (COCO Format)
# ---------------------------------------------------------
class MarineDebrisDataset(Dataset):
    """COCO-format detection/segmentation dataset with one foreground class.

    Every annotation is mapped to label ``1`` regardless of its category.
    When the annotation file cannot be loaded, or none of its images exist
    under ``img_dir``, the dataset falls back to two synthetic "mock" samples
    (random image, one fixed square object) and ``valid_dataset`` is False.

    Attributes set by ``__init__``:
        img_dir: Image directory passed in.
        color_transformer: ``ColorSpaceTransform`` applied to every image.
        valid_dataset: True when real data is served, False for mock data.
        ids: Image ids served by the dataset, in index order.
        coco_data, images, ann_map: Present only when the JSON loaded; the raw
            JSON, ``{image_id: {'path', 'info'}}`` and
            ``{image_id: [annotation, ...]}`` respectively.
    """

    def __init__(self, img_dir, annotation_file, color_space='RGB'):
        """Load the annotation file and match its images against ``img_dir``.

        Args:
            img_dir: Directory containing the image files (may be None).
            annotation_file: Path to a COCO-style JSON file (may be None).
            color_space: Name forwarded to ``ColorSpaceTransform``.
        """
        self.img_dir = img_dir
        self.color_transformer = ColorSpaceTransform(color_space)
        self.valid_dataset = False

        if img_dir and annotation_file and os.path.exists(annotation_file):
            try:
                print(f"   Loading JSON: {os.path.basename(annotation_file)}...")
                with open(annotation_file, 'r') as f:
                    coco_data = json.load(f)
                # The annotation file is picked heuristically, so it may be
                # valid JSON without being a COCO file. Reject it here rather
                # than crash on the first key access below.
                if not isinstance(coco_data, dict) or not isinstance(coco_data.get('images'), list):
                    raise ValueError("not a COCO annotation file (no 'images' list)")
                self.coco_data = coco_data
                self.valid_dataset = True
            except (OSError, ValueError) as e:
                # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
                print(f"   Error parsing JSON: {e}")
                self.valid_dataset = False

        if self.valid_dataset:
            self.images = {}
            self.ids = []
            self.ann_map = {}

            # Group annotations by the image they belong to.
            for ann in self.coco_data.get('annotations') or []:
                if 'image_id' not in ann:
                    continue
                self.ann_map.setdefault(ann['image_id'], []).append(ann)

            print(f"   Mapping images from {os.path.basename(self.img_dir)}...")
            count_found = 0
            for img in self.coco_data['images']:
                # Try the recorded relative path first, then the bare file
                # name directly under img_dir.
                path_opts = [
                    os.path.join(self.img_dir, img['file_name']),
                    os.path.join(self.img_dir, os.path.basename(img['file_name']))
                ]
                final_path = next((p for p in path_opts if os.path.exists(p)), None)

                if final_path:
                    self.images[img['id']] = {'path': final_path, 'info': img}
                    self.ids.append(img['id'])
                    count_found += 1

            print(f"   {count_found} valid images matched.")
            if count_found == 0:
                print("   WARNING: JSON loaded but no images matched filenames.")
                self.valid_dataset = False

        if not self.valid_dataset:
            print("   Using Mock Data (Dataset load failed)")
            self.ids = [1, 2]

    def __getitem__(self, index):
        """Return ``(image_tensor, target)`` for the sample at ``index``.

        ``target`` is a dict with ``boxes`` (N, 4) float32 in
        ``[x1, y1, x2, y2]`` order, ``labels`` (N,) int64, ``masks``
        (N, H, W) uint8 and ``image_id`` (a one-element tensor holding
        ``index``). N is 0 when the image has no usable annotation.
        """
        if self.valid_dataset:
            img_id = self.ids[index]
            img = cv2.imread(self.images[img_id]['path'])
            if img is None:
                # Unreadable file: serve a black placeholder as a sample with
                # no objects. The stored annotations refer to the real
                # image's geometry and would be meaningless on it.
                img = np.zeros((480, 640, 3), dtype=np.uint8)
                anns = []
            else:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                anns = self.ann_map.get(img_id, [])
        else:
            img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
            anns = []

        img_tensor = self.color_transformer(img)

        boxes, labels, masks_list = [], [], []
        h, w = img.shape[:2]

        if self.valid_dataset:
            for ann in anns:
                bbox = ann.get('bbox')
                if bbox is None or len(bbox) != 4:
                    continue
                x, y, bw, bh = bbox
                # Skip degenerate boxes (COCO boxes are x, y, width, height).
                if not (bw > 1 and bh > 1):
                    continue

                boxes.append([x, y, x + bw, y + bh])
                labels.append(1)

                # Only polygon segmentations are rasterised. Any other form
                # (or a missing one) leaves the mask empty.
                mask = np.zeros((h, w), dtype=np.uint8)
                seg = ann.get('segmentation')
                if isinstance(seg, list) and len(seg) > 0 and isinstance(seg[0], list):
                    for s in seg:
                        coords = np.asarray(s)
                        # A polygon needs at least three (x, y) pairs. Odd or
                        # too-short coordinate lists would break the reshape.
                        if coords.ndim != 1 or coords.size < 6 or coords.size % 2:
                            continue
                        poly = coords.reshape((-1, 2)).astype(np.int32)
                        cv2.fillPoly(mask, [poly], 1)
                masks_list.append(mask)
        else:
            # Mock sample: one fixed 100x100 square object.
            boxes.append([100, 100, 200, 200])
            labels.append(1)
            mask = np.zeros((h, w), dtype=np.uint8)
            mask[100:200, 100:200] = 1
            masks_list.append(mask)

        if len(boxes) > 0:
            boxes_t = torch.as_tensor(boxes, dtype=torch.float32)
            labels_t = torch.as_tensor(labels, dtype=torch.int64)
            masks_t = torch.as_tensor(np.array(masks_list), dtype=torch.uint8)
        else:
            # Explicitly shaped empty tensors for images without objects.
            boxes_t = torch.zeros((0, 4), dtype=torch.float32)
            labels_t = torch.zeros((0,), dtype=torch.int64)
            masks_t = torch.zeros((0, h, w), dtype=torch.uint8)

        target = {
            "boxes": boxes_t, "labels": labels_t, "masks": masks_t,
            "image_id": torch.tensor([index])
        }
        return img_tensor, target

    def __len__(self):
        """Return the number of samples (two for the mock fallback)."""
        return len(self.ids)

# ---------------------------------------------------------
# 4. MODEL & EVALUATION LOGIC
# ---------------------------------------------------------
def get_model(num_classes):
    """
    Build a pretrained Mask R-CNN (ResNet-50 FPN) with new prediction heads.

    Args:
        num_classes: Number of output classes for the new box and mask
            heads. The caller in this file passes 2 (background + waste).

    Returns:
        The model with replaced heads and adjusted ROI-head settings.
    """
    detector = maskrcnn_resnet50_fpn(weights="DEFAULT")
    heads = detector.roi_heads

    # Swap the box classifier for one sized to our class count.
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Swap the mask head likewise; 256 is the hidden-layer width.
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, 256, num_classes)

    # Class imbalance / small objects: sample a larger share of foreground
    # (trash) regions for the ROI loss. The original author notes the
    # default as 0.25.
    heads.fg_bg_sampler.positive_fraction = 0.50

    # Cap on detections kept per image.
    # NOTE(review): 100 appears to equal torchvision's default
    # box_detections_per_img. If a higher cap for dense clusters was
    # intended, this needs a larger value. Confirm.
    heads.detections_per_img = 100

    return detector

def calculate_iou(pred, gt):
    """Return intersection-over-union of two masks, or 0.0 when both are empty.

    Args:
        pred: Predicted mask; non-zero entries count as foreground.
        gt: Ground-truth mask, broadcast-compatible with ``pred``.
    """
    union_area = np.logical_or(pred, gt).sum()
    if not union_area > 0:
        # Neither mask has foreground pixels, so the ratio is undefined.
        return 0.0
    overlap_area = np.logical_and(pred, gt).sum()
    return overlap_area / union_area

def _collate_detection_batch(batch):
    """Turn a list of (image, target) pairs into (images, targets) tuples."""
    return tuple(zip(*batch))


def _match_masks(pred_masks, gt_masks, iou_threshold=0.5):
    """Greedily match predicted masks to ground-truth masks for one image.

    Each prediction is compared with its best-overlapping ground-truth mask.
    It is a true positive when that IoU exceeds ``iou_threshold`` and the
    ground-truth mask has not been claimed already; otherwise it is a false
    positive. Ground-truth masks left unclaimed are false negatives.

    Returns:
        Tuple ``(tp, fp, fn, matched_ious)``.
    """
    claimed = set()
    matched_ious = []
    fp = 0
    for p_mask in pred_masks:
        best_iou, best_idx = 0, -1
        for idx, g_mask in enumerate(gt_masks):
            iou = calculate_iou(p_mask, g_mask)
            if iou > best_iou:
                best_iou, best_idx = iou, idx

        if best_iou > iou_threshold and best_idx not in claimed:
            claimed.add(best_idx)
            matched_ious.append(best_iou)
        else:
            fp += 1

    tp = len(claimed)
    return tp, fp, len(gt_masks) - tp, matched_ious


def train_and_evaluate(dataset_name, color_space, img_dir, ann_file, device, epochs=50, seed=42):
    """Train Mask R-CNN on one dataset/colour-space pair and score it.

    The dataset is split 80/20 into training and validation subsets. The
    split is driven by ``seed``, so every colour space evaluated on the same
    dataset sees the same split and the scores are comparable.

    Args:
        dataset_name: Name used in log output.
        color_space: Colour space name forwarded to ``MarineDebrisDataset``.
        img_dir: Image directory.
        ann_file: COCO annotation file.
        device: ``torch.device`` to train and evaluate on.
        epochs: Number of training epochs.
        seed: Seed for the train/validation split.

    Returns:
        Tuple ``(f1, mean_iou)`` over the validation subset; ``(0.0, 0.0)``
        when the dataset has fewer than five samples.
    """
    print(f"\n--- Running: {dataset_name} | {color_space} ---")

    # 1. Prepare Data (FULL DATASET USE)
    dataset = MarineDebrisDataset(img_dir, ann_file, color_space)
    if len(dataset) < 5:
        print("Dataset too small. Skipping.")
        return 0.0, 0.0

    # Seeded 80/20 split: a private generator keeps the split identical
    # across runs without touching the global RNG used for training.
    split_rng = torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(dataset), generator=split_rng).tolist()
    split = int(0.8 * len(dataset))

    train_ds = torch.utils.data.Subset(dataset, indices[:split])
    test_ds = torch.utils.data.Subset(dataset, indices[split:])

    print(f"   Training on {len(train_ds)} images | Validation on {len(test_ds)} images")
    print(f"   Epochs: {epochs}")

    # Increase batch size if GPU allows to stabilize gradients
    batch_size = 4
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=_collate_detection_batch)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False,
                             collate_fn=_collate_detection_batch)

    # 2. Model & Optimizer
    model = get_model(2).to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    # Learning Rate Scheduler: LR is multiplied by 0.1 every 20 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    # 3. Train Loop
    for ep in range(epochs):
        model.train()
        losses = []
        for imgs, tgts in train_loader:
            if not imgs:
                continue
            imgs = [img.to(device) for img in imgs]
            tgts = [{k: v.to(device) for k, v in t.items()} for t in tgts]

            # In training mode the model returns a dict of loss terms.
            loss_dict = model(imgs, tgts)
            loss = sum(loss_dict.values())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        lr_scheduler.step()

        if (ep+1) % 5 == 0:
            avg_loss = np.mean(losses) if losses else 0
            print(f"   Epoch {ep+1}/{epochs} | Loss: {avg_loss:.4f} | LR: {optimizer.param_groups[0]['lr']}")

    # 4. Evaluation Loop
    model.eval()
    ious = []
    tp, fp, fn = 0, 0, 0
    with torch.no_grad():
        for imgs, tgts in test_loader:
            imgs = [img.to(device) for img in imgs]
            output = model(imgs)[0]  # batch_size is 1 for evaluation

            gt_masks = tgts[0]['masks'].cpu().numpy()
            # Binarise the predicted per-instance masks at 0.5.
            pred_masks = (output['masks'][:, 0].cpu().numpy() > 0.5).astype(np.uint8)

            # Images without ground truth are scored too: every detection on
            # them is a false positive and must lower precision.
            img_tp, img_fp, img_fn, img_ious = _match_masks(pred_masks, gt_masks)
            tp += img_tp
            fp += img_fp
            fn += img_fn
            ious.extend(img_ious)

    mean_iou = float(np.mean(ious)) if ious else 0.0
    # The small epsilon avoids division by zero when there are no detections.
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-6)

    print(f"   Result -> F1: {f1:.4f} | mIoU: {mean_iou:.4f}")
    return f1, mean_iou

def plot_results(results):
    """Save a grouped bar chart of F1 score per dataset and colour space.

    Args:
        results: Mapping ``{dataset_name: {color_space: {'f1': float, ...}}}``.
            A colour space missing for a dataset is drawn with an F1 of 0.

    The chart is written to ``experiment_results_full.png`` in the current
    working directory. Nothing is written when ``results`` is empty.
    """
    print("\n[4/4] Plotting Results...")
    datasets = list(results.keys())
    if not datasets:
        print("No datasets to plot.")
        return

    # Collect colour spaces from every dataset (first-seen order), not only
    # from the first one, so no series is dropped silently.
    color_spaces = []
    for ds in datasets:
        for color in results[ds]:
            if color not in color_spaces:
                color_spaces.append(color)

    n_series = len(color_spaces)
    x = np.arange(len(datasets))
    # Keep each group within 0.8 of a slot so neighbouring groups never
    # overlap, however many colour spaces are plotted.
    width = min(0.2, 0.8 / max(n_series, 1))

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for i, color in enumerate(color_spaces):
            scores = [results[ds].get(color, {'f1':0})['f1'] for ds in datasets]
            ax.bar(x + width * i, scores, width, label=color)

        ax.set_ylabel('F1 Score')
        ax.set_title('Impact of Color Space (50 Epochs, Full Data)')
        # Centre the tick under the middle of each group of bars.
        ax.set_xticks(x + width * max(n_series - 1, 0) / 2)
        ax.set_xticklabels(datasets)
        ax.legend(title="Color Space")
        ax.set_ylim(0, 1.0)
        ax.grid(axis='y', linestyle='--', alpha=0.7)

        fig.tight_layout()
        fig.savefig('experiment_results_full.png')
    finally:
        # Release the figure; pyplot otherwise keeps it alive.
        plt.close(fig)
    print("Graph saved to 'experiment_results_full.png'")

# ---------------------------------------------------------
# 5. MAIN EXECUTION
# ---------------------------------------------------------
if __name__ == "__main__":
    # --- Report the runtime environment and choose the compute device ---
    print("\n--- System Diagnostics ---")
    for label, value in (
        ("OS", sys.platform),
        ("Python", sys.version.split()[0]),
        ("PyTorch", torch.__version__),
    ):
        print(f"   {label}: {value}")

    if not torch.cuda.is_available():
        DEVICE = torch.device('cpu')
        print("   GPU Available: NO")
        print("   [WARNING] Running on CPU. Training will be slow.")

        # Installation hints, keyed by platform. Other platforms get none.
        cpu_advice = {
            'linux': (
                "   To fix on Linux/WSL:",
                "   1. Stop this script.",
                "   2. Run: pip install torch torchvision --upgrade",
                "      (PyTorch on Linux usually finds CUDA automatically)",
            ),
            'win32': (
                "   To fix on Windows:",
                "   1. Stop this script.",
                "   2. Run: pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121",
            ),
        }
        for advice_line in cpu_advice.get(sys.platform, ()):
            print(advice_line)
    else:
        DEVICE = torch.device('cuda')
        print(f"   GPU Available: YES - {torch.cuda.get_device_name(0)}")
    print("--------------------------\n")

    # Number of training epochs for every dataset / colour-space run.
    EPOCHS_PER_RUN = 50

    # Step 1: make sure both datasets are on disk.
    p_seaclear, p_trashcan = download_and_prep_datasets()

    # Step 2: locate (image directory, annotation file) for each dataset.
    ds_paths = {}
    for ds_label, ds_root in (('SeaClear', p_seaclear), ('TrashCan', p_trashcan)):
        ds_paths[ds_label] = find_dataset_components(ds_root, ds_label)

    # Step 3: one training/evaluation run per dataset and colour space.
    color_spaces_to_test = ['RGB', 'HSV', 'LAB']
    results_data = {}

    print(f"\n[2/4] Starting Full Experiments (Epochs={EPOCHS_PER_RUN}) on Device: {DEVICE}")

    for ds_name, (img_dir, ann_file) in ds_paths.items():
        if not (img_dir and ann_file):
            print(f"Skipping {ds_name} (Missing files)")
            continue

        per_space = results_data.setdefault(ds_name, {})
        for cs in color_spaces_to_test:
            f1, iou = train_and_evaluate(ds_name, cs, img_dir, ann_file, DEVICE, EPOCHS_PER_RUN)
            per_space[cs] = {'f1': f1, 'iou': iou}

    # Step 4: persist the scores as JSON text and draw the comparison chart.
    print("\n[3/4] Experiment Complete. Saving Summary...")
    with open("results_summary_full.txt", "w") as summary_file:
        summary_file.write(json.dumps(results_data, indent=4))

    if not results_data:
        print("No results to plot.")
    else:
        plot_results(results_data)

ModuleNotFoundError: No module named 'torch'