This notebook trains a Detectron2 Mask R-CNN instance-segmentation model to segment a body part in images.

In [None]:
from google.colab import drive

# Mount Google Drive into the Colab runtime at a single, named mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


# Install Dependencies

In [None]:
!pip install -U torch==2.2.1 torchvision
!pip install git+https://github.com/facebookresearch/fvcore.git
import torch, torchvision
torch.__version__

In [None]:
!git clone https://github.com/facebookresearch/detectron2 detectron2_repo
!pip install -e detectron2_repo

In [None]:
# NOTE: you may need to restart the runtime before running this cell so that the
# packages installed above take effect.
# Basic setup.
# Set up the detectron2 logger (done once, before the rest of the library is used).
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import matplotlib.pyplot as plt
import numpy as np
import cv2
from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

# Import Images

In [None]:
from zipfile import ZipFile

# (archive on Drive, local directory to extract into), processed in this order.
ARCHIVES_TO_EXTRACT = [
    ('/content/drive/MyDrive/ML_Stuff/Train_1.zip', 'Train_1'),
    ('/content/drive/MyDrive/ML_Stuff/Train_2.zip', 'Train_2'),
    ('/content/drive/MyDrive/ML_Stuff/Test.zip', 'Test'),
    ('/content/drive/MyDrive/ML_Stuff/Train_3.zip', 'Train_3'),
]

# Unpack every image archive into its own directory on the runtime's local disk.
for archive_path, target_dir in ARCHIVES_TO_EXTRACT:
    with ZipFile(archive_path, 'r') as archive:
        archive.extractall(target_dir)

In [None]:
from detectron2.data import DatasetCatalog
from detectron2.data.datasets import register_coco_instances

# (dataset name, COCO-format annotation JSON, directory holding the extracted images)
TRAIN_DATASETS = [
    ("Train_1_b_dataset", "/content/drive/MyDrive/ML_Stuff/Train_1.json", "Train_1"),
    ("Train_2_b_dataset", "/content/drive/MyDrive/ML_Stuff/Train_2.json", "Train_2"),
    ("Train_3_b_dataset", "/content/drive/MyDrive/ML_Stuff/Train_3.json", "Train_3"),
]

for dataset_name, annotation_json, image_root in TRAIN_DATASETS:
    # Registering a name twice raises, so skip names that are already present;
    # this keeps the cell safe to re-run without restarting the runtime.
    if dataset_name not in DatasetCatalog.list():
        register_coco_instances(dataset_name, {}, annotation_json, image_root)

In [None]:
def _metadata_and_dicts(dataset_name):
    """Look up a registered dataset and return ``(metadata, dataset dicts)``."""
    return MetadataCatalog.get(dataset_name), DatasetCatalog.get(dataset_name)


# One (metadata, dicts) pair per registered training split.
train_1_metadata, train_1_dicts = _metadata_and_dicts("Train_1_b_dataset")
train_2_metadata, train_2_dicts = _metadata_and_dicts("Train_2_b_dataset")
train_3_metadata, train_3_dicts = _metadata_and_dicts("Train_3_b_dataset")

Show sample images

In [None]:
import random


def show_random_samples(dataset_dicts, metadata, num_samples=3, scale=0.5):
    """Display a few randomly chosen images with their dataset annotations drawn on.

    Args:
        dataset_dicts: list of dataset records; each record must have a "file_name" entry.
        metadata: metadata object passed through to the Visualizer.
        num_samples: maximum number of images to display.
        scale: display scale passed through to the Visualizer.
    """
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the dataset size.
    sample_count = min(num_samples, len(dataset_dicts))
    for record in random.sample(dataset_dicts, sample_count):
        img = cv2.imread(record["file_name"])
        if img is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            print(f"Could not read image: {record['file_name']}")
            continue
        # Reverse the channel axis on the way in and again on the way out
        # (OpenCV arrays are BGR).
        visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=scale)
        vis = visualizer.draw_dataset_dict(record)
        cv2_imshow(vis.get_image()[:, :, ::-1])


show_random_samples(train_1_dicts, train_1_metadata)
show_random_samples(train_2_dicts, train_2_metadata)


# Training

In [None]:
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2.checkpoint import Checkpointer
import os

In [None]:
# Build the training configuration: start from detectron2's defaults, layer the
# Mask R-CNN R50-FPN 3x recipe on top, then override what this project needs.
cfg = get_cfg()
cfg.OUTPUT_DIR = "./model_9"
cfg.merge_from_file("./detectron2_repo/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.DATASETS.TRAIN = ("Train_1_b_dataset", "Train_2_b_dataset", "Train_3_b_dataset")
cfg.DATASETS.TEST = ()   # no metrics implemented for this dataset
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl"  # initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.001  # Learning rate
cfg.SOLVER.MAX_ITER = 2250
# The 3x recipe's LR decay steps lie far beyond MAX_ITER, so they never fire and only
# produce a warning; clear them to state explicitly that the LR is not decayed.
cfg.SOLVER.STEPS = []
# NOTE(review): SOLVER.OPTIMIZER is not one of detectron2's built-in config keys and
# DefaultTrainer's optimizer builder does not read it, so training still uses the
# default SGD optimizer. To really train with Adam, subclass DefaultTrainer and
# override build_optimizer. The key is kept so anything reading it keeps working.
cfg.SOLVER.OPTIMIZER = "ADAM"
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256   # faster, and good enough for this toy dataset
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # a single foreground class
cfg.MODEL.RPN.NMS_THRESH = 0.315  # Higher = increased overlap between bounding boxes
cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.315

# Make sure the checkpoint/log directory exists before the trainer writes to it.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

In [None]:
# Build the trainer from the configuration assembled above.
trainer = DefaultTrainer(cfg)
# resume=False: presumably starts from cfg.MODEL.WEIGHTS rather than from a previous
# checkpoint in OUTPUT_DIR -- confirm against the detectron2 docs.
trainer.resume_or_load(resume=False)
trainer.train()

# Keep a handle on the trained network for the explicit save below.
model = trainer.model

# Create a checkpointer object that writes into the same output directory as training
checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR)

# Save the model manually under the checkpoint name "model_trained"
checkpointer.save("model_trained")

Load Validation Dataset

In [None]:
from zipfile import ZipFile

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances

# Unpack the validation images onto the runtime's local disk.
with ZipFile('/content/drive/MyDrive/ML_Stuff/Validation.zip', 'r') as zip_ref:
    zip_ref.extractall('Validation')

# Registering a name twice raises, so only register when the name is not present yet;
# this keeps the cell safe to re-run without restarting the runtime.
if "Validation_b_dataset" not in DatasetCatalog.list():
    register_coco_instances("Validation_b_dataset", {}, "/content/drive/MyDrive/ML_Stuff/Validation.json", "Validation")
validation_butt_metadata = MetadataCatalog.get("Validation_b_dataset")
validation_butt_dicts = DatasetCatalog.get("Validation_b_dataset")

Create a predictor from the model we just trained; it is used below to run inference on held-out images (this is also the code we use to load and use the model later on):

In [None]:
# Prefer the checkpoint written by this session's training run; fall back to the copy
# kept on Drive when no local checkpoint exists (e.g. training was skipped this session).
local_weights = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
drive_weights = os.path.join("/content/drive/MyDrive/ML_Stuff/model_9", "model_final.pth")
cfg.MODEL.WEIGHTS = local_weights if os.path.exists(local_weights) else drive_weights
# Minimum confidence score for a detection to be kept at inference time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.4
# Must name a dataset that is actually registered: the previous value
# ("Test_butt_dataset") was never registered anywhere in this notebook.
cfg.DATASETS.TEST = ("Validation_b_dataset", )
predictor = DefaultPredictor(cfg)

[04/28 01:43:09 d2.checkpoint.detection_checkpoint]: [DetectionCheckpointer] Loading from /content/drive/MyDrive/ML_Stuff/model_9/model_final.pth ...


Randomly select a few validation images, run the trained model on each one, and highlight only the highest-confidence detection:

In [None]:
from detectron2.utils.visualizer import ColorMode

# random.sample raises ValueError when asked for more items than exist,
# so cap the preview size at the number of validation records.
num_preview = min(3, len(validation_butt_dicts))

for d in random.sample(validation_butt_dicts, num_preview):
    im = cv2.imread(d["file_name"])
    if im is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        print(f"Could not read image: {d['file_name']}")
        continue

    outputs = predictor(im)
    # Reverse the channel axis for the Visualizer (OpenCV arrays are BGR).
    v = Visualizer(im[:, :, ::-1],
                   metadata=validation_butt_metadata,
                   scale=0.8,
                   instance_mode=ColorMode.IMAGE_BW   # remove the colors of unsegmented pixels
    )
    instances = outputs["instances"].to("cpu")

    if len(instances) > 0:
        # Keep only the single most confident detection; the one-element slice
        # preserves the container type expected by draw_instance_predictions.
        best_index = int(instances.scores.argmax())
        highest_score_instance = instances[best_index:best_index + 1]
    else:
        # Nothing detected: draw the (empty) result as-is.
        highest_score_instance = instances

    v = v.draw_instance_predictions(highest_score_instance)
    cv2_imshow(v.get_image()[:, :, ::-1])