In [None]:
from fastai.vision.all import *
from fastai.callback.fp16 import *
import torch
from tqdm import tqdm
from torchsummary import summary
import json
import wandb
from fastai.callback.wandb import WandbCallback
import random
from pycocotools import mask
import matplotlib.pyplot as plt
import rasterio

In [None]:
def grab_tile_files_from_partition(data, percent_empty):
    """Split a partition into annotated tiles plus a random sample of un-annotated tiles.

    Args:
        data: dict with an 'images' list (entries carrying 'id' and 'file_name')
            and an 'annotations' list (entries carrying 'image_id').
        percent_empty: target fraction, in [0, 1), of the returned tiles that
            should be un-annotated. The number of un-annotated tiles is derived
            from the number of annotated ones so that
            empty / (annotated + empty) ~= percent_empty.

    Returns:
        A two-element list ``[masked_image_files, sampled_unmasked_files]``.
        If fewer un-annotated tiles exist than requested, all of them are
        returned (in random order) instead of raising.

    Raises:
        ValueError: if ``percent_empty`` is outside [0, 1).
    """
    if not 0 <= percent_empty < 1:
        raise ValueError(f"percent_empty must be in [0, 1), got {percent_empty!r}")

    # A set makes the per-image membership test O(1) instead of scanning a list.
    masked_image_ids = {ann['image_id'] for ann in data['annotations']}

    masked_image_files = []
    unmasked_image_files = []
    for im in data['images']:
        bucket = masked_image_files if im['id'] in masked_image_ids else unmasked_image_files
        bucket.append(im['file_name'])

    masked_num = len(masked_image_files)
    requested_empty = int(percent_empty * masked_num / (1 - percent_empty))
    # random.sample raises when asked for more items than exist, so cap the request.
    num_empty = min(requested_empty, len(unmasked_image_files))

    print("Returning", masked_num, "Masked Tile Files and", num_empty, "Unmasked Tile Files")

    return [masked_image_files, random.sample(unmasked_image_files, num_empty)]

def get_masks_labels_for_file(file, ds):
    """Collect the segmentations and category ids annotated on one image.

    Args:
        file: the image's 'file_name' as it appears in ``ds['images']``.
        ds: dict with an 'images' list (entries carrying 'file_name' and 'id')
            and an 'annotations' list (entries carrying 'image_id',
            'category_id' and 'segmentation').

    Returns:
        ``(masks, labels)``: two parallel lists holding each matching
        annotation's 'segmentation' and 'category_id', in annotation order.
        Both are empty when the image has no annotations.

    Raises:
        KeyError: if ``file`` is not listed in ``ds['images']``.
    """
    id_map = {im['file_name']: im['id'] for im in ds['images']}
    # Resolve the id once instead of repeating the lookup for every annotation.
    image_id = id_map[file]

    masks = []
    labels = []
    for ann in ds['annotations']:
        if ann['image_id'] == image_id:
            labels.append(ann['category_id'])
            masks.append(ann['segmentation'])
    return masks, labels

def grab_and_download_tile_and_mask(tile_location, save_location, file, ds, tile_size=(1024, 1024)):
    """Export one tile and its class mask as a pair of PNG files.

    Writes ``<save_location><stem>.png`` (band 1 of the raster, cast to uint8)
    and ``<save_location><stem>_mask.png`` (per-pixel category ids, 0 where
    nothing is annotated), where ``stem`` is ``file`` without its extension.

    Args:
        tile_location: directory prefix of the source raster; it is
            concatenated with ``file``, so it must end with a path separator.
        save_location: output directory prefix, created if missing; also
            concatenated, so it must end with a path separator.
        file: raster file name, as listed in ``ds['images']``.
        ds: annotation dict accepted by ``get_masks_labels_for_file``.
        tile_size: ``(height, width)`` of the tile in pixels, used for decoding
            the segmentations and sizing the mask canvas.
    """
    # exist_ok avoids the check-then-create race of exists() followed by makedirs().
    os.makedirs(save_location, exist_ok=True)

    height, width = tile_size
    masks, labels = get_masks_labels_for_file(file, ds)
    mask_canvas = np.zeros(shape=(height, width))
    for segmentation, label in zip(masks, labels):
        rle = mask.frPyObjects(segmentation, height, width)
        decoded = mask.decode(rle)
        if decoded.ndim == 3:
            # For polygon-list segmentations pycocotools yields one RLE (and one
            # decoded plane) per polygon; merge the planes into a single 2-D mask.
            decoded = decoded.any(axis=2)
        # Assign rather than add: where annotations overlap, summing category ids
        # would produce a label that belongs to neither class. Later annotations win.
        mask_canvas[decoded.astype(bool)] = label

    with rasterio.open(tile_location + file) as src:
        # Read the first band
        band1 = src.read(1)

    # Plain cast to uint8: no rescaling is applied here.
    # NOTE(review): assumes the band already holds 0-255 values; confirm.
    band1_uint8 = band1.astype(np.uint8)

    vv = Image.fromarray(band1_uint8)
    mask_im = Image.fromarray(mask_canvas.astype(np.uint8))

    # splitext handles extensions of any length (the old [:-4] assumed three letters).
    stem = os.path.splitext(file)[0]
    vv.save(save_location + stem + ".png")
    mask_im.save(save_location + stem + "_mask.png")




In [None]:
# Load the annotation JSON of the training partition into `data`; its 'images'
# and 'annotations' entries are what the helper functions above read.
split_path = "/root/partitions/train_tiles_context_1024/"
with open(split_path+'instances_TiledCeruleanDatasetV2.json') as f:
    data = json.load(f)

In [None]:
# Annotated tile files plus a random sample of un-annotated ones
# (.1 is the requested fraction of un-annotated tiles).
masked_files, unmasked_files = grab_tile_files_from_partition(data, .1)

In [None]:
# Scratch check: `m` is the (masks, labels) tuple for the first annotated tile.
m = get_masks_labels_for_file(masked_files[0], data)

In [None]:
# Scratch check: the file name minus its last four characters (presumably the
# extension), i.e. the stem used when naming the exported PNGs.
masked_files[0][:-4]

In [None]:
# Smoke test: export a single un-annotated tile and its mask as PNGs.
grab_and_download_tile_and_mask("/root/partitions/train_tiles_context_1024/tiled_images/",
                                "/root/partitions/tiles_context_1024_png/train/",
                                unmasked_files[0], 
                                data)

In [None]:
# Change the working directory to /root so the relative path below resolves,
# then open one exported mask PNG for inspection.
os.chdir("/root")
maskerino = Image.open("partitions/tiles_context_1024_png/train/S1A_IW_GRDH_1SDV_20210323T142456_20210323T142521_037127_045F0F_F358_vv-image_local_tile_9_mask.png")

In [None]:
# Multiply the pixel values by 255 so the mask is visible when displayed.
# NOTE(review): presumably the array is uint8, in which case label 1 becomes 255
# but larger labels wrap around; confirm this is only meant for eyeballing.
# Re-running this cell rescales the already-rescaled image.
maskerino = Image.fromarray(np.array(maskerino)*255)

In [None]:
# Display the rescaled mask inline.
maskerino

In [None]:
# Export the training partition: reload the annotations, re-draw the sample of
# un-annotated tiles, then write a tile PNG and a mask PNG for every selected file.
# split_path already ends with '/', so the joined path contains a double slash.
split_path = "/root/partitions/train_tiles_context_1024/"
with open(split_path+'/instances_TiledCeruleanDatasetV2.json') as f:
    data = json.load(f)
masked_files, unmasked_files = grab_tile_files_from_partition(data, .1)

# Annotated tiles first ...
for masked_file in masked_files:
    grab_and_download_tile_and_mask("/root/partitions/train_tiles_context_1024/tiled_images/",
                                "/root/partitions/tiles_context_1024_png/train/",
                                masked_file, 
                                data)
# ... then the sampled un-annotated tiles, written to the same directory.
for unmasked_file in unmasked_files:
    grab_and_download_tile_and_mask("/root/partitions/train_tiles_context_1024/tiled_images/",
                            "/root/partitions/tiles_context_1024_png/train/",
                            unmasked_file, 
                            data)

In [None]:
# Inspect the contents of the validation partition directory.
os.listdir("/root/partitions/val_tiles_context_1024")

In [None]:
# Export the validation partition into .../tiles_context_1024_png/valid/.
# NOTE(review): "\r" is a carriage-return escape, so these paths point at a
# sub-directory whose name is a single CR character. Presumably that directory
# really exists on disk (e.g. created by a script with CRLF line endings);
# confirm against the listing above before "fixing" the path.
split_path = "/root/partitions/val_tiles_context_1024/\r"
with open(split_path+'/instances_TiledCeruleanDatasetV2.json') as f:
    data = json.load(f)
masked_files, unmasked_files = grab_tile_files_from_partition(data, .1)

# Only the annotated tiles are exported for validation.
for masked_file in masked_files:
    grab_and_download_tile_and_mask("/root/partitions/val_tiles_context_1024/\r/tiled_images/",
                                "/root/partitions/tiles_context_1024_png/valid/",
                                masked_file, 
                                data)

# Disabled: export of the sampled un-annotated validation tiles.
# NOTE(review): as written this would save into the train/ directory and reads
# from a path without the "\r" component; fix both before re-enabling.
# for unmasked_file in unmasked_files:
#     grab_and_download_tile_and_mask("/root/partitions/val_tiles_context_1024/tiled_images/",
#                             "/root/partitions/tiles_context_1024_png/train/",
#                             unmasked_file, 
#                             data)


In [None]:
# Load the learner hyper-parameters. The context manager closes the file handle,
# which the previous bare open() call left to the garbage collector.
with open("/root/work/starter_learner_config.json") as config_file:
    learner_config = json.load(config_file)

Model / Learner params

In [None]:
# Unpack the model / learner settings from the config into module-level names.
freeze_epochs =  learner_config["freeze_epochs"] # epochs with the encoder frozen (the train_loop call below passes a literal instead)
unfreeze_epochs = learner_config["unfreeze_epochs"] # epochs with the encoder trainable (same caveat)
loss_func_name = learner_config["loss_func"] # name of the loss class, looked up in globals() by build_configured_learner
w_d = learner_config["w_d"] # default weight decay (wd) for build_configured_learner
backbone = learner_config["backbone"] # name of the backbone constructor, looked up in globals()
final_resize_resolution = learner_config["final_resize_resolution"] # passed as SIZE to get_dataloader_seg
num_of_workers = learner_config["num_of_workers"] # DataLoader worker count
progressive_resizing = learner_config["progressive_resizing"] # sequence of item-level resize values, one per training stage
batch_size = learner_config["batch_size"] # DataLoader batch size

Data Params (Not Relevant to this script)

In [None]:
# Dataset-generation settings, read for reference only: none of these names is
# used again in the visible part of this notebook.
classes = learner_config["classes"]
area_thresh = learner_config["area_thresh"]
negative_sample_count_train = learner_config["negative_sample_count_train"]
aux_channels = learner_config["aux_channels"]
scale_limit = learner_config["scale_limit"]
rotate_limit = learner_config["rotate_limit"]
r_g_b_shift_limit = learner_config["r/g/b_shift_limit"]


In [None]:
class NormalizeAndClamp(Transform):
    """Standardise image tensors with a fixed mean/std, then clamp the result.

    Args:
        mean: value subtracted from the input.
        std: value the centred input is divided by.
        min_val: lower clamp bound applied after standardisation.
        max_val: upper clamp bound applied after standardisation.
    """

    def __init__(self, mean, std, min_val=-3, max_val=3):
        self.mean = mean
        self.std = std
        self.min_val = min_val
        self.max_val = max_val

    # The TensorImage annotation restricts type dispatch to the image half of an
    # (image, mask) batch. Without it the transform would also rescale and clamp
    # the integer class mask.
    def encodes(self, x: TensorImage):
        x = (x - self.mean) / self.std
        return x.clamp(min=self.min_val, max=self.max_val)

def get_image_files_no_masks(path, pct_negative=0.0):
    """Return the image paths whose masks should be used, selected by mask content.

    Every ``*_mask.png`` found under ``path`` is opened and classified as empty
    (``getbbox()`` finds no non-zero pixel) or non-empty. All non-empty masks
    are kept, together with the first ``pct_negative`` share of the empty ones.

    Args:
        path: root directory handed to ``get_image_files``.
        pct_negative: fraction (0-1) of the empty masks to include.

    Returns:
        A list of strings: for each selected mask, the path of its image, i.e.
        the mask path with the trailing ``_mask.png`` replaced by ``.png``.
    """
    mask_suffix = "_mask.png"
    empty_masks = []
    non_empty_masks = []

    all_files = get_image_files(path)
    masks = [f for f in all_files if f.name.endswith(mask_suffix)]

    for m in tqdm(masks):
        with Image.open(m) as img:
            # getbbox() returns None when the image has no non-zero pixel.
            if not img.getbbox():
                empty_masks.append(m)
            else:
                non_empty_masks.append(m)

    empty_count = int(pct_negative * len(empty_masks))
    print(f"Training on {len(non_empty_masks)} non-empty masks, and {empty_count} empty masks")
    fileset = non_empty_masks + empty_masks[:empty_count]
    # Strip only the trailing suffix; str.replace('_mask', '') would also rewrite
    # any directory or file-name component that happens to contain '_mask'.
    return [str(f)[:-len(mask_suffix)] + ".png" for f in fileset]

def get_y_seg(x_path):
    """Map an image path to the path of its mask.

    Only a trailing ``.png`` extension is rewritten to ``_mask.png``; a
    ``.png`` occurring earlier in the path (e.g. in a directory name) is left
    untouched. Paths that do not end in ``.png`` are returned unchanged.

    Args:
        x_path: image path (a string or any object whose ``str()`` is the path).

    Returns:
        The mask path as a string.
    """
    x_path = str(x_path)
    image_suffix = '.png'
    if not x_path.endswith(image_suffix):
        return x_path
    return x_path[:-len(image_suffix)] + '_mask.png'

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fastai DataLoaders for grayscale images
path_seg = "/root/partitions/tiles_context_1024_png"  # Dataset's path
# Items whose parent directory is named 'valid' go to the validation split.
ParentSplitter_seg = FuncSplitter(lambda o: Path(o).parent.name == 'valid')
# Unpacked positionally as (mean, std) into NormalizeAndClamp.
SAR_stats = [0.2087162, 0.13736105] # Calculated from the entire training dataset

# Class names handed to MaskBlock; the list position is the class index.
codes = ['background', 'infrastructure', 'natural', 'vessel_coincident', 'vessel_recent', 'vessel_old', 'ambiguous']

# Callbacks attached to every learner built by build_configured_learner.
cbs_seg = [TerminateOnNaNCallback(), GradientAccumulation(8), GradientClip(), SaveModelCallback(), ShowGraphCallback()]
# Optional callbacks, currently disabled:
#  ShortEpochCallback(pct=0.1, short_valid=False),
# EarlyStoppingCallback(min_delta=.001, patience=5)



In [None]:
def get_dataloader_seg(SIZE, resize, batch_size, num_of_workers):
    """Build the fastai DataLoaders for single-channel tile segmentation.

    Args:
        SIZE: side length passed (as a square) to the batch-level augmentations.
        resize: size handed to the item-level ``Resize`` transform.
        batch_size: batch size of the loaders.
        num_of_workers: number of DataLoader worker processes.

    Returns:
        The DataLoaders built from ``path_seg`` and moved to ``device``.
    """
    batch_augmentations = aug_transforms(
        do_flip=True,
        flip_vert=True,
        pad_mode=PadMode.Zeros,
        size=(SIZE, SIZE),
    )

    # Grayscale image in, class mask out; train/valid split by parent folder name.
    seg_block = DataBlock(
        blocks=[ImageBlock(cls=PILImageBW), MaskBlock(codes=codes)],
        get_items=get_image_files_no_masks,
        splitter=ParentSplitter_seg,
        get_y=get_y_seg,
        batch_tfms=batch_augmentations,
        item_tfms=Resize(resize),
    )

    loaders = seg_block.dataloaders(
        path_seg,
        path=path_seg,
        bs=batch_size,
        num_workers=num_of_workers,
    )
    dls_seg = loaders.to(device)

    # NOTE(review): these two assignments set attributes on the DataLoaders
    # wrapper. They probably do not replace the pipelines of the underlying
    # train/valid loaders, in which case NormalizeAndClamp is never applied.
    # Confirm, and pass the transforms through the DataBlock if it is intended.
    dls_seg.after_item = Pipeline([ToTensor(), IntToFloatTensor()])
    dls_seg.after_batch = Pipeline([NormalizeAndClamp(*SAR_stats)])
    return dls_seg

#learn_baseline=unet_learner(dls=dls_seg, arch=resnet101, loss_func=loss_func, cbs=cbs_seg, n_in=1, lr=1e-3, wd=1e-3) #

In [None]:
def build_configured_learner(dls_seg, backbone=backbone, loss_func_name=loss_func_name, cbs_seg=cbs_seg, lr=1e-3, wd=w_d):
    """Build a U-Net ``Learner`` from the configured backbone and loss.

    Args:
        dls_seg: DataLoaders the learner trains on.
        backbone: name of the backbone constructor, looked up in ``globals()``.
            Only names containing "convnext" are supported.
        loss_func_name: name of the loss class, looked up in ``globals()`` and
            instantiated with ``ignore_index=6, axis=1``.
        cbs_seg: callbacks attached to the learner.
        lr: learning rate passed to ``Learner``.
        wd: weight decay passed to ``Learner``.

    Returns:
        The configured ``Learner``.

    Raises:
        ValueError: if ``backbone`` is not a convnext architecture.
        KeyError: if ``backbone`` or ``loss_func_name`` is not a global name.
    """
    # Fail early with a clear message; previously an unsupported backbone fell
    # through to the return statement and raised UnboundLocalError on `unet`.
    if "convnext" not in backbone:
        raise ValueError(
            f"Unsupported backbone {backbone!r}: only 'convnext' architectures are handled"
        )

    # ignore_index=6 is the position of 'ambiguous' in `codes`.
    loss_func = globals()[loss_func_name](ignore_index=6, axis=1)

    model_func = globals()[backbone]
    body = create_body(model_func(), 1, pretrained=True)
    # n_out=7 matches the seven entries of `codes`.
    unet = DynamicUnet(body[0], n_out=7, img_size = (128,128))
    return Learner(dls_seg, unet, loss_func=loss_func, cbs=cbs_seg, lr=lr, wd=wd)

In [None]:
def train_loop(freeze_epochs,unfreeze_epochs, progressive_resizing=progressive_resizing):
    """Train in two phases (encoder frozen, then trainable) over every resize stage.

    One set of DataLoaders is built per entry of ``progressive_resizing``. Each
    phase visits the stages in order and trains for an equal share of that
    phase's epoch budget. The share is truncated to an integer, so a budget
    smaller than the number of stages yields 0 epochs per stage.

    Args:
        freeze_epochs: total epochs to spend with ``learner.model[0]`` frozen.
        unfreeze_epochs: total epochs to spend with it trainable.
        progressive_resizing: non-empty sequence of item-level resize values.

    Returns:
        The trained ``Learner`` (previously it was discarded when the function
        returned).

    Raises:
        ValueError: if ``progressive_resizing`` is empty.
    """
    if not progressive_resizing:
        raise ValueError("progressive_resizing must contain at least one resize value")

    progressive_dls = [
        get_dataloader_seg(final_resize_resolution, resize, batch_size, num_of_workers)
        for resize in progressive_resizing
    ]
    learner = build_configured_learner(progressive_dls[0])

    def _set_encoder_trainable(trainable):
        # model[0] is the part of the network treated as the encoder.
        for param in learner.model[0].parameters():
            param.requires_grad = trainable

    # Per-stage epoch counts do not change between stages, so compute them once.
    n_stages = len(progressive_dls)
    fr_ep = int(freeze_epochs / n_stages)
    un_ep = int(unfreeze_epochs / n_stages)

    # Phase 1: encoder frozen.
    for resize, dls in zip(progressive_resizing, progressive_dls):
        learner.dls = dls
        print("Freezing encoder and training for", fr_ep, "epochs at ", resize, "input size..,")
        _set_encoder_trainable(False)
        learner.fit_one_cycle(fr_ep)

    # Phase 2: whole network trainable.
    for resize, dls in zip(progressive_resizing, progressive_dls):
        learner.dls = dls
        print("Unfreezing encoder and training for", un_ep, "epochs at ", resize, "input size...")
        _set_encoder_trainable(True)
        learner.fit_one_cycle(un_ep)

    return learner

    

In [None]:
# Run both training phases.
# NOTE(review): the epoch counts are literal 2s here, not the freeze_epochs /
# unfreeze_epochs values read from learner_config; confirm this is intended.
train_loop(freeze_epochs=2, unfreeze_epochs=2)

In [None]:
# Build a standalone set of DataLoaders (item-level resize of 512) for inspection.
dls = get_dataloader_seg(final_resize_resolution, 512, batch_size, num_of_workers)

In [None]:
# Visual sanity check of one batch from the loaders built above.
dls.show_batch()