In [1]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# os.environ['TORCHDYNAMO_VERBOSE'] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn, FasterRCNN_MobileNet_V3_Large_FPN_Weights, fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models import MobileNet_V3_Large_Weights, ResNet50_Weights
from torchvision.models.detection.rpn import concat_box_prediction_layers
from torchvision.models.detection.roi_heads import fastrcnn_loss
import transformers
from transformers import DetaImageProcessor, DetaForObjectDetection, DetaConfig, DetrImageProcessor, DetrForObjectDetection, DetrConfig, DeformableDetrForObjectDetection, DeformableDetrConfig, DeformableDetrImageProcessor
from transformers.models.deta.image_processing_deta import AnnotionFormat

import numpy as np
import random
import math
from pathlib import Path
from fastai.vision import *
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import RichProgressBar, ModelCheckpoint, EarlyStopping, LearningRateMonitor, RichModelSummary
import cv2
from pycocotools import coco, cocoeval, _mask
from pycocotools import mask as maskUtils
from PIL import  Image
from matplotlib import pyplot as plt
import logging
from logging.config import fileConfig
import sys
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

# Configure logging from the project's logging.ini when it exists; otherwise fall back to a
# basic stdout handler so the notebook still runs (fileConfig raises on a missing file).
LOGGING_CONFIG_PATH = Path("logging.ini")
if LOGGING_CONFIG_PATH.is_file():
    fileConfig(str(LOGGING_CONFIG_PATH))
else:
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger("trainer")

# Seed Python, NumPy and torch (and, via workers=True, DataLoader workers) so augmentations
# and weight initialisation are repeatable. Note that cudnn.benchmark below can still make
# GPU kernels nondeterministic.
SEED = 42
L.seed_everything(SEED, workers=True)

# Allow reduced-precision internal math for float32 matmuls (faster on recent GPUs).
torch.set_float32_matmul_precision('medium')
# Let cuDNN benchmark and cache the fastest convolution algorithm per input shape.
torch.backends.cudnn.benchmark = True

Could not load the custom kernel for multi-scale deformable attention: CUDA_HOME environment variable is not set. Please set it to your CUDA install root.


In [2]:
# All splits live under one dataset root; set the DATASET_ROOT environment variable to use
# another location instead of editing the notebook.
DATASET_ROOT = Path(
    os.environ.get(
        "DATASET_ROOT",
        "/home/vamsik1211/Data/git-repos/ClearquoteProject/exercise-2/dataset",
    )
)
ROOT_DIR = DATASET_ROOT / "train"
ROOT_DIR_TEST = DATASET_ROOT / "test"


def load_coco_annotations(root: Path) -> dict:
    """Load one pycocotools ``COCO`` index per sub-directory of ``root``.

    Each sub-directory must contain a ``coco_data.json``. The result maps the sub-directory
    name to its COCO object. Sub-directories are visited in sorted order because
    ``Path.iterdir`` order depends on the filesystem, and the dict order decides how the
    per-folder datasets are concatenated into one index space.
    """
    return {
        image_dir.name: coco.COCO(image_dir / "coco_data.json")
        for image_dir in sorted(root.iterdir())
        if image_dir.is_dir()
    }


# Train Data
coco_annotations = load_coco_annotations(ROOT_DIR)
# Test Data
coco_annotations_test = load_coco_annotations(ROOT_DIR_TEST)

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
load

In [3]:
class LoadData(Dataset):
    """Detection dataset over several COCO-annotated image folders, concatenated.

    ``coco_annotations`` maps a sub-folder name of ``root_dir`` to the COCO index loaded from
    that folder. Global indices run through the folders in the dict's iteration order.
    Each item goes through the optional albumentations ``transform`` and then through the
    Hugging Face DETR ``processor``.

    Args:
        root_dir: directory holding one sub-folder per COCO file.
        coco_annotations: folder name -> COCO index for that folder.
        transform: albumentations pipeline, called as
            ``transform(image=..., bboxes=..., labels=...)`` with COCO-format
            ``[x, y, w, h]`` boxes.
        device: stored but not used by this class.
        processor: DETR image processor used for per-item encoding and batch padding.
    """

    def __init__(
        self,
        root_dir: Path,
        coco_annotations: "dict[str, coco.COCO]",
        transform=None,
        device="cpu",
        processor: "DetrImageProcessor" = None,
    ):
        self.root_dir = root_dir
        self.coco_annotations = coco_annotations
        self.transform = transform
        self.device = device
        self.processor = processor

    def load_images_paths(self) -> list[Path]:
        """Return every ``.jpg`` under each sub-folder of ``root_dir`` (recursively, sorted).

        The former pattern ``"**.jpg"`` is invalid: before Python 3.13 ``Path.glob`` raises
        ``ValueError`` when ``**`` is not a whole path component.
        """
        images = []
        for image_dir in self.root_dir.iterdir():
            if image_dir.is_dir():
                images.extend(sorted(image_dir.glob("**/*.jpg")))
        return images

    def __len__(self):
        """Total number of images across all COCO files."""
        return sum(len(section.getImgIds()) for section in self.coco_annotations.values())

    def get_idx_coco_section(self, idx: int) -> "tuple[str, coco.COCO, int]":
        """Map a global dataset index to ``(folder_name, coco_index, image_id)``.

        Within a folder, images are ordered by ascending COCO image id and the real id is
        returned. Non-contiguous ids, or ids that do not start at 1, are therefore handled.
        For ids ``1..N`` this gives the same mapping as the former ``local_index + 1``.
        Negative indices count from the end, like a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        position = idx + total if idx < 0 else idx
        if not 0 <= position < total:
            raise IndexError(f"index {idx} out of range for dataset of size {total}")

        for folder_name, section in self.coco_annotations.items():
            image_ids = sorted(section.getImgIds())
            if position < len(image_ids):
                return folder_name, section, image_ids[position]
            position -= len(image_ids)
        # Unreachable: the bounds check above guarantees some section contains the index.
        raise IndexError(f"index {idx} out of range for dataset of size {total}")

    def collate_fn(self, batch):
        """Combine ``(pixel_values, target)`` items into one batch dict.

        ``processor.pad`` brings the images to a common size and returns the matching
        ``pixel_mask``. Targets are kept as a list of dicts.
        """
        pixel_values = [pixel_value for pixel_value, _ in batch]
        encoding = self.processor.pad(pixel_values, return_tensors="pt")
        return {
            "pixel_values": encoding["pixel_values"],
            "pixel_mask": encoding["pixel_mask"],
            "labels": [target for _, target in batch],
        }

    def __getitem__(self, idx):
        """Load, transform and encode one image together with its boxes.

        Returns:
            ``(pixel_values, target)``: the processor's channels-first image tensor for this
            image and its DETR target dict (``encoding["labels"][0]``).
        """
        folder_name, section, image_id = self.get_idx_coco_section(idx)
        image_info = section.loadImgs([image_id])[0]
        image_path = self.root_dir / folder_name / image_info["file_name"]
        # The context manager closes the file handle (matters with many DataLoader workers).
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))

        # The processor ignores crowd annotations anyway; dropping them up front keeps the
        # box and label lists aligned through the transform.
        raw_annotations = [
            ann for ann in section.imgToAnns[image_id] if not ann.get("iscrowd", 0)
        ]
        boxes = [list(ann["bbox"]) for ann in raw_annotations]
        labels = [ann["category_id"] for ann in raw_annotations]

        if self.transform:
            transformed = self.transform(image=image, bboxes=boxes, labels=labels)
            image = transformed["image"]
            # Use the transformed geometry. A Resize changes the image size, so the boxes
            # must be in the new pixel coordinates. Previously the transformed boxes were
            # discarded, and the original-resolution boxes were sent with the resized image.
            boxes = transformed["bboxes"]
            labels = transformed["labels"]

        coco_target = {
            "image_id": image_id,
            "annotations": [
                self._make_annotation(box, label) for box, label in zip(boxes, labels)
            ],
        }
        encoding = self.processor(images=image, annotations=coco_target, return_tensors="pt")
        # Index the batch axis instead of squeeze(), which would also drop any other size-1 dim.
        return encoding["pixel_values"][0], encoding["labels"][0]

    @staticmethod
    def _make_annotation(box, label) -> dict:
        """Build a fresh COCO annotation dict from an ``[x, y, w, h]`` box and a category id.

        ``area`` is the box area in the (possibly transformed) pixel coordinates.
        """
        x, y, w, h = (float(v) for v in box)
        return {"bbox": [x, y, w, h], "category_id": int(label), "area": w * h, "iscrowd": 0}

In [4]:
def _make_pipeline(*augmentations):
    """Compose a fixed 720x1280 (height x width) resize followed by `augmentations`.

    Boxes are tracked in COCO ``[x, y, w, h]`` format. The class ids travel through the
    ``labels`` keyword, so albumentations keeps them paired with their boxes.
    """
    return A.Compose(
        [A.Resize(720, 1280), *augmentations],
        bbox_params=A.BboxParams(format="coco", label_fields=['labels']),
    )


# Training: the resize plus photometric-only augmentations.
transform = _make_pipeline(
    A.ChannelShuffle(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.GaussNoise(p=0.5, var_limit=(10.0, 100.0)),
    A.Equalize(p=0.5),
)

# Evaluation: the deterministic resize only.
transform_test = _make_pipeline()

In [5]:
import timm

def create_timm_backbone(name: str, pretrained: bool = True):
    """Build a feature-map backbone from a timm classification model.

    The model's last two children are dropped. For typical timm CNN classifiers (e.g.
    ResNet) these are the global pooling layer and the classifier; verify this for the
    architecture you pass in.

    Args:
        name: timm model name passed to ``timm.create_model``.
        pretrained: whether timm should load pretrained weights.

    Returns:
        ``(backbone, num_channels)``: an ``nn.Sequential`` of the remaining children, and the
        channel width of the model's pre-head features (``timm_model.num_features``).
    """
    timm_model = timm.create_model(name, pretrained=pretrained)

    # Remove the pooling + classification head to keep spatial feature maps.
    backbone = torch.nn.Sequential(*list(timm_model.children())[:-2])

    # `num_features` is timm's standard attribute for the width of the features fed to the
    # head. Reading `.out_channels` off the last remaining child raised AttributeError when
    # that child is not a conv: e.g. a Sequential stage (ResNet's layer4) or a norm layer.
    num_channels = timm_model.num_features

    return backbone, num_channels

In [6]:
# Training schedule and loader settings.
MAX_EPOCHS, MIN_EPOCHS = 300, 10
BATCH_SIZE = 16
NUM_WORKERS = 4

# Hugging Face DETR checkpoint whose config and image processor are reused
# (alternative: "facebook/detr-resnet-101").
HF_MODEL_NAME = "facebook/detr-resnet-50"

# Category metadata comes from the first COCO file of the test split. This presumes every
# folder shares the same category table; TODO confirm.
CATEGORIES = list(coco_annotations_test.values())[0].cats
ID2LABEL = {category_id: category["name"] for category_id, category in CATEGORIES.items()}
NUM_LABELS = len(CATEGORIES)

# Peak learning rate for each DETR sub-module. The backbone, input projection and query
# embeddings use the smaller rate; the transformer and prediction heads use the larger one.
SLOW_LR = 1e-4
FAST_LR = 1e-3

LEARNING_RATES = {
    "backbone": SLOW_LR,
    "input_proj": SLOW_LR,
    "query_position_embeddings": SLOW_LR,
    "encoder": FAST_LR,
    "decoder": FAST_LR,
    "class_labels_classifier": FAST_LR,
    "bbox_predictor": FAST_LR,
}

WEIGHT_DECAY = 1e-5

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info("Using device %s", DEVICE)

# The albumentations pipelines already resize every image, so the processor only rescales
# pixels (plus whatever normalisation its config enables) and converts the COCO boxes.
# Padding is deferred to LoadData.collate_fn, which pads each batch.
# The former device=..., num_labels=3 and ignore_mismatched_sizes=True arguments are not
# image-processor settings and were dropped. The label count belongs in DetrConfig, and
# "3" contradicted the categories in the annotations.
processor = DetrImageProcessor.from_pretrained(
    HF_MODEL_NAME,
    revision="no_timm",
    format=AnnotionFormat.COCO_DETECTION,
    do_resize=False,
    do_pad=False,
    do_rescale=True,
)

odometer_dataset = LoadData(ROOT_DIR, coco_annotations, transform=transform, processor=processor, device=DEVICE)
odometer_dataset_test = LoadData(ROOT_DIR_TEST, coco_annotations_test, transform=transform_test, processor=processor, device=DEVICE)

odometer_dataloader = DataLoader(odometer_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=odometer_dataset.collate_fn, num_workers=NUM_WORKERS)
odometer_dataloader_test = DataLoader(odometer_dataset_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=odometer_dataset_test.collate_fn, num_workers=NUM_WORKERS)

# One optimizer/scheduler step per training batch. This is exactly the number of batches the
# training loader yields (drop_last=False, i.e. ceil(len(dataset) / BATCH_SIZE)), as an int.
STEPS_PER_EPOCH = len(odometer_dataloader)

LOG_DIR = "tb_logs"
CHECKPOINT_DIR = "checkpoints"
CHECKPOINT_FILE_NAME = "detr_model-{epoch:02d}-{val_loss:.2f}"


[INFO]: Using device cuda [/tmp/ipykernel_242410/1391204146.py:33]


In [7]:
# Display the category-id -> name mapping read from the COCO annotations.
ID2LABEL

{1: 'LCD', 2: 'odometer'}

In [8]:
# Sanity check: the dataset size, then one transformed sample. Each item is a
# (pixel_values, target) pair; show the shape of its image tensor.
print(len(odometer_dataset))
ret_data = odometer_dataset[78]
sample_pixel_values, _sample_target = ret_data
sample_pixel_values.shape

3400


torch.Size([3, 720, 1280])

In [9]:
# DETR predicts num_labels + 1 classes and reserves index `num_labels` for "no object".
# The COCO category ids are used directly as class labels, and they start at 1 (see
# ID2LABEL). With num_labels=len(categories), the highest real id therefore equalled the
# "no object" index and that class was trained as background. Size the head by the largest
# id instead, and name unused ids "N/A" as the COCO-trained HF checkpoints do.
DETR_ID2LABEL = {class_id: ID2LABEL.get(class_id, "N/A") for class_id in range(max(ID2LABEL) + 1)}
detr_config = DetrConfig.from_pretrained(
    HF_MODEL_NAME,
    revision="no_timm",
    num_labels=len(DETR_ID2LABEL),
    id2label=DETR_ID2LABEL,
    label2id={name: class_id for class_id, name in DETR_ID2LABEL.items()},
)

In [10]:
# Build DETR from `detr_config` alone; no checkpoint state dict is loaded by this call.
# Warm-start alternative:
#   DetrForObjectDetection.from_pretrained(HF_MODEL_NAME, revision="no_timm",
#                                          config=detr_config, ignore_mismatched_sizes=True)
detr_model = DetrForObjectDetection(detr_config)

# Make every backbone parameter trainable. The HF model presumably freezes some backbone
# stages at construction; TODO confirm for the installed transformers version.
detr_model.model.backbone.requires_grad_(True)

detr_model = detr_model.to(DEVICE)

In [11]:
# Smoke test: push one training batch through the model without gradients.
# Arguments are passed by keyword. The third positional parameter of
# DetrForObjectDetection.forward is `decoder_attention_mask`, not `labels`. Passed
# positionally, the targets were silently ignored and no loss was computed.
batch = next(iter(odometer_dataloader))
detr_model.eval()
with torch.no_grad():
    outputs = detr_model(
        pixel_values=batch["pixel_values"].to(DEVICE),
        pixel_mask=batch["pixel_mask"].to(DEVICE),
        labels=[{k: v.to(DEVICE) for k, v in target.items()} for target in batch["labels"]],
    )

outputs.loss

# processor.post_process_object_detection(outputs, threshold=0.0)


In [12]:
# processor.post_process_object_detection(outputs, threshold=0.0)
# batch["pixel_mask"].min()
# outputs.logits

In [13]:
from typing import Any


from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS, OptimizerLRScheduler


class DetrModelPL(L.LightningModule):
    """Lightning wrapper that trains a Hugging Face DETR model with per-module SGD.

    Optimization is manual (``automatic_optimization = False``). Each DETR sub-module listed
    in ``PARAM_MODULES`` gets its own SGD optimizer and its own OneCycleLR scheduler, both
    stepped once per training batch, so every module can use its own peak learning rate.

    Args:
        model: the ``DetrForObjectDetection`` to train.
        processor: image processor; stored on the module, not used by the steps below.
        device: unused (Lightning handles placement); kept for call compatibility.
        learning_rates: peak learning rate per module, keyed like ``PARAM_MODULES``.
        weight_decay: SGD weight decay applied by every optimizer.
        steps_per_epoch, max_epochs: size each OneCycleLR schedule to
            ``max_epochs * steps_per_epoch`` steps. OneCycleLR raises if it is stepped more
            often than that, so the Trainer must not run longer.
        min_epochs, batch_size, model_name: recorded as hyperparameters only.
        num_workers: accepted but neither recorded nor used.
    """

    # (key in `learning_rates`, scheduler name shown by LearningRateMonitor,
    #  sub-module path relative to `self.model`)
    PARAM_MODULES = (
        ("backbone", "model_backbone_lr", "model.backbone"),
        ("input_proj", "model_input_projection_lr", "model.input_projection"),
        ("query_position_embeddings", "model_query_position_embeddings_lr", "model.query_position_embeddings"),
        ("encoder", "model_encoder_lr", "model.encoder"),
        ("decoder", "model_decoder_lr", "model.decoder"),
        ("class_labels_classifier", "model_class_labels_classifier_lr", "class_labels_classifier"),
        ("bbox_predictor", "model_bbox_predictor_lr", "bbox_predictor"),
    )

    def __init__(self, model: DetrForObjectDetection,
                 processor: DetrImageProcessor,
                 device: str = "cpu",
                 learning_rates=LEARNING_RATES,
                 weight_decay=WEIGHT_DECAY,
                 steps_per_epoch=STEPS_PER_EPOCH,
                 max_epochs=MAX_EPOCHS,
                 min_epochs=MIN_EPOCHS,
                 batch_size=BATCH_SIZE,
                 num_workers=NUM_WORKERS,
                 model_name=HF_MODEL_NAME,
                 ):
        super().__init__()

        self.model = model
        self.processor = processor

        # Several optimizers/schedulers are stepped by hand in training_step.
        self.automatic_optimization = False

        self.save_hyperparameters(
            "learning_rates",
            "weight_decay",
            "steps_per_epoch",
            "max_epochs",
            "min_epochs",
            "batch_size",
            "model_name"
        )

    @staticmethod
    def _as_list(items):
        """Lightning returns a bare object instead of a list when there is only one."""
        return list(items) if isinstance(items, (list, tuple)) else [items]

    def _forward_with_loss(self, batch):
        """Run DETR with targets and return ``(loss, loss_dict, batch_size)``."""
        pixel_values = batch["pixel_values"].to(self.device)
        pixel_mask = batch["pixel_mask"].to(self.device)
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
        return outputs["loss"], outputs["loss_dict"], pixel_values.shape[0]

    def _log_losses(self, stage: str, loss, loss_dict, batch_size: int, on_step: bool):
        """Log the total loss and every DETR loss component as ``{stage}_{name}``."""
        log_kwargs = dict(on_step=on_step, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log(f"{stage}_loss", loss, **log_kwargs)
        for name, value in loss_dict.items():
            # Log tensors directly; `.item()` forced a GPU sync for every component each step.
            self.log(f"{stage}_{name}", value.detach(), **log_kwargs)

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """One manual optimization step across all per-module optimizers and schedulers.

        Raises:
            RuntimeError: if the loss is NaN or infinite.
        """
        optimizers = self._as_list(self.optimizers())
        schedulers = self._as_list(self.lr_schedulers())

        loss, loss_dict, batch_size = self._forward_with_loss(batch)

        if not torch.isfinite(loss):
            # Fail loudly with a traceback instead of sys.exit(), which in a notebook only
            # surfaces as a bare SystemExit and was logged at DEBUG level.
            logger.error("Loss is %s at batch %d, stopping training", loss.item(), batch_idx)
            raise RuntimeError(f"Non-finite training loss: {loss.item()}")

        self._log_losses("train", loss, loss_dict, batch_size, on_step=True)

        for optimizer in optimizers:
            optimizer.zero_grad()
        self.manual_backward(loss)
        for optimizer in optimizers:
            optimizer.step()

        # With manual optimization Lightning does not step schedulers itself.
        for scheduler in schedulers:
            scheduler.step()

        return loss

    def validation_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Compute and log the validation loss.

        Metrics are logged per epoch only. Per-step values during validation all share the
        same global step and just overplot.
        """
        loss, loss_dict, batch_size = self._forward_with_loss(batch)
        self._log_losses("val", loss, loss_dict, batch_size, on_step=False)
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """One SGD optimizer + per-step OneCycleLR scheduler for each entry of PARAM_MODULES.

        Reads learning rates, weight decay and schedule length from ``self.hparams``.
        Values passed to the constructor therefore take effect; they used to be ignored in
        favour of the module-level constants.
        """
        learning_rates = self.hparams.learning_rates
        weight_decay = self.hparams.weight_decay
        epochs = self.hparams.max_epochs
        # OneCycleLR requires an int step count.
        steps_per_epoch = math.ceil(self.hparams.steps_per_epoch)

        optimizers, schedulers = [], []
        for lr_key, scheduler_name, module_path in self.PARAM_MODULES:
            max_lr = learning_rates[lr_key]
            optimizer = torch.optim.SGD(
                self.model.get_submodule(module_path).parameters(),
                lr=max_lr,
                weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=max_lr,
                pct_start=0.15,
                div_factor=1.5,
                final_div_factor=100000,
                base_momentum=0.75,
                max_momentum=0.95,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
            )
            optimizers.append(optimizer)
            schedulers.append({"scheduler": scheduler, "interval": "step", "name": scheduler_name})

        return optimizers, schedulers

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """The module-level training DataLoader."""
        return odometer_dataloader

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """The module-level test DataLoader, used for validation."""
        return odometer_dataloader_test

        

In [14]:
model = DetrModelPL(detr_model, processor)

callbacks = [
    RichProgressBar(),
    LearningRateMonitor(logging_interval="step"),
    RichModelSummary(max_depth=4),
    # Keep the three best checkpoints by validation loss (saved under the logger's directory).
    ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        filename=CHECKPOINT_FILE_NAME,
        save_top_k=3,
    ),
]

tb_logger = TensorBoardLogger(
    save_dir=LOG_DIR,
    name="detr_model",
)

trainer = L.Trainer(
    # Take the epoch limits from the module. Its OneCycleLR schedules are sized for exactly
    # max_epochs * steps_per_epoch steps and raise if stepped beyond that.
    max_epochs=model.hparams.max_epochs,
    min_epochs=model.hparams.min_epochs,
    # Follow DEVICE (which falls back to CPU) instead of hard-coding "gpu", which raises
    # on machines without CUDA.
    accelerator="gpu" if DEVICE.type == "cuda" else "cpu",
    callbacks=callbacks,
    logger=tb_logger,
    precision="32-true",
    enable_checkpointing=True,
)

trainer.fit(model)

Output()