In [1]:
import os

import dvc.api
import yaml
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger

from messis.messis import Messis
from messis.dataloader import GeospatialDataModule

# Parameters of the DVC pipeline; later cells read file paths from this mapping.
params = dvc.api.params_show()

# W&B cannot always detect the notebook file on its own, so the name is
# provided explicitly through the environment before any run is started.
os.environ["WANDB_NOTEBOOK_NAME"] = "model_training.ipynb"

In [2]:
# Experiment tracking through Weights & Biases.
wandb_logger = WandbLogger(
    entity='crop-classification',
    project='messis',
    log_model='all',
)

# The chip/fold statistics file supplies the number of classes per tier.
stats_path = params['paths']['chips_fold_stats']
with open(stats_path, 'r') as stats_file:
    chips_fold_stats = yaml.safe_load(stats_file)

# Hyperparameters are handed to the model as a plain dictionary: the class
# counts come from the statistics file, everything else is fixed here.
# TODO add these to the params.yaml
class_count_keys = ('num_classes_tier1', 'num_classes_tier2', 'num_classes_tier3')
hparams = {key: chips_fold_stats[key] for key in class_count_keys}
hparams.update(
    img_size=224,  # new image size
    patch_size=16,
    num_frames=3,
    bands=[0, 1, 2, 3, 4, 5],
    weight_tier1=1,
    weight_tier2=1,
    weight_tier3=1,
    weight_tier3_refined=1,
    debug=False,
    lr=1e-3,
)

In [3]:
# Seed the Python, NumPy and PyTorch RNGs (and dataloader workers) before the
# model is built, so that repeated runs of this notebook are comparable.
SEED = 42
seed_everything(SEED, workers=True)

model = Messis(hparams)

data_module = GeospatialDataModule(
    data_dir='./data/',
    test_fold=0,
    batch_size=4,
    num_workers=1,
    # Read the crop size from the hyperparameters instead of repeating the
    # literal 224, so the two values cannot drift apart.
    # NOTE(review): assumes crop_to is meant to equal the model's img_size — confirm.
    crop_to=hparams['img_size'],
    # NOTE(review): hparams['debug'] is False while the data module is created
    # with debug=True — confirm that this mismatch is intended.
    debug=True,
)

trainer = Trainer(
    logger=wandb_logger,    # Attach the W&B logger
    max_epochs=1            # Set the number of epochs
)

Loaded pretrained weights from './prithvi/models/Prithvi_100M.pt' with partial matching.


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
# Fit the model using the dataloaders provided by the data module; the trainer
# created above logs to W&B and is limited to max_epochs=1.
trainer.fit(model, datamodule=data_module)

Loading mean/std stats from c:\dev\messis\data\chips_fold_stats.yaml
Stats with selected test fold 0: {'mean': [545.200927734375, 645.622802734375, 392.03851318359375, 2928.294189453125, 1824.5634765625, 1107.721923828125], 'n_bands': 6, 'n_chips': 290, 'n_timesteps': 3, 'std': [454.2867126464844, 391.0761413574219, 348.5946044921875, 1274.0751953125, 649.343994140625, 541.7734985351562]} over 3 timesteps.
Loading mean/std stats from c:\dev\messis\data\chips_fold_stats.yaml
Stats with selected test fold 0: {'mean': [545.200927734375, 645.622802734375, 392.03851318359375, 2928.294189453125, 1824.5634765625, 1107.721923828125], 'n_bands': 6, 'n_chips': 290, 'n_timesteps': 3, 'std': [454.2867126464844, 391.0761413574219, 348.5946044921875, 1274.0751953125, 649.343994140625, 541.7734985351562]} over 3 timesteps.



  | Name  | Type                   | Params
-------------------------------------------------
0 | model | HierarchicalClassifier | 178 M 
-------------------------------------------------
91.8 M    Trainable params
86.7 M    Non-trainable params
178 M     Total params
713.849   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\flori\AppData\Local\pypoetry\Cache\virtualenvs\messis-NTiJd-Nx-py3.12\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
c:\Users\flori\AppData\Local\pypoetry\Cache\virtualenvs\messis-NTiJd-Nx-py3.12\Lib\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
