In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from fastai.vision.all import *
from fastai.callback.fp16 import *
import torch
import wandb
from fastai.callback.wandb import WandbCallback

from icevision.data import Dataset

# Training below uses MixedPrecision and torch.cuda.empty_cache(), so a CUDA device
# is required. The "cpu" branch only keeps `device` bound before the assertion fires.
gpu_available = torch.cuda.is_available()
device = torch.device("cuda") if gpu_available else torch.device("cpu")
assert gpu_available, "WARNING: No GPU currently available"

In [None]:
from ceruleanml.data import class_list
from ceruleanml.learner_config import (
    memtile_size,
    rrctile_size,
    run_list,
    final_px,
    classes_to_keep,
    get_tfms,
    wd,
    record_collection_train,
    record_collection_val,
    record_collection_test,
    model_name,
    num_workers,
    model_type,
    aux_layers
)
# Display the training schedule; the training loop unpacks each entry as (size, epochs, is_frozen).
run_list

In [None]:
# Hyperparameters and dataset sizes recorded alongside the Weights & Biases run.
config = dict(
    memtile_size=memtile_size,
    rrctile_size=rrctile_size,
    run_list=run_list,
    final_px=final_px,
    classes_to_keep=classes_to_keep,
    weight_decay=wd,
    num_workers=num_workers,
    train_record_count=len(record_collection_train),
    val_record_count=len(record_collection_val),
    test_record_count=len(record_collection_test),
    model_type=model_type,
    aux_layers=aux_layers,
    model_name=model_name,
)

wandb.init(
    project='cv3',
    entity="skytruth",
    name=model_name,
    config=config,
)

In [None]:
# Per-resolution settings, keyed by training tile size in pixels.
bs_d = {  # batch size
    512: 8,
    256: 16,
    224: 16,
    128: 32,
    64: 64,
}
lr_d = {  # learning rate
    512: 1e-3,
    256: 1e-3,
    224: 1e-3,
    128: 1e-3,
    64: 1e-3,
}
# Encoder architectures selectable through `model_type`.
model_dict = dict(
    resnet18=resnet18,
    resnet34=resnet34,
    convnext_small=convnext_small,
    convnext_large=convnext_large,
)

In [None]:
def _is_val_record(o):
    """Return True when the record file's grandparent directory name contains "val".

    NOTE(review): this is a substring match, so any grandparent directory whose
    name merely contains "val" would also be routed to the validation split.
    """
    grandparent_name = str(o.filepath.parent.parent.stem)
    return "val" in grandparent_name


splitter = FuncSplitter(_is_val_record)

def get_image(record, n_layers=None):
    """Load ``record`` and return its image with the configured number of layers.

    Args:
        record: an icevision record; ``record.load().img`` supplies the image.
        n_layers: number of input layers to produce. Defaults to
            ``len(aux_layers)`` so fastai can keep calling this as a one-argument getter.

    Returns:
        The loaded image unchanged when ``n_layers == 3``. When ``n_layers == 1``,
        its first band converted to a single-channel ('L') image.

    Raises:
        ValueError: if ``n_layers`` is neither 1 nor 3.
    """
    if n_layers is None:
        n_layers = len(aux_layers)
    # Validate before loading so an unsupported config fails without touching disk.
    # (The original `raise("...")` raised a str, which itself fails with TypeError.)
    if n_layers not in (1, 3):
        raise ValueError(f"Layer count not Supported: {n_layers}")
    img = record.load().img
    if n_layers == 1:
        img = img.split()[0].convert('L')
    return img

def get_mask(record):
    """Load ``record`` and return its flattened per-pixel class-index mask."""
    loaded_record = record.load()
    return generate_flattened_mask_array(loaded_record)

def generate_flattened_mask_array(record):
    """Collapse a record's per-instance masks into a single 2-D class-index mask.

    Each instance mask is multiplied by its class id, looked up by name in the
    record's class map. The per-pixel maximum across instances is then kept, so
    where instances overlap the larger class id wins. Uncovered pixels are 0.

    Args:
        record: a loaded icevision record exposing ``detection.labels``,
            ``detection.class_map`` and ``detection.mask_array.data``. When there
            are no labels, only ``common.img_size`` is read.

    Returns:
        np.ndarray of shape (height, width). When there are no labels it is an
        all-zero uint8 array; otherwise its dtype follows numpy promotion of the
        mask data with the uint8 label ids.
    """
    string_labels = record.detection.labels
    if not string_labels:
        img_size = record.common.img_size
        # icevision's ImgSize is ordered (width, height), but numpy arrays are
        # (rows, cols) = (height, width). Passing img_size directly transposed the
        # mask for non-square tiles.
        return np.zeros((img_size.height, img_size.width), dtype=np.uint8)

    class_map = record.detection.class_map
    labels = np.array([class_map.get_by_name(label) for label in string_labels], dtype=np.uint8)
    # Presumably (n_instances, height, width), one mask per label; TODO confirm.
    masks = record.detection.mask_array.data

    # Broadcast each label over its instance mask, then keep the per-pixel max.
    weighted_masks = masks * labels[:, np.newaxis, np.newaxis]
    return np.max(weighted_masks, axis=0)

In [None]:
# Sanity-check the getters on one training record: input image next to its class-index mask.
r = Dataset(record_collection_train)[6]
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].imshow(get_image(r))
axs[0].set_title("Image")
# Nearest-neighbour keeps the discrete class ids from being blended when the mask is rescaled.
axs[1].imshow(get_mask(r), interpolation="nearest")
axs[1].set_title("Mask (class index)")
plt.show()

In [None]:
# Training callbacks shared by every stage of the run_list loop.
cbs = [
    WandbCallback(log_model=True),
    # ShortEpochCallback(pct=0.1, short_valid=True), 
    # EarlyStoppingCallback(min_delta=.0001, patience=20),
    TerminateOnNaNCallback(), 
    # NOTE(review): fastai's n_acc is believed to count samples, not batches — confirm.
    GradientAccumulation(256), 
    GradientClip(), 
    # Default fname presumably "model", the same name the checkpoint cell saves/loads — confirm.
    SaveModelCallback(), 
    ShowGraphCallback(),
    # NOTE(review): fastai's `Learner.to_fp16()` is believed to simply add this same callback — confirm.
    MixedPrecision(),
    ]

In [None]:
def get_seg_dls(size):
    """Build segmentation DataLoaders with tiles reduced to ``size`` pixels.

    Train and val records are pooled and separated by ``splitter``. Transforms come
    from ``get_tfms`` and the batch size from ``bs_d[size]``. The result is moved to
    ``device``. Single-layer configs use ``PILImageBW``; otherwise ``PILImage``
    (RGB, loaded via PIL) is used.
    """
    image_type = PILImage if len(aux_layers) != 1 else PILImageBW
    batch_tfms, item_tfms = get_tfms(reduced_resolution_tile_size=size)
    seg_dblock = DataBlock(
        blocks=(ImageBlock(image_type), MaskBlock(codes=classes_to_keep)),
        getters=[get_image, get_mask],
        splitter=splitter,
        batch_tfms=list(batch_tfms),
        item_tfms=list(item_tfms),
    )
    records = list(record_collection_train + record_collection_val)
    dls = seg_dblock.dataloaders(source=records, batch_size=bs_d[size], verbose=False)
    return dls.to(device)

# Build the DataLoaders at the final training resolution; the Learner is created from these.
dls = get_seg_dls(size=final_px)

In [None]:
# Preview a batch from the DataLoaders built in the previous cell. Rebuilding them here
# (as this cell used to) repeated the same expensive get_seg_dls(final_px) call.
dls.show_batch()

In [None]:
# Pixel-wise cross-entropy. axis=1 is assumed to be the class-channel axis of the
# DynamicUnet output (n_out=len(classes_to_keep)) — TODO confirm (batch, classes, H, W).
loss_func = CrossEntropyLossFlat(axis=1)

In [None]:
# Encoder: the selected architecture, cut to its body and adapted to len(aux_layers) inputs.
# NOTE(review): the architecture is instantiated with no weights argument. Depending on the
# torchvision/fastai version, this may give randomly initialised weights, with
# `pretrained=True` only affecting first-layer adaptation — confirm the encoder is pretrained.
body = create_body(
    model_dict[model_type](),
    n_in=len(aux_layers),
    pretrained=True,
)
if 'convnext' in model_type:
    # ConvNeXt bodies come back wrapped one level deeper; unwrap to the feature extractor.
    body = body[0]
model = DynamicUnet(
    body,
    n_out=len(classes_to_keep),
    img_size=(final_px, final_px),
)

In [None]:
# Assemble the Learner; its default lr is the final-resolution entry of lr_d.
learner = Learner(
    dls=dls,
    model=model,
    loss_func=loss_func,
    cbs=cbs,
    lr=lr_d[final_px],
    wd=wd,
    metrics=[DiceMulti(), foreground_acc],
)
learner.to(device);

In [None]:
# How to initialise the Learner's weights before training:
#   start_new=True          -> keep the freshly built model and checkpoint it as "model"
#   load_model_name="name"  -> load checkpoint "name", then re-save it as "model"
#   neither                 -> resume from the existing "model" checkpoint
start_new = True
load_model_name = False

if start_new:
    print("Starting from scratch")
    learner.save("model")
else:
    if load_model_name:
        print(f"Loading {load_model_name}")
        learner.load(load_model_name)
        learner.save("model")
    else:
        print("Continuing current training session")
        learner.load("model")
    # export_scripted_model(learner)

In [None]:
from datetime import datetime
from ceruleanml.inference import save_fastai_model_state_dict_and_tracing

def export_scripted_model(learner, model_name, base_dir="/root/experiments/cv3"):
    """Export ``learner`` into a fresh timestamped experiment directory.

    Creates ``<base_dir>/<YYYY_mm_dd_HH_MM_SS>_<model_name>_unet`` (and any missing
    parents), prints it, and passes it to ``save_fastai_model_state_dict_and_tracing``.

    Args:
        learner: the fastai Learner to export; its ``dls`` are passed along too.
        model_name: name used in the directory name and passed to the exporter.
        base_dir: root directory for experiments (defaults to the previous hard-coded path).

    Returns:
        pathlib.Path of the experiment directory.
    """
    from pathlib import Path

    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    experiment_dir = Path(base_dir) / f"{timestamp}_{model_name}_unet"
    # parents=True replaces the separate exists()/makedirs() step for base_dir;
    # exist_ok=True tolerates two exports within the same second.
    experiment_dir.mkdir(parents=True, exist_ok=True)
    print(experiment_dir)
    # XXX Ethan need to check this swap works
    save_fastai_model_state_dict_and_tracing(learner, learner.dls, model_name, experiment_dir)
    return experiment_dir

In [None]:
def set_encoder_state(learner,frozen=False):
    """Freeze or unfreeze the encoder, i.e. ``learner.model[0]``.

    Sets ``requires_grad = not frozen`` on every encoder parameter and prints how
    many parameters the encoder holds relative to the whole model.
    """
    encoder = learner.model[0]
    encoder_params = list(encoder.parameters())
    n_encoder = sum(p.numel() for p in encoder_params)
    n_total = sum(p.numel() for p in learner.model.parameters())

    verb = "Freezing" if frozen else "Unfreezing"
    print(verb, n_encoder, 'encoder parameters out of', n_total, 'total parameters')

    trainable = not frozen
    for p in encoder_params:
        p.requires_grad = trainable

In [None]:
# Progressive training: each run_list stage sets the tile size, epoch count and
# encoder freezing, then checkpoints and exports the model.
running_total_epochs = {}
for size, epochs, is_frozen in run_list:
    print("PR: Starting from running total", running_total_epochs)
    print("PR: image size", size)
    print("PR: epochs", epochs)
    print("PR: encoder is", is_frozen)

    frozen = is_frozen == 'frozen'
    learner.dls = get_seg_dls(size)  # batch size comes from bs_d[size]
    set_encoder_state(learner, frozen=frozen)
    # Use the per-size learning rate from lr_d, matching the per-size batch size, rather than
    # always falling back to the Learner-wide default (lr_d[final_px]).
    learner.fit_one_cycle(epochs, lr_max=lr_d[size])

    running_total_epochs[size] = running_total_epochs.get(size, 0) + epochs
    learner.save(model_name)
    export_scripted_model(learner, model_name)

torch.cuda.empty_cache()

In [None]:
# Visual check of the trained model: predictions shown against targets (fastai `show_results`).
learner.show_results()

In [None]:
# Close the Weights & Biases run opened by wandb.init at the top, so the run is marked
# finished and its remaining data is uploaded.
wandb.finish()