In [2]:
import os
import sys
import multiprocessing as mp
import pickle
from pathlib import Path
import wandb
import torch
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
from segmentation_models_pytorch import Unet
from collections import OrderedDict
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from sklearn.metrics import roc_auc_score, roc_curve, auc

sys.path.append("../scripts/")
from asm_datamodules import *
from asm_models import *

In [3]:
%load_ext autoreload
%autoreload 2

# Set up inference with lightning trainer functionalities

In [4]:
# Directory holding the checkpoint to evaluate; the commented-out line is an alternative run.
#artifact_dir = "/n/home07/kayan/asm/notebooks/artifacts/model-z1woyme2:v19"
artifact_dir = "/n/home07/kayan/asm/notebooks/artifacts/model-ztyg139f:v19"
# map_location="cpu" lets the checkpoint deserialize even when the GPU it was saved from
# is not available; load_state_dict() later copies the values into the task's own parameters.
state_dict = torch.load(f"{artifact_dir}/model.ckpt", map_location="cpu")["state_dict"]

In [5]:
# Data root directory; passed to ASMDataModule as `root` further down.
root = "/n/holyscratch01/tambe_lab/kayan/karena/"

In [6]:
# model parameters
# lr and freeze_backbone are forwarded to the task below; n_epoch becomes the Trainer's max_epochs.
lr = 1e-5
n_epoch = 5
batch_size = 64
# Loss identifier forwarded to the task (presumably cross-entropy -- confirm in asm_models).
loss = "ce"
# Per-class weights; wrapped in a tensor when the task is built.
class_weights = [0.2,0.8]
num_workers = 8
# mines_only, split, split_n and split_path are forwarded to ASMDataModule.
mines_only = False
split = False
split_n = None
split_path = "/n/home07/kayan/asm/data/splits/9_all_data_lowlr_save-split"
freeze_backbone = False
# NOTE(review): save_split is not referenced by any cell visible in this notebook.
save_split = False

# Build the task. The architecture arguments (model, backbone, in_channels, num_classes)
# have to match the checkpoint for the load_state_dict() call below to match all keys.
task = CustomSemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=True,
    loss=loss,
    class_weights = torch.Tensor(class_weights),
    in_channels=4,
    num_classes=2,
    lr=lr,
    patience=5,
    freeze_backbone=freeze_backbone,
    freeze_decoder=False
)

# Restore the checkpointed weights into the freshly built task.
task.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
# device configuration
if torch.cuda.is_available():
    device, num_devices = "cuda", torch.cuda.device_count()
else:
    # With the CPU accelerator, Lightning's `devices` is the number of processes to
    # launch, not a thread count, so asking for mp.cpu_count() devices would spawn one
    # process per core. A single process is what a notebook evaluation needs.
    device, num_devices = "cpu", 1
# Number of CPU cores reported by the OS (kept for any cell that reads `workers`).
workers = mp.cpu_count()
# Upper bound on intra-op CPU threads used by torch.
torch.set_num_threads(32)

In [8]:
# Datamodule configured entirely from the notebook-level settings defined above.
datamodule = ASMDataModule(
    root=root,
    batch_size=batch_size,
    num_workers=num_workers,
    transforms=min_max_transform,
    mines_only=mines_only,
    split=split,
    split_n=split_n,
    split_path=split_path,
)

In [9]:
# Trainer used only for running test loops: no logger and no checkpoint writing.
trainer = Trainer(
    accelerator=device,
    devices=num_devices,
    max_epochs=n_epoch,
    logger=False,
    enable_checkpointing=False,
)

/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.1 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Test by feeding in a datamodule, and print example input

In [43]:
# Run the test loop, letting the Trainer obtain the test data from the datamodule.
trainer.test(model=task, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-50048574-5e97-5636-860e-ec25631abc1b]


Testing: |          | 0/? [00:00<?, ?it/s]

tensor([[[0.1240, 0.1024, 0.0862,  ..., 0.4923, 0.5039, 0.5165],
         [0.1132, 0.1496, 0.1420,  ..., 0.5943, 0.5915, 0.6119],
         [0.1018, 0.1497, 0.1489,  ..., 0.6251, 0.6491, 0.6763],
         ...,
         [0.3329, 0.3062, 0.2777,  ..., 0.1909, 0.1932, 0.2264],
         [0.3169, 0.3133, 0.2952,  ..., 0.2062, 0.2225, 0.2194],
         [0.2521, 0.3097, 0.3133,  ..., 0.1845, 0.1942, 0.1942]],

        [[0.1848, 0.1422, 0.1231,  ..., 0.5770, 0.5415, 0.5462],
         [0.1800, 0.1972, 0.1851,  ..., 0.6386, 0.6058, 0.6290],
         [0.1815, 0.2205, 0.2130,  ..., 0.6526, 0.6512, 0.6605],
         ...,
         [0.4474, 0.4141, 0.3756,  ..., 0.2359, 0.2133, 0.2211],
         [0.4152, 0.4068, 0.3838,  ..., 0.2437, 0.2541, 0.2512],
         [0.3560, 0.4033, 0.3997,  ..., 0.2258, 0.2702, 0.2962]],

        [[0.1937, 0.1615, 0.1634,  ..., 0.5584, 0.5230, 0.4957],
         [0.1788, 0.2158, 0.2240,  ..., 0.6307, 0.5971, 0.6347],
         [0.1945, 0.2300, 0.2357,  ..., 0.6295, 0.6541, 0.

[{'test_loss': 0.10330305248498917,
  'test_MulticlassAccuracy': 0.990871012210846,
  'test_MulticlassJaccardIndex': 0.9819071888923645}]

## Test by feeding in a DATALOADER, and print example input

In [35]:
# Build the test dataloader directly from the datamodule.
# NOTE(review): _dataloader_factory is a private method; presumably it needs the datamodule's
# test dataset to be set up already (e.g. by the trainer.test call above) -- confirm.
test_dataloader = datamodule._dataloader_factory("test")

In [36]:
# Run the same test loop, but hand the Trainer the dataloader directly instead of the datamodule.
trainer.test(model=task, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-50048574-5e97-5636-860e-ec25631abc1b]


Testing: |          | 0/? [00:00<?, ?it/s]

Input data type: torch.float32
Example input: tensor([[[0.1240, 0.1024, 0.0862,  ..., 0.4923, 0.5039, 0.5165],
         [0.1132, 0.1496, 0.1420,  ..., 0.5943, 0.5915, 0.6119],
         [0.1018, 0.1497, 0.1489,  ..., 0.6251, 0.6491, 0.6763],
         ...,
         [0.3329, 0.3062, 0.2777,  ..., 0.1909, 0.1932, 0.2264],
         [0.3169, 0.3133, 0.2952,  ..., 0.2062, 0.2225, 0.2194],
         [0.2521, 0.3097, 0.3133,  ..., 0.1845, 0.1942, 0.1942]],

        [[0.1848, 0.1422, 0.1231,  ..., 0.5770, 0.5415, 0.5462],
         [0.1800, 0.1972, 0.1851,  ..., 0.6386, 0.6058, 0.6290],
         [0.1815, 0.2205, 0.2130,  ..., 0.6526, 0.6512, 0.6605],
         ...,
         [0.4474, 0.4141, 0.3756,  ..., 0.2359, 0.2133, 0.2211],
         [0.4152, 0.4068, 0.3838,  ..., 0.2437, 0.2541, 0.2512],
         [0.3560, 0.4033, 0.3997,  ..., 0.2258, 0.2702, 0.2962]],

        [[0.1937, 0.1615, 0.1634,  ..., 0.5584, 0.5230, 0.4957],
         [0.1788, 0.2158, 0.2240,  ..., 0.6307, 0.5971, 0.6347],
         [0.

[{'test_loss': 22.92302894592285,
  'test_MulticlassAccuracy': 0.16873769462108612,
  'test_MulticlassJaccardIndex': 0.09214282780885696}]

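Somewhere in how the Trainer handles a datamodule versus a bare dataloader, input data from the datamodule gets divided by 255 while input data from the dataloader doesn't. A likely cause (to be confirmed against `ASMDataModule`): Lightning only runs datamodule hooks such as `on_after_batch_transfer` when a datamodule is attached to the Trainer, and torchgeo's base datamodule uses that hook to normalize images with mean 0 and std 255 by default.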

This model was trained by feeding in a datamodule, so its batch norm statistics correspond to input data that has been divided by 255 (on top of min-max scaling). When we test on data from a dataloader, which only has min-max scaling, the statistics no longer match the inputs, which explains the terrible results. It also explains why running a few forward passes with `model.train()` before testing fixes the issue: the batch norm statistics adjust to the new data.

# Confirm this is the root of the issue

Let's try dividing data from the dataloader by 255, doing nothing else, and feeding it into test()

In [14]:
def custom_transform(sample):
    """Apply ``min_max_transform`` and then divide the image by 255.

    Used to check whether an extra division by 255 is the only difference
    between inputs served through the datamodule and inputs served through a
    plain DataLoader.

    Args:
        sample: Sample accepted by ``min_max_transform``; the result must
            expose an ``"image"`` entry that supports division.

    Returns:
        The object returned by ``min_max_transform``, with its ``"image"``
        entry replaced by that image divided by 255.
    """
    sample = min_max_transform(sample)
    sample["image"] = sample["image"] / 255
    return sample

In [15]:
# Rebuild the test split by hand so the extra /255 scaling can be injected via `transforms`.
# `root` is the same directory the datamodule was given, so both paths read the same data.
test_dataset = ASMDataset(
    root=root,
    transforms=custom_transform,
    split="test",
    bands=["R", "G", "B", "NIR"],
    split_path=split_path,
)
# Reuse the notebook-level batch_size / num_workers rather than repeating the literals,
# so this loader cannot silently drift from the datamodule's settings.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

In [16]:
# Re-run the test loop on the dataloader whose samples were additionally divided by 255.
trainer.test(model=task, dataloaders=test_dataloader)

You are using a CUDA device ('NVIDIA A100-SXM4-40GB MIG 3g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-50048574-5e97-5636-860e-ec25631abc1b]


Testing: |          | 0/? [00:00<?, ?it/s]

Input data type: torch.float32
Example input: tensor([[[0.0005, 0.0004, 0.0003,  ..., 0.0019, 0.0020, 0.0020],
         [0.0004, 0.0006, 0.0006,  ..., 0.0023, 0.0023, 0.0024],
         [0.0004, 0.0006, 0.0006,  ..., 0.0025, 0.0025, 0.0027],
         ...,
         [0.0013, 0.0012, 0.0011,  ..., 0.0007, 0.0008, 0.0009],
         [0.0012, 0.0012, 0.0012,  ..., 0.0008, 0.0009, 0.0009],
         [0.0010, 0.0012, 0.0012,  ..., 0.0007, 0.0008, 0.0008]],

        [[0.0007, 0.0006, 0.0005,  ..., 0.0023, 0.0021, 0.0021],
         [0.0007, 0.0008, 0.0007,  ..., 0.0025, 0.0024, 0.0025],
         [0.0007, 0.0009, 0.0008,  ..., 0.0026, 0.0026, 0.0026],
         ...,
         [0.0018, 0.0016, 0.0015,  ..., 0.0009, 0.0008, 0.0009],
         [0.0016, 0.0016, 0.0015,  ..., 0.0010, 0.0010, 0.0010],
         [0.0014, 0.0016, 0.0016,  ..., 0.0009, 0.0011, 0.0012]],

        [[0.0008, 0.0006, 0.0006,  ..., 0.0022, 0.0021, 0.0019],
         [0.0007, 0.0008, 0.0009,  ..., 0.0025, 0.0023, 0.0025],
         [0.

/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 39. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'test_loss': 0.10330305248498917,
  'test_MulticlassAccuracy': 0.990871012210846,
  'test_MulticlassJaccardIndex': 0.9819071888923645}]