In [1]:
import multiprocessing as mp
import os
import shutil
import sys

import matplotlib.pyplot as plt
import torch
from torchgeo.models import RCF
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
import geopandas as gpd
import rasterio
import numpy as np
import wandb
from sklearn.metrics import roc_auc_score, roc_curve, auc

sys.path.append("../scripts/")
from asm_datamodules import *
from asm_models import *

In [2]:
# Enable IPython's autoreload extension so that edits to imported modules (including the
# project helpers imported from ../scripts/) are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded) before executing each cell.
%autoreload 2

# RCF Extraction Model

In [3]:
# device configuration
# Lightning's `devices` is the number of devices/processes to train on. On GPU that is the
# number of visible GPUs; on CPU it must be 1 -- passing the core count would make Lightning
# launch one DDP process per core instead of using the cores for intra-op threading.
if torch.cuda.is_available():
    device, num_devices = "cuda", torch.cuda.device_count()
else:
    device, num_devices = "cpu", 1

# mp.cpu_count() reports every core on the node, not the cores this process may run on
# (e.g. under a SLURM allocation); prefer the process's CPU affinity where the OS exposes it.
if hasattr(os, "sched_getaffinity"):
    workers = len(os.sched_getaffinity(0))
else:
    workers = mp.cpu_count()

# Cap torch intra-op threads at 8, but never above the usable cores (avoids oversubscription).
torch_num_threads = min(8, workers)
torch.set_num_threads(torch_num_threads)
print(f"Running on {num_devices} {device}(s) with {workers} cpus")
print(f"Torch indicates there are {torch.get_num_threads()} CPUs")

# file names and paths
# Root for data files; set ASM_DATA_ROOT to point at a different location on other machines.
root = os.environ.get("ASM_DATA_ROOT", "/n/holyscratch01/tambe_lab/kayan/karena/")

Running on 1 cuda(s) with 64 cpus
Torch indicates there are 8 CPUs


# Prepare Data

In [9]:
# data parameters
# These values are read when the datamodule is built in the next cell; re-run that cell after
# editing them (the saved execution counts show it was last run *before* this one).
batch_size, num_workers = 64, 8

# Dataset selection and train/val/test split options.
mines_only = False
split, split_n = False, None
split_path = "/n/home07/kayan/asm/data/splits/9_all_data_lowlr_save-split"
save_split = False

# NOTE(review): `freeze_backbone` and `save_split` are not read by any cell visible here.
freeze_backbone = False

In [8]:
# create and set up datamodule
# Builds the datamodule from the data parameters above and materializes the train/val loaders.
# NOTE(review): `rcf` is not defined in any visible cell; presumably it comes from one of the
# star imports (asm_datamodules / asm_models) -- confirm and import it explicitly.
datamodule = ASMDataModule(
    root=root,
    batch_size=batch_size,
    num_workers=num_workers,
    split=split,
    split_n=split_n,
    split_path=split_path,
    mines_only=mines_only,
    transforms=rcf,
)
datamodule.setup("fit")
train_dataloader, val_dataloader = datamodule.train_dataloader(), datamodule.val_dataloader()

# Initialize RCF pixelwise segmentation (2-class classification) task

In [14]:
# model parameters
lr = 1e-5
n_epoch = 1
batch_size = 64
loss = "ce"
class_weights = [0.5,0.5]

In [15]:
# Pixelwise 2-class segmentation task over 16-channel input, backed by the "rcf" model.
# NOTE(review): confirm what `weights=True` selects for the "rcf" model inside
# CustomSemanticSegmentationTask (defined in the project helper modules).
task = CustomSemanticSegmentationTask(
    model="rcf",
    weights=True,
    in_channels=16,
    num_classes=2,
    loss=loss,
    class_weights=torch.Tensor(class_weights),
    lr=lr,
)

In [16]:
# Lightning trainer on the device(s) chosen in the configuration cell. Experiment logging is
# off (logger=False), but Lightning's default checkpoint callback is still active.
# NOTE(review): the fit output shows the checkpoint directory already holds files from earlier
# runs; consider a per-run `default_root_dir` so checkpoints don't mix.
trainer = Trainer(
    max_epochs=n_epoch,
    accelerator=device,
    devices=num_devices,
    logger=False,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [17]:
# Train the task for `n_epoch` epochs, validating on `val_dataloader`.
trainer.fit(
    task,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

You are using a CUDA device ('NVIDIA A100-SXM4-40GB MIG 3g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory /n/home07/kayan/asm/notebooks/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-50e60ea6-7c54-5212-8df1-d6c45c908599]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | RCFRegression    | 34    
-------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 57. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
`Trainer.fit` stopped: `max_epochs=1` reached.


In [18]:
# set up datamodule for testing
# Evaluate the trained task on the held-out test split.
# NOTE(review): the recorded run reports ~0.023 accuracy on a 2-class problem, far below what a
# constant or random predictor would score; this points to a label/evaluation problem rather
# than just under-training -- investigate before trusting these metrics.
datamodule.setup("test")
test_dataloader = datamodule.test_dataloader()

trainer.test(task, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-50e60ea6-7c54-5212-8df1-d6c45c908599]
/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 39. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'test_loss': 1.0391552448272705,
  'test_MulticlassAccuracy': 0.02345840074121952,
  'test_MulticlassJaccardIndex': 0.01186840794980526}]