In [1]:
import os
import tempfile
from glob import glob
from typing import Any, Callable, Dict, Optional, Iterator, Union, Tuple, List
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch

from torch.utils.data import DataLoader
from torchgeo.datasets import BoundingBox, stack_samples

import torch.nn.functional as F
from torchgeo.samplers.batch import RandomBatchGeoSampler
from torchgeo.samplers.single import GridGeoSampler, RandomGeoSampler
from torchgeo.samplers.utils import get_random_bounding_box
from torchgeo.samplers.constants import Units

import math
import random
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Work from the repository root so the relative `data/...` paths used below
# (and, presumably, the `src.*` package imports) resolve. The default is the
# original author's checkout; set the WILDFIRE_ROOT environment variable to
# run this notebook from a different location.
PROJECT_ROOT = os.environ.get("WILDFIRE_ROOT", "/home/users/sofijas/WildfireDistribution/")
os.chdir(PROJECT_ROOT)

In [3]:
# Re-import edited project modules (e.g. src.data_loading, src.tasks)
# automatically before each cell runs, so changes to the .py files are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
from src.data_loading import LandcoverSimple, LandcoverComplex, MODIS_CCI, MODIS_JD, Landsat7
from src.samplers import ConstrainedRandomBatchGeoSampler
from src.datamodules import MODISJDLandcoverSimpleDataModule , MODISJDLandcoverSimpleLandsatDataModule

In [5]:
batch_size: int = 2  # passed as `batch_size` to the data modules below

In [6]:
sampler_size: int = 256  # passed as `patch_size` to the data modules (presumably pixels — TODO confirm)

In [7]:
length: int = 10  # passed as `length` to the data modules; presumably samples per epoch — TODO confirm

In [32]:
# Arguments shared by both data modules. They were previously copy-pasted
# into each constructor call, which had already let the two drift apart;
# keeping them in one place makes the remaining differences explicit.
common_datamodule_kwargs = dict(
    modis_root_dir="data/modis/2017/",
    landcover_root_dir="data/landcover/",
    patch_size=sampler_size,
    length=length,
    batch_size=batch_size,
    one_hot_encode=False,
)

# Unbalanced sampling (balance_samples=False).
datamodule = MODISJDLandcoverSimpleDataModule(
    **common_datamodule_kwargs,
    num_workers=2,
    balance_samples=False,
    grid_sampler=False,
)

# Balanced sampling with burn_prop=0.5 (presumably the target fraction of
# patches containing burned area — TODO confirm against src.datamodules).
# NOTE(review): num_workers differs from the module above (0 vs 2), and
# grid_sampler is left at the datamodule's default here; both kept as in the
# original — confirm whether these differences are intentional.
datamodule_balanced = MODISJDLandcoverSimpleDataModule(
    **common_datamodule_kwargs,
    num_workers=0,
    balance_samples=True,
    burn_prop=0.5,
)

In [33]:
from torchgeo.trainers import SemanticSegmentationTask
from src.tasks import BinarySemanticSegmentationTask

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from pytorch_lightning.callbacks import Callback

Notes on losses:

The Jaccard loss is based on Intersection over Union (IoU). It scores overlap with the positive class rather than per-pixel accuracy, so it accounts for one class (background/no fire in our case) being far more common. It also prevents situations where we get a very high accuracy just by predicting all background (an all-black mask).

Since we only have one positive class (fire), the IoU should be computed only for that class and not for the background.

In [40]:
# Binary fire / no-fire segmentation: a U-Net with an ImageNet-initialised
# ResNet-18 encoder, one input channel and one output class, trained with the
# Jaccard (IoU) loss discussed in the notes above.
# NOTE(review): learning_rate=0.1 is high for Adam-style optimisers — confirm
# which optimiser BinarySemanticSegmentationTask configures.
# NOTE(review): ignore_zeros is passed None rather than a bool (the
# multi-class variant below used True) — confirm this is intended.
model = BinarySemanticSegmentationTask(
    segmentation_model="unet",
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=1,
    num_filters=32,
    num_classes=1,
    loss="jaccard",
    learning_rate=0.1,
    ignore_zeros=None,
    learning_rate_schedule_patience=5,
)

# Alternative tried earlier: torchgeo's multi-class task, treating background
# as an explicit class. It used the same arguments as above except for:
#   model = SemanticSegmentationTask(..., num_classes=2, ignore_zeros=True, ...)

In [41]:
# Smoke-test configuration: fast_dev_run pushes a single batch through the
# train/val/test/predict loops (as the log below reports).
# For a real run, use something like:
#     Trainer(max_epochs=5, precision=16, log_every_n_steps=1)
# NOTE(review): precision=16 likely needs a GPU, and the log below reports
# "GPU available: False" — confirm before switching.
trainer = Trainer(fast_dev_run=True)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).


In [42]:
# NOTE(review): the log below shows loss=0 on the unbalanced `datamodule`.
# segmentation_models_pytorch's JaccardLoss zeroes the loss when the target
# has no positive pixels, so batches without burned area give no training
# signal; consider fitting on `datamodule_balanced` — TODO confirm.
trainer.fit(model, datamodule=datamodule)


  | Name          | Type             | Params
---------------------------------------------------
0 | model         | Unet             | 14.3 M
1 | loss          | JaccardLoss      | 0     
2 | train_metrics | MetricCollection | 0     
3 | val_metrics   | MetricCollection | 0     
4 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
14.3 M    Trainable params
0         Non-trainable params
14.3 M    Total params
57.288    Total estimated model params size (MB)


Epoch 0:  50%|████████████████████████████▌                            | 1/2 [00:02<00:02,  2.79s/it, loss=0, v_num=]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                              | 0/1 [00:00<?, ?it/s][A
Epoch 0: 100%|█████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.09s/it, loss=0, v_num=][A
Epoch 0: 100%|█████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.09s/it, loss=0, v_num=][A
