# Setting up trainer for combined dataset


In [1]:
# import pandas as pd
from pathlib import Path
from rasterio.plot import show
from ftw_ma import *
import lightning.pytorch as pl

## Model

In [2]:
model = CustomSemanticSegmentationTask(
    model="unet", 
    backbone="efficientnet-b3", 
    weights=True, 
    in_channels=4,
    num_classes=3, 
    loss="localtversky",
    # loss="ce",
    ignore_index=3
)

Using custom trainer
"backbone":        efficientnet-b3
"class_weights":   None
"freeze_backbone": False
"freeze_decoder":  False
"ignore_index":    3
"in_channels":     4
"loss":            localtversky
"lr":              0.001
"model":           unet
"model_kwargs":    {}
"num_classes":     3
"num_filters":     3
"patch_weights":   False
"patience":        10


## Data module

In [3]:
# from ftw_ma.datamodule import FTWMapAfricaDataModule
dm = FTWMapAfricaDataModule(
    batch_size=32,
    num_workers=4,
    data_dir="/Users/LEstes/data/labels/cropland/",
    catalog="../data/toycat2.csv",
    normalization_strategy="min_max",
    normalization_stat_procedure="lab",
    aug_list= ["rotation", "hflip", "vflip", "sharpness", "gamma", 
               "brightness", "contrast", "rescale", "satslidemix"]
)

## Trainer

In [4]:
# choose accelerator
if torch.backends.mps.is_available():
    accel = "mps"
    # acceal = "cpu"  # mps has issues with some operations
elif torch.cuda.is_available():
    accel = "gpu"
else:
    accel = "cpu"

trainer = pl.Trainer(
    max_epochs=3,
    devices=1,
    accelerator=accel,
    precision=32,
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


### Fit

In [5]:
trainer.fit(model, datamodule=dm)


  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | Unet             | 13.2 M | train
1 | criterion     | SafeLossWrapper  | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
13.2 M    Trainable params
0         Non-trainable params
13.2 M    Total params
52.639    Total estimated model params size (MB)
484       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/LEstes/.pyenv/versions/ftw-mapafrica/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [12]:
dm.num_workers = 0
dm.setup("test")
# dm.setup("test")

In [14]:
from tqdm import tqdm
for batch in tqdm(dm.test_dataloader()):
    print(batch["image"])

# run in notebook
# dl = dm.test_dataloader()
# print("num_workers:", dl.num_workers, "pin_memory:", dl.pin_memory)
# b = next(iter(dl))  # try with current settings (may raise the error)
# # if above errors, try:
# dm.num_workers = 0
# dl = dm.test_dataloader()
# b = next(iter(dl))
# for k,v in b.items():
#     print(k, type(v), getattr(v, "device", None))

100%|██████████| 1/1 [00:00<00:00, 11.23it/s]

tensor([[[[2.0223e-03, 2.0181e-03, 2.0011e-03,  ..., 1.9014e-03,
           1.9077e-03, 1.9183e-03],
          [2.0351e-03, 2.0287e-03, 2.0117e-03,  ..., 1.9141e-03,
           1.9247e-03, 1.9353e-03],
          [2.0499e-03, 2.0414e-03, 2.0223e-03,  ..., 1.9290e-03,
           1.9438e-03, 1.9565e-03],
          ...,
          [2.1475e-03, 2.1772e-03, 2.2112e-03,  ..., 2.2282e-03,
           2.2494e-03, 2.2600e-03],
          [2.1666e-03, 2.1921e-03, 2.2260e-03,  ..., 2.3109e-03,
           2.3237e-03, 2.3279e-03],
          [2.1942e-03, 2.2176e-03, 2.2451e-03,  ..., 2.3958e-03,
           2.4043e-03, 2.4022e-03]],

         [[1.2923e-03, 1.2881e-03, 1.2817e-03,  ..., 1.1841e-03,
           1.1905e-03, 1.1968e-03],
          [1.2966e-03, 1.2945e-03, 1.2881e-03,  ..., 1.1820e-03,
           1.1905e-03, 1.2011e-03],
          [1.2966e-03, 1.2945e-03, 1.2902e-03,  ..., 1.1820e-03,
           1.1947e-03, 1.2075e-03],
          ...,
          [1.4069e-03, 1.4218e-03, 1.4388e-03,  ..., 1.4515




In [None]:
{k: (v.item() if isinstance(v, torch.Tensor) else v)
     for k, v in trainer.callback_metrics.items()}

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir {trainer.logger.log_dir}

'/Users/LEstes/Dropbox/projects/ftw-mappingafrica-integration/notebooks/lightning_logs/version_16'