In [None]:
# Process-wide environment settings, applied before the library imports in the
# next cell so they are already present when those libraries load:
#   WANDB_NOTEBOOK_NAME         - notebook name that W&B records for the run
#   PYTORCH_ENABLE_MPS_FALLBACK - "1" lets PyTorch fall back to CPU for ops
#                                 that the Apple MPS backend does not implement
import os

os.environ.update(
    {
        "WANDB_NOTEBOOK_NAME": "model_training.ipynb",
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
    }
)

In [None]:
import dvc.api
import yaml
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger

from messis.messis import Messis, LogConfusionMatrix, LogMessisMetrics
from messis.dataloader import GeospatialDataModule

# Load the DVC-tracked pipeline parameters as a dict; later cells read file
# locations from params['paths'] (chips_fold_stats, tier_names).
params = dvc.api.params_show()

In [None]:
# Weights & Biases logger for this run (log_model=False: checkpoints are not
# uploaded as W&B artifacts).
wandb_logger = WandbLogger(entity='crop-classification', project='messis', log_model=False)

# Per-fold chip statistics (YAML); only the per-tier class counts are used here.
with open(params['paths']['chips_fold_stats']) as stats_file:
    chips_fold_stats = yaml.safe_load(stats_file)

# Hyperparameters handed to the model as a plain dict. Every tier gets the
# same loss weight; 'tier3_refined' deliberately reuses the tier3 class count.
hparams = {
    'img_size': 224,
    'patch_size': 16,
    'num_frames': 3,
    'bands': [0, 1, 2, 3, 4, 5],
    'debug': False,
    'lr': 1e-3,
    'subsets': {'train': 4, 'val': 2},
    'tiers': {
        tier_name: {
            'loss_weight': 1,
            'num_classes': chips_fold_stats[stats_key],
        }
        for tier_name, stats_key in (
            ('tier1', 'num_classes_tier1'),
            ('tier2', 'num_classes_tier2'),
            ('tier3', 'num_classes_tier3'),
            ('tier3_refined', 'num_classes_tier3'),
        )
    },
}
# TODO add these to the params.yaml

In [None]:
# Seed the RNGs (Python, NumPy, torch) before building the model so weight
# initialisation and data shuffling are reproducible on Restart & Run All.
# TODO: move SEED into params.yaml alongside the other hyperparameters.
SEED = 42
seed_everything(SEED)

model = Messis(hparams)

# NOTE(review): the data module (and the confusion-matrix callback below) run
# with debug=True while hparams['debug'] is False. Confirm this is intended
# before a full training run.
data_module = GeospatialDataModule(
    data_dir='./data/',
    test_fold=0,
    batch_size=2,
    num_workers=1,
    debug=True,
    subsets=hparams.get('subsets', None),
)

# Tier names are shared by both callbacks; read them once from params.yaml.
tier_names = params['paths']['tier_names']

trainer = Trainer(
    callbacks=[
        LogMessisMetrics(hparams, tier_names, debug=False),
        LogConfusionMatrix(hparams, tier_names, debug=True),
    ],
    logger=wandb_logger,   # send metrics to the W&B logger created earlier
    max_epochs=1,          # number of training epochs
    log_every_n_steps=1,   # log on every step (Lightning's default is 50)
)

In [None]:
# Run the Lightning fit loop (training, plus validation if the data module
# provides a validation dataloader).
trainer.fit(model=model, datamodule=data_module)