In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from fastai.vision.all import *
from fastai.callback.fp16 import *
import torch
import wandb
from fastai.callback.wandb import WandbCallback

from icevision.data import Dataset

In [None]:
from ceruleanml.data import class_list
from ceruleanml.learner_config import (
    memtile_size,
    run_list,
    final_px,
    classes_to_keep,
    get_tfms,
    wd,
    record_collection_train,
    record_collection_val,
    record_collection_test,
    record_ids_train,
    record_ids_val,
    record_ids_test,
    model_name,
    num_workers,
)
# Echo the training schedule; the training loop further down unpacks each entry as (size, epochs).
run_list

In [None]:
# wandb.init(project='cv3-experiments')  # XXX Ethan: figure out how to log this into a shared project, and how to end the run (wandb.finish()?) once the run_list loop has completed

In [None]:
# Per-image-size hyperparameters; both tables are keyed by image size in pixels.
bs_d = {  # Batch Size for each image size
    512: 8,
    256: 16,
    224: 16,
    128: 32,
    64: 64,
}
lr_d = {  # Learning Rate for each image size
    512: 1e-3,
    256: 1e-3,
    224: 1e-3,
    128: 1e-3,
    64: 1e-3,
}

In [None]:
# Build the segmentation network: the body cut from a convnext_small model, wrapped in a DynamicUnet.
# NOTE(review): convnext_small() is called without a weights argument — confirm that
# create_body(..., pretrained=True) really yields pretrained weights with the installed fastai version.
model = convnext_small()
body = create_body(model, n_in=3, pretrained=True)
# NOTE(review): n_out=7 is hard-coded — presumably it must match len(class_list), which is
# used as the MaskBlock codes further down; verify.
unet = DynamicUnet(body[0], n_out=7, img_size=(final_px, final_px))
loss_func = CrossEntropyLossFlat()

In [None]:
def _is_validation_record(record):
    """Return True when the grandparent directory name of the record's file contains "val"."""
    grandparent_name = str(record.filepath.parent.parent.stem)
    return "val" in grandparent_name


# Route records to the validation split according to their location on disk.
splitter = FuncSplitter(_is_validation_record)

def get_image_by_record_id(record):
    """DataBlock getter for the input: return the image stored on `record`.

    Despite the name, the argument is the record object itself, not an id.
    """
    image = record.img
    return image

def get_mask_by_record_id(record):
    """DataBlock getter for the target: return the flattened class-index mask of `record`.

    Despite the name, the argument is the record object itself, not an id.
    Delegates to `generate_flattened_mask_array`.
    """
    return generate_flattened_mask_array(record)

def generate_flattened_mask_array(record):
    """Collapse a record's per-instance masks into a single class-index mask.

    Each instance mask is multiplied by the integer id of its label (looked up
    with ``record.detection.class_map.get_by_name``) and the per-pixel maximum
    over all instances is returned. Where instances overlap, the larger class
    id wins; pixels covered by no instance stay 0.

    Args:
        record: object exposing ``detection.labels``,
            ``detection.mask_array.data``, ``detection.class_map`` and
            ``common.img_size``.

    Returns:
        A 2-D ``np.ndarray``. For a record without labels this is an all-zero
        ``uint8`` array sized from ``record.common.img_size``.
    """
    string_labels = record.detection.labels
    if not string_labels:
        img_size = record.common.img_size
        # An IceVision ImgSize is a (width, height) namedtuple, whereas numpy
        # wants (rows, cols) == (height, width). Passing it through unchanged
        # transposes the blank mask for non-square images.
        height = getattr(img_size, "height", None)
        width = getattr(img_size, "width", None)
        if height is None or width is None:
            # Plain tuple / int: keep the original behaviour.
            shape = img_size
        else:
            shape = (height, width)
        return np.zeros(shape, dtype=np.uint8)

    # Indexed below as (instance, row, col): one binary mask per labelled instance.
    masks = record.detection.mask_array.data
    class_map = record.detection.class_map
    labels = np.array(
        [class_map.get_by_name(label) for label in string_labels],
        dtype=np.uint8,
    )

    # Broadcast one class id over each instance mask, then keep the largest id per pixel.
    weighted_masks = masks * labels[:, np.newaxis, np.newaxis]
    return np.max(weighted_masks, axis=0)

In [None]:
# Sanity check: plot one training record's image next to its flattened mask.
r = Dataset(record_collection_train)[6]
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].imshow(get_image_by_record_id(r)[:,:,0])  # index 0 of the last axis only
axs[1].imshow(get_mask_by_record_id(r))
plt.show()

In [None]:
# train_tfms, val_tfms = get_tfms(reduced_resolution_tile_size=final_px) # XXX ETHAN distinguish which tfms are batch_tfms and which are item_tfms


In [None]:
# Callbacks handed to the Learner. Every candidate is currently commented out, so this is an empty list.
cbs = [
    # WandbCallback(log_model=True),
    # ShortEpochCallback(pct=0.1, short_valid=True), 
    # EarlyStoppingCallback(min_delta=.001, patience=5), 
    # TerminateOnNaNCallback(), 
    # GradientAccumulation(8), 
    # GradientClip(), 
    # SaveModelCallback(), 
    # ShowGraphCallback(),
    # MixedPrecision(), # I'm used to this being a .to_fp16() on a learner, rather than a callback???
    ]

In [None]:
# Assemble the segmentation DataBlock: image input, mask target whose codes come from class_list.
seg_dblock = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=class_list)), # ImageBlock is RGB by default, uses PIL
        getters=[get_image_by_record_id, get_mask_by_record_id],
        splitter=splitter,
        batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)]
    )
# NOTE(review): only the first 10 train and first 10 val records are fed in here — looks like a
# smoke-test subset; confirm before launching a real training run.
dls = seg_dblock.dataloaders(Dataset(record_collection_train[:10]+record_collection_val[:10]), batch_size=bs_d[final_px], verbose=True)
dls.show_batch()

In [None]:
# Bundle the DataLoaders, U-Net, loss and callbacks into a Learner; the learning rate is looked up by final_px.
learner = Learner(dls=dls, model=unet, loss_func=loss_func, cbs=cbs, lr=lr_d[final_px], wd=wd)

In [None]:
# Show sample predictions from the learner as constructed above (this cell precedes the fit_one_cycle call).
learner.show_results()

In [None]:
# Train for 50 epochs with fastai's one-cycle schedule.
learner.fit_one_cycle(50) # XXX TARGET FOR TONIGHT

In [None]:
# Checkpoint bookkeeping around the "model" checkpoint:
#   start_new is truthy        -> save the learner's current state as "model"
#   load_model_name is truthy  -> load that checkpoint, then re-save it as "model"
#   otherwise                  -> reload the existing "model" checkpoint
start_new = True
load_model_name = False  # falsy means "no named checkpoint"; set to a checkpoint name to load it

if start_new:
    print("Starting from scratch")
    learner.save("model")
elif load_model_name:
    print(f"Loading {load_model_name}")
    learner.load(load_model_name)
    learner.save("model")
else:
    print("Continuing current training session")
    learner.load("model")
    # NOTE(review): export_scripted_model takes (learner, model_name); this commented call passes only the learner.
    # export_scripted_model(learner)

In [None]:
from datetime import datetime
from ceruleanml.inference import save_fastai_model_state_dict_and_tracing

def export_scripted_model(learner, model_name):
    """Export `learner` into a fresh, timestamped experiment directory.

    Creates ``/root/experiments/cv3/<timestamp>_<model_name>_unet/`` (including
    any missing parent directories), prints the path, and passes the learner,
    the file name ``"model.pt"`` and the directory to
    ``save_fastai_model_state_dict_and_tracing``.

    Args:
        learner: the learner to export.
        model_name: label embedded in the experiment directory name.
    """
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    experiment_dir = Path(f'/root/experiments/cv3/{timestamp}_{model_name}_unet/')
    # parents=True: on a fresh machine /root/experiments/cv3 may not exist yet, and
    # without it mkdir raises FileNotFoundError after training has already finished.
    experiment_dir.mkdir(parents=True, exist_ok=True)
    print(experiment_dir)
    save_template = "model.pt"
    save_fastai_model_state_dict_and_tracing(learner, save_template, experiment_dir) # XXX Ethan need to check this swap works

In [None]:
# Cumulative number of epochs trained so far, keyed by image size.
running_total_epochs = {}

# Progressive training schedule: each run_list entry is (image size, number of epochs).
# NOTE(review): with the DataLoaders rebuild below still commented out, `size` currently
# only selects the learning rate and the running-total key; the data itself is not resized.
for size, epochs in run_list:
    # train_tfms, val_tfms = get_tfms(reduced_resolution_tile_size=size) # XXX ETHAN distinguish which tfms are batch and which are item
    # seg_dblock = DataBlock(
    #         blocks=(ImageBlock, MaskBlock(codes=class_list)), # ImageBlock is RGB by default, uses PIL
    #         n_inp=1,
    #         splitter=splitter,
    #         get_x=get_image_by_record_id,
    #         get_y=get_mask_by_record_id,
    #         batch_tfms=batch_tfms,
    #         item_tfms=item_tfms, # XXX ETHAN if you find a cheaper easier way to push new TFMS into an existing DLS, then we don't need to fully recreate these here... learner.dls.add_tfms()
    #     )

    # learner.dls = seg_dblock.dataloaders(source=train_val_record_ids, batch_size=bs_d[size])
    # The placeholder call below was left live as `add_tfms(.......)`, which is a syntax
    # error and stopped the whole cell from parsing; keep it commented until it has real arguments.
    # learner.dls.add_tfms(...) # XXX ETHAN to explore?
    print("PR: Starting from running total", running_total_epochs)
    print("PR: image size", size)
    print("PR: epochs", epochs)

    learner.fine_tune(epochs, lr_d[size], freeze_epochs=0)

    # Accumulate epochs per size; a size not seen before starts from 0.
    running_total_epochs[size] = running_total_epochs.get(size, 0) + epochs
    learner.save(model_name)
    export_scripted_model(learner, model_name)

# Release cached GPU memory once the whole schedule has finished.
torch.cuda.empty_cache()

In [None]:
# Show sample predictions again, this time after the run_list training loop.
learner.show_results()

In [None]:
# Pull a single batch from the training DataLoader for manual inspection.
inputs, targets = learner.dls.train.one_batch()

In [None]:
# Display the shape of the target batch fetched above.
targets.shape