In [1]:
import os

import dvc.api
import yaml
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger

from messis.dataloader import GeospatialDataModule
from messis.messis import Messis, LogConfusionMatrix, LogMessisMetrics

# Load the parameters tracked by DVC for this project; later cells read file
# paths from params['paths'].
params = dvc.api.params_show()

# W&B cannot always detect the notebook's file name on its own, so name it
# explicitly through the WANDB_NOTEBOOK_NAME environment variable. This has to
# be set before the W&B run is initialised.
os.environ["WANDB_NOTEBOOK_NAME"] = "model_training.ipynb"

In [2]:
# Create a W&B logger.
# log_model=True uploads checkpoints once, when training finishes. With
# log_model='all' every checkpoint written during training is uploaded as an
# artifact as well.
# NOTE(review): the recorded run of this notebook ended with
# "OSError: [Errno 28] No space left on device"; per-checkpoint uploads of a
# ~714 MB model are the suspected cause - confirm against the wandb
# cache/staging directory before relying on this.
wandb_logger = WandbLogger(entity='crop-classification', project='messis', log_model=True)

# Load the chips fold statistics; only the per-tier class counts are used here.
# The encoding is explicit so the read does not depend on the platform default.
with open(params['paths']['chips_fold_stats'], 'r', encoding='utf-8') as file:
    chips_fold_stats = yaml.safe_load(file)

# Which entry of chips_fold_stats supplies the class count of each tier.
# 'tier3_refined' reuses the tier-3 class count.
num_classes_key_by_tier = {
    'tier1': 'num_classes_tier1',
    'tier2': 'num_classes_tier2',
    'tier3': 'num_classes_tier3',
    'tier3_refined': 'num_classes_tier3',
}

# Make sure to pass your Hyperparameters as a dictionary
# TODO add these to the params.yaml
hparams = {
    'img_size': 224,
    'patch_size': 16,
    'num_frames': 3,
    'bands': [0, 1, 2, 3, 4, 5],
    'debug': False,
    'lr': 1e-3,
    'subsets': {
        'train': 8,
        'val': 4,
    },
    # Every tier gets the same loss weight; only the class count differs.
    'tiers': {
        tier: {
            'loss_weight': 1,
            'num_classes': chips_fold_stats[stats_key],
        }
        for tier, stats_key in num_classes_key_by_tier.items()
    },
}

In [3]:
# Seed the Python, NumPy and PyTorch global RNGs (and dataloader workers) before
# anything random happens, so repeated runs are comparable.
# NOTE(review): this only makes the random subset selection repeatable if the
# data module draws from these global RNGs - confirm in messis.dataloader.
SEED = 42
seed_everything(SEED, workers=True)

model = Messis(hparams)

data_module = GeospatialDataModule(
    data_dir='./data/',
    test_fold=0,
    batch_size=4,
    num_workers=1,
    # Same value as before (224), now read from the hyperparameters.
    # NOTE(review): assumes crop_to must equal the model's img_size - confirm.
    crop_to=hparams['img_size'],
    debug=True,
    subsets=hparams.get('subsets'),
)

trainer = Trainer(
    callbacks=[
        LogMessisMetrics(hparams, debug=False),
        # debug=True here; presumably the source of the per-batch
        # "Updating confusion matrix ..." messages in the run output - verify.
        LogConfusionMatrix(hparams, params['paths']['tier_names'], debug=True),
    ],
    logger=wandb_logger,    # Attach the logger
    max_epochs=5,           # Set the number of epochs
    log_every_n_steps=1,
)

Loaded pretrained weights from './prithvi/models/Prithvi_100M.pt' with partial matching.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [4]:
# Start training: the trainer drives the model through the train and validation
# batches supplied by the data module.
trainer.fit(model=model, datamodule=data_module)

wandb: Currently logged in as: crop-classification. Use `wandb login --relogin` to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888884685, max=1.0…

Loading mean/std stats from ./data/chips_fold_stats.yaml
Stats with selected test fold 0: {'mean': [545.200927734375, 645.622802734375, 392.03851318359375, 2928.294189453125, 1824.5634765625, 1107.721923828125], 'n_bands': 6, 'n_chips': 290, 'n_timesteps': 3, 'std': [454.2867126464844, 391.0761413574219, 348.5946044921875, 1274.0751953125, 649.343994140625, 541.7734985351562]} over 3 timesteps.
Randomly selecting 8 samples from 290 samples.
Loading mean/std stats from ./data/chips_fold_stats.yaml
Stats with selected test fold 0: {'mean': [545.200927734375, 645.622802734375, 392.03851318359375, 2928.294189453125, 1824.5634765625, 1107.721923828125], 'n_bands': 6, 'n_chips': 290, 'n_timesteps': 3, 'std': [454.2867126464844, 391.0761413574219, 348.5946044921875, 1274.0751953125, 649.343994140625, 541.7734985351562]} over 3 timesteps.
Randomly selecting 4 samples from 56 samples.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params
-------------------------------------------------
0 | model | HierarchicalClassifier | 178 M 
-------------------------------------------------
91.8 M    Trainable params
86.7 M    Non-trainable params
178 M     Total params
713.849   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\flori\AppData\Local\pypoetry\Cache\virtualenvs\messis-NTiJd-Nx-py3.11\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
  return F.conv3d(
  x = F.scaled_dot_product_attention(
c:\Users\flori\AppData\Local\pypoetry\Cache\virtualenvs\messis-NTiJd-Nx-py3.11\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Updating confusion matrix for train tier1
Updating confusion matrix for train tier2
Updating confusion matrix for train tier3
Updating confusion matrix for train tier3_refined
Updating confusion matrix for train tier1
Updating confusion matrix for train tier2
Updating confusion matrix for train tier3
Updating confusion matrix for train tier3_refined


Validation: |          | 0/? [00:00<?, ?it/s]

Updating confusion matrix for val tier1
Updating confusion matrix for val tier2
Updating confusion matrix for val tier3
Updating confusion matrix for val tier3_refined
Logging and resetting confusion matrix for val tier1 Update count: 1
Logging and resetting confusion matrix for val tier2 Update count: 1
Logging and resetting confusion matrix for val tier3 Update count: 1
Logging and resetting confusion matrix for val tier3_refined Update count: 1
Logging and resetting confusion matrix for train tier1 Update count: 2
Logging and resetting confusion matrix for train tier2 Update count: 2
Logging and resetting confusion matrix for train tier3 Update count: 2
Logging and resetting confusion matrix for train tier3_refined Update count: 2
Updating confusion matrix for train tier1
Updating confusion matrix for train tier2
Updating confusion matrix for train tier3
Updating confusion matrix for train tier3_refined
Updating confusion matrix for train tier1
Updating confusion matrix for train ti

Validation: |          | 0/? [00:00<?, ?it/s]

Updating confusion matrix for val tier1
Updating confusion matrix for val tier2
Updating confusion matrix for val tier3
Updating confusion matrix for val tier3_refined
Logging and resetting confusion matrix for val tier1 Update count: 1
Logging and resetting confusion matrix for val tier2 Update count: 1
Logging and resetting confusion matrix for val tier3 Update count: 1
Logging and resetting confusion matrix for val tier3_refined Update count: 1
Logging and resetting confusion matrix for train tier1 Update count: 2
Logging and resetting confusion matrix for train tier2 Update count: 2
Logging and resetting confusion matrix for train tier3 Update count: 2
Logging and resetting confusion matrix for train tier3_refined Update count: 2
Updating confusion matrix for train tier1
Updating confusion matrix for train tier2
Updating confusion matrix for train tier3
Updating confusion matrix for train tier3_refined
Updating confusion matrix for train tier1
Updating confusion matrix for train ti

Validation: |          | 0/? [00:00<?, ?it/s]

Updating confusion matrix for val tier1
Updating confusion matrix for val tier2
Updating confusion matrix for val tier3
Updating confusion matrix for val tier3_refined
Logging and resetting confusion matrix for val tier1 Update count: 1
Logging and resetting confusion matrix for val tier2 Update count: 1
Logging and resetting confusion matrix for val tier3 Update count: 1
Logging and resetting confusion matrix for val tier3_refined Update count: 1
Logging and resetting confusion matrix for train tier1 Update count: 2
Logging and resetting confusion matrix for train tier2 Update count: 2
Logging and resetting confusion matrix for train tier3 Update count: 2
Logging and resetting confusion matrix for train tier3_refined Update count: 2
Updating confusion matrix for train tier1
Updating confusion matrix for train tier2
Updating confusion matrix for train tier3
Updating confusion matrix for train tier3_refined
Updating confusion matrix for train tier1
Updating confusion matrix for train ti

Validation: |          | 0/? [00:00<?, ?it/s]

Updating confusion matrix for val tier1
Updating confusion matrix for val tier2
Updating confusion matrix for val tier3
Updating confusion matrix for val tier3_refined
Logging and resetting confusion matrix for val tier1 Update count: 1
Logging and resetting confusion matrix for val tier2 Update count: 1
Logging and resetting confusion matrix for val tier3 Update count: 1
Logging and resetting confusion matrix for val tier3_refined Update count: 1
Logging and resetting confusion matrix for train tier1 Update count: 2
Logging and resetting confusion matrix for train tier2 Update count: 2
Logging and resetting confusion matrix for train tier3 Update count: 2
Logging and resetting confusion matrix for train tier3_refined Update count: 2
Updating confusion matrix for train tier1
Updating confusion matrix for train tier2
Updating confusion matrix for train tier3
Updating confusion matrix for train tier3_refined
Updating confusion matrix for train tier1
Updating confusion matrix for train ti

Validation: |          | 0/? [00:00<?, ?it/s]

Updating confusion matrix for val tier1
Updating confusion matrix for val tier2
Updating confusion matrix for val tier3
Updating confusion matrix for val tier3_refined
Logging and resetting confusion matrix for val tier1 Update count: 1
Logging and resetting confusion matrix for val tier2 Update count: 1
Logging and resetting confusion matrix for val tier3 Update count: 1
Logging and resetting confusion matrix for val tier3_refined Update count: 1
Logging and resetting confusion matrix for train tier1 Update count: 2
Logging and resetting confusion matrix for train tier2 Update count: 2
Logging and resetting confusion matrix for train tier3 Update count: 2
Logging and resetting confusion matrix for train tier3_refined Update count: 2


OSError: [Errno 28] No space left on device