In [5]:
import yaml
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

import segmentation_models_pytorch as smp

from src.xarray_module import XarrayDataModule
from src.litsegmodel import LitSegModel

# Fix the global RNG seed up front so that runs are repeatable.
seed_everything(1234)

# All tunable parameters are read from config.yaml in the working directory.
# NOTE(review): yaml.FullLoader can build more than plain scalars/containers;
# if the config only holds simple values, yaml.safe_load would be sufficient.
with open('config.yaml') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Model hyperparameters consumed by smp.Unet in the cells below.
ENCODER_NAME = config['ENCODER_NAME']  # passed as encoder_name
N_BAND = config['N_BAND']              # passed as in_channels
N_CLASS = config['N_CLASS']            # passed as classes

Global seed set to 1234


## fit model

In [3]:
# Base segmentation network: a U-Net configured from the constants read
# out of config.yaml.
umodel = smp.Unet(encoder_name=ENCODER_NAME, in_channels=N_BAND, classes=N_CLASS)

# Data module that serves the xarray-backed datasets.
xmod = XarrayDataModule()

# Lightning wrapper around the U-Net.
model = LitSegModel(umodel)

# Checkpoints go to ./checkpoints/, ranked on the monitored "valid_loss"
# metric with save_top_k=2.
# NOTE(review): assumes LitSegModel logs a metric named "valid_loss" - confirm.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    monitor="valid_loss",
    save_top_k=2,
)

# Weights & Biases logging; log_model=False, so checkpoints are not uploaded.
wandb_logger = WandbLogger(
    project="overstory",
    name="run_0",
    log_model=False,
)

# Two stopping limits are set: a wall-clock budget (max_time, in
# DD:HH:MM:SS form, i.e. two hours) and an epoch cap (max_epochs).
trainer = Trainer(
    max_time="00:02:00:00",
    max_epochs=500,
    accelerator="gpu",
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)

trainer.fit(model=model, datamodule=xmod)

# Path of the best checkpoint as tracked by the checkpoint callback.
check_best = checkpoint_callback.best_model_path

  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Cannot find the ecCodes library
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | Unet                   | 24.5 M
1 | jaccard | MulticlassJaccardIndex | 0     
2 | confmat | BinaryConfusionMatrix  | 0     
3 | loss_fn | SoftCrossEntropyLoss   | 0     
---------------------------------------------------
24.5 M    Trainable params
0         Non-trainable params
24.5 M    Total params
97.834    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

## load best model and test

In [12]:

# Rebuild the U-Net with the same configuration used for training so the
# checkpointed weights can be loaded into it.
umodel = smp.Unet(
    encoder_name=ENCODER_NAME,
    in_channels=N_BAND,
    classes=N_CLASS,
)

xmod = XarrayDataModule()
trainer = Trainer()

# Checkpoint path is relative to the working directory (the same directory
# that holds config.yaml and the "checkpoints/" folder written by
# ModelCheckpoint), so the notebook is not tied to one user's home directory.
# NOTE(review): the file name is pinned by hand; confirm it matches
# checkpoint_callback.best_model_path (check_best) from the training run.
checkpoint = 'checkpoints/epoch=4-step=2155.ckpt'

# load_from_checkpoint builds the LitSegModel itself (receiving `model=umodel`)
# and restores the saved weights, so no separate LitSegModel(umodel)
# instantiation is needed beforehand.
model = LitSegModel.load_from_checkpoint(checkpoint, model=umodel)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# Evaluate the checkpoint-restored model on the data module's test split;
# left as the cell's last expression so the returned metrics are displayed.
trainer.test(model=model, datamodule=xmod)

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_jaccard          0.5796831846237183
        test_loss           0.45853060483932495
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.45853060483932495, 'test_jaccard': 0.5796831846237183}]