Not the cleanest notebook; it is kept mainly for archival purposes.

In [None]:
from fastai.vision.all import *
from utils import (
    generic_image_path,
    generic_image_path_processed,
    generic_segmentation_path,
    generic_segmentation_path_processed,
    paste_imgs,
)

In [None]:
def add_mask(source: Image.Image, mask: Image.Image) -> Image.Image:
    """Overlay `mask` on `source` as a semi-transparent layer.

    Both images are converted to RGBA. The overlay's green channel is zeroed
    and its alpha channel is set to 120 for every pixel, after which the
    overlay is alpha-composited over `source`.
    """
    base = source.convert('RGBA')

    # Writable pixel copy of the RGBA mask (last axis is R, G, B, A).
    overlay = np.array(mask.convert('RGBA'))
    overlay[..., 1] = 0    # clear the green channel
    overlay[..., 3] = 120  # same alpha for the whole layer

    return Image.alpha_composite(base, Image.fromarray(overlay))

In [None]:
# Project root, one level above the working directory; passed to the path helpers below.
root = Path("..")

In [22]:
# Metadata table; later cells read its `filename` and `is_valid` columns.
df = pd.read_csv("metadata.csv")

In [None]:
# Bind the `utils` path factories to `root`. The `*_processed` results are used as
# callables that take a file name (see `label_func` and the DataFrame columns below).
# NOTE(review): presumably each returns the full path of the given file under `root` — confirm in `utils`.
image_path = generic_image_path(root)
image_path_processed = generic_image_path_processed(root)
segmentation_path = generic_segmentation_path(root)
segmentation_path_processed = generic_segmentation_path_processed(root)

In [24]:
def label_func(p:Path):
    "Look up the processed segmentation entry for `p`, using only its file name."
    file_name = p.name
    return segmentation_path_processed(file_name)

In [None]:
# Derive the processed mask and image entries from each row's file name.
df = df.assign(
    segmentation_file_processed=df.filename.apply(segmentation_path_processed),
    image_file_processed=df.filename.apply(image_path_processed),
)

### Training

In [27]:
# Training loaders: image / mask pairs come from the DataFrame columns, the split
# follows the `is_valid` column, and each batch is normalized and augmented.
dls = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes=["nothing", "lipid_sac"])),
    get_x=ColReader("image_file_processed"),
    get_y=ColReader("segmentation_file_processed"),
    splitter=ColSplitter("is_valid"),
    batch_tfms=[
        Normalize.from_stats(*imagenet_stats),
        *aug_transforms(pad_mode='zeros', max_rotate=180, max_zoom=2),
    ],
).dataloaders(df, batch_size=4)

Due to IPython and Windows limitation, python multiprocessing isn't available now.
So `number_workers` is changed to 0 to avoid getting stuck
Could not do one pass in your dataloader, there is something wrong in it


In [None]:
# Visual sanity check: display a sample batch from the loaders.
dls.show_batch()

In [None]:
class CombinedLoss:
    "Focal loss plus `alpha`-weighted Dice loss"
    def __init__(self, axis=1, smooth=1., alpha=1.):
        # Stores `axis`, `smooth` and `alpha` on the instance.
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss = DiceLoss(axis, smooth)

    def __call__(self, pred, targ):
        "Return `focal(pred, targ) + alpha * dice(pred, targ)`."
        focal_term = self.focal_loss(pred, targ)
        dice_term = self.dice_loss(pred, targ)
        return focal_term + self.alpha * dice_term

    def decodes(self, x):
        "Index of the highest score along `axis`."
        return x.argmax(dim=self.axis)

    def activation(self, x):
        "Softmax along `axis`."
        return F.softmax(x, dim=self.axis)

In [None]:
def IoU(preds:Tensor, targs:Tensor, eps:float=1e-8):
    """Soft intersection-over-union (Jaccard index), averaged over classes.

    Class probabilities (sigmoid for a single output channel, softmax otherwise)
    are compared with the one-hot encoded target. Intersection and union are
    accumulated per class over the batch and over every spatial position, and
    the per-class ratios are then averaged. This is a metric (higher is
    better), not a loss.

    Args:
        preds: raw model output (logits) of shape [B, C, H, W].
        targs: integer class indices of shape [B, H, W] or [B, 1, H, W].
        eps: added to the denominator for numerical stability.
    Returns:
        iou: scalar tensor holding the mean per-class IoU.
    """
    num_classes = preds.shape[1]

    # Drop the channel axis only when the target really has one; an
    # unconditional `squeeze(1)` would also collapse H == 1 on [B, H, W] input.
    if targs.ndim == preds.ndim:
        targs = targs.squeeze(1)
    targs = targs.long()

    # A single output channel is treated as foreground vs. background (2 classes).
    depth = 2 if num_classes == 1 else num_classes

    # The lookup table is created on the target's device: indexing a CPU
    # tensor with CUDA indices raises at runtime.
    # [B, H, W] -> [B, H, W, depth] -> [B, depth, H, W]
    true_1_hot = torch.eye(depth, device=targs.device)[targs].permute(0, 3, 1, 2)

    if num_classes == 1:
        # Reorder channels to [foreground, background] to line up with `probas`.
        true_1_hot = torch.cat([true_1_hot[:, 1:2], true_1_hot[:, 0:1]], dim=1)
        pos_prob = torch.sigmoid(preds)
        probas = torch.cat([pos_prob, 1 - pos_prob], dim=1)
    else:
        # Softmax along the class dimension; class probabilities sum to 1 per pixel.
        probas = F.softmax(preds, dim=1)

    true_1_hot = true_1_hot.to(device=probas.device, dtype=probas.dtype)

    # Reduce over the batch and ALL spatial axes. The axes are taken from the
    # 4-D probabilities, not from the target, whose rank may be 3 — using the
    # target's rank would leave the width axis unreduced.
    dims = (0,) + tuple(range(2, probas.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)  # one value per class
    cardinality = torch.sum(probas + true_1_hot, dims)   # one value per class
    union = cardinality - intersection
    iou = (intersection / (union + eps)).mean()          # mean of the per-class IoU values
    return iou

In [None]:
# Loss instance with the default settings (axis=1, smooth=1., alpha=1.).
combined_loss = CombinedLoss()

In [None]:
# U-Net learner with a resnet34 encoder, using the combined loss and reporting IoU.
learn = unet_learner(dls, resnet34, loss_func=combined_loss, metrics=[IoU])

In [None]:
# Callbacks passed to training: early stopping (patience 5), checkpointing under
# the name "model_resnet34", and gradient accumulation (n_acc=8).
cbs = [
    EarlyStoppingCallback(patience=5),
    SaveModelCallback(fname="model_resnet34"),
    GradientAccumulation(n_acc=8),
]

In [None]:
# Re-create the learner so training starts from a fresh model, keeping the
# combined loss and the IoU metric. A bare `unet_learner(dls, resnet34)` here
# would silently replace the configured learner with one that uses the default
# loss and reports no metric.
learn = unet_learner(dls, resnet34, loss_func=combined_loss, metrics=[IoU])

learn.fine_tune(40, cbs=cbs)

In [None]:
# Export the whole learner; the validation section reloads this same file.
learn.export("models/learner.pkl")

In [None]:
# Additionally save a checkpoint under the name "model_backup".
learn.save("model_backup")

### Validation

In [None]:
# Reload the exported learner (cpu=False: do not force it onto the CPU).
# NOTE(review): the export is a pickle — only load files you trust, and the custom
# `CombinedLoss` / `IoU` definitions presumably must be in scope before loading.
learn2 = load_learner("models/learner.pkl", cpu=False)

In [None]:
# Load the checkpoint named "model_resnet34_after26" into the reloaded learner.
# NOTE(review): no visible cell writes a checkpoint with this name (the callback
# saves "model_resnet34", `learn.save` writes "model_backup") — presumably a
# manually renamed copy; confirm the file exists before running.
learn2.load("model_resnet34_after26")

In [None]:
# Evaluation loaders: same columns and `is_valid` split as for training, but the
# only batch transform is normalization (no augmentation).
dls2 = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes=["nothing", "lipid_sac"])),
    get_x=ColReader("image_file_processed"),
    get_y=ColReader("segmentation_file_processed"),
    splitter=ColSplitter("is_valid"),
    batch_tfms=[
        Normalize.from_stats(*imagenet_stats),
    ],
).dataloaders(df, batch_size=4)

In [None]:
# Run inference over the second loader of `dls2` (the validation split), without
# reordering, and keep the inputs alongside the predictions.
# NOTE(review): with `with_input=True` the result is presumably (inputs, preds, targets) — confirm.
res = learn2.get_preds(dl=dls2[1], reorder=False, with_input=True)