In [None]:
import os
import lightning as L
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks import RichProgressBar

import models
import config_tools
import dataset

In [None]:
# Path to the TOML file holding every tunable parameter used below
# (metadata paths, batch size, epochs, gradient clipping, ...).
CONFIG_PATH = 'configs/config.toml'
config = config_tools.load_config(CONFIG_PATH)

In [None]:
# Build the train/test split only when the train metadata file is absent, so
# re-running the notebook reuses the split already on disk.
# NOTE(review): only the *train* metadata path is checked; a missing test
# metadata file would not trigger regeneration — confirm that is intended.
train_metadata_path = config['metadata_parameters']['path_to_train_metadata']
if not os.path.exists(train_metadata_path):
    dataset.AtomicDataset.create_train_test_split(config)

In [None]:
# Build the model wrapper from the loaded config. The same object supplies the
# image processor (next cell) and is the model passed to `trainer.fit`.
model = models.AtomicModelWrapper(config)

In [None]:
# Image processor obtained from the model wrapper; it is passed to both
# datasets below and is the default padder used by `collate_fn`.
feature_extractor = model.get_processor()

def collate_fn(batch, processor=None):
    """Collate ``(pixel_values, labels)`` samples into one padded batch dict.

    NOTE(review): the DataLoaders in this notebook are built with
    ``<dataset>.collate_fn``, not this function, so training does not
    currently use it. Keep the two in sync or remove one of them.

    Args:
        batch: sequence of ``(pixel_values, labels)`` pairs, i.e. what
            indexing the dataset returns.
        processor: object exposing ``pad(images, return_tensors=...)``.
            Defaults to the module-level ``feature_extractor``, looked up at
            call time so this function can be defined before it exists.

    Returns:
        dict with ``'pixel_values'`` and ``'pixel_mask'`` taken from the padded
        encoding, and ``'labels'`` as a list in batch order.
    """
    if processor is None:
        processor = feature_extractor
    pixel_values = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]
    encoding = processor.pad(pixel_values, return_tensors="pt")
    return {
        'pixel_values': encoding['pixel_values'],
        'pixel_mask': encoding['pixel_mask'],
        # Kept as a plain list rather than stacked — presumably per-image
        # targets that can differ in size; confirm against the model's loss.
        'labels': labels,
    }

In [None]:
# `train=True` selects the training split; the test dataset relies on the
# constructor's default for `train` (presumably the held-out split — confirm
# in dataset.py).
train_dataset = dataset.AtomicDataset(config, feature_extractor, train=True)
test_dataset = dataset.AtomicDataset(config, feature_extractor)

NUM_WORKERS = 3  # DataLoader worker processes; tune to the available CPU cores

# Settings shared by both loaders; only shuffling and the collate_fn differ.
loader_kwargs = dict(
    batch_size=config['model_parameters']['batch_size'],
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    **loader_kwargs,
)

test_dataloader = DataLoader(
    test_dataset,
    shuffle=False,
    collate_fn=test_dataset.collate_fn,
    **loader_kwargs,
)

# Smoke check: load one sample so a broken data pipeline fails here,
# before the Trainer is started.
sample_pixel_values, sample_labels = test_dataset[0]

In [None]:
# Epoch count and gradient-clipping value both come from the config file;
# RichProgressBar swaps Lightning's default progress display for a Rich one.
trainer = L.Trainer(
    max_epochs=config['model_parameters']['epochs'],
    gradient_clip_val=config['model_parameters']['grad_clipping'],
    callbacks=[RichProgressBar()],
)

# The held-out test loader doubles as the validation loader, so the
# validation metrics reported during fitting are computed on the test split.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader,
)