In [4]:
import os
import sys
import shutil
import multiprocessing as mp
import matplotlib.pyplot as plt
import torch
from torchgeo.models import RCF
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
import geopandas as gpd
import rasterio
import numpy as np
import wandb
from sklearn.metrics import roc_auc_score, roc_curve, auc

sys.path.append("../scripts/")
from asm_datamodules import *
from asm_models import *

In [2]:
# Reload edited modules (e.g. the helper scripts in ../scripts/) before executing each cell,
# so changes to asm_datamodules / asm_models are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Configuration

In [5]:
# Device configuration.
# On GPU nodes, train on every visible CUDA device. On CPU-only machines use a single
# device: Lightning treats `devices=N` on the CPU accelerator as N separate training
# processes, so passing the CPU count would launch one process per core.
if torch.cuda.is_available():
    device, num_devices = "cuda", torch.cuda.device_count()
else:
    device, num_devices = "cpu", 1

# CPUs this process may actually run on (respects scheduler/cgroup affinity where the OS
# supports it). os.sched_getaffinity only exists on some platforms (e.g. Linux), so fall
# back to the total CPU count elsewhere.
try:
    workers = len(os.sched_getaffinity(0))
except AttributeError:
    workers = os.cpu_count() or 1
print(f"Running on {num_devices} {device}(s) with {workers} cpus")
# torch.get_num_threads() reports the intra-op thread pool size, not a CPU count.
print(f"Torch is using {torch.get_num_threads()} intra-op CPU threads")

# file names and paths
# Root for data files. Override with the ASM_DATA_ROOT environment variable.
# os.path.join(..., "") guarantees a trailing separator, like the original hard-coded value.
root = os.path.join(os.environ.get("ASM_DATA_ROOT", "/n/holyscratch01/tambe_lab/kayan/karena/"), "")

Running on 1 cuda(s) with 1 cpus
Torch indicates there are 1 CPUs


In [6]:
# model parameters
lr = 1e-4
n_epoch = 1
batch_size = 64
loss = "ce"
class_weights = [0.2, 0.8]  # one weight per class (num_classes=2); None disables class weighting
num_workers = 1  # TODO(review): consider deriving this from `workers` in the device cell
mines_only = False
split = False
split_n = None
split_path = "/n/home07/kayan/asm/data/splits/9_all_data_lowlr_save-split"
freeze_backbone = False
save_split = False

# Reproducibility: seed the global torch and numpy RNGs so shuffling and weight
# initialisation repeat across Restart & Run All.
# NOTE(review): whether the empirical RCF patch sampling draws from these global RNGs
# depends on asm_datamodules — confirm, and pass a seed there explicitly if not.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Prepare Data

In [21]:
# Training split that the empirical RCF samples its filters from.
emp_rcf_dataset = ASMDataset(
    split="train",
    split_path=split_path,
    transforms=min_max_transform,
)

In [22]:
# RCF parameters
features = 1000  # number of RCF output channels; also used as the task's in_channels
crop_size = 32
mode = "empirical"  # "empirical" samples filters from `dataset`; "gaussian" draws random filters
dataset = emp_rcf_dataset

# torchgeo's RCF produces two output channels per filter, so `features` must be even, and
# it only supports the two modes above. Fail fast here rather than inside datamodule setup.
# NOTE(review): assumes ASMDataModule builds a torchgeo RCF from these values — confirm.
assert features > 0 and features % 2 == 0, f"features must be a positive even number, got {features}"
assert mode in ("empirical", "gaussian"), f"unknown RCF mode {mode!r}"

In [23]:
# Datamodule whose inputs are RCF-transformed; the RCF settings above are forwarded as keywords.
datamodule = ASMDataModule(
    root=root,
    split_path=split_path,
    batch_size=batch_size,
    num_workers=num_workers,
    split=split,
    split_n=split_n,
    mines_only=mines_only,
    transforms=rcf,
    features=features,
    crop_size=crop_size,
    mode=mode,
    dataset=dataset,
)

In [24]:
# Prepare the fit-stage splits, then build the train and validation loaders (in that order).
datamodule.setup("fit")
train_dataloader, val_dataloader = (
    datamodule.train_dataloader(),
    datamodule.val_dataloader(),
)

# Initialize the RCF pixelwise classification (semantic segmentation) task

In [25]:
# Optional per-class loss weights, converted to a tensor only when provided.
class_weight_tensor = None if class_weights is None else torch.Tensor(class_weights)

task = CustomSemanticSegmentationTask(
    model="rcf",
    weights=True,  # NOTE(review): meaning of weights=True for model="rcf" is defined in asm_models — confirm
    loss=loss,
    class_weights=class_weight_tensor,
    in_channels=features,  # must equal the number of RCF output channels
    num_classes=2,
    lr=lr,
)

In [26]:
# Trainer settings collected in one place; logging is disabled (logger=False).
trainer_kwargs = dict(
    accelerator=device,
    devices=num_devices,
    max_epochs=n_epoch,
    logger=False,
)
trainer = Trainer(**trainer_kwargs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train the task on the loaders built above.
# NOTE(review): the saved output of this cell stops at "Sanity Checking" and the cell has no
# execution count, i.e. it never finished — re-run it before trusting the test results below.
trainer.fit(
    model=task,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-aa265898-f38d-56fb-a517-3ec989e16b08]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | RCFRegression    | 2.0 K 
---------------------------------------------------
2.0 K     Trainable params
0         Non-trainable params
2.0 K     Total params
0.008     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

In [18]:
# Evaluate the task on the test split and keep the returned metrics for later use.
# NOTE(review): the saved results below come from execution count 18, i.e. before the
# In [21]-In [26] data/task/trainer cells were last run, so they do not reflect this task.
# The recorded accuracy (~0.02) is also suspiciously low for a 2-class problem — re-run
# after training completes before interpreting these numbers.
datamodule.setup("test")
test_dataloader = datamodule.test_dataloader()
test_results = trainer.test(model=task, dataloaders=test_dataloader)
test_results

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-50e60ea6-7c54-5212-8df1-d6c45c908599]
/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 39. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'test_loss': 1.0391552448272705,
  'test_MulticlassAccuracy': 0.02345840074121952,
  'test_MulticlassJaccardIndex': 0.01186840794980526}]