In [None]:
# Re-import edited modules (e.g. ceruleanml) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures (show_batch / show_results) inside the notebook.
%matplotlib inline

In [None]:
from fastai.vision.all import *
from fastai.callback.fp16 import *
import torch
from tqdm import tqdm
import json
import wandb
from fastai.callback.wandb import WandbCallback

In [None]:
from ceruleanml.data import class_list
from ceruleanml.learner_config import (
    run_list,
    final_px,
    classes_to_keep,
    get_tfms,
    wd,
    record_collection_train,
    record_collection_val,
    record_collection_test,
    model_name,
    num_workers,
)
# Show the (tile size, epochs) schedule that the training loop iterates over.
run_list

In [None]:
# Start the Weights & Biases run that WandbCallback logs metrics/models into.
wandb.init(project='cv3-experiments')

In [None]:
# Batch size per training tile size (pixels per side); batches grow as tiles shrink.
bs_d = {
    512: 8,
    256: 16,
    224: 16,
    128: 32,
    64: 64,
}
# Learning rate per tile size; every size in bs_d currently shares the same value.
lr_d = dict.fromkeys(bs_d, 1e-3)

In [None]:
# Build a U-Net segmentation model on a ConvNeXt-Small encoder.
# NOTE(review): `convnext_small()` is created without weights, and recent fastai
# versions of `create_body` only use `pretrained` to adapt the first layer of an
# already-built model — confirm whether the encoder is actually pretrained.
model = convnext_small()
body = create_body(model, n_in=3, pretrained=True)
# One output channel per mask code, so the head can never drift from
# MaskBlock(codes=class_list) used by the DataBlock.
unet = DynamicUnet(body[0], n_out=len(class_list), img_size=(final_px, final_px))
# DynamicUnet emits (batch, n_out, H, W) logits: the class axis is 1. The default
# axis=-1 would flatten over image width and mismatch the (batch, H, W) masks.
loss_func = CrossEntropyLossFlat(axis=1)

In [None]:
def get_image_by_record_id(record_id):
    """DataBlock `get_x`: look up the image for `record_id` via `get_image_path`.

    NOTE(review): neither `get_image_path` nor `combined_record_collection` is
    defined in the visible notebook (the collection only appears in a
    commented-out cell) — confirm their source before a fresh-kernel run.
    """
    return get_image_path(combined_record_collection, record_id)

def get_mask_by_record_id(record_id):
    """DataBlock `get_y`: build the segmentation target for `record_id` via `record_to_mask`.

    NOTE(review): neither `record_to_mask` nor `combined_record_collection` is
    defined in the visible notebook — confirm their source before a
    fresh-kernel run.
    """
    return record_to_mask(combined_record_collection, record_id)

# XXX ETHAN distinguish which tfms are batch and which are item
# NOTE(review): get_tfms returns (train_tfms, val_tfms), but the DataBlock below
# reads `item_tfms` / `batch_tfms`, which are not defined in the visible notebook.
train_tfms, val_tfms = get_tfms(reduced_resolution_tile_size=final_px)

# Items whose parent directory is named 'valid' go to the validation split.
# No trailing comma: `FuncSplitter(...),` would bind a 1-tuple, which DataBlock
# cannot call as a splitter.
# NOTE(review): the DataBlock's items are record ids (train_val_record_ids), not
# file paths — confirm `Path(record_id).parent.name` is meaningful for them.
splitter = FuncSplitter(lambda o: Path(o).parent.name == 'valid')

coco_seg_dblock = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes=class_list)),  # ImageBlock is RGB by default, uses PIL
    n_inp=1,
    splitter=splitter,
    get_x=get_image_by_record_id,
    get_y=get_mask_by_record_id,
    batch_tfms=batch_tfms,
    item_tfms=item_tfms,
)

dls = coco_seg_dblock.dataloaders(source=train_val_record_ids, batch_size=bs_d[final_px])
dls.show_batch()

In [None]:
# The learner consumes the FastAI DataBlock `dls` built in the DataBlock cell.
# A leftover icevision-style call,
#   get_dataloaders(model_type, [record_collection_train, record_collection_val], get_tfms(), ...)
# used names that are never defined or imported in this notebook and would have
# replaced that `dls`, so it is intentionally not executed here.
# Learner callbacks: only W&B logging (log_model=True) is active.
cbs = [WandbCallback(log_model=True)]
# Opt-in callbacks — add to `cbs` to enable:
#   ShortEpochCallback(pct=0.1, short_valid=True)
#   EarlyStoppingCallback(min_delta=.001, patience=5)
#   TerminateOnNaNCallback()
#   GradientAccumulation(8)
#   GradientClip()
#   SaveModelCallback()
#   ShowGraphCallback()
#   MixedPrecision()

In [None]:
# Wrap the hand-built U-Net in a plain fastai Learner; `lr` and `wd` act as
# defaults for fit calls that don't pass their own.
# NOTE(review): no `splitter` is given, so the model presumably forms a single
# parameter group and freeze/unfreeze inside fine_tune have no effect — confirm.
learner = Learner(dls=dls, model=unet, loss_func=loss_func, cbs=cbs, lr=lr_d[final_px], wd=wd)

In [None]:
# Checkpoint bootstrap — exactly one branch runs:
#   start_new=True              -> save the freshly initialised weights as the checkpoint
#   load_model_name="some_name" -> load `some_name`, then re-save it as the checkpoint
#   otherwise                   -> resume from the checkpoint the training loop last saved
start_new = True
load_model_name = None  # set to a saved-model name to initialise from it

# The checkpoint name is `model_name`, the same name the training loop saves
# under, so the resume branch picks up the latest trained weights.
if start_new:
    print("Starting from scratch")
    learner.save(model_name)
elif load_model_name:
    print(f"Loading {load_model_name}")
    learner.load(load_model_name)
    learner.save(model_name)
else:
    print("Continuing current training session")
    learner.load(model_name)

In [None]:
from datetime import datetime
# from ceruleanml.inference import save_icevision_model_state_dict_and_tracing

def export_scripted_model(learner, model_name, experiments_root=Path('/root/experiments/cv3')):
    """Export `learner` into a new timestamped experiment directory.

    Args:
        learner: the trained learner to export.
        model_name: tag embedded in the directory name
            (`<YYYY_MM_DD_HH_MM_SS>_<model_name>_unet`).
        experiments_root: parent directory for experiment folders; it and any
            missing parents are created if needed.

    Returns:
        Path of the experiment directory that was written to.

    NOTE(review): `save_icevision_model_state_dict_and_tracing` is only imported
    in a commented-out line; it must be in scope before this is called.
    XXX Ethan need to change this to be the proper export for FastaiUnet
    """
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    experiment_dir = Path(experiments_root) / f"{timestamp}_{model_name}_unet"
    # parents=True so a missing experiments root doesn't abort the export that
    # runs after each (long) training stage.
    experiment_dir.mkdir(parents=True, exist_ok=True)
    print(experiment_dir)
    save_template = "model.pt"
    save_icevision_model_state_dict_and_tracing(learner, save_template, experiment_dir)
    return experiment_dir

In [None]:
# Progressive-resizing training: for each (tile size, epochs) entry in run_list,
# rebuild the dataloaders at that size, train, then checkpoint and export.
running_total_epochs = {}  # tile size -> epochs trained at that size so far

for size, epochs in run_list:
    # XXX ETHAN distinguish which tfms are batch and which are item
    # NOTE(review): train_tfms/val_tfms are computed but the DataBlock reads the
    # undefined `item_tfms` / `batch_tfms`, so per-size transforms are not applied yet.
    train_tfms, val_tfms = get_tfms(reduced_resolution_tile_size=size)
    coco_seg_dblock = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=class_list)),  # ImageBlock is RGB by default, uses PIL
        n_inp=1,
        splitter=splitter,
        get_x=get_image_by_record_id,
        get_y=get_mask_by_record_id,
        batch_tfms=batch_tfms,
        item_tfms=item_tfms,
    )
    learner.dls = coco_seg_dblock.dataloaders(source=train_val_record_ids, batch_size=bs_d[size])

    print("PR: Starting from running total", running_total_epochs)
    print("PR: image size", size)
    print("PR: epochs", epochs)

    # NOTE(review): fastai's fine_tune halves base_lr before its unfrozen
    # fit_one_cycle, so the peak LR here is presumably lr_d[size] / 2 — confirm intended.
    learner.fine_tune(epochs, lr_d[size], freeze_epochs=0)  # cbs=cbs

    running_total_epochs[size] = running_total_epochs.get(size, 0) + epochs
    learner.save(model_name)
    export_scripted_model(learner, model_name)

torch.cuda.empty_cache()

In [None]:
# Visual sanity check: plot model predictions next to their target masks.
learner.show_results()

In [None]:
# train_val_record_ids = record_ids_train + record_ids_val
# # combined_record_collection = record_collection_with_negative_small_filtered_train + record_collection_with_negative_small_filtered_val
# combined_record_collection = record_collection_train + record_collection_val
# def get_val_indices(combined_ids, val_ids):
#     return list(range(len(combined_ids)))[-len(val_ids):]

# #show_data.show_records(random.choices(combined_train_records, k=9), ncols=3)

# ### Constructing a FastAI DataBlock that uses parsed COCO Dataset from icevision parser. aug_transforms can only be used with_context=True

# val_indices = get_val_indices(train_val_record_ids, record_ids_val)

In [None]:
# Pull a single training batch to inspect exactly what the model receives.
inputs, targets = learner.dls.train.one_batch()

In [None]:
# Target mask batch shape — presumably (batch, H, W) for MaskBlock targets; verify.
targets.shape