<a href="https://colab.research.google.com/github/corypham/CourtCheck/blob/main/save_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2 import model_zoo
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.evaluation import COCOEvaluator

# Unregister the dataset (and its metadata) if it already exists
def unregister_dataset(dataset_name):
    """Remove *dataset_name* from detectron2's DatasetCatalog and MetadataCatalog.

    Each catalog is checked independently. Stale metadata is therefore cleared
    even when the dataset entry itself is already gone, and vice versa. Names
    absent from both catalogs are ignored, so the call is safe to repeat before
    re-registering a dataset in the same session.

    Args:
        dataset_name: Registered dataset name, e.g. "tennis_game1_train".
    """
    if dataset_name in DatasetCatalog.list():
        DatasetCatalog.pop(dataset_name)
    # MetadataCatalog.get() creates an entry on first access, so metadata can
    # exist without a matching DatasetCatalog entry; remove it on its own check.
    if dataset_name in MetadataCatalog.list():
        MetadataCatalog.pop(dataset_name)

# Keypoint metadata attached to every registered train/val split.
# detectron2 treats the position in `keypoint_names` as the keypoint index, so this
# order must stay in sync with the keypoint order used in the COCO annotation files.
# NOTE(review): the abbreviations presumably encode court landmarks
# (e.g. T/B = top/bottom, L/R/M = left/right/middle, N = net) — confirm against
# the annotation guide before relying on that reading.
keypoint_names = [
    "BTL",   # 0
    "BTLI",  # 1
    "BTRI",  # 2
    "BTR",   # 3
    "BBR",   # 4
    "BBRI",  # 5
    "IBR",   # 6
    "NR",    # 7
    "NM",    # 8
    "ITL",   # 9
    "ITM",   # 10
    "ITR",   # 11
    "NL",    # 12
    "BBL",   # 13
    "IBL",   # 14
    "IBM",   # 15
    "BBLI",  # 16
]

# Keypoint pairs that swap places when an image is mirrored horizontally.
# Self-pairs such as ("ITM", "ITM") mark points that map onto themselves.
keypoint_flip_map = [
    ("BTL", "BTR"),
    ("BTLI", "BTRI"),
    ("BBL", "BBR"),
    ("BBLI", "BBRI"),
    ("ITL", "ITR"),
    ("ITM", "ITM"),
    ("NL", "NR"),
    ("IBL", "IBR"),
    ("IBM", "IBM"),
    ("NM", "NM"),
]

# No keypoint connections are defined; assigned below as keypoint_connection_rules.
skeleton = []

# Custom trainer with evaluator
class TrainerWithEval(DefaultTrainer):
    """DefaultTrainer that evaluates the configured test datasets with COCOEvaluator."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return a COCOEvaluator for *dataset_name*.

        Args:
            cfg: The training config; OUTPUT_DIR determines the default output folder.
            dataset_name: Name of a registered test dataset.
            output_folder: Where evaluation results are written. When None, a
                per-dataset folder ``<OUTPUT_DIR>/inference/<dataset_name>`` is
                used. Several test datasets (one per game) are evaluated in a
                single run, and a shared folder would let their result files
                overwrite each other.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        # cfg is passed positionally as in the original call; newer detectron2
        # releases deprecate this form in favour of explicit keyword arguments.
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

# Root of the project folder on the mounted Google Drive. Every dataset, config and
# output path used by train_model_incrementally is derived from it.
_DEFAULT_BASE_DIR = "/content/drive/MyDrive/ASA Tennis Bounds Project/models/court_detection_model"


def _game_dataset_paths(game_number, base_dir):
    """Return ``{split: (json_file, image_root)}`` for one game's "train" and "val" splits."""
    annotation_dir = os.path.join(
        base_dir, "annotations", "model_annotations", "games", f"game{game_number}"
    )
    image_dir = os.path.join(base_dir, "dataset", f"game{game_number}")
    return {
        split: (
            os.path.join(annotation_dir, f"game{game_number}_{split}.json"),
            os.path.join(image_dir, f"game{game_number}_{split}"),
        )
        for split in ("train", "val")
    }


def _check_dataset_paths(dataset_paths):
    """Raise FileNotFoundError listing every missing annotation file or image folder."""
    missing = []
    for splits in dataset_paths.values():
        for json_file, image_root in splits.values():
            if not os.path.isfile(json_file):
                missing.append(json_file)
            if not os.path.isdir(image_root):
                missing.append(image_root)
    if missing:
        raise FileNotFoundError(
            "Missing dataset files/directories:\n  " + "\n  ".join(missing)
        )


def _register_game_datasets(game_number, splits):
    """(Re-)register one game's COCO splits as ``tennis_game<N>_<split>`` with keypoint metadata."""
    for split, (json_file, image_root) in splits.items():
        name = f"tennis_game{game_number}_{split}"
        # register_coco_instances refuses names that are already registered, so
        # clear any previous registration (e.g. from re-running this cell).
        unregister_dataset(name)
        register_coco_instances(name, {}, json_file, image_root)
        metadata = MetadataCatalog.get(name)
        metadata.keypoint_names = keypoint_names
        metadata.keypoint_flip_map = keypoint_flip_map
        metadata.keypoint_connection_rules = skeleton


def train_model_incrementally(
    game_numbers, max_iter, resume=False, *, base_dir=_DEFAULT_BASE_DIR
):
    """Register the given games' train/val splits and train a keypoint R-CNN on them.

    The train splits of all listed games form cfg.DATASETS.TRAIN and their val
    splits form cfg.DATASETS.TEST. Checkpoints are written to
    ``<base_dir>/detectron2/game_model``.

    Args:
        game_numbers: Game numbers to include. Any iterable is accepted; it is
            materialized once because it is used several times below.
        max_iter: Total number of solver iterations (cfg.SOLVER.MAX_ITER). With
            resume=True the iteration counter is restored from the last
            checkpoint, so this is a cumulative total and must exceed the
            iteration already reached for any further training to happen.
        resume: Passed to ``resume_or_load``; True continues from the last
            checkpoint in the output folder.
        base_dir: Project folder containing ``annotations/``, ``dataset/`` and
            ``detectron2/``. Defaults to the original Google Drive location.

    Raises:
        ValueError: If ``game_numbers`` is empty.
        FileNotFoundError: If any annotation JSON or image folder is missing.
            This is checked before any dataset is (re-)registered or training
            starts, so a missing val split is not discovered only at the end.
    """
    game_numbers = list(game_numbers)
    if not game_numbers:
        raise ValueError("game_numbers must contain at least one game number")

    dataset_paths = {n: _game_dataset_paths(n, base_dir) for n in game_numbers}
    _check_dataset_paths(dataset_paths)
    for game_number, splits in dataset_paths.items():
        _register_game_datasets(game_number, splits)

    cfg = get_cfg()
    cfg.merge_from_file(
        os.path.join(
            base_dir, "detectron2", "configs", "COCO-Keypoints", "keypoint_rcnn_R_50_FPN_3x.yaml"
        )
    )
    cfg.DATASETS.TRAIN = tuple(f"tennis_game{n}_train" for n in game_numbers)
    cfg.DATASETS.TEST = tuple(f"tennis_game{n}_val" for n in game_numbers)
    cfg.DATALOADER.NUM_WORKERS = 4
    cfg.SOLVER.IMS_PER_BATCH = 4  # increase if you have more GPU memory
    cfg.SOLVER.BASE_LR = 0.0001  # lower learning rate for more careful training
    cfg.SOLVER.MAX_ITER = max_iter  # cumulative total number of iterations
    # Decay the learning rate at 75% and 87.5% of the cumulative schedule.
    cfg.SOLVER.STEPS = [int(max_iter * 0.75), int(max_iter * 0.875)]
    cfg.SOLVER.GAMMA = 0.1  # decay factor
    cfg.SOLVER.CHECKPOINT_PERIOD = 500  # save a checkpoint every 500 iterations
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256  # increase for more stable gradients
    # NOTE(review): value carried over from the original ("your dataset has 11
    # classes"); confirm it matches the categories in the COCO annotation files.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 11

    cfg.OUTPUT_DIR = os.path.join(base_dir, "detectron2", "game_model")
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    trainer = TrainerWithEval(cfg)
    trainer.resume_or_load(resume=resume)
    trainer.train()

# Train with mixed datasets incrementally
# Each stage re-registers every listed game and, with resume=True, continues from the
# last checkpoint in the output folder. Resuming restores the iteration counter, so
# max_iter is a cumulative total and must exceed the iteration the previous stage reached.
# train_model_incrementally([1], 8000, resume=False)        # Train with game 1 to 8000 iterations (initial training)
# train_model_incrementally([1, 2], 16000, resume=True)     # Continue training with games 1 and 2 to 16000 iterations
# train_model_incrementally([1, 2, 3], 24000, resume=True)  # Continue training with games 1, 2, and 3 to 24000 iterations
# Active stage: games 1-4, cumulative target of 44000 iterations.
train_model_incrementally([1, 2, 3, 4], 44000, resume=True)
# NOTE(review): the stages below still follow the older +8000-per-game schedule.
# 40000 is below the 44000 targeted by the active call above; if that run completed,
# the next stage would add no training iterations. Re-check max_iter before uncommenting.
# train_model_incrementally([1, 2, 3, 4, 5], 40000, resume=True)
# train_model_incrementally([1, 2, 3, 4, 5, 6], 48000, resume=True)
# train_model_incrementally([1, 2, 3, 4, 5, 6, 7], 56000, resume=True)
# train_model_incrementally([1, 2, 3, 4, 5, 6, 7, 8], 64000, resume=True)
# train_model_incrementally([1, 2, 3, 4, 5, 6, 7, 8, 9], 72000, resume=True)
