In [None]:
import os
import sys

# Put the parent of the notebook's working directory at the front of sys.path
# so that modules located there can be imported by the cells below
# (presumably where project_config, utils, models, ... live - confirm).
# Existing entries for that directory are removed first, so re-running this
# cell does not accumulate duplicate sys.path entries.
parent_dir = os.path.abspath('..')
while parent_dir in sys.path:
    sys.path.remove(parent_dir)
sys.path.insert(0, parent_dir)

# Enable IPython's autoreload extension. Mode 2 picks up edits to imported
# modules live, so changes to project code take effect without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from dotenv import load_dotenv
# Load environment variables from a local .env file (python-dotenv).
# NOTE(review): which variables the later cells rely on (e.g. cloud or W&B
# credentials) is not visible in this notebook - confirm.
load_dotenv()

In [None]:
from google.cloud import storage
from project_config import GCP_PROJECT_NAME

# Google Cloud Storage client for the project named in project_config.
# It is passed to observation_factory() when the datasets are created below.
# NOTE(review): credentials are presumably picked up from the environment - confirm.
gcp_client = storage.Client(project=GCP_PROJECT_NAME)

### Configuration

In [None]:
from os.path import expanduser

from experiment_configs.unet_fs_config import (
    unet_orig_config,
    unet_resblocks_config,
    resnet18_unet_config,
)
from experiment_configs.satmae_ft_config import (
    satmae_ft_doubleupsampling_config,
    satmae_ft_lineardecoder_config,
)
from experiment_configs.resnetmoco_ft_config import resnet50_moco_ft_config

# Experiment configuration used by every cell below. To run a different
# experiment, assign one of the other imported configs here instead.
config = resnet50_moco_ft_config

# Optional: uncomment to override the directory the run writes its outputs to.
#config.output_dir = expanduser("~/sandmining-watch/out/1009_resnet50_moco")

### Create Raster Vision datasets

In [None]:
from torch.utils.data import ConcatDataset

from utils.rastervision_pipeline import observation_to_scene, scene_to_training_ds, scene_to_validation_ds, warn_if_nan_in_raw_raster
from utils.data_management import observation_factory

from project_config import is_training, is_validation


# Sort every observation into the training or validation split according to
# project_config, and build a scene only for observations that are used.
all_observations = observation_factory(gcp_client)
training_scenes = []
validation_scenes = []

for observation in all_observations:
    is_train = is_training(observation.name)
    is_val = is_validation(observation.name)

    assert not (is_train and is_val), "An observation cannot be in both training and validation"

    # Skip observations that belong to neither split before doing any work
    # on them, so no scene is built only to be thrown away.
    if not (is_train or is_val):
        print(f"Ignoring observation {observation.name}")
        continue

    scene = observation_to_scene(config, observation)
    if is_train:
        training_scenes.append(scene)
    else:
        validation_scenes.append(scene)

# Fail early with a descriptive message: ConcatDataset rejects an empty list
# with a bare assertion, and later cells index training_datasets[0] and
# validation_datasets[0].
if not training_scenes or not validation_scenes:
    raise ValueError(
        "Need at least one training and one validation observation, got "
        f"{len(training_scenes)} training and {len(validation_scenes)} validation."
    )

# Optional diagnostic: uncomment to warn about NaNs in the raw rasters.
#all_scenes = training_scenes + validation_scenes
#for scene in all_scenes:
#    warn_if_nan_in_raw_raster(scene.raster_source)

# One dataset per scene; the per-scene lists are kept because later cells
# visualize individual datasets and evaluate per validation dataset.
training_datasets = [
    scene_to_training_ds(config, scene) for scene in training_scenes
]
validation_datasets = [
    scene_to_validation_ds(config, scene) for scene in validation_scenes
]

# Merged views over all scenes, used as the learner's training/validation data.
train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)


In [None]:
from utils.visualizing import Visualizer
# Visualizer configured with the channel list from the experiment config.
visualizer = Visualizer(config.s2_channels)

# Sanity check on the first training dataset: draw max_windows sampled
# windows and show them together with the scene's full image array.
# Raises IndexError if training_datasets is empty.
train_ds = training_datasets[0]
windows = [train_ds.sample_window() for _ in range(train_ds.max_windows)]
visualizer.show_windows(
    train_ds.scene.raster_source.get_image_array(),
    windows
)

# Same check for the first validation dataset, using the windows stored on
# the dataset itself rather than sampling new ones.
val_ds = validation_datasets[0]
visualizer.show_windows(
    val_ds.scene.raster_source.get_image_array(),
    val_ds.windows
)


# Train

In [None]:
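# Optional: uncomment to release cached GPU memory before building the model
# (e.g. after an out-of-memory error in an earlier run).
# import torch
# torch.cuda.empty_cache()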

In [None]:
from models.model_factory import model_factory
from ml.optimizer_factory import optimizer_factory
from ml.custom_learner import learner_factory

# The number of model input channels is the last entry of the first training
# scene's raster shape (a 3-tuple; presumably (height, width, channels) -
# confirm). Named throwaways are used instead of `_`, because assigning `_`
# at notebook top level overrides IPython's last-output variable.
_dim0, _dim1, n_channels = training_datasets[0].scene.raster_source.shape
model = model_factory(
    config,
    n_channels=n_channels,
)

optimizer = optimizer_factory(config, model)

learner = learner_factory(
    config=config,
    model=model,
    optimizer=optimizer,
    training_ds=train_dataset_merged,  # for development and debugging, use training_datasets[0] or similar to speed up
    validation_ds=val_dataset_merged,  # for development and debugging, use e.g. validation_datasets[0] to speed up
)
# Report statistics about the data the learner was given.
learner.log_data_stats()

#### Check GPU Activity

You can continuously monitor GPU activity by running the following command in a terminal:


`watch -d -n 0.5 nvidia-smi`

In [None]:
# Run this cell to log the run to Weights & Biases (W&B). You might need to
# authenticate to W&B first.
# NOTE(review): the evaluation cell at the end logs to W&B as well and
# presumably needs this run to be initialized - confirm.
learner.initialize_wandb_run()

In [None]:
# Number of passes over the training data; lower it for quick debugging runs.
NUM_EPOCHS = 15

# Train, then save the resulting model bundle.
learner.train(epochs=NUM_EPOCHS)
learner.save_model_bundle()

## Evaluate 

Run inference on the validation sites and log the results to W&B.

In [None]:
# Evaluate on the per-scene validation datasets (the list, not the merged
# dataset) and log the results to W&B.
# NOTE(review): presumably requires learner.initialize_wandb_run() to have been called - confirm.
learner.evaluate_and_log_to_wandb(validation_datasets)