In [None]:
from pathlib import Path
import os
import pathlib

# The learner pickle presumably contains PosixPath objects (saved on Linux/macOS - confirm).
# PosixPath cannot be instantiated on Windows, so alias it to WindowsPath before unpickling.
# Only do this on Windows: on a POSIX system the alias makes every Path() call fail.
temp = pathlib.PosixPath  # original class, kept so it can be restored if needed
if os.name == "nt":
    pathlib.PosixPath = pathlib.WindowsPath

from fastai.vision.all import *
# from fastai.vision.all import load_learner
import numpy as np
from PIL import Image
from IPython.display import display

# np.int was removed from recent NumPy releases, but the preset model still needs it.
# Only add the shim when it is missing, so an existing alias is never overridden.
if not hasattr(np, "int"):
    np.int = np.int32  # Need to add this to make preset model work

In [None]:
def add_mask2(source: Image.Image, truth: np.ndarray, pred: np.ndarray) -> Image.Image:
    """
    Overlay an annotation mask and a predicted mask on ``source``.

    Foreground pixels are tinted (alpha 75/255) as follows:
    - Green  - annotation only
    - Red    - prediction only
    - Yellow - overlap (red + green)

    Args:
        source: image to draw on; it is converted to RGBA first.
        truth: 2-D annotation mask; any value > 0 counts as foreground.
        pred: 2-D predicted mask with the same shape as ``truth``; any value > 0
            counts as foreground.

    Returns:
        A new RGBA image with the size of ``source``. If the masks are a different
        size than ``source`` (e.g. the model ran at a lower resolution), the overlay
        is resized to ``source`` with nearest-neighbour sampling before compositing.
    """
    source = source.convert('RGBA')
    truth_fg = np.asarray(truth) > 0
    pred_fg = np.asarray(pred) > 0

    overlay = np.zeros((*truth_fg.shape, 4), dtype=np.uint8)
    # Binarise before scaling: multiplying raw mask values by 255 overflows uint8 for
    # masks stored as 0/255 (e.g. a greyscale annotation image) or class indices > 1.
    overlay[:, :, 1] = truth_fg * 255  # green channel <- annotation
    overlay[:, :, 0] = pred_fg * 255   # red channel   <- prediction
    overlay[:, :, 3] = (truth_fg | pred_fg) * 75

    mask = Image.fromarray(overlay, 'RGBA')
    # alpha_composite requires both images to have the same size.
    if mask.size != source.size:
        mask = mask.resize(source.size, Image.NEAREST)
    return Image.alpha_composite(source, mask)

In [None]:
def image2segmentation_path(imgpath: Path) -> Path:
    """Return the path where the segmentation mask for ``imgpath`` is expected.

    Every occurrence of the substring "images" anywhere in the path (directory
    names and file name alike) is replaced by "segmentations", e.g.
    ``data/images/a.bmp`` -> ``data/segmentations/a.bmp``. The result is never
    None and is not checked for existence.
    """
    path_text = str(imgpath)
    return Path(path_text.replace("images", "segmentations"))

def pixels2area(n: int, pixel_size: float = 0.023) -> float:
    """Converts number of pixels into area in um^2.

    Args:
        n: number of (foreground) pixels.
        pixel_size: side length of one square pixel. Defaults to 0.023, the value
            previously hard-coded here. Pass the correct value when the images come
            from a different camera or downscaling factor.

    Returns:
        ``n * pixel_size * pixel_size``.

    NOTE(review): area2mass uses 0.197 * A**1.38, which resembles the lipid-sac-area
    regression of Vogedes et al. (2010), where A is in mm^2. Confirm whether the unit
    here is really um^2.
    """
    return n * pixel_size * pixel_size

def area2mass(A: float) -> float:
    """Mass in mg for an area in um^2, via the power law mass = 0.197 * A ** 1.38."""
    coefficient, exponent = 0.197, 1.38
    return coefficient * pow(A, exponent)

In [None]:
# TODO: fix this section, in which the conditional batch norm has to be (re)defined so that the saved learner can be read
import torch.nn as nn
# Function to modify BatchNorm layers
def modify_batchnorm(module):
    """Recursively wrap every ``nn.BatchNorm2d`` inside ``module`` in a ``ConditionalBatchNorm``.

    The existing BatchNorm layer is reused as the wrapper's ``bn`` submodule. Its
    learned affine parameters, running statistics and settings (eps, momentum, ...)
    are therefore preserved, and its state-dict keys become ``<name>.bn.<param>``.
    Layers that are already wrapped are skipped, so calling this twice is harmless.

    Args:
        module: the ``nn.Module`` to modify in place. Returns None.
    """
    for child_name, child in module.named_children():
        if isinstance(child, ConditionalBatchNorm):
            # Already wrapped: descending would wrap its inner BatchNorm2d a second time.
            continue
        if isinstance(child, nn.BatchNorm2d):
            wrapper = ConditionalBatchNorm(child.num_features)
            # Keep the trained layer rather than the freshly initialised one.
            wrapper.bn = child
            wrapper.train(child.training)
            setattr(module, child_name, wrapper)
        else:
            modify_batchnorm(child)

# Define a custom conditional batch normalization layer
class ConditionalBatchNorm(nn.Module):
    """BatchNorm2d that is bypassed when the feature map is only one pixel high or wide.

    The wrapped layer is stored as ``self.bn``, so its parameters appear as ``bn.*``
    in the state dict. Keep that attribute name: saved weights presumably rely on it.

    Args:
        num_features: number of channels, passed to ``nn.BatchNorm2d``.
    """

    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x):
        # Dimensions 2 and 3 are the spatial axes; if either has size <= 1 the input
        # is returned unchanged (dimension 3 is only inspected when dimension 2 > 1).
        if x.size(2) <= 1 or x.size(3) <= 1:
            return x
        return self.bn(x)

# Other definitions from training that are needed to load the saved learner
class CombinedLoss:
    """Focal loss plus ``alpha`` times Dice loss, with fastai-style decodes/activation.

    Args:
        axis: class axis of the predictions.
        smooth: smoothing term passed to ``DiceLoss``.
        alpha: weight of the Dice term.
    """

    def __init__(self, axis=1, smooth=1., alpha=1.):
        # store_attr() copies the arguments (axis, smooth, alpha) onto self; it inspects
        # the calling frame, so it must be called directly from __init__.
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss = DiceLoss(axis, smooth)

    def __call__(self, pred, targ):
        focal_term = self.focal_loss(pred, targ)
        dice_term = self.dice_loss(pred, targ)
        return focal_term + self.alpha * dice_term

    def decodes(self, x):
        """Class index per position: argmax over the class axis."""
        return x.argmax(dim=self.axis)

    def activation(self, x):
        """Class probabilities per position: softmax over the class axis."""
        return F.softmax(x, dim=self.axis)

def IoU(preds:Tensor, targs:Tensor, eps:float=1e-8):
    """Mean per-class (soft) intersection-over-union for segmentation.

    The value returned is the IoU itself (higher is better), not the Jaccard loss
    ``1 - IoU``.

    Args:
        preds: tensor of shape [B, C, H, W] holding the raw model outputs (logits).
            C == 1 is treated as binary segmentation (sigmoid); C > 1 as
            multi-class segmentation (softmax over the class axis).
        targs: integer class-index tensor of shape [B, H, W] or [B, 1, H, W].
        eps: added to the denominator for numerical stability.

    Returns:
        Scalar tensor: per-class IoU (intersection and union summed over the batch
        and all spatial positions), averaged over classes.
    """
    num_classes = preds.shape[1]

    # Accept [B, H, W] and [B, 1, H, W] targets. Only drop axis 1 when it really is a
    # channel axis, so a [B, H, W] target with H == 1 is left intact.
    if targs.ndim == preds.ndim:
        targs = targs.squeeze(1)
    targs = targs.long()  # class indices must be integer for the one-hot lookup

    # Binary segmentation needs two one-hot channels (background, foreground).
    n_onehot = 2 if num_classes == 1 else num_classes
    # Create the identity on the targets' device so indexing also works on GPU.
    eye = torch.eye(n_onehot, device=targs.device)
    true_1_hot = eye[targs].permute(0, 3, 1, 2).float()  # [B,H,W,C] -> [B,C,H,W]

    if num_classes == 1:
        # Reorder to [foreground, background] to line up with [P(pos), P(neg)].
        true_1_hot = torch.cat([true_1_hot[:, 1:2], true_1_hot[:, 0:1]], dim=1)
        pos_prob = torch.sigmoid(preds)
        probas = torch.cat([pos_prob, 1 - pos_prob], dim=1)
    else:
        # Per-pixel class probabilities that sum to 1.
        probas = F.softmax(preds, dim=1)

    true_1_hot = true_1_hot.type(preds.type())

    # Sum over the batch and every spatial axis, leaving one value per class. Use the
    # rank of the one-hot tensor (always [B, C, H, W]), not of the targets, whose
    # rank depends on whether a channel axis was supplied.
    dims = (0,) + tuple(range(2, true_1_hot.ndim))
    intersection = torch.sum(probas * true_1_hot, dims)  # [class0, class1, ...]
    cardinality = torch.sum(probas + true_1_hot, dims)   # [class0, class1, ...]
    union = cardinality - intersection
    iou = (intersection / (union + eps)).mean()  # mean of the per-class IoU values
    return iou

In [None]:
# Load the exported fastai learner, then load the trained weights into it.
# NOTE: load_learner unpickles the file - only load files from a trusted source.
# The classes/functions defined in the previous cell must exist before this runs.
LEARNER_FILE = "models/learner_down4_1015.pkl"
WEIGHTS_NAME = "model_resnet34_1015"  # presumably resolved to models/model_resnet34_1015.pth - confirm

model = load_learner(LEARNER_FILE, cpu=True)
model.load(WEIGHTS_NAME)
# Inference runs on the CPU because cpu=True is passed to load_learner above;
# use cpu=False (and drop this line) to run on a GPU.
model.dls.device = 'cpu'

In [None]:
# Image source to run the model on. Previously used alternatives:
#   Path("images/example_inputs_orig"), "*.bmp"
#   Path("../downscaled_4/data/images_processed"), "*.bmp"
#   Path("images/example_inputs_lowres"), "*.bmp"
#   Path("images/example_inputs_lowres2"), "*.bmp"
#   Path("../uvp_images/UVP6_darkedge_copepod_black_bg"), "*.png"   (UVP6)
IMAGE_DIR = Path("../uvp_images/UVP5_inverted_images")  # the UVP5 images
IMAGE_PATTERN = "*.jpg"
N_IMAGES = 100  # number of (sorted) images to process; None for all

imgs = sorted(IMAGE_DIR.glob(IMAGE_PATTERN))
# A wrong path silently yields an empty list; fail loudly instead.
if not imgs:
    raise FileNotFoundError(f"No {IMAGE_PATTERN} files found in {IMAGE_DIR.resolve()}")

# Slice from index 0 so the first image is included
# (NOTE(review): the earlier slice [1:100] skipped it - restore if that was intentional).
imgs = imgs[:N_IMAGES]


In [None]:
for pimg in imgs:
    # Where the annotation for this image would be (unused while ground truth is disabled below).
    p_segmentation = image2segmentation_path(pimg)

    # Learner.predict returns a 3-tuple whose first element is the decoded prediction
    # (here: the segmentation mask); the other two elements are discarded.
    mask, *_ = model.predict(pimg)
    im_pred_mask = mask.numpy()

    # There are no true segmentation files for these images, so an empty annotation is
    # used. To compare against annotations, load
    # np.array(Image.open(p_segmentation).convert("L")) when p_segmentation.exists();
    # image2segmentation_path never returns None, so an `is not None` test is not enough.
    im_truth = np.zeros_like(im_pred_mask)

    # Context manager so the image file handle is released as soon as the overlay is built.
    with Image.open(pimg) as im:
        img_with_masks = add_mask2(im, im_truth, im_pred_mask)

    lipid_annotated = area2mass(pixels2area((im_truth > 0).sum()))
    lipid_predicted = area2mass(pixels2area((im_pred_mask > 0.5).sum()))
    print("************************")
    print(pimg)
    print(f"Lipid annotation: {lipid_annotated:.5}mg")
    print(f"Lipid prediction: {lipid_predicted:.5}mg")
    # Overlay colours: green = annotation only, red = prediction only, yellow = overlap
    display(img_with_masks.convert("RGB"))

In [None]:
# TODO Add some segmentation diagnostics and output the images