# Counting Boats from Space - Part 2

In [None]:
# Reload edited project modules automatically before each cell is executed.
%reload_ext autoreload
%autoreload 2
# Load environment variables from a local .env file (python-dotenv extension).
%load_ext dotenv
%dotenv
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## 1. Install, Import requirements

In [None]:
# Various utilities
import os
import json
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import skimage
import torch
from torch.utils.data import DataLoader
import superintendent

In [None]:
from src.dataset import getImageSetDirectories, S2_Dataset, plot_dataset
from src.model import Model
from src.train import train, get_failures_or_success
from src.annotation_utils import display_image_and_references, display_heatmap_prediction

In [None]:
# Shown as the cell output: True when PyTorch can use a CUDA device.
torch.cuda.is_available() # gpu support

## 2. Init K-Fold Dataset

In [None]:
# --- Data selection -------------------------------------------------------
# Band names forwarded as `band_list` when the image-set directories are listed.
bands = ['img_08', 'bg_ndwi']
# Forwarded as `test_size` to the train/validation split
# (presumably the held-out fraction -- confirm in src.dataset).
test_size = 0.1

# --- Locations ------------------------------------------------------------
# Root data directory; the 'chips' sub-folder underneath it holds the imagery.
data_dir = '/home/jovyan/data'
# Folder containing labels.csv (relative to the notebook's working directory).
labels_dir = './data'
# Where training writes checkpoints and where they are loaded from for testing.
checkpoint_dir = '/home/jovyan/checkpoints'

In [None]:
# Split the image-set directories into training and validation lists.
chips_dir = os.path.join(data_dir, 'chips')
labels_csv = os.path.join(labels_dir, "labels.csv")

train_list, val_list, fig = getImageSetDirectories(
    data_dir=chips_dir,
    labels_filename=labels_csv,
    band_list=bands,
    test_size=test_size,
    plot_coords=False,
    plot_class_imbalance=True,
    seed=123,
)

# Display the returned figure as the cell output (mapbox plot of train/val coordinates).
fig

In [None]:
# Both datasets read their targets from the same labels file.
labels_csv = os.path.join(labels_dir, 'labels.csv')

# Augmentation is enabled for training only.
train_dataset = S2_Dataset(
    imset_dir=train_list,
    augment=True,
    labels_filename=labels_csv,
)
val_dataset = S2_Dataset(
    imset_dir=val_list,
    augment=False,
    labels_filename=labels_csv,
)

# Preview 14 training frames on 2 rows, in grayscale.
plot_dataset(train_dataset, n_frames=14, n_rows=2, cmap='gray')

## 3. Train PyTorch Classifier

In [None]:
# Model dimensions shared by the training and test cells.
# The input dimension is the leading axis of the first training sample's 'img'
# (assumes channel-first images, i.e. one channel per band -- TODO confirm).
first_sample = train_dataset[0]
input_dim = first_sample['img'].shape[0]

hidden_dim = 16
kernel_size = 3
pool_size = 10
n_max = 1

In [None]:
# Batch loaders: training in batches of 8, validation one sample at a time.
# NOTE(review): the validation loader also shuffles; kept as in the original.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=16,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
)

# Everything below is forwarded unchanged to train().
train_kwargs = dict(
    # architecture
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    kernel_size=kernel_size,
    pool_size=pool_size,
    n_max=n_max,
    drop_proba=0.15,
    # loss / water-mask settings
    ld=0.5,
    water_ndwi=0.4,
    # optimisation schedule
    n_epochs=50,
    lr=0.007,
    lr_step=2,
    lr_decay=0.95,
    # runtime
    device='cpu',
    checkpoints_dir=checkpoint_dir,
    seed=42,
    verbose=1,
    version='0.0.5',
)

best_metrics = train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    **train_kwargs,
)

# Print each metric returned by train() with four decimals.
for metric_name, metric_value in best_metrics.items():
    print('{} {:.4f}'.format(metric_name, metric_value))

In [None]:
# Reference result from a previous run (v0.0.4, epoch 34): train_clf_error 0.06563 / train_reg_error 0.10578 / val_clf_error 0.04367 / val_reg_error 0.06262

## 4. Test Model

In [None]:
# Rebuild the network with the dimensions defined above and restore the
# weights stored for version 0.0.4.
model = Model(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    kernel_size=kernel_size,
    pool_size=pool_size,
    n_max=n_max,
    device='cpu',
    version='0.0.4',
)

# The checkpoint lives in the model's own sub-folder of checkpoint_dir.
checkpoint_file = os.path.join(checkpoint_dir, model.folder, 'model.pth')
model.load_checkpoint(checkpoint_file=checkpoint_file)

# Switch to evaluation mode before running predictions.
model = model.eval()

In [None]:
# Run the model over the validation set and collect the items to review.
# With success=None and filter_on=None no success/failure filter is requested
# (presumably every sample is returned -- confirm in src.train).
# The two results feed the relabelling section below: `image_titles` is shown
# in the widget and `relabel_images` holds the matching image paths.
image_titles, relabel_images = get_failures_or_success(
    model,
    val_dataset,
    success=None,
    filter_on=None,
    water_ndwi=0.5,
    filter_peaks=True,
    shift_pool=False,
    downsample=False,
    plot_heatmap=False,
    hidden_channel=1,
)

## 5. Relabel inputs

### Load the labels and create the superintendent labelling widget

In [None]:
# Labels table indexed by (lat_lon, timestamp); the 'count' column is read as float.
# NOTE(review): this absolute path differs from os.path.join(labels_dir, 'labels.csv')
# used when building the datasets -- confirm both point at the same file.
csv_file_path = "/home/jovyan/data/labels.csv"
labels_df = pd.read_csv(
    csv_file_path,
    index_col=['lat_lon', 'timestamp'],
    dtype={'count': float},
)

# One labelling task per entry of image_titles; selectable labels are -1 .. 5.
labeller = superintendent.ClassLabeller(
    features=image_titles,
    options=list(range(-1, 6)),
    display_func=display_heatmap_prediction,
)

# Uncomment to render the widget as the cell output.
#labeller

### Extract new labels and save them in labels_df

In [None]:
# Copy every label entered in the widget back into labels_df.
# relabel_images and labeller.new_labels are aligned by position.
for image_path, new_count in zip(relabel_images, labeller.new_labels):
    # Skip images that were not relabelled (None / NaN). An explicit 0 is a
    # valid label (it is one of the widget options), so test for "missing"
    # and not for truthiness -- `if count:` silently dropped zero counts.
    if pd.isna(new_count):
        continue

    # Row key: the parent folder name is lat_lon and the timestamp is the part
    # of the file stem after 't_'.
    # NOTE(review): both are strings; confirm the CSV index was parsed as strings too.
    timestamp = image_path.stem.split('t_')[1]
    lat_lon = image_path.parts[-2]

    # Name the column explicitly: `.at[(lat_lon, timestamp)]` is unpacked as a
    # (row, column) pair, so it does not reliably target the 'count' cell.
    labels_df.loc[(lat_lon, timestamp), 'count'] = new_count

### Dump back to csv file

In [None]:
# Left commented out; uncomment to write the updated labels back to
# csv_file_path (this overwrites the existing file).
#labels_df.to_csv(csv_file_path)