# Training a Detectron2 Model on a Custom Medical Dataset

### Import

In [6]:
import os
import sys
import cv2
from generate_dataset import my_dataset_function


# Add the detectron2 repository folder to the module search path.
# NOTE(review): the argument is already absolute ("/detectron2_repo"), so
# os.path.abspath does not anchor it to this notebook's folder; it resolves
# against the filesystem (or drive) root — confirm that is the intended location.
sys.path.append(os.path.abspath("/detectron2_repo"))

# NOTE(review): detectron2 is imported through two different package paths
# below ("detectron2_repo.detectron2" and "detectron2") — confirm both resolve
# to the same installation.
from detectron2_repo.detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2_repo.detectron2.config import get_cfg
from detectron2_repo.detectron2.utils.visualizer import Visualizer
from detectron2.data import DatasetCatalog, MetadataCatalog

# Working directory captured when this cell runs; later cells use it as the
# base for the trained-model path and the test-image path.
script_dir = os.getcwd()

### 1. Registrazione del dataset in formato COCO

In [7]:
# Class names attached to the registered splits.
# NOTE(review): the list includes "background" and cfg.MODEL.ROI_HEADS.NUM_CLASSES
# is set to 4 later on; confirm the category ids produced by
# my_dataset_function() really use index 0 for a "background" class.
BRAIN_MRI_CLASSES = ["background", "edema", "non-enhancing tumor", "enhancing tumour"]

# Register each split only if its name is not in the catalog yet:
# DatasetCatalog.register() rejects a name that is already registered, which
# would make this cell fail when it is re-run in the same kernel.
# NOTE(review): both splits are produced by the same my_dataset_function()
# call, so "brain_mri_val" is not a held-out set — confirm this is intended.
for dataset_name in ("brain_mri_train", "brain_mri_val"):
    if dataset_name not in DatasetCatalog.list():
        DatasetCatalog.register(dataset_name, my_dataset_function)
    # Attach the class names to both splits so anything that looks up the
    # validation metadata sees the same labels as training.
    MetadataCatalog.get(dataset_name).set(thing_classes=BRAIN_MRI_CLASSES)

# Last expression of the cell: display the training metadata.
MetadataCatalog.get("brain_mri_train")

namespace(name='brain_mri_train', thing_classes=['brain'])

### 2. Configurazione del modello

In [8]:
# Start from the default detectron2 config and override the fields below.
cfg = get_cfg()

# --- Data ---
cfg.DATASETS.TRAIN = ("brain_mri_train",)
cfg.DATASETS.TEST = ("brain_mri_val",)
cfg.DATALOADER.NUM_WORKERS = 0

# --- Model ---
cfg.MODEL.WEIGHTS = "detectron2://ImageNetPretrained/MSRA/R-50.pkl"  # pre-trained weights
cfg.MODEL.DEVICE = "cpu"  # run on the CPU
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 4  # adjust to the number of classes
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 80  # number of ROIs per image
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3  # confidence threshold

# --- Solver ---
cfg.SOLVER.IMS_PER_BATCH = 2  # images per batch
cfg.SOLVER.BASE_LR = 0.00025  # learning rate
cfg.SOLVER.MAX_ITER = 500  # maximum number of iterations
cfg.SOLVER.WARMUP_ITERS = 50  # ramp the LR up gradually over the first iterations

### 3. Addestramento

In [9]:
# Train only when no final checkpoint exists yet; otherwise reuse it.
# The checkpoint is looked up under cfg.OUTPUT_DIR, the same directory the
# trained weights are read back from after training, so the existence check
# and the weights path always refer to the same file.
model_path = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")

if os.path.isfile(model_path):
    print("Modello già addestrato trovato, caricando il modello...")
else:
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)  # start from cfg.MODEL.WEIGHTS, not from a previous run
    trainer.train()

# Point the config at the trained weights and build the predictor on BOTH
# paths, so the notebook state after this cell does not depend on whether
# training ran or a cached checkpoint was found.
cfg.MODEL.WEIGHTS = model_path
predictor = DefaultPredictor(cfg)


[32m[12/12 21:32:46 d2.data.build]: [0mRemoved 0 images with no usable annotations. 6 images left.
[32m[12/12 21:32:46 d2.data.build]: [0mDistribution of instances among all 1 categories:
[36m|  category  | #instances   |
|:----------:|:-------------|
|   brain    | 11           |
|            |              |[0m
[32m[12/12 21:32:46 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [ResizeShortestEdge(short_edge_length=(800,), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[12/12 21:32:46 d2.data.build]: [0mUsing training sampler TrainingSampler
[32m[12/12 21:32:46 d2.data.common]: [0mSerializing the dataset using: <class 'detectron2.data.common._TorchSerializedList'>
[32m[12/12 21:32:46 d2.data.common]: [0mSerializing 6 elements to byte tensors and concatenating them all ...
[32m[12/12 21:32:46 d2.data.common]: [0mSerialized dataset takes 0.01 MiB
[32m[12/12 21:32:46 d2.data.build]: [0mMaking batched data loader with batch_size=2


Some model parameters or buffers are not found in the checkpoint:
[34mproposal_generator.rpn_head.anchor_deltas.{bias, weight}[0m
[34mproposal_generator.rpn_head.conv.{bias, weight}[0m
[34mproposal_generator.rpn_head.objectness_logits.{bias, weight}[0m
[34mroi_heads.box_predictor.bbox_pred.{bias, weight}[0m
[34mroi_heads.box_predictor.cls_score.{bias, weight}[0m
The checkpoint state_dict contains keys that are not used by the model:
  [35mfc1000.{bias, weight}[0m
  [35mstem.conv1.bias[0m
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[12/12 21:36:33 d2.utils.events]: [0m eta: 1:30:23  iter: 19  total_loss: 1.942  loss_cls: 0.5938  loss_box_reg: 0.4447  loss_rpn_cls: 0.6894  loss_rpn_loc: 0.2263    time: 11.3297  last_time: 11.2664  data_time: 0.0210  last_data_time: 0.0202   lr: 9.5155e-05  
[32m[12/12 21:40:20 d2.utils.events]: [0m eta: 1:28:10  iter: 39  total_loss: 1.724  loss_cls: 0.3399  loss_box_reg: 0.5787  loss_rpn_cls: 0.6473  loss_rpn_loc: 0.1448    time: 11.3234  last_time: 11.7918  data_time: 0.0237  last_data_time: 0.0200   lr: 0.00019505  
[32m[12/12 21:44:16 d2.utils.events]: [0m eta: 1:25:35  iter: 59  total_loss: 1.745  loss_cls: 0.2685  loss_box_reg: 0.792  loss_rpn_cls: 0.5811  loss_rpn_loc: 0.1022    time: 11.4998  last_time: 11.6430  data_time: 0.0195  last_data_time: 0.0182   lr: 0.00025  
[32m[12/12 21:48:17 d2.utils.events]: [0m eta: 1:22:46  iter: 79  total_loss: 1.395  loss_cls: 0.2218  loss_box_reg: 0.588  loss_rpn_cls: 0.4946  loss_rpn_loc: 0.09124    time: 11.6391  last_time

No evaluator found. Use `DefaultTrainer.test(evaluators=)`, or implement its `build_evaluator` method.


### 4. Inferenza

In [10]:
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # trained model
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # score threshold used for inference
predictor = DefaultPredictor(cfg)

# Load an MRI image for inference.
image_path = os.path.join(script_dir, "test", "test3.jpg")
image = cv2.imread(image_path)
# cv2.imread signals a missing or unreadable file by returning None instead of
# raising; fail here with a clear message rather than inside the predictor.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
outputs = predictor(image)
print(outputs)

[32m[12/12 23:15:11 d2.checkpoint.detection_checkpoint]: [0m[DetectionCheckpointer] Loading from ./output/model_final.pth ...


  return torch.load(f, map_location=torch.device("cpu"))


{'instances': Instances(num_instances=0, image_height=1534, image_width=1433, fields=[pred_boxes: Boxes(tensor([], size=(0, 4))), scores: tensor([]), pred_classes: tensor([], dtype=torch.int64)])}


### 5. Test sull'immagine

In [11]:
# Visualizza i risultati con bordi rossi per le istanze rilevate
v = Visualizer(image[:,:,::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs['instances'].to('cpu'))
img = v.get_image()[:, :, ::-1]
cv2.imwrite('output.jpg', img)

True