# Status: work in progress — this training script does not work yet.

# In [None]:
# syringe_keypoint_training_final.py
import torch
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import os
import json
import cv2
import random
import time
import gc
from datetime import datetime
from matplotlib import pyplot as plt
from typing import List, Dict
from collections import defaultdict
from fvcore.common.config import CfgNode as CN
from detectron2.engine.hooks import HookBase
from detectron2.evaluation import inference_context
from detectron2.utils.visualizer import Visualizer
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.data import (
    MetadataCatalog, 
    DatasetCatalog, 
    build_detection_test_loader,
    build_detection_train_loader,
    DatasetMapper
)
from detectron2.data.datasets import register_coco_instances
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.modeling import build_model
from detectron2.utils.events import EventStorage
from detectron2.utils.visualizer import ColorMode
import detectron2.data.transforms as T
from tqdm import tqdm
import threading
from detectron2.checkpoint import DetectionCheckpointer

# --------------------------
# 1. Dataset Configuration
# --------------------------
def filter_and_register_dataset(json_path: str, img_dir: str, dataset_name: str) -> str:
    """Filter a COCO export down to ``syringe1`` keypoint annotations and register it.

    Keeps only annotations of the ``syringe1`` category and the images that
    contain at least one of them. It then attaches the four syringe keypoint
    names and a skeleton to that category, and writes the result next to
    ``json_path`` as ``filtered_<dataset_name>.json``. Finally it registers that
    file with detectron2 under ``dataset_name``. An existing registration with
    the same name is replaced, so the function can be called again (e.g. when a
    notebook cell is re-run).

    Args:
        json_path: Path to the source COCO annotation file.
        img_dir: Directory holding the images referenced by ``json_path``.
        dataset_name: Name to register the filtered dataset under.

    Returns:
        Path of the filtered annotation file that was written.

    Raises:
        FileNotFoundError: ``json_path`` does not exist.
        NotADirectoryError: ``img_dir`` does not exist.
        ValueError: No ``syringe1`` category exists, an annotation's keypoints
            are not four (x, y, v) triplets, or the filtered dataset is empty.
        RuntimeError: detectron2 registration or loading failed.
    """
    keypoint_names = ["plunger_top", "plunger_bottom", "syringe_top", "syringe_bottom"]

    if not os.path.exists(json_path):
        raise FileNotFoundError(f"Annotation file missing: {json_path}")
    if not os.path.exists(img_dir):
        raise NotADirectoryError(f"Image directory missing: {img_dir}")

    with open(json_path, encoding="utf-8") as f:
        coco_data = json.load(f)

    # Identify the syringe category (a bare next() would surface as an opaque StopIteration).
    syringe1_id = next(
        (c["id"] for c in coco_data["categories"] if c["name"] == "syringe1"), None
    )
    if syringe1_id is None:
        raise ValueError(f"No 'syringe1' category found in {json_path}")

    # Keep only syringe1 annotations, validating their keypoints on the way.
    valid_annos = []
    for a in coco_data["annotations"]:
        if a["category_id"] != syringe1_id:
            continue
        flat = a["keypoints"]
        # Check the length before reshaping: reshape(-1, 3) either produces exactly
        # 3 columns or raises, so a column check afterwards could never fire. COCO
        # also requires one (x, y, v) triplet per keypoint name of the category.
        if len(flat) != 3 * len(keypoint_names):
            raise ValueError(f"Invalid keypoints in image {a['image_id']}")
        kps = np.asarray(flat, dtype=np.float32).reshape(-1, 3)
        a["keypoints"] = kps.reshape(-1).tolist()
        # COCO defines num_keypoints as the number of *labeled* keypoints (v > 0),
        # not the total number of keypoint slots.
        a["num_keypoints"] = int(np.count_nonzero(kps[:, 2] > 0))
        valid_annos.append(a)

    valid_img_ids = {a["image_id"] for a in valid_annos}
    valid_imgs = [img for img in coco_data["images"] if img["id"] in valid_img_ids]
    if not valid_imgs:
        raise ValueError(f"Empty dataset after filtering: {dataset_name}")

    # Attach keypoint names and a skeleton to the syringe1 category. The skeleton
    # uses COCO's 1-based indices: plunger_top<->plunger_bottom and
    # syringe_top<->syringe_bottom.
    categories = [
        {**c, "keypoints": list(keypoint_names), "skeleton": [[1, 2], [3, 4]]}
        for c in coco_data["categories"]
        if c["name"] == "syringe1"
    ]

    filtered_json = os.path.join(os.path.dirname(json_path), f"filtered_{dataset_name}.json")
    with open(filtered_json, "w", encoding="utf-8") as f:
        json.dump(
            {"images": valid_imgs, "annotations": valid_annos, "categories": categories},
            f,
        )

    # detectron2 refuses to register the same name twice, so drop an earlier
    # registration (and its possibly stale metadata) before registering again.
    if dataset_name in DatasetCatalog.list():
        DatasetCatalog.remove(dataset_name)
        if dataset_name in MetadataCatalog.list():
            MetadataCatalog.remove(dataset_name)

    try:
        register_coco_instances(dataset_name, {}, filtered_json, img_dir)
        num_samples = len(DatasetCatalog.get(dataset_name))
    except Exception as e:
        # Raise rather than exiting the process; the caller decides how to react.
        raise RuntimeError(f"Dataset registration failed: {e}") from e
    if num_samples == 0:
        raise ValueError(f"Empty dataset after registration: {dataset_name}")

    return filtered_json

# Initialize datasets FIRST: the configuration below looks them up by name.
# The dataset root can be overridden with the SYRINGE_DATASET_DIR environment
# variable; without it the original local path is used.
dataset_base = os.environ.get(
    "SYRINGE_DATASET_DIR",
    "/Users/andreas/Desktop/repos/masterthesis/datasets/Syringe-volume-estimation-detectron2",
)
try:
    train_json = filter_and_register_dataset(
        os.path.join(dataset_base, "train/_annotations.coco.json"),
        os.path.join(dataset_base, "train"),
        "syringe_train",
    )
    val_json = filter_and_register_dataset(
        os.path.join(dataset_base, "valid/_annotations.coco.json"),
        os.path.join(dataset_base, "valid"),
        "syringe_val",
    )
except Exception as e:
    print(f"❌ Dataset initialization failed: {str(e)}")
    # Raise SystemExit directly: the exit() builtin is injected by the `site`
    # module for the interactive shell and is not meant to be used in programs.
    raise SystemExit(1) from e

# --------------------------
# 2. Metadata Configuration
# --------------------------
# Keypoint metadata shared by both splits. The validation split receives it too,
# so its images/predictions can be visualised with named keypoints and the
# connection rules below, not only the training split.
# An empty keypoint_flip_map means horizontal-flip augmentation keeps every
# keypoint's label (fits these top/bottom names — revisit if left/right
# keypoints are ever added).
_SYRINGE_KEYPOINT_METADATA = {
    "thing_classes": ["syringe1"],
    "keypoint_names": ["plunger_top", "plunger_bottom", "syringe_top", "syringe_bottom"],
    "keypoint_flip_map": [],
    "keypoint_connection_rules": [
        ("plunger_top", "plunger_bottom", (0, 255, 0)),
        ("syringe_top", "syringe_bottom", (255, 0, 0)),
    ],
}
for _split_name in ("syringe_train", "syringe_val"):
    MetadataCatalog.get(_split_name).set(**_SYRINGE_KEYPOINT_METADATA)

# Training-split metadata handle, kept for visualisation code that uses it.
syringe_metadata = MetadataCatalog.get("syringe_train")

# --------------------------
# 3. Model Configuration
# --------------------------
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml"))

# Replace the model-zoo datasets with the syringe splits registered above.
cfg.DATASETS.TRAIN = ("syringe_train",)
cfg.DATASETS.TEST = ("syringe_val",)

# Empty weights path: no checkpoint is loaded, so the whole network starts from
# random initialisation.
# NOTE(review): fine-tuning from the model-zoo checkpoint of this config
# (model_zoo.get_checkpoint_url(...)) is the usual choice for small datasets — consider it.
cfg.MODEL.WEIGHTS = ""

# Verify dataset registration (explicit checks: asserts are stripped under `python -O`).
if len(DatasetCatalog.get(cfg.DATASETS.TRAIN[0])) == 0:
    raise RuntimeError("No training data registered!")
if len(DatasetCatalog.get(cfg.DATASETS.TEST[0])) == 0:
    raise RuntimeError("No validation data registered!")

# Device setup
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model architecture: a single class ("syringe1") with four keypoints.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 4
# One OKS sigma per keypoint; passed to COCOEvaluator for keypoint evaluation.
cfg.TEST.KEYPOINT_OKS_SIGMAS = [0.5] * 4

# Training parameters
cfg.SOLVER.BASE_LR = 0.00002
cfg.SOLVER.MAX_ITER = 100
cfg.SOLVER.IMS_PER_BATCH = 2
# Linear warmup over ~10% of the run, capped at 100 iterations. With
# WARMUP_ITERS == MAX_ITER the learning rate would ramp for the entire run and
# reach BASE_LR only on the final iteration.
cfg.SOLVER.WARMUP_ITERS = min(100, cfg.SOLVER.MAX_ITER // 10)
cfg.SOLVER.WARMUP_FACTOR = 0.001
# NOTE(review): SOLVER.STEPS is inherited from the 3x model-zoo config; steps
# beyond MAX_ITER are ignored, in which case GAMMA never takes effect — confirm.
cfg.SOLVER.GAMMA = 0.1
cfg.SOLVER.WEIGHT_DECAY = 0.0005

# Gradient clipping. The default config already defines the
# SOLVER.CLIP_GRADIENTS node, so set its fields instead of replacing the node.
cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm"
cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
cfg.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0

# Validation: evaluation period in iterations.
cfg.TEST.EVAL_PERIOD = 10

# Input config
cfg.INPUT.MIN_SIZE_TRAIN = (600,)
cfg.INPUT.MAX_SIZE_TRAIN = 1200
cfg.INPUT.MIN_SIZE_TEST = 800
cfg.INPUT.MAX_SIZE_TEST = 1200

# Output
cfg.OUTPUT_DIR = os.path.join(dataset_base, "output")
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# --------------------------
# 4. Training Hooks
# --------------------------
class TrainingMonitorHook(HookBase):
    """Show training progress (loss, iterations/s, ETA) in a tqdm progress bar.

    Args:
        max_iter: Iteration count the bar runs up to (normally cfg.SOLVER.MAX_ITER).
    """

    def __init__(self, max_iter: int):
        self.max_iter = max_iter
        self.pbar = None
        # Reset in before_train so timing excludes setup done after the hook is built.
        self.start_time = time.time()
        self._start_iter = 0

    def before_train(self):
        self.start_time = time.time()
        # Non-zero when the trainer resumes from a checkpoint.
        self._start_iter = self.trainer.start_iter
        self.pbar = tqdm(
            total=self.max_iter,
            initial=self._start_iter,
            desc="Training Progress",
            unit="iter",
            bar_format="{l_bar}{bar:30}{r_bar}",
            postfix={
                "loss": "N/A",
                "iter/s": "N/A",
                "eta": "N/A",
            },
        )

    def after_step(self):
        # trainer.iter is the 0-based index of the step that just finished,
        # so iter + 1 iterations are complete at this point.
        completed = self.trainer.iter + 1
        elapsed = time.time() - self.start_time
        ips = (completed - self._start_iter) / elapsed if elapsed > 0 else 0
        eta = (self.max_iter - completed) / ips if ips > 0 else 0
        metrics = self.trainer.storage.latest()

        self.pbar.set_postfix({
            "loss": f"{metrics.get('total_loss', (0,))[0]:.3f}",
            "iter/s": f"{ips:.2f}",
            "eta": f"{eta/3600:.1f}h" if eta > 3600 else f"{eta/60:.1f}m",
        })
        self.pbar.update(1)

        # Memory management: explicit garbage collection every 5 iterations.
        if self.trainer.iter % 5 == 0:
            gc.collect()

    def after_train(self):
        # The bar does not exist if training failed before this hook's before_train ran.
        if self.pbar is not None:
            self.pbar.close()
        print(f"\nTraining completed in {(time.time()-self.start_time)/3600:.2f} hours")

class ValidationHook(HookBase):
    """Run COCO keypoint evaluation on a validation split during training.

    Evaluation happens after every ``cfg.TEST.EVAL_PERIOD`` completed iterations
    (disabled when the period is not positive). Evaluation errors are printed
    and training continues.

    Args:
        cfg: detectron2 config; cloned so later edits to the caller's cfg do not
            affect the hook.
        dataset_name: Registered dataset to evaluate (default ``"syringe_val"``).
    """

    def __init__(self, cfg, dataset_name: str = "syringe_val"):
        self.cfg = cfg.clone()
        self.dataset_name = dataset_name
        self.evaluator = COCOEvaluator(
            dataset_name,
            output_dir=self.cfg.OUTPUT_DIR,
            use_fast_impl=False,
            kpt_oks_sigmas=cfg.TEST.KEYPOINT_OKS_SIGMAS,
        )
        # Built lazily on the first evaluation, then reused.
        self._val_loader = None

    def after_step(self):
        period = self.cfg.TEST.EVAL_PERIOD
        # trainer.iter is 0-based inside after_step, so iter + 1 iterations have
        # completed; evaluate after every `period` completed iterations.
        if period <= 0 or (self.trainer.iter + 1) % period != 0:
            return

        model = self.trainer.model
        original_mode = model.training
        start_time = time.time()
        try:
            if self._val_loader is None:
                self._val_loader = build_detection_test_loader(self.cfg, self.dataset_name)
            with inference_context(model):
                results = inference_on_dataset(model, self._val_loader, self.evaluator)

            print(f"\nValidation @ Iter {self.trainer.iter}:")
            # Fall back to the full result dict when there is no "keypoints" entry
            # (e.g. an empty result when nothing was predicted).
            print(json.dumps(results.get("keypoints", results), indent=2))
            print(f"Validation time: {time.time()-start_time:.1f}s")
        except Exception as e:
            print(f"\n⚠️ Validation error: {str(e)}")
        finally:
            # Restore the training mode even if evaluation raised part-way.
            model.train(original_mode)

# --------------------------
# 5. Training Execution
# --------------------------
if __name__ == "__main__":
    # Report how many samples each registered split contains.
    print("\nRegistered datasets:")
    print(f"Train: {len(DatasetCatalog.get('syringe_train'))} samples")
    print(f"Validation: {len(DatasetCatalog.get('syringe_val'))} samples")

    # Verify config values (explicit checks: asserts are stripped under `python -O`).
    if not cfg.DATASETS.TRAIN:
        raise RuntimeError("No training datasets configured!")
    if not cfg.DATASETS.TEST:
        raise RuntimeError("No validation datasets configured!")

    # Initialize trainer with progress display and periodic validation.
    trainer = DefaultTrainer(cfg)
    trainer.register_hooks([
        TrainingMonitorHook(cfg.SOLVER.MAX_ITER),
        ValidationHook(cfg),
    ])

    # Load cfg.MODEL.WEIGHTS through the trainer's own checkpointer (rooted at
    # cfg.OUTPUT_DIR). resume=False always loads cfg.MODEL.WEIGHTS and starts at
    # iteration 0 instead of resuming from an existing checkpoint.
    trainer.resume_or_load(resume=False)
    print("\nModel architecture verified:")
    print(trainer.model)

    # Hardware info
    print("\n" + "="*50)
    print(f"Training on {'GPU' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"Device: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
    else:
        print("⚠️ WARNING: Training on CPU - expect slow performance")
        print(f"Estimated time: {cfg.SOLVER.MAX_ITER*5/60:.1f} minutes (5s/iter)")
    print("="*50 + "\n")

    try:
        trainer.train()
    except KeyboardInterrupt:
        print("\nTraining interrupted by user!")
    finally:
        # Final evaluation BEFORE cleanup: the evaluator reads the filtered
        # validation annotations that the cleanup below deletes.
        final_model_path = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
        if os.path.exists(final_model_path):
            try:
                # Point cfg at the saved final weights, presumably for later use
                # (e.g. building a predictor from cfg). The evaluation below runs
                # on the in-memory trainer.model, not on this file.
                cfg.MODEL.WEIGHTS = final_model_path
                # Pass the 4-keypoint OKS sigmas (same settings as ValidationHook).
                # Without kpt_oks_sigmas COCOEvaluator falls back to the 17 COCO
                # person-keypoint sigmas, which do not match this 4-keypoint model.
                evaluator = COCOEvaluator(
                    "syringe_val",
                    output_dir=cfg.OUTPUT_DIR,
                    use_fast_impl=False,
                    kpt_oks_sigmas=cfg.TEST.KEYPOINT_OKS_SIGMAS,
                )
                val_loader = build_detection_test_loader(cfg, "syringe_val")
                print("\nFinal evaluation results:")
                print(inference_on_dataset(trainer.model, val_loader, evaluator))
            except Exception as e:
                print(f"\n⚠️ Final evaluation failed: {str(e)}")

        # Cleanup temporary filtered annotation files AFTER evaluation.
        try:
            if os.path.exists(train_json):
                os.remove(train_json)
            if os.path.exists(val_json):
                os.remove(val_json)
            print("Cleaned temporary files")
        except Exception as e:
            print(f"⚠️ Cleanup error: {str(e)}")