In [1]:
from typing import List, Optional
from torch.optim.optimizer import Optimizer

import os
import hydra
from torch import nn
from omegaconf import DictConfig
import pytorch_lightning as pl
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase

from src.utils import utils
from hydra import compose, initialize
from omegaconf import OmegaConf

# GlobalHydra.instance().clear()
initialize(config_path="../configs/", job_name="test_app")

hydra.initialize()

In [2]:
from torch.utils.data import Dataset
from torchvision import transforms

In [3]:
config = compose(config_name="experiments/base_object_detection.yaml", overrides=[])

In [4]:
default_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(512)
        ])

In [5]:
train_dataset: Dataset = hydra.utils.instantiate(config.datamodule.dataset,
                                         dataset_config = config.datamodule.dataset,
                                         SPLIT_FILE = os.path.join(config.datamodule.dataset.SPLIT_FILE, 'train.txt'),
                                         transforms = default_transforms,
                                          _recursive_=False)
val_dataset: Dataset = hydra.utils.instantiate(config.datamodule.dataset,
                                         dataset_config = config.datamodule.dataset,
                                         SPLIT_FILE = os.path.join(config.datamodule.dataset.SPLIT_FILE, 'val.txt'),
                                         transforms = default_transforms,
                                          _recursive_=False)    

In [6]:
len(val_dataset)

2022

In [7]:
datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, data_config = config.datamodule, _recursive_=False)

In [8]:
datamodule.setup()

In [9]:
train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()
test_dataloader = datamodule.test_dataloader()

In [10]:
len(train_dataloader), len(val_dataloader), len(test_dataloader)

(8994, 2022, 1451)

In [11]:
for batch_id, batch in enumerate(test_dataloader):
    break

In [14]:
batch['img'].shape

torch.Size([1, 3, 512, 512])

In [15]:
batch

{'img_id': ['011b8618-9b79-40fe-958b-3e35c55930f6_trap3_4744115559943958786'],
 'img': tensor([[[[0.7994, 0.7840, 0.7828,  ..., 0.7521, 0.7385, 0.7484],
           [0.8086, 0.7649, 0.8142,  ..., 0.7389, 0.7495, 0.7462],
           [0.7880, 0.7946, 0.8037,  ..., 0.7433, 0.7426, 0.7377],
           ...,
           [0.7320, 0.7395, 0.7451,  ..., 0.7145, 0.7119, 0.7213],
           [0.7415, 0.7333, 0.7202,  ..., 0.7145, 0.7147, 0.6972],
           [0.7392, 0.7199, 0.7197,  ..., 0.7031, 0.7201, 0.7171]],
 
          [[0.7563, 0.7408, 0.7431,  ..., 0.7111, 0.6993, 0.7092],
           [0.7655, 0.7218, 0.7710,  ..., 0.6997, 0.7103, 0.7070],
           [0.7449, 0.7515, 0.7606,  ..., 0.7041, 0.7034, 0.6984],
           ...,
           [0.6928, 0.7003, 0.7059,  ..., 0.6949, 0.6884, 0.6978],
           [0.7023, 0.6941, 0.6810,  ..., 0.6949, 0.6912, 0.6737],
           [0.6999, 0.6807, 0.6805,  ..., 0.6835, 0.6970, 0.6870]],
 
          [[0.1916, 0.1761, 0.1681,  ..., 0.1008, 0.0758, 0.0857],
     