In [1]:
import os
import re
import time
import subprocess
from dotenv import load_dotenv

# Load variables via python-dotenv first, then apply the notebook-specific
# settings. This runs before the wandb / pytorch_lightning imports further down.
load_dotenv()

os.environ.update({
    "WANDB_NOTEBOOK_NAME": "model_training.ipynb",
    "PYTORCH_ENABLE_MPS_FALLBACK": "1",
    "MPLBACKEND": "agg",
})

import yaml
import dvc.api
import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping

from messis.messis import Messis, LogConfusionMatrix, LogMessisMetrics
from messis.dataloader import GeospatialDataModule

# Parameters as returned by DVC's `params_show()`. The training cell reads
# `paths`, `chip_size` and `number_of_folds` from this mapping and also stores
# the whole mapping in the W&B run config.
params = dvc.api.params_show()

# Base URL of the Hugging Face model repository. Used to build the git remote
# URL (`<url>.git`) and the commit link recorded in the W&B run config.
HF_MODEL_URL = "https://huggingface.co/crop-classification/messis"

def is_interactive():
    """Tell whether the ``__main__`` module has no ``__file__`` attribute.

    Returns:
        bool: True when ``__main__`` lacks ``__file__``, False otherwise.
        The training cell uses this to choose between the "ddp_notebook"
        and "ddp" trainer strategies.
    """
    import __main__ as entry_module

    try:
        entry_module.__file__
    except AttributeError:
        return True
    return False


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [4]:
def get_latest_commit_hash(repo_url, timeout=60):
    """Return the hash of the first ref listed by ``git ls-remote``.

    Args:
        repo_url: URL or path of the git repository to query.
        timeout: Seconds to wait for the git command before giving up.

    Returns:
        str: The first whitespace-separated field of the command's output
        (the hash of the first listed ref), or ``"unknown"`` when git cannot
        be started, times out, exits with a non-zero status, or prints
        nothing.
    """
    try:
        # Capture stderr as well; with only stdout piped the error text is
        # never available to this function.
        result = subprocess.run(
            ["git", "ls-remote", repo_url],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
        )
    except (OSError, subprocess.SubprocessError) as error:
        # OSError: git executable missing; SubprocessError: e.g. timeout.
        print(f"Error occurred fetching latest commit hash from 🤗: {error}")
        return "unknown"

    if result.returncode != 0:
        stderr = result.stderr.decode('utf-8', errors='replace').strip()
        print(f"Error occurred fetching latest commit hash from 🤗: {stderr}")
        return "unknown"

    fields = result.stdout.decode('utf-8', errors='replace').split()
    if not fields:
        # Successful exit but no refs listed (e.g. an empty repository).
        return "unknown"
    return fields[0]


if __name__ == '__main__':
    # from multiprocessing import freeze_support
    # freeze_support() # for Windows support

    # Class counts per label tier are read from the chips statistics file.
    with open(params['paths']['chips_stats'], 'r') as file:
        chips_stats = yaml.safe_load(file)

    experiment_group = 'exp-7-label-hierarchy'
    experiment_name_short = 'test'
    # Hyperparameters
    hparams = {
        'experiment_name': f'{experiment_group}-{experiment_name_short}',
        'name': experiment_name_short,
        'experiment_group': experiment_group,
        'img_size': params['chip_size'],
        'patch_size': 16,
        'num_frames': 3,
        'bands': [0, 1, 2, 3, 4, 5],
        'debug': False,
        'lr': 1e-3,
        'optimizer': 'Adam',
        'optimizer_weight_decay': None,
        'optimizer_momentum': None,
        'batch_size': 16, # up to 16 on server0111
        'accumulate_grad_batches': 1, # when no accumulation, choose 1
        'early_stopping_metric': 'val_loss',
        'early_stopping_mode': 'min', # 'min' or 'max'
        'early_stopping_patience': 10,
        'max_epochs': 400,
        'dropout_p': 0.1,
        'backbone_weights_path': './prithvi/models/Prithvi_100M.pt',
        'freeze_backbone': True,
        'subsets': {
            'train': 4,
            'val': 2,
        },
        'test_folds': [5],
        'heads_spec': {
            'tier1': {
                'type': 'HierarchicalFCNHead',
                'loss_weight': 1,
                'num_classes_to_predict': chips_stats['num_classes_tier1'],
                'target_idx': 0
            },
            'tier2': {
                'type': 'HierarchicalFCNHead',
                'loss_weight': 1,
                'num_classes_to_predict': chips_stats['num_classes_tier2'],
                'target_idx': 1
            },
            'tier3': {
                'type': 'HierarchicalFCNHead',
                'loss_weight': 1,
                'num_classes_to_predict': chips_stats['num_classes_tier3'],
                'target_idx': 2
            },
            'tier3_refinement_head': {
                'type': 'LabelRefinementHead',
                'loss_weight': 1,
                'num_classes_to_predict': chips_stats['num_classes_tier3'],
                'target_idx': 2
            }
        },
    }

    # Every fold that is not a test fold takes one turn as the validation
    # fold. Sorted so run order and fold lists do not depend on set ordering.
    remaining_folds = sorted(set(range(params['number_of_folds'])) - set(hparams['test_folds']))

    hf_repo_url = f"{HF_MODEL_URL}.git"

    for fold in remaining_folds:
        hparams['val_folds'] = [fold]
        hparams['train_folds'] = [other for other in remaining_folds if other != fold]

        # Create a W&B logger
        wandb_logger = WandbLogger(
            name=f"{hparams['experiment_name']}-{fold}",
            entity='crop-classification',
            project='messis',
            log_model=False)

        try:
            wandb_logger.experiment.config['dvc'] = params

            model = Messis(hparams)

            data_module = GeospatialDataModule(
                data_dir='./data/',
                train_folds=hparams['train_folds'],
                val_folds=hparams['val_folds'],
                test_folds=hparams['test_folds'],
                batch_size=hparams.get('batch_size', 4),
                num_workers=1,      # 1 worker is enough for this dataset
                debug=False,
                subsets=hparams.get('subsets', None)
            )

            early_stopping = EarlyStopping(
                monitor=hparams.get('early_stopping_metric', 'val_loss'), # Metric to monitor
                mode=hparams.get('early_stopping_mode', 'min'), # 'min' or 'max'
                patience=hparams['early_stopping_patience'], # Number of epochs to wait for improvement
                verbose=True
            )

            trainer = Trainer(
                logger=wandb_logger,
                log_every_n_steps=16,
                profiler="simple",
                callbacks=[
                    LogMessisMetrics(hparams, params['paths']['dataset_info'], debug=False),
                    LogConfusionMatrix(hparams, params['paths']['dataset_info'], debug=False),
                    early_stopping
                ],
                accumulate_grad_batches=hparams['accumulate_grad_batches'],  # Gradient accumulation
                max_epochs=hparams['max_epochs'],
                accelerator="gpu",
                # Use Distributed Data Parallel (activate on server0111)
                strategy="ddp_notebook" if is_interactive() else "ddp",
                num_nodes=1,            # Number of nodes
                devices=1,              # Number of GPUs to use
                precision='16-mixed'    # Train with 16-bit precision (https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision)
            )

            trainer.fit(model, datamodule=data_module)

            commit_message = f"Messis | W&B Run {wandb_logger.experiment.name} (https://wandb.ai/crop-classification/messis/runs/{wandb_logger.experiment.id})"
            model.push_to_hub("crop-classification/messis", commit_message=commit_message)

            # Presumably gives the Hub a moment to expose the new commit
            # before the remote is queried - TODO confirm this is needed.
            time.sleep(2)

            # Record a link to the latest Hub commit alongside the run.
            wandb_logger.experiment.config['huggingface_commit'] = f"{HF_MODEL_URL}/commit/{get_latest_commit_hash(hf_repo_url)}"
        except BaseException:
            # Close the run as failed so it is not left open when the fold
            # aborts, then let the original error propagate.
            wandb.finish(exit_code=1)
            raise
        else:
            wandb.finish()

[0] [1, 2, 3, 4]
[1] [0, 2, 3, 4]
[2] [0, 1, 3, 4]


[3] [0, 1, 2, 4]
[4] [0, 1, 2, 3]
