In [None]:
import sys
import os
# Prepend the parent directory to the import path so the top-level imports used in
# later cells (project_config, models, utils) can be resolved.
# Note: '..' is resolved against the kernel's current working directory, not the
# notebook file's location — presumably the two coincide; confirm when launching
# Jupyter from a different directory.
sys.path.insert(0, os.path.abspath('..'))

# autoreload mode 2 reloads all modules before each cell runs, so edits to the
# project's .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from google.cloud import storage
from project_config import GCP_PROJECT_NAME

# Google Cloud Storage client for the project named in project_config; it is
# handed to observation_factory() when the scenes are built further down.
gcp_client = storage.Client(project=GCP_PROJECT_NAME)

### Configuration

In [None]:
from models.model_selection import ModelSelection
import albumentations as A

# Scenes whose id contains any of these substrings form the validation split;
# every other scene is used for training (see is_validation / is_training).
VALIDATION_SITES = ["Ken_Banda", "Sone_Rohtas"]
# Tile size handed to scene_to_training_ds / scene_to_validation_ds and, as
# (TILE_SIZE, TILE_SIZE), to get_model as img_size — presumably pixels; confirm
# in utils.rastervision_utils.
TILE_SIZE = 110

# Augmentation pipeline passed to scene_to_training_ds only; the validation
# datasets are built without it.
AUGMENTATION_TRAIN = A.Compose([
    A.RandomRotate90(),
    A.HorizontalFlip(),
    A.VerticalFlip(),
    # NOTE(review): the max_height / max_width / max_holes keywords appear to be
    # deprecated in newer albumentations releases — confirm against the pinned version.
    A.CoarseDropout(max_height=32, max_width=32, max_holes=3)
])

# Passed to the learner as experiment_dir.
EXPERIMENT_DIR = '../out/0831'

# Passed to construct_semantic_segmentation_learner as batch_size / learning_rate.
BATCH_SIZE = 256
LR = 3e-2

# Architecture selector passed to get_model as `selection`.
MODEL_TYPE = ModelSelection.Segformer


### Create Rastervision datasets

In [None]:
from torch.utils.data import ConcatDataset

from utils.rastervision_utils import observation_to_scene_s1s2, scene_to_training_ds, scene_to_validation_ds, warn_if_nan_in_raw_scene
from utils.data_management import observation_factory

def is_validation(scene):
    """Return True if the scene's id contains the name of any validation site.

    Matching is a plain substring test of each VALIDATION_SITES entry
    against ``scene.id``.
    """
    for site_name in VALIDATION_SITES:
        if site_name in scene.id:
            return True
    return False

def is_training(scene):
    """Return True for every scene that is not part of the validation split."""
    if is_validation(scene):
        return False
    return True


# Build one scene per observation returned by the factory.
all_observations = observation_factory(gcp_client)
all_scenes = [
    observation_to_scene_s1s2(observation)
    for observation in all_observations
]

for scene in all_scenes:
    warn_if_nan_in_raw_scene(scene)

# Split by site. These are materialised as lists on purpose: a bare filter()
# object is a one-shot iterator, so after building the datasets below it would
# be exhausted and any later cell reusing it would silently see no scenes.
training_scenes = [scene for scene in all_scenes if is_training(scene)]
validation_scenes = [scene for scene in all_scenes if is_validation(scene)]

# Only the training datasets receive the augmentation pipeline.
training_datasets = [
    scene_to_training_ds(scene, TILE_SIZE, AUGMENTATION_TRAIN)
    for scene in training_scenes
]
validation_datasets = [
    scene_to_validation_ds(scene, TILE_SIZE)
    for scene in validation_scenes
]

# Every scene must end up in exactly one of the two splits.
assert len(training_datasets) + len(validation_datasets) == len(all_scenes), (
    f"split mismatch: {len(training_datasets)} training + "
    f"{len(validation_datasets)} validation != {len(all_scenes)} scenes"
)
# ConcatDataset rejects an empty list with a generic message; fail earlier with
# one that points at the likely cause.
assert training_datasets, "no training scenes found; check VALIDATION_SITES and the observations"
assert validation_datasets, (
    f"no scene id matches any of VALIDATION_SITES={VALIDATION_SITES}"
)

train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)


In [None]:
from utils.visualizing import show_windows, show_image_in_dataset
from project_config import DISPLAY_GROUPS, CLASS_CONFIG

# Visual sanity check of a single training dataset before training.
# NOTE(review): the index 2 is hard-coded and raises IndexError when fewer than
# three training scenes exist.
ds_to_visualize = training_datasets[2]
# Draw the dataset's windows on top of the scene's raster (full-extent slice).
show_windows(
    ds_to_visualize.scene.raster_source[:,:],
    ds_to_visualize.windows
)
# Display imagery from the same dataset using the project's class and display-group config.
show_image_in_dataset(
    ds_to_visualize,
    CLASS_CONFIG,
    DISPLAY_GROUPS,
)

# Train

In [None]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Bare expression so the notebook displays which device was selected.
device

In [None]:
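# Optional: uncomment to release cached GPU memory held by PyTorch, e.g. after an
# interrupted run, before constructing the model again.
# torch.cuda.empty_cache()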

In [None]:
# Display the raster shape of the first training scene; the next cell unpacks
# its last element as the channel count passed to get_model.
training_datasets[0].scene.raster_source.shape

In [None]:
from utils.rastervision_utils import construct_semantic_segmentation_learner
from models.model_selection import get_model

# The unpacking requires a 3-element shape and takes the last element as the
# channel count — presumably (height, width, channels); TODO confirm.
# The first training scene is taken as representative of all scenes.
_, _, n_channels = training_datasets[0].scene.raster_source.shape
model = get_model(
    selection=MODEL_TYPE,
    n_channels=n_channels,
    img_size=(TILE_SIZE, TILE_SIZE),
)

# Wire the model, the merged datasets and the hyperparameters from the
# configuration cell into a learner that writes to EXPERIMENT_DIR.
learner = construct_semantic_segmentation_learner(
    model=model,
    training_ds=train_dataset_merged,  # for development and debugging, use training_datasets[0] or similar to speed up
    validation_ds=val_dataset_merged,  # for development and debugging, use training_datasets[1] or similar to speed up
    batch_size=BATCH_SIZE,
    learning_rate=LR,
    # class_loss_weights=[1., 10.],
    experiment_dir=EXPERIMENT_DIR
)
# Log dataset statistics before training starts.
learner.log_data_stats()

## Check GPU Activity

You can continuously monitor GPU activity by running the following command in a terminal:


`watch -d -n 0.5 nvidia-smi`

In [9]:
# Train for a single epoch, then save the model bundle.
# NOTE(review): epochs=1 is hard-coded here rather than set in the configuration
# cell — presumably a smoke-test value; raise it for a real run.
learner.train(epochs=1)
learner.save_model_bundle()

# Tensorboard

To start TensorBoard, run the following commands in a terminal:
`conda activate rastervision` (or whatever your conda environment is called)

`tensorboard --logdir sandmining-watch/out` (make sure it's relative to your current path in the terminal)

You will also need to port-forward; on your local machine, run:

`ssh -N -f -L localhost:6006:localhost:6006 <USERNAME>@fati.ischool.berkeley.edu`

Note: port-forwarding is not needed if you're running from within VS Code; it should automatically offer to open TensorBoard in the browser.