In [None]:
# Libraries
from fastai.vision.all import *
from fastcore.xtras import Path

# from fastai.callback.hook import summary
# from fastai.callback.progress import ProgressCallback
# from fastai.callback.schedule import lr_find, fit_flat_cos

# from fastai.data.block import DataBlock
# from fastai.data.external import untar_data, URLs
# from fastai.data.transforms import get_image_files, Normalize, FuncSplitter

# from fastai.losses import BaseLoss
# from fastai.layers import Mish
# from fastai.optimizer import ranger

# from fastai.torch_core import tensor

# from fastai.vision.augment import aug_transforms
# from fastai.vision.core import PILImage, PILMask, Image

# from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
# from fastai.vision.learner import unet_learner

# from PIL import Image
import numpy as np

from torch import nn
# from torchvision.transforms import ToPILImage
# from torchvision.models.resnet import resnet34

import torch
import torch.nn.functional as F

from numba import jit, njit, prange

from tqdm import tqdm

import cv2

In [None]:
from utils import * # own utilities script

## Image representation

In [None]:
# Root directory of the processed data; the paths used below are built from it.
data_path = Path('../data/processed/')
# get_pkl_data is not defined in this notebook (presumably provided by `utils` — confirm).
# Its three return values are unpacked as training data, test data and samples.
train_data, test_data, samples = get_pkl_data(data_path)

In [None]:
# function that stacks each frame with its two successors as the channels of one image
def make_three_channel(video):
    """Build a three-channel image for every frame of a video.

    Channel 0 of output frame ``f`` is input frame ``f``, channel 1 is frame
    ``f+1`` and channel 2 is frame ``f+2``. Channels whose source frame lies
    past the end of the video stay zero.

    Parameters
    ----------
    video : np.ndarray
        Array indexed as ``video[:, :, f]`` (frame index on the third axis).

    Returns
    -------
    np.ndarray
        Array of shape ``(video.shape[0], video.shape[1], 3, num_frames)`` with
        the dtype of ``video``. A video with zero frames yields an empty array
        instead of raising.
    """
    height, width = video.shape[0:2]
    num_frames = video.shape[2]

    frames = np.zeros((height, width, 3, num_frames), dtype=video.dtype)
    for offset in range(3):
        # number of output frames that still have a source frame `offset` steps ahead
        n_valid = num_frames - offset
        if n_valid > 0:
            frames[:, :, offset, :n_valid] = video[:, :, offset:]
    return frames

# convert the video of every dataset entry to three-channel frames
def trichannel_frames(data):
    """Replace each entry's 'video' with the result of make_three_channel.

    Entries are indexed as ``data[0] ... data[len(data)-1]`` and are updated in
    place; ``data`` is returned as well. A tqdm bar reports progress per entry.
    """
    for idx in tqdm(range(len(data))):
        data[idx]['video'] = make_three_channel(data[idx]['video'])
    return data

def sharpen_and_brighten(video):
    """Sharpen and histogram-equalize every slice ``video[:, :, i]``.

    Each slice is filtered with a 3x3 sharpening kernel (cv2.filter2D, same
    depth as the input) and then passed through cv2.equalizeHist. The results
    are written into a new array with the shape and dtype of ``video``.
    """
    # 3x3 kernel: centre weight 5, the four direct neighbours -1
    sharpen_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=np.float32)

    result = np.empty_like(video)
    for idx in range(video.shape[2]):
        filtered = cv2.filter2D(video[:, :, idx], -1, sharpen_kernel)
        result[:, :, idx] = cv2.equalizeHist(filtered)
    return result

def brighten(video):
    """Histogram-equalize every slice ``video[:, :, i]`` with cv2.equalizeHist.

    Returns a new array with the shape and dtype of ``video``.
    """
    result = np.empty_like(video)
    for idx in range(video.shape[2]):
        # equalize the intensity histogram of this slice
        result[:, :, idx] = cv2.equalizeHist(video[:, :, idx])
    return result

def normalize_light_data(data):
    """Replace each entry's 'video' with its histogram-equalized version (brighten).

    Entries are updated in place and ``data`` is returned. A tqdm bar reports
    progress per entry.
    """
    for idx in tqdm(range(len(data))):
        data[idx]['video'] = brighten(data[idx]['video'])
    return data

In [None]:
# Convert both datasets to three-channel frames.
# trichannel_frames overwrites each entry's 'video' in place, so re-running this
# cell would apply the conversion to already-converted videos.
train_data = trichannel_frames(train_data)
test_data = trichannel_frames(test_data)


In [None]:
from tqdm import tqdm

# function that tiles four consecutive frames into one 2x2 image, for every frame of the video
def attach_four_frames(vid_list, train=True):
    """Tile each frame with its three successors into a double-size image.

    ``vid_list[0]`` is the video, indexed as ``video[:, :, f]``. When ``train``
    is true, ``vid_list[1]`` holds the labels with the same indexing.

    For output frame ``f`` the quadrants are filled as
    top-left = f, top-right = f+1, bottom-left = f+2, bottom-right = f+3;
    quadrants whose source frame is past the end of the video stay zero.
    Label images have the same double size, but only the top-left quadrant is
    filled (with the label of frame ``f``).

    Returns ``(frames, labels)`` when ``train`` is true, otherwise ``frames``;
    both are stacked along the third axis.
    """
    video = vid_list[0]
    if train:
        videolab = vid_list[1]

    height, width = video.shape[0:2]
    num_frames = video.shape[2]

    # row/column slices of the four quadrants, ordered by frame offset 0..3
    top, bottom = slice(0, height), slice(height, 2*height)
    left, right = slice(0, width), slice(width, 2*width)
    quadrants = [(top, left), (top, right), (bottom, left), (bottom, right)]

    with tqdm(total=num_frames+train*num_frames) as pbar:
        framelist = []
        for f in range(num_frames):
            canvas = np.zeros((height*2, width*2), dtype=video.dtype)
            for offset, (rows, cols) in enumerate(quadrants):
                if f+offset < num_frames:
                    canvas[rows, cols] = video[:, :, f+offset]
            framelist.append(canvas)
            pbar.update(1)
        frames = np.dstack(framelist)

        if not train:
            return frames

        labellist = []
        for f in range(num_frames):
            label_canvas = np.zeros((height*2, width*2), dtype=videolab.dtype)
            label_canvas[top, left] = videolab[:, :, f]
            labellist.append(label_canvas)
            pbar.update(1)

        return frames, np.dstack(labellist)

# convert the video (and label) of every dataset entry to 2x2 mosaic frames
def quad_frames(data, train):
    """Apply attach_four_frames to every entry of ``data``.

    With ``train`` true both 'video' and 'label' of each entry are replaced,
    otherwise only 'video'. Entries are updated in place and ``data`` is
    returned. A tqdm bar reports progress per entry.
    """
    for idx in tqdm(range(len(data))):
        if train:
            mosaic, mosaic_labels = attach_four_frames([data[idx]['video'], data[idx]['label']], train=train)
            data[idx]['label'] = mosaic_labels
        else:
            mosaic = attach_four_frames([data[idx]['video']], train=train)
        data[idx]['video'] = mosaic
    return data


Make png scan files.

In [None]:
# save_train_test_pngs is not defined in this notebook (presumably from `utils` — confirm).
# trichannel=True matches the three-channel conversion applied to both datasets above.
save_train_test_pngs(train_data, test_data, data_path, trichannel=True)

Make png label files.

In [None]:
# save_label_pngs is not defined in this notebook (presumably from `utils` — confirm).
# Later cells read masks from data_path/'train'/'labels' as '<scan stem>_lab.png'.
save_label_pngs(train_data, data_path)

Now write a function that splits the training data into a training and validation set. You can use the `random.sample` function for this. Split along ids, and allow an option to include amateur data.

In [None]:
# get_sample_split_txt is not defined in this notebook (presumably from `utils` — confirm).
# NOTE(review): 0.1 looks like the validation fraction, 'full' the sample selection and
# 420 a random seed — verify against the helper's signature.
train_fnames, valid_fnames = get_sample_split_txt(0.1, 'full', data_path, 420)

## Dataloaders

In [None]:
from fastai.vision.all import *
from utils import *

# Locations of the processed training scans and their segmentation masks
data_path = Path('../data/processed/')
path_im = data_path/'train'/'scans'
path_lbl = data_path/'train'/'labels'
fnames = get_image_files(path_im)
lbl_names = get_image_files(path_lbl)

# provide path to an image --> returns path to the mask
# (a named function rather than a lambda, so the DataBlock that uses it can be pickled)
def get_mask(o):
    """Return the mask path for scan path `o`: `<stem>_lab<suffix>` inside `path_lbl`."""
    return path_lbl/f'{o.stem}_lab{o.suffix}'

# codes for each segmentation class
codes = np.array(['BG', 'MV'])

In [None]:
# Visual check: display the first scan returned by get_image_files
PILImage.create(fnames[0]).show(figsize=(2,2), title="Scan")

# Print the mask path that get_mask derives for that scan
print("The mask for", fnames[0], "is:\n", get_mask(fnames[0]))

# Load and display the corresponding mask
msk = PILMask.create(get_mask(fnames[0]))
msk.show(figsize=(2,2), alpha=1, title="Mask")

In [None]:
# Try out FileSplitter on the scan files using the validation list in vld_expert.txt.
# The result is only displayed, not stored; the DataBlocks below use vld_full.txt instead.
FileSplitter(data_path/'train'/'vld_expert.txt')(fnames) # example

In [None]:
# get the resolution of an expert and an amateur frame
# (get_resolution is not defined in this notebook; presumably from `utils` — confirm the axis order it returns)
res_exp, res_am = get_resolution(fnames, data_path)
# Scaled variants of those resolutions, truncated to int.
# NOTE(review): half_res keeps the order (res_exp[0], res_exp[1]) here, while the
# "Full Size Training" section builds it as (res_exp[1], res_exp[0]) — confirm which is intended.
half_res = (int(res_exp[0]/2), int(res_exp[1]/2)); 
double_am = (int(res_am[0]*2), int(res_am[1]*2))
trip_am = (int(res_am[0]*3), int(res_am[1]*3))

In [None]:
# Segmentation DataBlock: scans as inputs, masks (classes from `codes`) as targets.
# Items are resized to double_am; the flip/rotate/zoom/warp/affine arguments of
# aug_transforms are all switched off, followed by normalization with imagenet_stats.
mvscans = DataBlock(blocks=(ImageBlock, MaskBlock(codes)), # blocks for segmentation
                    get_items=get_image_files, # how to get the files: use function
                    splitter=FileSplitter(data_path/'train'/'vld_full.txt'), # validation files are listed in vld_full.txt
                    get_y=get_mask,
                    item_tfms=Resize(double_am),
                    batch_tfms=[*aug_transforms(size=double_am, 
                                                do_flip=False, 
                                                max_rotate=0.,
                                                max_zoom=1.0,
                                                max_warp=0.,
                                                p_affine=0.), Normalize.from_stats(*imagenet_stats)]
                    )

# Build the dataloaders from the scan directory with a batch size of 4
dls = mvscans.dataloaders(path_im, bs=4)

Show a batch of images and labels.

In [None]:
# Display up to four scan/mask pairs from a batch; vmin/vmax fix the displayed value range to [0, 1]
dls.show_batch(max_n = 4, vmin=0, vmax=1, figsize=(4,4))

In [None]:
# Index of the amateur video to preview; the next cell increments it on every run
amid = 0

In [None]:
# Preview up to three scans of amateur video `amid`, each with its mask drawn on the same axis
_,axs = plt.subplots(1,3, figsize=(12,3))
# stems of all scans whose name contains 'am' and belongs to video `amid`
amnames = fnames.filter(lambda x: 'am' in x.stem).map(lambda x: x.stem).filter(lambda x: f'am_{amid}_' in x)

# zip stops at the shorter sequence, so videos with fewer than three scans do not raise
for ax, name in zip(axs, amnames):
    PILImage.create(path_im/f'{name}.png').show(title=name, ctx=ax)
    PILMask.create(path_lbl/f'{name}_lab.png').show(ctx=ax)

print(amnames)

# advance to the next amateur video; re-running this cell steps through the videos
amid += 1

In [None]:
# Attach the class names to the dataloaders
dls.vocab = codes
# Map each class name to its integer index in `codes`; the trailing expression displays the dict
name2id = {v:k for k,v in enumerate(codes)}; name2id

In [None]:
# U-Net learner on a resnet34 backbone, with self-attention, Mish activations and
# the ranger optimizer. acc_camvid is not defined in this notebook (presumably from `utils` — confirm).
opt = ranger
learn = unet_learner(
    dls, # dataloaders
    resnet34, # architecture
    metrics=acc_camvid, # metric
    self_attention=True,
    act_cls=Mish,
    opt_func=opt
)

# Print the layer-by-layer summary of the model
learn.summary()

Find optimal learning rate.

In [None]:
# Run the learning-rate finder; its result is displayed only — `lr` is set by hand in the next cell
learn.lr_find()

Fit the model for 10 epochs.

In [None]:
lr = 1e-3 # set learning rate
# Train for 10 epochs with the flat-then-cosine schedule; slice(lr) is passed as the learning rate
learn.fit_flat_cos(10, slice(lr))

Save model.

In [None]:
# Make sure the output directory for checkpoints exists
path_models = Path('../out/models/')
path_models.mkdir(exist_ok=True, parents=True)
# NOTE(review): learn.save resolves the name relative to the learner's own model
# directory (presumably ./models), hence the leading '../../' to end up in
# ../out/models — confirm against learn.path / learn.model_dir.
learn.save('../../out/models/full_sample_extended_tric_224_10ep')

## Inference

Let's look at some results. They are not so bad!

In [None]:
# Display up to four validation samples together with the model's predictions
learn.show_results(max_n=4, figsize=(6,12))

## Full Size Training

Restart the kernel.

In [None]:
from fastai.vision.all import *
from utils import *

# Locations of the processed training scans and their segmentation masks
data_path = Path('../data/processed/')
path_im = data_path/'train'/'scans'
path_lbl = data_path/'train'/'labels'
fnames = get_image_files(path_im)
lbl_names = get_image_files(path_lbl)

# provide path to an image --> returns path to the mask
# (a named function rather than a lambda, so the DataBlock that uses it can be pickled)
def get_mask(o):
    """Return the mask path for scan path `o`: `<stem>_lab<suffix>` inside `path_lbl`."""
    return path_lbl/f'{o.stem}_lab{o.suffix}'

# codes for each segmentation class
codes = np.array(['BG', 'MV'])

# validation file names, one per line of vld_full.txt
# (splitlines avoids the empty last entry that split('\n') yields for a trailing newline)
valid_fnames = (data_path/'train'/'vld_full.txt').read_text().splitlines()

# resolution of an expert and of an amateur frame
res_exp, res_am = get_resolution(fnames, data_path)

In [None]:
# Half of the expert resolution, truncated to int, with the two axes swapped
# relative to res_exp (index 1 first).
# NOTE(review): the "Dataloaders" section builds half_res in the order
# (res_exp[0], res_exp[1]) — confirm which order Resize expects here.
half_res = (int(res_exp[1]/2), int(res_exp[0]/2))

Make dataloaders with half-resolution images (`half_res`).

In [None]:
# Segmentation DataBlock at half_res: items are resized to half_res, then
# aug_transforms (called with only `size` set) and imagenet normalization run per batch.
mvscans = DataBlock(blocks=(ImageBlock, MaskBlock(codes)), # blocks for segmentation
                    get_items=get_image_files, # how to get the files: use function
                    splitter=FileSplitter(data_path/'train'/'vld_full.txt'), # validation files are listed in vld_full.txt
                    get_y=get_mask,
                    item_tfms=Resize(half_res),
                    batch_tfms=[*aug_transforms(size=half_res), Normalize.from_stats(*imagenet_stats)]
                    )

# Batch size 1 for the larger images
dls = mvscans.dataloaders(path_im, bs=1)

Assign vocab, make learner, load weights.

In [None]:
# Same learner configuration as in the low-resolution section, built on the half_res dataloaders
opt = ranger
dls.vocab = codes
learn = unet_learner(
    dls, # dataloaders
    resnet34, # architecture
    metrics=acc_camvid, # metric
    self_attention=True,
    act_cls=Mish,
    opt_func=opt
)
# NOTE(review): no cell in this notebook saves a checkpoint named
# 'full_sample_lowres_10ep' (the earlier section saves
# 'full_sample_extended_tric_224_10ep') — confirm this file exists.
learn.load('../../out/models/full_sample_lowres_10ep')

In [None]:
# Run the learning-rate finder on the half_res learner; `lr` is set by hand in the next cell
learn.lr_find()

In [None]:
# Learning rate used by the fit calls in the following cells
lr = 1e-3 # set learning rate

Fit the model for ten epochs.

In [None]:
# Train the half_res learner for 10 epochs with the flat-then-cosine schedule
learn.fit_flat_cos(10, slice(lr))

In [None]:
### dls full size
# DataBlock at the full expert resolution (res_exp)
mvscans = DataBlock(blocks=(ImageBlock, MaskBlock(codes)), # blocks for segmentation
                    get_items=get_image_files, # how to get the files: use function
                    splitter=FileSplitter(data_path/'train'/'vld_full.txt'), # validation files are listed in vld_full.txt
                    get_y=get_mask,
                    item_tfms=Resize(res_exp),
                    batch_tfms=[*aug_transforms(size=res_exp), Normalize.from_stats(*imagenet_stats)]
                    )

dls_full = mvscans.dataloaders(path_im, bs=1)
dls_full.vocab = codes
# NOTE(review): learn_full is a newly created learner; no weights from `learn`
# are loaded into it in this cell — confirm that training from the backbone alone is intended.
learn_full = unet_learner(
    dls_full, # dataloaders
    resnet34, # architecture
    metrics=acc_camvid, # metric
    self_attention=True,
    act_cls=Mish,
    opt_func=opt
)

# Train for 5 epochs; `opt` and `lr` come from earlier cells
learn_full.fit_flat_cos(5, slice(lr))

Now we can save the model and use it for inference.

In [None]:
# NOTE(review): this saves `learn` (the half_res learner), not `learn_full`,
# which was trained in the previous cell and is never saved in this notebook — confirm intent.
learn.save('../../out/models/full_sample_highres_10ep')

In [None]:
# Unfreeze all layers of `learn` and fine-tune for 10 more epochs with
# learning rates given by the slice from 1e-6 to lr/10
learn.unfreeze()
lrs = slice(1e-6,lr/10)
learn.fit_flat_cos(10, lrs)
# Save the fine-tuned weights under a separate name
learn.save('../../out/models/full_sample_highres_10ep_unfreeze')

## Full Size Inference

In [None]:
# Image size fed to the model at inference time
res_inference = (224, 224)
# load the model and a dataloader with the correct image size
# (the flip/rotate/zoom/warp/affine arguments of aug_transforms are all switched off)
mvscans = DataBlock(blocks=(ImageBlock, MaskBlock(codes)), # blocks for segmentation
                    get_items=get_image_files, # how to get the files: use function
                    splitter=FileSplitter(data_path/'train'/'vld_full.txt'), # validation files are listed in vld_full.txt
                    get_y=get_mask,
                    item_tfms=[Resize(res_inference)],
                    batch_tfms=[*aug_transforms(size=res_inference, 
                                                do_flip=False, 
                                                max_rotate=0.,
                                                max_zoom=1.0,
                                                max_warp=0.,
                                                p_affine=0.), Normalize.from_stats(*imagenet_stats)]
                    )

# NOTE: this rebinds `dls`, replacing the dataloaders created in earlier cells
dls = mvscans.dataloaders(path_im, bs=4)
learn_pred = unet_learner(
    dls, # dataloaders
    resnet34, # architecture
    metrics=acc_camvid, # metric
    self_attention=True,
    act_cls=Mish,
    opt_func=ranger
)
# Load the checkpoint written by the save cell of the "Dataloaders" section
learn_pred.load('../../out/models/full_sample_extended_tric_224_10ep')


In [None]:
def shape_mask_to_frame(msk_array, dim_frame):
    """Resize a square mask so that it matches the shape of a video frame.

    The mask is scaled with RatioResize so that its size matches the frame
    height ``dim_frame[0]`` and is then cropped/padded with CropPad to
    ``dim_frame``. The result is returned as a tensor.
    """
    side = msk_array.shape[0]
    scale_y = dim_frame[0]/side
    mask_img = PILMask.create(msk_array)
    scaled = RatioResize(scale_y*side)(mask_img)
    fitted = CropPad(dim_frame)(scaled)
    return tensor(fitted)


def show_some_predictions(preds, test_data, ids, frames, threshold = 0.5,  k = -1):
    """Plot six consecutive predictions on top of their video frames (2x3 grid).

    Shape of predictions must be square.

    Parameters
    ----------
    preds : result of ``get_preds``; ``preds[0][i][1]`` is read as the
        foreground probability map of prediction ``i``.
    test_data : indexable collection whose entries hold the video under 'video',
        indexed as ``video[:, :, frame]``.
    ids, frames : video id and frame number of every prediction, same order as ``preds[0]``.
    threshold : probabilities above this value are drawn as mask.
    k : index of the first prediction to show; ``-1`` picks a random start such
        that all six indices are valid.
    """
    n_shown = 6
    fig, axs = plt.subplots(2, 3, figsize=(12, 8))

    if k == -1 :
        # randint includes both ends, so the last admissible start is len - n_shown
        k = random.randint(0, max(0, len(preds[0]) - n_shown))

    print("Index of predictions: ", k, ", ..., ", k+5)
    for i in range(k, k+n_shown) :
        ax = axs[(i-k)//3, (i-k)%3]
        pred_array = preds[0][i][1] > threshold # binarize the foreground probability

        scan = test_data[ids[i]]['video'][:,:,frames[i]]
        img = PILImage.create(scan)

        # scale the mask to the frame height, then crop/pad it to the frame shape
        msk = PILMask.create(pred_array)
        scale_y = img.shape[0]/scan.shape[0]
        msk_on_frame = CropPad(img.shape)(RatioResize(scale_y*scan.shape[0])(msk))

        ax.imshow(img, cmap='gray', alpha=1)
        ax.imshow(tensor(msk_on_frame), alpha=0.5)
        ax.set_title(f"Image ID {ids[i]}, Frame {frames[i]}")
        ax.axis('off')

    plt.show()


def predictions2list(preds, test_data, threshold=0.5):
    """Group the per-frame predictions by video and rescale them to the frame size.

    Walks through ``preds[0]`` in order. For the video id at the current
    position, all frames carrying that id are collected, each probability map
    ``preds[0][i][1]`` is resized with shape_mask_to_frame and thresholded.

    Note: reads the module-level ``ids`` and ``frames`` (they are not
    parameters) and assumes the predictions of one video are consecutive.

    Returns a list of dictionaries with

    - "name": name of the video
    - "prediction": boolean masks stacked along the third axis
    """
    results = []
    pos = 0
    n_preds = len(preds[0])
    while pos < n_preds:
        vid = ids[pos]
        vid_frames = frames[np.array(ids) == vid]

        masks = []
        for fr in vid_frames:
            frame_shape = test_data[vid]['video'][:,:,fr].shape
            prob_map = preds[0][pos][1, :, :]

            resized = shape_mask_to_frame(prob_map, frame_shape).numpy()
            masks.append(resized > threshold)
            pos += 1

        results.append({
            "name": test_data[vid]['name'],
            "prediction": np.dstack(masks)
        })

    return results
    

In [None]:
def save_zipped_pickle(obj, filename, protocol=2):
    """Pickle ``obj`` into a gzip-compressed file.

    Parameters
    ----------
    obj : any picklable object.
    filename : str or path-like destination; an existing file is overwritten.
    protocol : pickle protocol number, 2 by default.
    """
    # Imported locally so the helper does not rely on a star-import to provide these modules
    import gzip
    import pickle

    with gzip.open(filename, 'wb') as f:
        pickle.dump(obj, f, protocol)

In [None]:
# list of files for prediction, associated ids and frames
# (stems of the test scans; the code below reads the id from the second and the
# frame number from the third '_'-separated field of each stem)
pred_list = get_image_files(data_path/'test'/'scans').map(lambda o: o.stem)
pred_list = L(sorted(pred_list, key=lambda o: 1e6*int(o.split("_")[1]) + int(o.split("_")[2]))) # sort by id and frame

#all predictions
#pred_list = pred_list[0:300] # first 300 predictions

# video id and frame number of every file, in the sorted order
ids = pred_list.map(lambda o: int(o.split("_")[1]))
frames = pred_list.map(lambda o: int(o.split("_")[2]))

# test dataloader over the sorted scan files, built from the inference learner's dataloaders
dl_test = learn_pred.dls.test_dl([data_path/'test'/('scans/' + pred_list[i] + '.png') for i in range(len(pred_list))])

dl_test.show_batch()


In [None]:
# make predictions
preds = learn_pred.get_preds(dl=dl_test)

In [None]:
# Unique spatial shapes of the predicted class maps (argmax over the first axis)
print('The resolution of the prediction is (are): ', L([preds[0][i].argmax(dim=0).numpy().shape for i in range(len(preds[0]))]).unique() )

# load test images
# NOTE: this rebinds test_data to what get_pkl_data returns, replacing the
# version converted by trichannel_frames earlier in the notebook
_, test_data, _ = get_pkl_data(data_path)

# Unique shapes of the video frames the predictions refer to
print('The resolutions of the test data are:', L([test_data[ids[i]]['video'][:,:,frames[i]].shape for i in range(len(ids))]).unique())

too dark: 14 (*2)
18 (*1.5)

In [None]:
# Frame counter for the inspection cell below, which increments it on every run
f = 0

In [None]:
# Test video to inspect (note: `id` shadows the Python builtin of the same name)
id = 9

# Mean pixel value of the whole video
print('avg light:', test_data[id]['video'].mean())
# Frame f multiplied by a float factor (1. here)
# NOTE(review): presumably the factor is edited by hand to try out brightening — confirm
frame = test_data[id]['video'][:,:,f]*1.
# clip values above 255
frame[frame > 255] = 255
plt.imshow(frame, cmap='gray')
# advance to the next frame; re-running this cell steps through the video
f += 1

In [None]:
# Test frames to add to the training set as (video id, frame number) pairs,
# taken from the predictions of "full_sample_noval_tric_224_15ep".
# One row per video id, in the original order.
add_list = [
    (11, 74), (11, 84),
    (12, 0), (12, 19), (12, 27), (12, 48), (12, 56),
    (13, 8), (13, 11), (13, 15), (13, 24), (13, 26),
    (14, 18), (14, 30), (14, 37),
    (15, 15), (15, 23), (15, 29), (15, 39), (15, 63),
    (16, 5), (16, 15), (16, 29),
    (17, 11), (17, 23), (17, 39), (17, 51),
    (18, 2), (18, 7), (18, 21),
    (19, 0), (19, 8), (19, 22), (19, 25), (19, 44),
    (0, 0), (0, 5), (0, 16), (0, 28),
    (1, 5), (1, 22), (1, 28), (1, 46),
    (2, 1), (2, 18), (2, 23),
    (3, 12), (3, 21), (3, 28),
    (4, 6), (4, 18), (4, 26), (4, 41),
    (5, 4), (5, 25), (5, 47), (5, 65),
    (6, 21), (6, 25),
    (9, 27), (9, 35), (9, 39),
    (10, 16), (10, 20), (10, 22),
]

In [None]:
# Show the keys of the first prediction dictionary.
# NOTE(review): predictions_reshaped is only assigned in a later cell
# (predictions2list call), so this cell fails on a fresh top-to-bottom run.
predictions_reshaped[0].keys()

In [None]:

# Write the selected test frames and their predicted masks into the training folders
# as 'tst_<id>_<frame>.png' (scan) and 'tst_<id>_<frame>_lab.png' (label).
# NOTE(review): predictions_reshaped is assigned in a later cell; ToPILImage's
# import is commented out in the first cell (presumably provided elsewhere — confirm).
# NOTE(review): the scan is indexed with four axes ([:,:,:,fr]), i.e. it expects the
# three-channel videos, while test_data is reloaded from the pickles in an
# earlier cell — confirm which version is in memory when this runs.
for id_key in add_list:
    id = id_key[0]
    fr = id_key[1]
    ToPILImage()(predictions_reshaped[id]['prediction'][:,:,fr].astype(np.uint8)).save(data_path/'train'/'labels'/f'tst_{id}_{fr}_lab.png', format="PNG")
    ToPILImage()(test_data[id]['video'][:,:,:,fr].astype(np.uint8)).save(data_path/'train'/'scans'/f'tst_{id}_{fr}.png', format="PNG")

In [None]:
# Step through the predictions six at a time: every run of this cell shows the
# window starting at `k` and then advances `k`.
# `k` is initialized on first use so the cell also works on a fresh kernel.
if 'k' not in globals():
    k = 1400
show_some_predictions(preds, test_data, ids, frames, 0.3, k)
k = k + 6

In [None]:
# Group the predictions per video, rescale them to the frame size and binarize at 0.3
predictions_reshaped = predictions2list(preds, test_data, threshold=0.3)

In [None]:

# Overlay six consecutive raw (224x224) predictions on their scans, with the
# scans resized to the prediction size.
fig, axs = plt.subplots(2, 3, figsize=(12, 8))

# random start chosen so that all six indices are valid for the predictions at hand
# (randint includes both ends)
k = random.randint(0, max(0, len(preds[0]) - 6))
for i in range(k, k+6) :
    ax = axs[(i-k)//3, (i-k)%3]
    pred_array = preds[0][i][1] > 0.2 # a threshold of 0.2 is applied to the prediction

    scan = test_data[ids[i]]['video'][:,:,frames[i]]
    img = PILImageBW.create(scan)
    # height/width ratio of the scan, used to scale it before cropping/padding to 224
    ar = img.shape[0]/img.shape[1]
    rsz = lambda o: CropPad(224)(RatioResize(224/ar)(o))

    ax.imshow(rsz(img), cmap='gray', alpha=1)
    ax.imshow(pred_array, alpha=0.5)
    ax.set_title(f"Image ID {ids[i]}, Frame {frames[i]}")
    ax.axis('off')

plt.show()


In [None]:
# Make sure the output directory for predictions exists
predpath = Path("../out/predictions/")
predpath.mkdir(exist_ok=True, parents=True)
# Store the per-video predictions as a gzip-compressed pickle
save_zipped_pickle(predictions_reshaped, predpath/"unet_224_amex_extended_tric_30pct_10ep.pkl")

In [None]:
# Sanity check: number of test videos vs. number of prediction entries.
# (printed explicitly: a bare expression that is not the last line of a cell is not displayed)
print(len(test_data), len(predictions_reshaped))
# For every test video, report whether the prediction has the same shape as the video
for i in range(len(test_data)):
    print(test_data[i]['video'].shape == predictions_reshaped[i]['prediction'].shape)
