In [None]:
import os
import lightning as L
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks import RichProgressBar

import config_tools
from models import AtomicFasterRCNN
from dataset import AtomicDataset

In [2]:
# Load the experiment configuration (TOML) that drives the dataset, model and trainer settings.
config = config_tools.load_config('configs/faster_rcnn_config.toml')

# Seed Python/NumPy/torch RNGs (and DataLoader worker processes) so that the
# shuffled training order and model initialisation are reproducible on
# Restart & Run All.
SEED = 42
L.seed_everything(SEED, workers=True)

In [3]:
# Create the train/test split only when the train metadata file is missing,
# so re-running the notebook reuses an existing split.
# NOTE(review): presumably create_train_test_split writes both the train and
# test metadata files; only the train path is checked here — confirm.
train_metadata_path = config['metadata_parameters']['path_to_train_metadata']
if not os.path.exists(train_metadata_path):
    AtomicDataset.create_train_test_split(config)

In [4]:
# Number of DataLoader worker processes, shared by both splits.
NUM_WORKERS = 3


def make_dataloader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook's shared loader settings.

    Args:
        dataset: an AtomicDataset split.
        shuffle: whether to reshuffle samples every epoch (True for training).

    Returns:
        A DataLoader that batches with ``config['model_parameters']['batch_size']``,
        pins host memory, and collates with ``AtomicDataset.collate_fn``.
    """
    return DataLoader(
        dataset,
        shuffle=shuffle,
        batch_size=config['model_parameters']['batch_size'],
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=AtomicDataset.collate_fn,
    )


train_dataset = AtomicDataset(config, is_train=True)
# NOTE(review): relies on AtomicDataset's default for is_train to select the test split — confirm.
test_dataset = AtomicDataset(config)

train_dataloader = make_dataloader(train_dataset, shuffle=True)
test_dataloader = make_dataloader(test_dataset, shuffle=False)

In [None]:
# Build the Faster R-CNN LightningModule from the same config used for the datasets.
model = AtomicFasterRCNN(config)

In [None]:
# Number of batches whose gradients are accumulated before each optimizer step.
# Read from config like the other trainer settings; falls back to the value
# previously hard-coded here (64) when the config does not define it.
# NOTE(review): assumes config['model_parameters'] is dict-like (supports .get) — confirm.
accumulate_grad_batches = config['model_parameters'].get('accumulate_grad_batches', 64)

trainer = L.Trainer(
    max_epochs=config['model_parameters']['epochs'],
    gradient_clip_val=config['model_parameters']['grad_clipping'],
    callbacks=[RichProgressBar()],
    accumulate_grad_batches=accumulate_grad_batches,
)

# NOTE(review): test_dataloader is built earlier in the notebook but never
# passed to fit, so no validation runs during training. If AtomicFasterRCNN
# defines validation_step, consider
# trainer.fit(model, train_dataloader, val_dataloaders=test_dataloader).
trainer.fit(model, train_dataloader)