In [None]:
from icevision import models, tfms
from torchvision.ops import MultiScaleRoIAlign
from ceruleanml import coco_load_fastai, data, preprocess
from ceruleanml.coco_load_fastai import record_collection_to_record_ids, get_image_path, record_to_mask
from torchsummary import summary


In [None]:
# Tile size (px) that selects the training partition directory (train_tiles_context_<N>).
memtile_size = 1024  # setting memtile_size=0 means use full scenes instead of tiling
# Side (px) of the square crop taken by get_tfms' RandomSizedCrop; also selects the val/test partitions.
rrctile_size = 1024
run_list = [
    [512, 80],
    # [416, 60],
]  # List of [px size, training time in minutes] pairs; run_list[-1][0] is the final output resolution

# Negative-sample counts passed to preprocess.load_set_record_collection per split
# (presumably tiles without annotations — confirm against preprocess).
negative_sample_count_train = 100
negative_sample_count_val = 0
negative_sample_count_test = 0
negative_sample_count_rrctrained = 0  # NOTE(review): not used by any cell visible here

# Area threshold passed to load_set_record_collection (presumably minimum annotation area in px — confirm).
area_thresh = 100  # XXX maybe run a histogram on this to confirm that we have much more than 100 px normally!

# Classes dropped entirely.
classes_to_remove = [
    "ambiguous",
    # "natural_seep",
]
# Classes merged into another class.
classes_to_remap = {
    "old_vessel": "recent_vessel",
    "coincident_vessel": "recent_vessel",
}

# Classes neither removed nor remapped away; len() of this sets the Mask R-CNN num_classes.
classes_to_keep = [
    c
    for c in data.class_list
    if c not in classes_to_remove + list(classes_to_remap.keys())
]

# NOTE(review): `thresholds` is not referenced by any cell visible here.
thresholds = {
    "pixel_nms_thresh": 0.4,  # prediction vs itself, pixels
    "bbox_score_thresh": 0.2,  # prediction vs score, bbox
    "poly_score_thresh": 0.2,  # prediction vs score, polygon
    "pixel_score_thresh": 0.2,  # prediction vs score, pixels
    "groundtruth_dice_thresh": 0.0,  # prediction vs ground truth, threshold
}

num_workers = 8  # NOTE(review): not referenced by any cell visible here; os.cpu_count() is a reasonable starting point

In [None]:
# icevision Mask R-CNN with a ResNeXt-101 32x8d FPN backbone.
# NOTE(review): `num_classes` comes from classes_to_keep; icevision/torchvision count the
# background as a class — confirm data.class_list includes "background".
# NOTE(review): `model` is reassigned to convnext_small() in a later cell, so this
# Mask R-CNN is not the network trained at the end of the notebook.
model_type = models.torchvision.mask_rcnn
backbone = model_type.backbones.resnext101_32x8d_fpn
model = model_type.model(
    backbone=backbone(pretrained=True),
    num_classes=len(classes_to_keep),
    box_nms_thresh=0.5,
    # Mask-head RoIAlign over FPN levels "0".."3" at 56x56
    # (4x torchvision's default 14x14 mask pooling resolution).
    mask_roi_pool=MultiScaleRoIAlign(
        featmap_names=["0", "1", "2", "3"],
        output_size=56,
        sampling_ratio=2,
    ),
)

In [None]:
# Regularization: weight decay later handed to the fastai Learner (wd=wd).
wd = 1e-2


# Ablation studies for aux channels
def triplicate(img, **params):
    """Ablation: overwrite every channel with channel 0, in place.

    Extra keyword ``params`` are accepted and ignored (same signature as ``no_op``).
    Returns the same array object that was passed in.
    """
    first_channel = img[..., 0:1]
    img[...] = first_channel  # broadcasts the single channel across the last axis
    return img


def sat_mask(img, **params):
    """Ablation: copy channel 0 into every channel, then make channel 2 a nonzero mask.

    After the copy, channel 2 holds channel-0 values, so the final channel 2 is
    ``channel0 != 0`` (stored as 0/1 in integer arrays). Modifies ``img`` in place,
    ignores ``params``, and returns the same array.
    """
    img[...] = img[..., :1]
    is_nonzero = img[..., 2] != 0
    img[..., 2] = is_nonzero
    return img


def vessel_traffic(img, **params):
    """Ablation: replace channel 1 with channel 0, in place; returns ``img``.

    NOTE(review): get_tfms labels channel 1 (g) "Infrastructure Vicinity" and
    channel 2 (b) "Vessel Density", so this function (named for vessel traffic)
    overwrites the infrastructure channel — confirm which naming is correct.
    """
    source = img[..., 0]
    img[..., 1] = source
    return img


def infra_distance(img, **params):
    """Ablation: replace channel 2 with channel 0, in place; returns ``img``.

    NOTE(review): get_tfms labels channel 2 (b) "Vessel Density", so this function
    (named for infrastructure distance) overwrites the vessel channel — confirm
    which naming is correct.
    """
    source = img[..., 0]
    img[..., 2] = source
    return img


def no_op(img, **params):
    """Identity image transform: return ``img`` unchanged; ``params`` is ignored.

    Used in get_tfms as the ``tfms.A.Lambda(image=...)`` slot, presumably so one of
    the ablation functions above can be swapped in.
    """
    return img


def get_tfms(
    memtile_size=memtile_size,
    rrctile_size=rrctile_size,
    reduced_resolution_tile_size=run_list[-1][0],
    scale_limit=0.05,
    rotate_limit=10,
    border_mode=0,  # cv2.BORDER_CONSTANT, use pad_fill_value
    pad_fill_value=(0, 0, 0),  # no_value; a tuple so the shared default cannot be mutated
    mask_value=0,
    interpolation=0,  # cv2.INTER_NEAREST
    r_shift_limit=10,  # SAR Imagery
    g_shift_limit=0,  # Infrastructure Vicinity
    b_shift_limit=0,  # Vessel Density
):
    """Build the icevision-wrapped albumentations pipelines for training and validation.

    Train: random flip -> affine (scale/rotate, constant-border fill, ``fit_output``)
    -> fixed-size crop resized to ``reduced_resolution_tile_size`` -> per-channel
    RGBShift -> ``no_op`` Lambda placeholder.
    Valid: the same fixed-size crop -> ``no_op`` Lambda placeholder.

    Args:
        memtile_size: Accepted for interface compatibility; not used here.
        rrctile_size: Side (px) of the square crop window (min == max height).
        reduced_resolution_tile_size: Output side (px) after the crop is resized.
        scale_limit: Affine scale range is ``1 +/- scale_limit``.
        rotate_limit: Affine rotation range is ``+/- rotate_limit`` degrees.
        border_mode: OpenCV border mode for the affine warp.
        pad_fill_value: Image fill value for the constant border.
        mask_value: Mask fill value for the constant border.
        interpolation: OpenCV interpolation flag for affine and crop resize.
        r_shift_limit, g_shift_limit, b_shift_limit: RGBShift limits per channel.

    Returns:
        list: ``[train_tfms, valid_tfms]`` (``tfms.A.Adapter`` instances).
    """

    def _sized_crop():
        # min_max_height is pinned to rrctile_size and w2h_ratio=1, so the crop is
        # always an rrctile_size square (random position) resized to the output size.
        return tfms.A.RandomSizedCrop(
            p=1,
            min_max_height=[rrctile_size, rrctile_size],
            height=reduced_resolution_tile_size,
            width=reduced_resolution_tile_size,
            w2h_ratio=1,
            interpolation=interpolation,
        )

    train_tfms = tfms.A.Adapter(
        [
            tfms.A.Flip(p=0.5),
            tfms.A.Affine(
                p=1,
                scale=(1 - scale_limit, 1 + scale_limit),
                rotate=(-rotate_limit, rotate_limit),
                interpolation=interpolation,
                mode=border_mode,
                cval=pad_fill_value,
                cval_mask=mask_value,
                fit_output=True,
            ),
            _sized_crop(),
            tfms.A.RGBShift(
                p=1,
                r_shift_limit=r_shift_limit,
                g_shift_limit=g_shift_limit,
                b_shift_limit=b_shift_limit,
            ),
            # Placeholder slot; swap no_op for an ablation function to test aux channels.
            tfms.A.Lambda(p=1, image=no_op),
        ]
    )
    valid_tfms = tfms.A.Adapter(
        [
            _sized_crop(),
            tfms.A.Lambda(p=1, image=no_op),
        ]
    )

    return [train_tfms, valid_tfms]


In [None]:
# Datasets
# NOTE(review): hard-coded absolute root for every partition; consider making it configurable.
mount_path = "/root"

# Parsing COCO Dataset with Icevision
json_name = "instances_TiledCeruleanDatasetV2.json"  # COCO annotation file name inside each partition dir

# Paths follow <mount_path>/partitions/<set>/{<json_name>, tiled_images}.
# The training partition is chosen by memtile_size; val/test/rrctrained by rrctile_size.
train_set = f"train_tiles_context_{memtile_size}"
coco_json_path_train = f"{mount_path}/partitions/{train_set}/{json_name}"
tiled_images_folder_train = f"{mount_path}/partitions/{train_set}/tiled_images"

val_set = f"val_tiles_context_{rrctile_size}"
coco_json_path_val = f"{mount_path}/partitions/{val_set}/{json_name}"
tiled_images_folder_val = f"{mount_path}/partitions/{val_set}/tiled_images"

test_set = f"test_tiles_context_{rrctile_size}"
coco_json_path_test = f"{mount_path}/partitions/{test_set}/{json_name}"
tiled_images_folder_test = f"{mount_path}/partitions/{test_set}/tiled_images"

# NOTE(review): the rrctrained paths are not loaded by any cell visible here.
rrctrained_set = f"train_tiles_context_{rrctile_size}"
coco_json_path_rrctrained = f"{mount_path}/partitions/{rrctrained_set}/{json_name}"

tiled_images_folder_rrctrained = (
    f"{mount_path}/partitions/{rrctrained_set}/tiled_images"
)

In [None]:
# Load each split's COCO annotations into an icevision record collection, applying the
# same area threshold and class removal/remapping; only the negative-sample count differs.
record_collection_train = preprocess.load_set_record_collection(
    coco_json_path_train,
    tiled_images_folder_train,
    area_thresh,
    negative_sample_count_train,
    preprocess=True,
    classes_to_remap=classes_to_remap,
    classes_to_remove=classes_to_remove,
    classes_to_keep=classes_to_keep,
)

record_collection_val = preprocess.load_set_record_collection(
    coco_json_path_val,
    tiled_images_folder_val,
    area_thresh,
    negative_sample_count_val,
    preprocess=True,
    classes_to_remap=classes_to_remap,
    classes_to_remove=classes_to_remove,
    classes_to_keep=classes_to_keep,
)

record_collection_test = preprocess.load_set_record_collection(
    coco_json_path_test,
    tiled_images_folder_test,
    area_thresh,
    negative_sample_count_test,
    preprocess=True,
    classes_to_remap=classes_to_remap,
    classes_to_remove=classes_to_remove,
    classes_to_keep=classes_to_keep,
)

In [None]:
# Convert each split's record collection into the record ids later used as DataBlock items.
record_ids_train, record_ids_val, record_ids_test = [
    coco_load_fastai.record_collection_to_record_ids(collection)
    for collection in (
        record_collection_train,
        record_collection_val,
        record_collection_test,
    )
]

# Model name encodes class count, backbone tag, final crop size, rrc tile size and total minutes.
# NOTE(review): "rnxt101" names the Mask R-CNN backbone, but the later cells train a
# convnext_small DynamicUnet — update the tag if this name is used for that model.
model_name = (
    f"{len(classes_to_keep)}cls_rnxt101_pr{run_list[-1][0]}"
    f"_px{rrctile_size}_{sum(r[1] for r in run_list)}min"
)


In [None]:
# Sanity check: display the type returned by preprocess.load_set_record_collection.
type(record_collection_train)

In [None]:
from fastai.vision.all import *
from fastai.callback.fp16 import *
import torch
from tqdm import tqdm
# from torchsummary import summary
import json
import wandb
# from fastai.callback.wandb import WandbCallback

In [None]:
# Display the class names left after removal and remapping.
classes_to_keep

In [None]:
# from ceruleanml import data
# from ceruleanml import evaluation
# from ceruleanml import preprocess
# from fastai.data.block import DataBlock
# from fastai.vision.data import ImageBlock, MaskBlock
# from fastai.vision.augment import aug_transforms, Resize
# from fastai.vision.learner import unet_learner
# from fastai.data.transforms import IndexSplitter
# from fastai.metrics import DiceMulti, Dice, accuracy_multi, PrecisionMulti, RecallMulti
# from fastai.callback.fp16 import MixedPrecision
# # from fastai.callback.tensorboard import TensorBoardCallback
# from fastai.vision.core import PILImageBW
# from datetime import datetime
# from pathlib import Path
# import os, random
# from icevision.visualize import show_data
# import torch
# from fastai.callback.tracker import EarlyStoppingCallback, SaveModelCallback
# import skimage.io as skio
# import numpy as np
# from math import log

In [None]:

# Validation ids are appended after the training ids; get_val_indices relies on this order.
train_val_record_ids = record_ids_train + record_ids_val
# combined_record_collection = record_collection_with_negative_small_filtered_train + record_collection_with_negative_small_filtered_val
# Same train-then-val concatenation for the collection used by the get_x/get_y lookups.
combined_record_collection = record_collection_train + record_collection_val
def get_val_indices(combined_ids, val_ids):
    """Return the positions of the validation items inside ``combined_ids``.

    ``combined_ids`` is expected to be ``train_ids + val_ids``, so the validation
    items occupy the trailing ``len(val_ids)`` positions. The result feeds the
    fastai ``IndexSplitter``.

    Args:
        combined_ids: Sized sequence of train ids followed by val ids.
        val_ids: Sized sequence of validation ids (only its length is used).

    Returns:
        list[int]: ``[len(combined_ids) - len(val_ids), ..., len(combined_ids) - 1]``;
        an empty list when ``val_ids`` is empty.
    """
    n_total = len(combined_ids)
    # The old slice ``[-len(val_ids):]`` became ``[-0:] == [0:]`` for an empty val set
    # and returned *every* index, silently making all training data validation data.
    # Clamping at 0 keeps the old result (all indices) if val_ids is longer than combined_ids.
    start = max(n_total - len(val_ids), 0)
    return list(range(start, n_total))

#show_data.show_records(random.choices(combined_train_records, k=9), ncols=3)

### Constructing a FastAI DataBlock that uses parsed COCO Dataset from icevision parser. aug_transforms can only be used with_context=True

# Positions of the validation ids inside train_val_record_ids (passed to IndexSplitter).
val_indices = get_val_indices(train_val_record_ids, record_ids_val)

def get_image_by_record_id(record_id):
    """Map a record id to its image path (the DataBlock ``get_x``).

    Looks the id up in the module-level ``combined_record_collection`` (train + val).
    """
    image_path = get_image_path(combined_record_collection, record_id)
    return image_path

def get_mask_by_record_id(record_id):
    """Map a record id to its segmentation mask (the DataBlock ``get_y``).

    Delegates to ``record_to_mask`` over the module-level ``combined_record_collection``.
    """
    mask = record_to_mask(combined_record_collection, record_id)
    return mask


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fastai DataLoaders for grayscale images
# NOTE(review): path_seg, ParentSplitter_seg and SAR_stats are not used by the cells below.
path_seg = "/root/work/masked_tiles"  # Dataset's path
ParentSplitter_seg = FuncSplitter(lambda o: Path(o).parent.name == 'valid')
SAR_stats = [0.2087162, 0.13736105] # Calculated from the entire training dataset

# NOTE(review): `codes` is unused (the DataBlock uses data.class_list), and its vessel
# names differ from the ones in classes_to_remap — do not rely on it without checking.
codes = ['background', 'infrastructure', 'natural', 'vessel_coincident', 'vessel_recent', 'vessel_old', 'ambiguous']

# fastai's GradientAccumulation counts *samples* (it sums the batch size of each yb),
# not batches. n_acc=8 with bs=16 never accumulated and scaled every gradient by 16/8 = 2.
# Accumulate over 8 batches of 16 samples instead — keep in sync with `bs` below.
grad_accum_samples = 8 * 16
# cbs_seg = [WandbCallback(log_model=True),TerminateOnNaNCallback(), GradientAccumulation(grad_accum_samples), GradientClip(), SaveModelCallback(), ShowGraphCallback()]
cbs_seg = [TerminateOnNaNCallback(), GradientAccumulation(n_acc=grad_accum_samples), GradientClip(), SaveModelCallback(), ShowGraphCallback()]

#  ShortEpochCallback(pct=0.1, short_valid=False),
# EarlyStoppingCallback(min_delta=.001, patience=5)

In [None]:
size = 512
bs = 16

# Item-level resize to `size`, then batch-level flips/rotation/warp at the same size.
batch_transfms = list(aug_transforms(size=size, max_rotate=180, max_warp=0.1, flip_vert=True))
coco_seg_dblock = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes=data.class_list)),  # ImageBlock is RGB by default, uses PIL
    n_inp=1,
    get_x=get_image_by_record_id,
    get_y=get_mask_by_record_id,
    splitter=IndexSplitter(val_indices),
    item_tfms=Resize(size),
    batch_tfms=batch_transfms,
)

dls = coco_seg_dblock.dataloaders(source=train_val_record_ids, batch_size=bs)

In [None]:
# Display the full class list used as the MaskBlock codes.
data.class_list

In [None]:
# NOTE(review): this replaces the Mask R-CNN `model` defined earlier. Called with no
# weights argument, this presumably builds a randomly initialised ConvNeXt; the
# `pretrained=True` passed to create_body below does not load weights for an already
# built model — confirm whether pretrained weights were intended.
model=convnext_small()

In [None]:
# Cut off the classifier head; body[0] is the convolutional feature extractor.
body = create_body(model, 3, pretrained=True)
# One output channel per MaskBlock code (previously hardcoded as 7), and build the
# decoder at the same resolution the DataBlock produces (was 128 vs. size=512).
unet = DynamicUnet(body[0], n_out=len(data.class_list), img_size=(size, size))

In [None]:
# torchsummary defaults to device="cuda" and builds CUDA inputs whenever a GPU is present,
# but `unet` is still on the CPU here (the Learner moves it later), which makes the forward
# pass fail with an input/weight type mismatch. Summarise on the CPU at the training size.
summary(unet, (3, size, size), device="cpu")

In [None]:
# Segmentation logits are (bs, n_classes, H, W): the class axis is 1. The default axis=-1
# would treat image width as the class dimension and mismatch the flattened targets.
loss_func = CrossEntropyLossFlat(axis=1)

In [None]:
# Plain fastai Learner around the hand-built DynamicUnet; loss, weight decay and
# callbacks come from the cells above.
unet_learn = Learner(
    dls,
    unet,
    loss_func=loss_func,
    lr=1e-3,
    wd=wd,
    cbs=cbs_seg,
)

In [None]:
# Visual check of augmented training images with their masks.
dls.show_batch()

In [None]:
# NOTE(review): no cell visible here trains the learner first, so these predictions come
# from an untrained decoder — useful only as a pipeline smoke test.
unet_learn.show_results()

In [None]:
# Pull one training batch to inspect input/target tensors directly.
inputs, targets = unet_learn.dls.train.one_batch()

In [None]:
# Display the target batch shape.
targets.shape