Training locally
===

- Train a dust-segmentation model locally with Raster Vision (RV): GeoDatasets, a Learner, and a Visualizer.

In [None]:
from os.path import join
from pathlib import Path

from tqdm.auto import tqdm
import torch

In [None]:
from rastervision.core.data import ClassConfig

# Class schema shared by the datasets, the model head and the visualizer.
# Two classes: 'background' (also designated as the null class) and 'dust'.
class_config = ClassConfig(
    null_class='background',
    names=['background', 'dust'],
    colors=['lightgray', 'maroon'],
)

In [None]:
import os

# Input/output locations. The defaults point at the original author's machine;
# set DUST_IMG_DIR / DUST_LABEL_DIR / DUST_OUT_DIR to run elsewhere without
# editing the notebook.
img_dir = os.environ.get(
    'DUST_IMG_DIR',
    '/home/ahassan/un-sandstorm/data/gibs/downloaded/VIIRS_SNPP_CorrectedReflectance_TrueColor/')
label_dir = os.environ.get(
    'DUST_LABEL_DIR',
    '/home/ahassan/un-sandstorm/data/gibs/labels_geojson/')
out_dir = os.environ.get('DUST_OUT_DIR', 'data/train/2024-04-29_workshop')

In [None]:
# Vector label files, one per labelled image tile (read from label_dir).
# Each name is "<YYYY-MM-DD>_<a>,<b>,<c>,<d>.json": a date plus four numbers
# that look like a 10-degree bounding box (c == a + 10, d == b + 10 in every
# entry; axis order is not established here -- TODO confirm). The matching
# image is the .tif with the same stem in img_dir.
# NOTE: list order matters -- the first `train_size` entries become the
# training set and the remainder the validation set.
filenames_labeled = [
    '2022-05-25_10,40,20,50.json',
    '2022-01-19_30,30,40,40.json',
    '2022-04-07_20,30,30,40.json',
    '2022-05-15_10,30,20,40.json',
    '2022-06-22_30,40,40,50.json',
    '2022-01-06_30,50,40,60.json',
    '2022-07-26_30,60,40,70.json',
    '2022-03-18_30,60,40,70.json',
    '2022-03-29_20,30,30,40.json',
    '2022-05-02_20,40,30,50.json',
    '2022-09-26_20,40,30,50.json',
    '2022-05-06_30,40,40,50.json',
    '2022-01-28_20,50,30,60.json',
    '2022-08-27_30,40,40,50.json',
    '2022-05-14_30,60,40,70.json',
    '2022-07-11_20,50,30,60.json',
    '2022-04-08_30,50,40,60.json',
    '2022-05-23_30,40,40,50.json',
    '2022-03-18_20,60,30,70.json',
    '2022-05-17_30,60,40,70.json',
    '2022-09-12_10,30,20,40.json',
    '2022-06-17_30,40,40,50.json',
    '2022-05-14_20,40,30,50.json',
    '2022-10-20_30,60,40,70.json',
    '2022-05-15_20,50,30,60.json',
    '2022-05-23_10,40,20,50.json',
    '2022-02-04_30,60,40,70.json',
    '2022-07-03_30,40,40,50.json',
    '2022-06-20_20,60,30,70.json',
    '2022-01-31_10,30,20,40.json',
    '2022-09-27_20,30,30,40.json',
    '2022-04-12_10,30,20,40.json',
    '2022-05-14_30,50,40,60.json',
    '2022-03-18_10,40,20,50.json',
    '2022-05-18_40,60,50,70.json',
    '2022-05-26_10,40,20,50.json',
    '2022-04-24_30,30,40,40.json',
    '2022-06-18_20,60,30,70.json',
    '2022-06-20_30,60,40,70.json',
    '2022-03-19_10,40,20,50.json',
    '2022-06-21_20,60,30,70.json',
    '2022-05-13_10,30,20,40.json',
    '2022-09-25_30,40,40,50.json',
    '2022-06-03_30,30,40,40.json',
    '2022-01-21_20,60,30,70.json',
    '2022-09-04_20,30,30,40.json',
    '2022-10-02_20,60,30,70.json',
    '2022-05-23_30,30,40,40.json',
    '2022-03-18_20,50,30,60.json',
    '2022-06-26_20,40,30,50.json',
    '2022-01-06_20,50,30,60.json',
    '2022-01-20_30,30,40,40.json',
    '2022-06-23_20,50,30,60.json',
    '2022-10-02_30,60,40,70.json',
    '2022-09-11_20,60,30,70.json',
    '2022-09-05_10,40,20,50.json',
    '2022-04-07_30,40,40,50.json',
    '2022-07-02_30,40,40,50.json',
    '2022-05-14_20,60,30,70.json',
    '2022-06-21_20,30,30,40.json',
    '2022-01-31_10,40,20,50.json',
    '2022-01-19_30,40,40,50.json',
    '2022-05-07_30,40,40,50.json',
    '2022-01-14_30,60,40,70.json',
    '2022-06-21_10,30,20,40.json',
    '2022-09-12_20,60,30,70.json',
    '2022-06-19_30,60,40,70.json',
    '2022-05-14_10,50,20,60.json',
    '2022-06-18_10,40,20,50.json',
    '2022-02-03_20,40,30,50.json',
    '2022-04-25_20,30,30,40.json',
    '2022-05-15_30,30,40,40.json',
    '2022-05-05_20,60,30,70.json',
    '2022-07-04_30,50,40,60.json',
    '2022-01-27_20,40,30,50.json',
    '2022-03-19_20,30,30,40.json',
    '2022-06-23_30,60,40,70.json',
    '2022-01-11_20,50,30,60.json',
    '2022-06-05_20,60,30,70.json',
    '2022-08-29_10,30,20,40.json',
    '2022-05-26_20,30,30,40.json',
    '2022-05-05_30,40,40,50.json',
    '2022-01-29_20,40,30,50.json',
    '2022-03-12_30,60,40,70.json',
    '2022-05-29_20,60,30,70.json',
    '2022-04-12_40,50,50,60.json',
    '2022-06-19_10,40,20,50.json',
    '2022-07-06_30,50,40,60.json',
    '2022-07-24_30,60,40,70.json',
    '2022-05-28_20,40,30,50.json',
    '2022-09-27_20,40,30,50.json',
    '2022-09-02_20,30,30,40.json',
    '2022-06-27_20,60,30,70.json',
    '2022-07-04_10,50,20,60.json',
    '2022-07-08_10,30,20,40.json',
    '2022-01-27_30,40,40,50.json',
    '2022-05-02_20,60,30,70.json',
    '2022-09-05_10,30,20,40.json',
    '2022-01-27_30,50,40,60.json',
    '2022-06-17_10,30,20,40.json',
    '2022-05-25_20,50,30,60.json',
    '2022-01-27_30,60,40,70.json',
    '2022-06-26_30,40,40,50.json',
    '2022-09-26_30,60,40,70.json',
    '2022-07-02_30,30,40,40.json',
    '2022-05-03_30,50,40,60.json',
    '2022-07-05_30,50,40,60.json',
    '2022-06-29_10,40,20,50.json',
    '2022-06-11_20,60,30,70.json',
    '2022-05-04_30,30,40,40.json',
    '2022-06-21_30,40,40,50.json',
    '2022-03-17_30,60,40,70.json',
    '2022-01-11_30,50,40,60.json',
    '2022-07-04_10,40,20,50.json',
    '2022-06-04_30,60,40,70.json',
    '2022-07-05_20,60,30,70.json',
    '2022-02-21_30,40,40,50.json',
    '2022-07-01_10,40,20,50.json',
    '2022-06-22_20,60,30,70.json',
    '2022-06-18_20,40,30,50.json',
    '2022-02-21_10,40,20,50.json',
    '2022-01-21_10,30,20,40.json',
    '2022-06-19_10,50,20,60.json',
    '2022-07-05_30,60,40,70.json',
    '2022-06-27_10,30,20,40.json',
    '2022-03-31_30,40,40,50.json',
    '2022-06-19_20,50,30,60.json',
    '2022-06-17_30,60,40,70.json',
    '2022-06-19_10,30,20,40.json',
    '2022-03-30_40,50,50,60.json',
    '2022-07-02_20,60,30,70.json',
    '2022-05-05_20,50,30,60.json',
    '2022-05-16_10,30,20,40.json',
    '2022-07-02_30,60,40,70.json',
    '2022-07-04_30,60,40,70.json',
    '2022-04-07_10,30,20,40.json',
    '2022-07-03_10,30,20,40.json',
    '2022-06-20_10,30,20,40.json',
    '2022-09-04_10,30,20,40.json',
    '2022-01-28_30,40,40,50.json',
    '2022-06-26_20,50,30,60.json',
    '2022-09-26_20,30,30,40.json',
    '2022-07-05_10,50,20,60.json',
    '2022-05-23_20,40,30,50.json',
    '2022-06-10_10,30,20,40.json',
    '2022-07-01_20,40,30,50.json',
    '2022-06-21_20,50,30,60.json',
    '2022-05-13_30,40,40,50.json',
    '2022-05-13_20,50,30,60.json',
    '2022-05-17_30,40,40,50.json',
    '2022-02-02_20,60,30,70.json',
    '2022-03-29_20,50,30,60.json',
    '2022-06-20_10,40,20,50.json',
    '2022-06-17_10,40,20,50.json',
    '2022-03-18_30,50,40,60.json',
    '2022-05-16_20,60,30,70.json',
    '2022-06-21_30,60,40,70.json',
    '2022-03-25_30,50,40,60.json',
    '2022-09-24_20,60,30,70.json',
    '2022-03-05_40,40,50,50.json',
    '2022-09-24_30,40,40,50.json',
    '2022-01-27_20,50,30,60.json',
    '2022-01-21_20,40,30,50.json',
    '2022-07-01_30,60,40,70.json',
    '2022-05-01_20,40,30,50.json',
    '2022-01-28_20,40,30,50.json',
    '2022-06-23_20,60,30,70.json',
    '2022-07-11_30,40,40,50.json',
    '2022-07-04_20,30,30,40.json',
    '2022-06-03_20,60,30,70.json',
    '2022-06-11_30,50,40,60.json',
    '2022-03-06_40,60,50,70.json',
    '2022-05-17_10,30,20,40.json',
    '2022-06-28_20,30,30,40.json',
    '2022-06-18_30,40,40,50.json',
    '2022-05-22_10,40,20,50.json',
    '2022-04-10_20,50,30,60.json',
    '2022-07-01_20,60,30,70.json',
    '2022-06-28_30,40,40,50.json',
    '2022-04-07_20,40,30,50.json',
    '2022-07-24_30,40,40,50.json',
    '2022-02-10_30,60,40,70.json',
    '2022-02-02_20,50,30,60.json',
    '2022-02-03_20,50,30,60.json',
    '2022-03-18_20,40,30,50.json',
    '2022-07-25_30,50,40,60.json',
    '2022-03-11_20,50,30,60.json',
    '2022-03-17_10,40,20,50.json',
    '2022-05-15_20,60,30,70.json',
    '2022-05-25_10,30,20,40.json',
    '2022-06-21_10,40,20,50.json',
    '2022-05-28_10,40,20,50.json',
    '2022-05-22_30,60,40,70.json',
    '2022-09-24_30,50,40,60.json',
    '2022-05-03_30,60,40,70.json',
    '2022-06-26_20,30,30,40.json',
    '2022-01-22_30,60,40,70.json',
    '2022-05-25_30,50,40,60.json',
    '2022-05-17_20,60,30,70.json',
    '2022-03-19_20,60,30,70.json',
    '2022-06-28_10,30,20,40.json',
    '2022-04-24_20,40,30,50.json',
    '2022-06-26_10,30,20,40.json',
    '2022-06-25_30,30,40,40.json',
    '2022-07-04_10,30,20,40.json',
    '2022-05-05_20,40,30,50.json',
    '2022-03-30_20,60,30,70.json',
    '2022-03-05_20,50,30,60.json',
    '2022-10-22_30,30,40,40.json',
    '2022-01-06_10,50,20,60.json',
    '2022-07-02_10,30,20,40.json',
    '2022-11-14_20,60,30,70.json',
    '2022-03-06_20,40,30,50.json',
    '2022-05-26_30,60,40,70.json',
    '2022-04-24_20,30,30,40.json',
    '2022-07-01_10,30,20,40.json',
    '2022-05-02_30,40,40,50.json',
    '2022-01-28_30,50,40,60.json',
    '2022-05-24_20,50,30,60.json',
    '2022-07-01_30,40,40,50.json',
    '2022-06-05_30,60,40,70.json',
    '2022-05-14_20,50,30,60.json',
    '2022-05-22_20,60,30,70.json',
    '2022-06-25_10,40,20,50.json',
    '2022-03-29_30,60,40,70.json',
    '2022-06-22_10,40,20,50.json',
    '2022-03-17_20,50,30,60.json',
    '2022-07-24_10,30,20,40.json',
    '2022-05-18_20,40,30,50.json',
    '2022-01-21_30,50,40,60.json',
    '2022-07-11_40,50,50,60.json',
    '2022-06-10_30,50,40,60.json',
    '2022-09-25_20,40,30,50.json',
    '2022-07-25_20,30,30,40.json',
    '2022-06-30_20,40,30,50.json',
    '2022-07-24_20,30,30,40.json',
    '2022-05-28_30,60,40,70.json',
    '2022-05-26_20,60,30,70.json',
    '2022-05-22_20,30,30,40.json',
    '2022-02-22_20,60,30,70.json',
    '2022-05-15_10,50,20,60.json',
    '2022-07-19_10,40,20,50.json',
    '2022-05-26_20,50,30,60.json',
    '2022-06-11_30,30,40,40.json',
    '2022-01-18_30,50,40,60.json',
    '2022-07-19_30,60,40,70.json',
    '2022-07-19_10,30,20,40.json',
    '2022-05-04_20,50,30,60.json',
    '2022-05-23_30,60,40,70.json',
    '2022-05-17_10,40,20,50.json',
    '2022-03-05_20,40,30,50.json',
    '2022-05-23_10,30,20,40.json',
    '2022-03-12_10,30,20,40.json',
    '2022-07-05_20,50,30,60.json',
    '2022-05-03_20,50,30,60.json',
    '2022-03-31_30,60,40,70.json',
    '2022-05-27_20,50,30,60.json',
    '2022-06-29_10,30,20,40.json',
    '2022-05-22_20,40,30,50.json',
    '2022-05-24_10,40,20,50.json',
    '2022-03-06_30,30,40,40.json',
    '2022-05-22_10,50,20,60.json',
    '2022-06-09_20,60,30,70.json',
    '2022-03-29_10,40,20,50.json',
    '2022-06-11_20,50,30,60.json',
    '2022-05-20_10,30,20,40.json',
    '2022-05-05_10,30,20,40.json',
    '2022-07-11_20,40,30,50.json',
    '2022-06-28_20,60,30,70.json',
    '2022-03-11_20,30,30,40.json',
    '2022-06-22_30,50,40,60.json',
    '2022-04-25_30,30,40,40.json',
    '2022-05-24_30,50,40,60.json',
    '2022-06-30_30,40,40,50.json',
    '2022-05-13_30,60,40,70.json',
    '2022-06-12_20,60,30,70.json',
    '2022-07-05_30,40,40,50.json',
    '2022-06-26_10,40,20,50.json',
    '2022-06-09_10,30,20,40.json',
    '2022-06-11_30,60,40,70.json',
    '2022-06-18_30,60,40,70.json',
    '2022-02-03_30,30,40,40.json',
    '2022-01-22_20,60,30,70.json',
    '2022-06-28_10,40,20,50.json',
    '2022-05-02_20,50,30,60.json',
    '2022-07-04_20,50,30,60.json',
    '2022-09-11_30,60,40,70.json',
    '2022-01-22_20,50,30,60.json',
    '2022-05-03_20,60,30,70.json',
    '2022-06-22_10,30,20,40.json',
    '2022-01-31_30,30,40,40.json',
    '2022-05-16_30,40,40,50.json',
    '2022-05-27_20,40,30,50.json',
    '2022-07-25_30,60,40,70.json',
    '2022-07-26_20,30,30,40.json',
    '2022-06-04_20,60,30,70.json',
    '2022-05-06_20,40,30,50.json',
    '2022-05-01_30,40,40,50.json',
    '2022-01-21_30,60,40,70.json',
    '2022-05-15_30,60,40,70.json',
    '2022-03-29_40,50,50,60.json',
    '2022-08-29_20,30,30,40.json',
    '2022-09-02_10,30,20,40.json',
    '2022-06-19_30,40,40,50.json',
    '2022-05-24_20,40,30,50.json',
    '2022-03-19_10,50,20,60.json',
    '2022-06-25_10,30,20,40.json',
    '2022-10-23_30,60,40,70.json',
    '2022-06-17_20,60,30,70.json',
    '2022-03-11_30,40,40,50.json',
    '2022-05-28_20,60,30,70.json',
    '2022-07-10_30,60,40,70.json',
    '2022-03-11_10,30,20,40.json',
    '2022-03-17_20,30,30,40.json',
    '2022-06-05_20,50,30,60.json',
    '2022-06-12_30,60,40,70.json',
    '2022-06-23_30,50,40,60.json',
    '2022-07-24_30,50,40,60.json',
    '2022-09-11_40,60,50,70.json',
    '2022-06-30_10,30,20,40.json',
    '2022-06-29_10,50,20,60.json',
    '2022-06-22_30,30,40,40.json',
    '2022-05-29_30,60,40,70.json',
    '2022-05-17_20,50,30,60.json',
    '2022-05-20_30,40,40,50.json',
    '2022-03-18_20,30,30,40.json',
    '2022-05-13_10,50,20,60.json',
    '2022-09-12_30,60,40,70.json',
    '2022-06-29_30,40,40,50.json',
    '2022-03-19_20,40,30,50.json',
    '2022-07-24_20,40,30,50.json',
    '2022-01-11_30,60,40,70.json',
    '2022-03-05_40,50,50,60.json',
    '2022-02-04_10,30,20,40.json',
    '2022-03-30_30,60,40,70.json',
    '2022-09-27_10,30,20,40.json',
    '2022-05-04_30,40,40,50.json',
    '2022-06-17_20,50,30,60.json',
    '2022-02-03_20,60,30,70.json',
    '2022-03-18_10,50,20,60.json',
    '2022-05-17_20,40,30,50.json',
    '2022-06-28_20,40,30,50.json',
    '2022-07-08_20,30,30,40.json',
    '2022-06-11_30,40,40,50.json',
    '2022-05-24_30,40,40,50.json',
    '2022-07-25_10,30,20,40.json',
    '2022-06-22_40,50,50,60.json',
    '2022-07-11_30,50,40,60.json',
    '2022-07-10_10,40,20,50.json',
    '2022-05-04_20,60,30,70.json',
    '2022-05-26_40,60,50,70.json',
    '2022-05-18_20,50,30,60.json',
    '2022-02-20_30,30,40,40.json',
    '2022-03-17_30,50,40,60.json',
    '2022-07-02_20,50,30,60.json',
    '2022-05-18_30,40,40,50.json',
    '2022-07-10_30,40,40,50.json',
    '2022-03-04_30,40,40,50.json',
    '2022-06-19_20,40,30,50.json',
    '2022-05-27_20,60,30,70.json',
    '2022-04-03_30,60,40,70.json',
    '2022-10-22_30,40,40,50.json',
    '2022-03-06_20,30,30,40.json',
    '2022-01-20_30,40,40,50.json',
    '2022-03-04_30,50,40,60.json',
    '2022-03-29_20,60,30,70.json',
    '2022-07-05_10,40,20,50.json',
    '2022-06-13_30,40,40,50.json',
    '2022-02-02_30,60,40,70.json',
    '2022-09-26_40,60,50,70.json',
    '2022-06-03_30,40,40,50.json',
    '2022-06-27_10,40,20,50.json',
    '2022-07-03_30,60,40,70.json',
    '2022-09-24_30,60,40,70.json',
    '2022-01-21_20,50,30,60.json',
    '2022-05-24_30,60,40,70.json',
    '2022-06-13_20,60,30,70.json',
    '2022-11-14_30,60,40,70.json',
    '2022-07-06_30,60,40,70.json',
    '2022-10-22_30,60,40,70.json',
    '2022-03-04_20,40,30,50.json',
    '2022-03-18_10,60,20,70.json',
    '2022-05-27_30,60,40,70.json',
    '2022-01-14_20,60,30,70.json',
    '2022-05-16_20,50,30,60.json',
    '2022-06-12_10,30,20,40.json',
    '2022-06-19_20,60,30,70.json',
    '2022-01-20_30,60,40,70.json',
    '2022-03-11_30,30,40,40.json',
    '2022-05-23_20,60,30,70.json',
    '2022-06-29_20,30,30,40.json',
    '2022-03-10_30,60,40,70.json',
    '2022-05-14_10,30,20,40.json',
    '2022-05-26_10,30,20,40.json',
    '2022-07-04_30,40,40,50.json',
    '2022-06-22_30,60,40,70.json',
    '2022-02-20_30,40,40,50.json',
    '2022-09-05_20,30,30,40.json',
    '2022-07-08_30,60,40,70.json',
    '2022-06-13_30,60,40,70.json',
    '2022-05-25_20,40,30,50.json',
    '2022-05-27_10,40,20,50.json',
    '2022-05-17_10,50,20,60.json',
    '2022-04-08_30,60,40,70.json',
    '2022-03-25_30,60,40,70.json',
    '2022-05-16_20,40,30,50.json',
    '2022-05-26_30,40,40,50.json',
    '2022-05-07_20,40,30,50.json',
    '2022-09-11_30,40,40,50.json',
    '2022-07-25_20,40,30,50.json',
    '2022-05-18_10,30,20,40.json',
    '2022-03-10_20,40,30,50.json',
    '2022-04-10_20,40,30,50.json',
    '2022-05-26_20,40,30,50.json',
    '2022-04-03_40,50,50,60.json',
    '2022-03-05_10,40,20,50.json',
    '2022-06-25_20,30,30,40.json',
    '2022-04-12_30,40,40,50.json',
    '2022-05-23_10,50,20,60.json',
    '2022-05-13_20,40,30,50.json',
    '2022-03-17_20,40,30,50.json',
    '2022-09-27_10,40,20,50.json',
    '2022-05-18_10,40,20,50.json',
    '2022-03-06_20,50,30,60.json',
]

In [None]:
from rastervision.pipeline.file_system.utils import list_paths

# One GeoTIFF per labelled file: same stem as the label, with a ".tif" suffix.
# (To train on every image in img_dir instead, this list could be built with
# list_paths(img_dir, ext='.tif') and sorted.)
img_uris = [
    join(img_dir, Path(label_fname).stem + '.tif')
    for label_fname in filenames_labeled
]
len(img_uris)

In [None]:
# Pair every image with the label file of the same stem in label_dir; the two
# lists stay index-aligned.
label_uris = [
    join(label_dir, Path(img_uri).stem + '.json') for img_uri in img_uris
]

In [None]:
import albumentations as A
from rastervision.core.data import ClassInferenceTransformer
from rastervision.pytorch_learner import (
    SemanticSegmentationRandomWindowGeoDataset,
    SemanticSegmentationSlidingWindowGeoDataset)

# Augmentations applied to training chips only (the val datasets get none).
data_augmentation_transform = A.Compose([
    A.Flip(),
    A.ShiftScaleRotate(),
    A.CoarseDropout(max_height=32, max_width=32, max_holes=5)
])


def _label_vector_source_kw() -> dict:
    """Build the ``label_vector_source_kw`` shared by train and val datasets.

    Maps the ``dust_over_land`` and ``dust_over_water`` label classes onto the
    single ``dust`` class and uses the ``background`` id as the transformer's
    ``default_class_id``. Defined once so train and val labels cannot drift
    apart; a fresh dict (and transformer) is built on every call so datasets
    never share a transformer instance.
    """
    return dict(vector_transformers=[
        ClassInferenceTransformer(
            default_class_id=class_config.get_class_id('background'),
            class_config=class_config,
            class_name_mapping=dict(
                dust_over_land='dust', dust_over_water='dust'),
        )
    ])


def make_train_ds(img_uri: str, label_uri: str):
    """Create a random-window training dataset for one image/label pair.

    Args:
        img_uri: URI of the image raster.
        label_uri: URI of the vector label file for the same scene.

    Returns:
        A ``SemanticSegmentationRandomWindowGeoDataset`` drawing at most 16
        augmented 512 px windows, each resized to 256 px.
    """
    return SemanticSegmentationRandomWindowGeoDataset.from_uris(
        class_config=class_config,
        image_uri=img_uri,
        label_vector_uri=label_uri,
        label_vector_source_kw=_label_vector_source_kw(),
        # Upper bound is exclusive, so every window is exactly 512 px.
        size_lims=(512, 513),
        out_size=256,
        max_windows=16,
        padding=0,
        transform=data_augmentation_transform,
    )


def make_val_ds(img_uri: str, label_uri: str):
    """Create a sliding-window validation dataset for one image/label pair.

    Args:
        img_uri: URI of the image raster.
        label_uri: URI of the vector label file for the same scene.

    Returns:
        A ``SemanticSegmentationSlidingWindowGeoDataset`` of non-overlapping
        (stride == size) 512 px windows, each resized to 256 px, with no
        augmentation.
    """
    return SemanticSegmentationSlidingWindowGeoDataset.from_uris(
        class_config=class_config,
        image_uri=img_uri,
        label_vector_uri=label_uri,
        label_vector_source_kw=_label_vector_source_kw(),
        size=512,
        stride=512,
        padding=0,
        out_size=256,
    )

In [None]:
# The first `train_size` images are used for training, the rest for validation.
train_size = 300
# Fail fast: an empty val (or train) split would otherwise only surface later
# as an opaque error from ConcatDataset.
assert 0 < train_size < len(img_uris), (
    f'train_size={train_size} must leave at least one image for each of '
    f'train and val (have {len(img_uris)} images)')
# NOTE(review): the split follows list order, and many dates appear in several
# adjacent tiles, so the same dust event can end up in both train and val,
# which may inflate validation metrics. Consider splitting by date instead.

img_uris_train = img_uris[:train_size]
label_uris_train = label_uris[:train_size]

img_uris_val = img_uris[train_size:]
label_uris_val = label_uris[train_size:]

In [None]:
from torch.utils.data import ConcatDataset

# One random-window dataset per training image, concatenated into a single
# dataset. `total` uses the real list length so the progress bar is correct
# even when there are fewer images than `train_size`.
with tqdm(zip(img_uris_train, label_uris_train),
          total=len(img_uris_train), desc='train datasets') as bar:
    train_dses = [
        make_train_ds(img_uri, label_uri) for img_uri, label_uri in bar
    ]
train_ds = ConcatDataset(train_dses)
len(train_ds)

In [None]:
# One sliding-window dataset per validation image, concatenated into a single
# dataset.
val_size = len(img_uris_val)
with tqdm(zip(img_uris_val, label_uris_val), total=val_size) as bar:
    val_dses = [make_val_ds(*uri_pair) for uri_pair in bar]
val_ds = ConcatDataset(val_dses)
len(val_ds)

In [None]:
from rastervision.pytorch_learner import SemanticSegmentationVisualizer

# Visualizer that draws chips and labels using the class names and colors
# defined in class_config.
viz = SemanticSegmentationVisualizer(
    class_colors=class_config.colors,
    class_names=class_config.names,
)

In [None]:
# Sanity check: plot a shuffled batch of 8 (augmented) training chips next to
# their labels to confirm that imagery and labels line up.
x, y = viz.get_batch(train_ds, 8, shuffle=True)
viz.plot_batch(x, y, show=True)

---

- **TODO:** elaborate on pretrained weights and fine-tuning.

In [None]:
# Panoptic-FPN segmentation model with a ResNet-18 backbone, loaded via
# torch.hub from tag 0.3 of AdeelH/pytorch-fpn. pretrained=True presumably
# loads pretrained backbone weights (confirm in the repo) that are then
# fine-tuned here. out_size matches the datasets' out_size of 256.
model = torch.hub.load(
    'AdeelH/pytorch-fpn:0.3',
    'make_fpn_resnet',
    name='resnet18',
    pretrained=True,
    in_channels=3,
    num_classes=len(class_config),
    fpn_type='panoptic',
    fpn_channels=128,
    out_size=(256, 256),
)

In [None]:
from rastervision.pytorch_learner import DataConfig

data_cfg = DataConfig(
    class_config=class_config,
    # Number of DataLoader worker processes; 4 already loads data in parallel.
    # Set to 0 to load in the main process (e.g. when debugging).
    num_workers=4,
)

In [None]:
from rastervision.pytorch_learner import SolverConfig

# Optimisation settings: 16 chips per batch, learning rate 1e-4.
solver_cfg = SolverConfig(lr=1e-4, batch_sz=16)

In [None]:
from rastervision.pytorch_learner import SemanticSegmentationLearnerConfig

# Combine the data and solver settings into a single learner configuration.
learner_cfg = SemanticSegmentationLearnerConfig(
    solver=solver_cfg,
    data=data_cfg,
)

In [None]:
from rastervision.pytorch_learner import SemanticSegmentationLearner

# Wrap the hub model and the train/val datasets in a learner whose outputs
# are written under out_dir, then log statistics about the datasets.
learner = SemanticSegmentationLearner(
    cfg=learner_cfg,
    model=model,
    train_ds=train_ds,
    valid_ds=val_ds,
    output_dir=out_dir,
)
learner.log_data_stats()

In [None]:
# Fine-tune for a fixed, short schedule of 5 epochs.
learner.train(epochs=5)

In [None]:
# Export the trained model as a model bundle for later prediction
# (presumably written under the learner's output_dir -- confirm location).
learner.save_model_bundle()

---

In [None]:
# Predict on the validation DataLoader. The result is an iterator that yields
# (x, y, z) batches: inputs, labels and raw model outputs (raw_out=True).
preds = learner.predict_dataloader(
    learner.valid_dl,
    batched_output=True,
    return_format='xyz',
    raw_out=True,
)

In [None]:
from itertools import islice

# Plot at most the next 4 batches of validation predictions. islice stops
# cleanly if fewer batches remain, instead of raising StopIteration as
# next() would. `preds` is consumed: re-running this cell shows the following
# batches; re-run the previous cell to start from the beginning again.
for x, y, z in islice(preds, 4):
    learner.visualizer.plot_batch(x, y, output_path=None, z=z, show=True)