## Refine pre-trained model Ghana per tile

Using high quality labels just from Ghana

In [2]:
import os
import sys
import importlib
from pathlib import Path
import pandas as pd

In [3]:
# Location of the local deeplearner checkout that the next cell puts on sys.path.
repo = "deeplearner"
clone_path = "/home/mappers/projects/"
# os.path.join avoids the doubled "/" that plain string formatting produced,
# because clone_path already ends with a separator.
repo_clone_path = os.path.join(clone_path, repo)

In [4]:
# Put the local checkout (and its inner 'deeplearner/' directory) at the front of
# sys.path so the imports below resolve to the clone under repo_clone_path.
sys.path.insert(0, os.path.join(repo_clone_path, 'deeplearner/'))
sys.path.insert(0, repo_clone_path)
import deeplearner
# Re-execute the top-level package; importlib.reload does not reload submodules
# that were already imported.
importlib.reload(deeplearner)
# NOTE(review): later cells use names that are not defined in this notebook
# (e.g. unet, DataLoader, load_dataset, ModelCompiler, planetData, make_reproducible);
# presumably they arrive through these star imports — confirm, and consider
# importing them explicitly.
from deeplearner.models import *
from deeplearner.losses import *
from deeplearner.datatorch2 import *
from deeplearner.utils import *
from deeplearner.compiler import *

## Configuration

In [5]:
# Single dictionary holding every setting the cells below read.
config = {
    # --- directories ---
    "source_dir": "/home/mappers/data/",
    "working_dir": "/home/mappers/tmp/gh_cg_tz_ng/refine_ghana_per_tile",

    # --- train and validation dataset ---
    # "train_csv_name": "catalog_gh_cg_tz_ng_v1.csv",
    "train_pickle_name": "refine_ghana_per_tile.pickle",
    "val_pickle_name": "val_ghana_per_tile.pickle",
    "lbl_patchSize": 200,
    "one_side_buffer": 12,
    "tile_buffer": 11,
    "img_path_cols": ["dir_os"],
    "norm_stats_type": "local_per_tile",
    "label_path_col": "dir_label",
    "train_lbl_quality_groups": (3, 4),
    "val_lbl_quality_groups": (3, 4),
    "transformations": [
        "vflip", "hflip", "rotate", "resize", "shift_brightness",
    ],
    "rotationDegree": (-90, 90),
    "bshift_band_grouping": [4],

    # --- train and validation DataLoader ---
    "train_BatchSize": 32,
    "val_BatchSize": 2,

    # --- model ---
    "input_channels": 4,
    "n_classes": 3,

    # --- model compiler ---
    "gpuDevices": [0],
    # Pre-trained parameters handed to ModelCompiler as params_init.
    "params_init_path": (
        "s3://activemapper/DL/models/gh_cg_tz_ng/local_per_tile/"
        "unet_params.pth"
    ),
    # Layer indices 0..57, handed to ModelCompiler as freeze_params.
    "freeze_layer_ls": list(range(58)),

    # --- model fitting ---
    "epochs": 20,
    "optimizer": "nesterov",
    "LR": 0.01,
    "LR_policy": "PolynomialLR",
    "criterion": "BalancedTverskyFocalLoss(gamma = 0.9)",
    "momentum": 0.95,
    "resume": False,
    "resume_epoch": None,
    "bucket": "activemapper",
    "prefix_out": "DL/models/gh_cg_tz_ng/refine_ghana_per_tile",

    # --- prediction ---
}

# Resolve the directories this run reads from and writes to, creating any that are
# missing. mkdir(parents=True, exist_ok=True) replaces the exists()/makedirs() pairs:
# it cannot fail if the directory appears between the check and the creation.
pickle_dir = Path(config["source_dir"]) / "pickles"
pickle_dir.mkdir(parents=True, exist_ok=True)

train_pickle_path = pickle_dir / config["train_pickle_name"]
val_pickle_path = pickle_dir / config["val_pickle_name"]

# Creating <working_dir>/logs with parents=True also creates working_dir itself.
log_dir = Path(config["working_dir"]) / "logs"
log_dir.mkdir(parents=True, exist_ok=True)

# config

### Workflow

**Step 1.** Setup seeding to make the experiment reproducible. 

In [6]:
# Seed setup for the run (Step 1). make_reproducible is not defined in this notebook;
# presumably it comes from the deeplearner star imports — confirm.
make_reproducible()

**Step 2.** Load the train and val datasets (i.e. divide the dataset into mini-batches after applying the augmentation, convert to tensor and put them on GPU if available)

In [7]:
# Training set: read from the pickle under pickle_dir, then batched with the
# configured train batch size and shuffling enabled.
train_dataset = load_dataset(train_pickle_path)
train_dataloader = DataLoader(
    train_dataset, batch_size=config["train_BatchSize"], shuffle=True
)

In [8]:
# Validation set: same loading path as training, but with its own batch size and
# shuffling disabled so batches come in a fixed order.
validation_dataset = load_dataset(val_pickle_path)
validation_dataloader = DataLoader(
    validation_dataset, batch_size=config["val_BatchSize"], shuffle=False
)

**Step 3.** Initialize the model

In [9]:
# Instantiate the network by calling `unet` directly. The previous
# eval('unet'.lower()) only resolved this same global name from a constant string,
# so the eval added nothing but indirection.
model = unet(config["input_channels"], config["n_classes"])

**Step 4.** Compile

The model Compiler is responsible for: 
1) handling the model parallelism on multiple GPUs, 
2) loading existing model parameters if needed, and
3) freezing user-defined layers of the model if model-based transfer learning is pursued.

In [10]:
# Wrap the bare network in ModelCompiler, passing the buffer size, GPU list,
# initial-parameter path and list of layers to freeze from config.
# NOTE(review): `model` is rebound from the raw network to the compiler object, so
# re-running this cell without re-running Step 3 would pass the compiler object
# back into ModelCompiler.
model = ModelCompiler(
    model, buffer = config["one_side_buffer"], 
    gpuDevices = config["gpuDevices"], 
    params_init = config["params_init_path"],
    freeze_params = config["freeze_layer_ls"]
)

----------GPU available----------
total number of trainable parameters: 20.5M
---------- Pre-trained model compiled successfully ----------


**Step 5.** Train the model

In [11]:
# Fine-tune on the training loader and validate on the validation loader, with the
# optimiser, learning-rate policy, loss, momentum and resume settings taken from config.
# NOTE(review): criterion is passed as a string containing a constructor call;
# presumably fit() evaluates it to build the loss object — confirm.
model.fit(
    train_dataloader, 
    validation_dataloader, 
    epochs = config["epochs"], 
    optimizer_name = config["optimizer"], 
    lr_init = config["LR"], 
    lr_policy = config["LR_policy"], 
    criterion = config["criterion"], 
    momentum = config["momentum"],
    resume = config["resume"], 
    resume_epoch = config["resume_epoch"]
)

-------------------------- Start training --------------------------
[1/20]
train loss:0.763514104343596
validation loss: 0.5836250590005269
LR: 0.009919999758550428
time: 24
[2/20]
train loss:0.7530495552789598
validation loss: 0.5842316497427722
LR: 0.009919999758550428
time: 23
[3/20]
train loss:0.75078642084485
validation loss: 0.58841101002569
LR: 0.009839837734062955
time: 22
[4/20]
train loss:0.7539393816675458
validation loss: 0.5855828745601078
LR: 0.009759511943429413
time: 23
[5/20]
train loss:0.7506121794382731
validation loss: 0.5802314729585002
LR: 0.009679020358515226
time: 23
[6/20]
train loss:0.7562589304787772
validation loss: 0.5814116720420619
LR: 0.009598360904656872
time: 23
[7/20]
train loss:0.7543824400220599
validation loss: 0.5803205695624153
LR: 0.009517531459092827
time: 23
[8/20]
train loss:0.7557718299684071
validation loss: 0.5820903285872191
LR: 0.009436529849324329
time: 23
[9/20]
train loss:0.7492261557351976
validation loss: 0.5906727826533218
LR: 0.0

**Step 6.** Save the trained parameters

In [12]:
# Persist the run to the configured bucket and prefix; the recorded cell output
# reports loss files and model parameters uploaded to S3.
model.save(bucket=config["bucket"], outPrefix=config["prefix_out"])

Loss files uploaded to s3
model parameters uploaded to s3!, at  DL/models/gh_cg_tz_ng/refine_ghana_per_tile


**Step 7.** Evaluate the trained model against the evaluation dataset and report a number of accuracy metrics in a csv format.

In [13]:
# Evaluate on the validation loader, using the same bucket and output prefix as the
# saved parameters.
model.evaluate(validation_dataloader, bucket=config["bucket"], 
               outPrefix=config["prefix_out"])

-------------------------- Start evaluation --------------------------
-------------------------- Evaluation finished in 19s --------------------------


#### Prediction

**Step 8.** Prediction

In [None]:
def load_pred_data(dir_data, pred_patch_size, pred_buffer, pred_composite_buffer, 
                   pred_batch, catalog, catalog_row, img_path_cols, average_neighbors=False):
    """Build the prediction DataLoader(s) for one catalog tile.

    Args:
        dir_data: Passed through to ``planetData`` as its first argument.
        pred_patch_size, pred_buffer, pred_composite_buffer: Passed through to
            ``planetData`` unchanged.
        pred_batch: Batch size of the returned DataLoader(s).
        catalog: DataFrame with ``tile_col`` and ``tile_row`` columns. It is
            no longer modified by this function.
        catalog_row: Row of the tile to predict. It is used positionally
            (``catalog.iloc``) to find the tile's col/row and is also passed to
            ``planetData`` as ``catalogIndex``.
        img_path_cols: Passed to ``planetData`` as ``imgPathCols``.
        average_neighbors: When truthy, also load the four edge-adjacent tiles.

    Returns:
        When ``average_neighbors`` is falsy: ``(data_loader, meta, tile)`` for
        the requested tile. Otherwise a dict with keys ``"center"``, ``"top"``,
        ``"left"``, ``"right"`` and ``"bottom"`` whose values are such tuples,
        or ``None`` where the catalog has no tile at that position.
    """
    def load_single_tile(catalog_ind = catalog_row):
        dataset = planetData(dir_data, catalog, pred_patch_size, pred_buffer, 
                             pred_composite_buffer, "predict", 
                             catalogIndex=catalog_ind, imgPathCols=img_path_cols)
        data_loader = DataLoader(dataset, batch_size=pred_batch, shuffle=False)
        return data_loader, dataset.meta, dataset.tile

    # Direct crop of edge pixels: only the requested tile is needed.
    if not average_neighbors:
        return load_single_tile()

    center = catalog.iloc[catalog_row]
    tile_col, tile_row = center.tile_col, center.tile_row

    def neighbor_index(col, row):
        # Index label of the first catalog row at (col, row), or None if absent.
        # A boolean mask replaces the helper "tile_col_row" column that used to be
        # written into the caller's DataFrame on every call.
        is_match = (catalog["tile_col"] == col) & (catalog["tile_row"] == row)
        hits = catalog.index[is_match.to_numpy()]
        return hits[0] if len(hits) else None

    row_dict = {
        "center": catalog_row,
        "top": neighbor_index(tile_col, tile_row - 1),
        "left": neighbor_index(tile_col - 1, tile_row),
        "right": neighbor_index(tile_col + 1, tile_row),
        "bottom": neighbor_index(tile_col, tile_row + 1),
    }
    return {
        position: load_single_tile(catalog_ind=index) if index is not None else None
        for position, index in row_dict.items()
    }

In [None]:
# Output prefix for prediction rasters, and the prediction catalog path (the next
# cell joins it onto dir_data before reading it).
# NOTE(review): the prefix names "DFUNet_WithoutAttention" while this notebook
# builds a unet in Step 3 — confirm this is the intended output location.
prefix_out = 'DL/predictions/gh_cg_tz/DFUNet_WithoutAttention_05032022/ghana'
catalog = 'catalogs/predict/catalog_ghana_retiled_8.csv'

In [None]:
# NOTE(review): dir_data is not assigned in any cell visible in this notebook, so
# this cell fails on a fresh kernel unless it is defined elsewhere first.
pred_catalog = pd.read_csv(os.path.join(dir_data, catalog))
# Index labels of the catalog rows whose `type` column equals 'center'.
inds = pred_catalog.query("type == 'center'").index.values

In [None]:
# Run prediction tile by tile over the 'center' rows selected above.
# NOTE(review): pred_patch_size, pred_buffer, pred_composite_buffer, pred_batch,
# img_path_cols, average_neighbors, bucket and shrink_pixels are not assigned in any
# cell visible in this notebook — define them (e.g. in the config cell) before running.
for i in inds:
    print("Predicting on index %s" % (i))
    # With average_neighbors set this is a dict of per-position loaders rather than
    # a single (loader, meta, tile) tuple; the same flag is forwarded to predict().
    pred_dataloader = load_pred_data(
        dir_data, pred_patch_size, pred_buffer, pred_composite_buffer, 
        pred_batch, pred_catalog, i, img_path_cols, 
        average_neighbors = average_neighbors
    )
    p = model.predict(
        pred_dataloader, bucket, prefix_out, 
        pred_buffer, averageNeighbors=average_neighbors, 
        shrinkBuffer = shrink_pixels
    )

In [None]:
import math

# Plot every predicted score map, three per row. The grid keeps the original
# 3x3 / (15, 20) layout for up to nine tiles and adds rows beyond that, instead of
# raising IndexError on the tenth tile.
# NOTE(review): pyplot, rasterio, show and bucket are not defined in any cell
# visible in this notebook — confirm where they come from.
n_cols = 3
n_rows = max(3, math.ceil(len(inds) / n_cols))
fig, ax = pyplot.subplots(n_rows, n_cols, figsize=(15, 20 * n_rows / 3))
for i, ind in enumerate(inds):
    col, row = pred_catalog.iloc[ind][['tile_col', 'tile_row']]
    image_path = 's3://{}/{}/Score_1/score_c{}_r{}.tif'.\
        format(bucket, prefix_out, col, row)
    # Context manager closes the raster handle as soon as the pixels are read.
    with rasterio.open(image_path) as src:
        img = src.read()
    grid_row, grid_col = divmod(i, n_cols)
    show(img.astype('uint8'), ax = ax[grid_row][grid_col], title=os.path.basename(image_path))

pyplot.savefig("score_map_unet.png")

In [None]:
import math
import re
from rasterio.plot import reshape_as_raster, reshape_as_image

def findRowCol(filename):
    """Extract the tile row and column numbers embedded in a tile filename.

    The name must contain ``_c<digits>_r<digits>.tif`` (for example
    ``score_c123_r456.tif``). Any number of digits is accepted, because the
    score files are written with unpadded tile numbers.

    Args:
        filename: File name or path containing the ``_c..._r....tif`` pattern.

    Returns:
        ``(row, col)`` as integers — note the order: row first, then column.

    Raises:
        IndexError: If the column or row pattern is not found.
    """
    # \d+ instead of exactly three digits; the dot before "tif" is escaped so it
    # matches only a literal "." rather than any character.
    col = re.findall(r"(?<=_c)\d+(?=_r)", filename)[0]
    row = re.findall(r"(?<=_r)\d+(?=\.tif)", filename)[0]
    return int(row), int(col)

def rescale_image(image, bands=(0, 1, 2, 3)):
    """Read selected bands of a raster and scale each one by its own maximum.

    Args:
        image: Object whose ``read()`` result is passed to ``reshape_as_image``
            (an open rasterio dataset in this notebook).
        bands: Indices of the bands to keep, in output order.

    Returns:
        float64 array with the selected bands on the last axis, each divided by
        that band's maximum value. A band whose maximum is 0 is left unchanged
        instead of being divided by zero.
    """
    img = reshape_as_image(image.read())[:, :, list(bands)].astype('float64')
    # After the selection above the channels sit at positions 0..len(bands)-1, so
    # iterate by position. Indexing with the original band ids only worked for the
    # default bands=(0, 1, 2, 3).
    for position in range(img.shape[-1]):
        peak = img[:, :, position].max()
        if peak != 0:
            img[:, :, position] /= peak

    return img

# Side-by-side comparison figure: one row per tile, one column per model, plus the
# source image in the last column.
path_to_unet_tiles = 'predicted_tiles_unet'
path_to_dfunet_noattn_tiles = 'predicted_tiles_dfunet_noattn'
path_to_simpledfunet_tiles = 'predicted_tiles_simpledfunet'

# os.listdir() returns entries in arbitrary order, so sort before comparing the
# three listings; this also makes the row order of the figure deterministic.
predicted_unet = sorted(os.listdir(path_to_unet_tiles))
predicted_dfunet = sorted(os.listdir(path_to_dfunet_noattn_tiles))
predicted_simpledfunet = sorted(os.listdir(path_to_simpledfunet_tiles))

assert predicted_dfunet == predicted_unet == predicted_simpledfunet
n = len(predicted_unet)
assert len(os.listdir('dir_os')) == len(os.listdir('dir_gs')) == n

# squeeze=False keeps `ax` two-dimensional even when there is a single tile.
fig, ax = pyplot.subplots(n, 4, figsize=(18,40), squeeze=False)

# (title prefix, directory) for the three prediction columns, left to right.
model_panels = [
    ("UNet_", path_to_unet_tiles),
    ("DFUNet_NoAttn_", path_to_dfunet_noattn_tiles),
    ("Simple DFUNet_", path_to_simpledfunet_tiles),
]

os_images = []
for i in range(n):
    tile_name = predicted_unet[i]
    tile_label = tile_name[:-4]  # file name without its 4-character extension

    for panel, (title_prefix, tiles_dir) in enumerate(model_panels):
        # Close each prediction raster as soon as its pixels are in memory.
        with rasterio.open(os.path.join(tiles_dir, tile_name)) as src:
            img = src.read()
        show(img.astype('uint8'), ax = ax[i][panel], title=title_prefix + tile_label)

    unet_path = os.path.join(path_to_unet_tiles, tile_name)
    row, col = findRowCol(unet_path)
    # gs_path = os.path.basename(pred_catalog.loc[pred_catalog['tile_col']==col, 'dir_gs'].iloc[0])
    # img = rasterio.open(f"dir_gs/{gs_path}").read()
    # show(img.astype('uint8'), ax = ax[i][3], title="GS")

    # Match on both column and row: filtering on tile_col alone picked the first
    # catalog entry in that column, whichever row it belonged to.
    same_tile = (pred_catalog['tile_col'] == col) & (pred_catalog['tile_row'] == row)
    os_path = os.path.basename(pred_catalog.loc[same_tile, 'dir_os'].iloc[0])
    # These dataset handles stay open because they are collected in os_images.
    img = rasterio.open(f"dir_os/{os_path}")
    os_images.append(img)
    ax_curr = ax[i][3]
    img_rescaled = rescale_image(img)
    ax_curr.imshow(img_rescaled[:,:,(3,2,1)])

pyplot.savefig("score_map_compare_simpledfunet.svg",facecolor='white', format='svg',transparent=False, dpi=300)