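# Question 6: fine-tune a Detectron2 instance-segmentation model on a custom
# COCO-format dataset, then visualize its predictions on a single image.

# In [None]: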
import os

import cv2
import numpy as np
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.data.detection_utils import read_image
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.utils.visualizer import ColorMode, Visualizer


# 1. Prepare the Dataset
dataset_name = "custom_dataset"  # Change this to your dataset name
dataset_dir = "/path/to/dataset"  # Change this to your dataset directory

# Register the COCO-format dataset: annotations in train.json, images in images/.
# detectron2 refuses to register the same name twice, and its metadata refuses
# to overwrite json_file/image_root with a different value. Drop any entry left
# over from a previous run of this cell so that re-running it (e.g. after
# editing the paths above) works.
if dataset_name in DatasetCatalog.list():
    DatasetCatalog.remove(dataset_name)
if dataset_name in MetadataCatalog.list():
    MetadataCatalog.remove(dataset_name)
register_coco_instances(
    dataset_name,
    {},
    os.path.join(dataset_dir, "train.json"),
    os.path.join(dataset_dir, "images"),
)

# 2. Set Up Detectron2
cfg = get_cfg()
cfg.merge_from_file("path/to/config.yaml")  # Change this to your model configuration file
cfg.DATASETS.TRAIN = (dataset_name,)
cfg.DATASETS.TEST = ()  # no test datasets, so nothing is evaluated during training
cfg.DATALOADER.NUM_WORKERS = 2  # Adjust the number of workers based on your system
cfg.MODEL.WEIGHTS = "path/to/pretrained/model.pth"  # Change this to your pretrained model weights
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 5000
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128

# Take the number of classes from the dataset itself rather than hard-coding it.
# For COCO-registered datasets, `thing_classes` is only filled in once the
# annotation file has been parsed, so load the dataset once before reading it.
# (This parses train.json an extra time; the trainer loads it again below.)
DatasetCatalog.get(dataset_name)
num_classes = len(MetadataCatalog.get(dataset_name).thing_classes)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes

# 3. Train the Model
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)  # checkpoints are written here
trainer = DefaultTrainer(cfg)
# resume=False: start from cfg.MODEL.WEIGHTS instead of the last checkpoint
# found in cfg.OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()

# 4. Load the Trained Model for Inference
# Use the checkpoint produced by training, not the pretrained initialization
# weights that cfg.MODEL.WEIGHTS pointed to before. DefaultTrainer saves the
# final weights as "model_final.pth" in cfg.OUTPUT_DIR.
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # hide low-confidence detections; tune as needed
# DefaultPredictor builds the model, loads cfg.MODEL.WEIGHTS and switches it to
# eval mode. On each call it also applies the test-time resize, converts the
# image to a tensor, and runs without gradients.
predictor = DefaultPredictor(cfg)
model = predictor.model  # kept for any later cells that reference `model`

# 5. Segment an Image
image_path = "path/to/input/image.jpg"  # Change this to your input image path
# DefaultPredictor expects an HxWxC uint8 image in BGR order (it converts
# internally if the model was trained on RGB), so read the file as BGR.
image = read_image(image_path, format="BGR")
outputs = predictor(image)

# Visualize the segmentation results. Visualizer expects RGB, so flip the BGR
# input; get_image() returns RGB, so flip it back to BGR for OpenCV.
metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
visualizer = Visualizer(image[:, :, ::-1], metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
vis_output = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
output_image = vis_output.get_image()[:, :, ::-1]

# Display the output image. This needs a GUI; in a headless session, save it
# with cv2.imwrite(...) instead.
cv2.imshow("Segmentation Output", output_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
