In [None]:
import os
import sys

# Environment configuration, applied ahead of the remaining imports in this cell:
#   WANDB_NOTEBOOK_NAME          - notebook file name reported to Weights & Biases
#   PYTORCH_ENABLE_MPS_FALLBACK  - PyTorch switch for falling back when an op is unavailable on MPS
os.environ.update({
    "WANDB_NOTEBOOK_NAME": "model_training.ipynb",
    "PYTORCH_ENABLE_MPS_FALLBACK": "1",
})

import yaml
import dvc.api
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping

from messis.messis import Messis, LogConfusionMatrix, LogMessisMetrics
from messis.dataloader import GeospatialDataModule

# Pipeline parameters as reported by DVC (dvc.api.params_show()).
# Later cells read file locations from params['paths'].
params = dvc.api.params_show()

In [None]:
# Experiment logger: reports this run to the 'messis' project of the
# 'crop-classification' Weights & Biases entity.
wandb_logger = WandbLogger(
    entity='crop-classification',
    project='messis',
    log_model=False,
)

# Load the chip statistics file; the hyperparameters below take the per-tier
# class counts (num_classes_tier1/2/3) from it.
# The encoding is pinned to UTF-8 so parsing does not depend on the platform's
# default locale encoding.
with open(params['paths']['chips_stats'], 'r', encoding='utf-8') as file:
    chips_stats = yaml.safe_load(file)

# Hyperparameters for the model, the data module and the trainer.
# TODO add these to the params.yaml
hparams = {
    'img_size': 224,
    'patch_size': 16,
    'num_frames': 3,
    'bands': list(range(6)),  # band indices 0..5
    'debug': False,
    'lr': 1e-3,
    'batch_size': 4,
    'accumulate_grad_batches': 1,
    'max_epochs': 5,
    # Optional entry, forwarded to GeospatialDataModule(subsets=...) when present:
    # 'subsets': {
    #     'train': 4,
    #     'val': 2,
    # },
    # One entry per tier, each with the same loss weight; the number of classes
    # comes from the chip statistics. 'tier3_refined' reuses the tier-3 count.
    'tiers': {
        tier_name: {
            'loss_weight': 1,
            'num_classes': chips_stats[stats_key],
        }
        for tier_name, stats_key in (
            ('tier1', 'num_classes_tier1'),
            ('tier2', 'num_classes_tier2'),
            ('tier3', 'num_classes_tier3'),
            ('tier3_refined', 'num_classes_tier3'),
        )
    },
}

In [None]:
# Model built from the hyperparameter dictionary.
model = Messis(hparams)

# Single source of truth for the debug switch: reuse the flag stored in hparams
# (False when the key is absent) instead of repeating a hard-coded literal for
# the data module and each callback.
debug_enabled = hparams.get('debug', False)

# Data module reading from ./data/ with fold 0 held out as the test fold.
data_module = GeospatialDataModule(
    data_dir='./data/',
    test_fold=0,
    batch_size=hparams.get('batch_size', 4),
    num_workers=2,
    debug=debug_enabled,
    subsets=hparams.get('subsets', None)
)

early_stopping = EarlyStopping(
    monitor='val_loss',  # Metric to monitor
    patience=3,          # Number of epochs to wait for improvement
    verbose=True,
    mode='min'
)

# Lightning does not support the plain "ddp" strategy inside an interactive
# session (Jupyter/IPython); "ddp_notebook" is the fork-based variant meant for
# notebooks. The interactive check mirrors the usual sys.ps1 / -i detection, so
# running this code as a plain script still uses regular "ddp".
# NOTE: the fork-based strategy expects that CUDA has not been initialised in
# this kernel before trainer.fit() is called.
running_interactively = hasattr(sys, "ps1") or bool(sys.flags.interactive)
ddp_strategy = "ddp_notebook" if running_interactively else "ddp"

trainer = Trainer(
    logger=wandb_logger,
    log_every_n_steps=1,
    callbacks=[
        LogMessisMetrics(hparams, params['paths']['dataset_info'], debug=debug_enabled),
        LogConfusionMatrix(hparams, params['paths']['dataset_info'], debug=debug_enabled),
        early_stopping
    ],
    accumulate_grad_batches=hparams['accumulate_grad_batches'],  # Gradient accumulation
    max_epochs=hparams['max_epochs'],
    accelerator="gpu",
    strategy=ddp_strategy,  # Distributed Data Parallel (notebook-compatible when interactive)
    num_nodes=1,            # Number of nodes
    devices=2,              # Number of GPUs to use
    precision='16-mixed'    # Train with 16-bit precision (https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision)
)

In [None]:
# Start training: fit `model` with the configured `trainer`, taking data from `data_module`.
trainer.fit(model, datamodule=data_module)