Not the cleanest notebook; kept mainly for archiving purposes.

In [None]:
from fastai.vision.all import *
from utils import (
    generic_image_path,
    generic_image_path_processed,
    generic_segmentation_path,
    generic_segmentation_path_processed,
    paste_imgs,
)
# `np.int` was a deprecated alias of the builtin `int` (removed in NumPy 1.24), but the
# preset model code still references it. Re-create the alias as it originally was
# (Python `int`, not `np.int32`), and only when it is missing so a working attribute on
# an older NumPy is left untouched.
if not hasattr(np, "int"):
    np.int = int

In [None]:
def add_mask(source: Image.Image, mask: Image.Image, alpha: int = 120) -> Image.Image:
    """Overlay `mask` on `source` as a semi-transparent red tint.

    Args:
        source: image to draw on; converted to RGBA.
        mask: mask image; converted to RGBA. Its red channel is kept and green/blue
            are cleared, so bright mask pixels become red and black pixels stay black.
            NOTE(review): presumably a 0/255 mask; a raw 0/1 label mask would be nearly
            invisible (the earlier version scaled it by 255).
        alpha: opacity (0-255) written to every pixel of the overlay.

    Returns:
        A new RGBA image: `source` alpha-composited with the tinted mask.
    """
    source = source.convert('RGBA')
    # np.array returns a writable (H, W, 4) uint8 copy of the mask pixels.
    overlay = np.array(mask.convert('RGBA'))
    # Clear green AND blue so only red remains; the previous slice `1:2` cleared
    # only green, which produced a magenta tint.
    overlay[:, :, 1:3] = 0
    # Uniform opacity: this also darkens `source` wherever the mask is black.
    overlay[:, :, 3] = alpha
    return Image.alpha_composite(source, Image.fromarray(overlay))

In [None]:
# Each resolution scenario has its own copy of the "data/" folder, so switching
# scenarios only means changing this root, not the folder names in utils.
# Full-resolution alternative: root = Path("..")
root = Path("..") / "downscaled_2"

In [None]:
# Per-image metadata; the columns used later in this notebook are `filename` and `is_valid`.
df = pd.read_csv(filepath_or_buffer="metadata.csv")

In [None]:
# `utils` derives per-file path functions from `root`; each one maps a filename to its
# location. Only the processed image and mask folders are used below. The raw
# variants (`generic_image_path`, `generic_segmentation_path`) are left unused;
# presumably they point at the unprocessed originals (verify in utils).
image_path_processed, segmentation_path_processed = (
    generic_image_path_processed(root),
    generic_segmentation_path_processed(root),
)

In [None]:
def label_func(p: Path):
    """Return the processed segmentation-mask path for image path `p`, keyed by file name.

    Not used by the DataBlocks below, which read mask paths from `df` instead.
    """
    file_name = p.name
    return segmentation_path_processed(file_name)

In [None]:
# Resolve every metadata filename to its processed mask and image locations;
# the DataBlocks read these two columns.
df["segmentation_file_processed"] = df["filename"].apply(segmentation_path_processed)
df["image_file_processed"] = df["filename"].apply(image_path_processed)

### Training

In [None]:
# Training dataloaders. Image and mask file paths are read from the columns of `df`
# (ColReader), and the boolean `is_valid` column picks the validation rows.
# Batch transforms: ImageNet normalisation plus augmentation (zero padding,
# rotations up to 180 degrees, zoom up to 2x).
# batch_size=1: larger batches raised "stack expects each tensor to be equal size",
# presumably because images differ in size and no resize transform is applied.
# drop_last=True was added while chasing BatchNorm's
# "Expected more than 1 value per channel" error.
dls = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes=["nothing", "lipid_sac"])),
    get_x=ColReader("image_file_processed"),
    get_y=ColReader("segmentation_file_processed"),
    splitter=ColSplitter("is_valid"),
    # Alternative split: RandomSplitter(valid_pct=0.8, seed=42)
    batch_tfms=[
        Normalize.from_stats(*imagenet_stats),
        *aug_transforms(pad_mode="zeros", max_rotate=180, max_zoom=2),
    ],
).dataloaders(
    df,
    batch_size=1,
    drop_last=True,
)

In [None]:
# Sanity check: fetch one training batch and print the image and mask tensor shapes
# (use dls.show_batch() for a visual check).
xb, yb = dls.one_batch()
print(xb.shape, yb.shape, sep="\n")

In [None]:
class CombinedLoss:
    """Segmentation loss: focal loss plus `alpha` times Dice loss.

    Args:
        axis: class axis of the prediction logits (1 for [B, C, H, W]).
        smooth: smoothing term forwarded to fastai's `DiceLoss`.
        alpha: weight of the Dice term relative to the focal term.

    `decodes` and `activation` are the hooks fastai looks for on a loss function
    to turn raw model output into class labels / probabilities.
    """
    def __init__(self, axis=1, smooth=1., alpha=1.):
        store_attr()  # keeps axis, smooth and alpha as attributes
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss = DiceLoss(axis, smooth)

    def __call__(self, pred, targ):
        focal_term = self.focal_loss(pred, targ)
        dice_term = self.dice_loss(pred, targ)
        return focal_term + self.alpha * dice_term

    def decodes(self, x):
        """Class-index map: argmax over the class axis."""
        return x.argmax(dim=self.axis)

    def activation(self, x):
        """Per-class probabilities: softmax over the class axis."""
        return F.softmax(x, dim=self.axis)

In [None]:
def IoU(preds:Tensor, targs:Tensor, eps:float=1e-8):
    """Mean soft intersection-over-union (Jaccard index) over classes.

    Predicted probabilities (not hard argmax labels) are compared with one-hot
    targets. For each class, intersection and union are pooled over the batch and
    all spatial positions, and the per-class ratios are averaged. This is used as
    a metric, so it returns the IoU itself (higher is better), not a `1 - IoU` loss.

    Args:
        preds: raw logits of shape [B, C, H, W]. With C == 1 the single channel is
            treated as a foreground logit (sigmoid); otherwise softmax is taken
            over the class dimension.
        targs: integer class labels of shape [B, H, W] or [B, 1, H, W].
        eps: added to the denominator for numerical stability.

    Returns:
        0-dim tensor holding the class-averaged IoU.
    """
    num_classes = preds.shape[1]

    # Drop an explicit channel axis ([B, 1, H, W] -> [B, H, W]). Only squeeze when the
    # target has the prediction's rank, so a genuine H == 1 is never removed.
    if targs.ndim == preds.ndim:
        targs = targs.squeeze(1)
    targs = targs.long()  # one_hot requires integer labels

    if num_classes == 1:
        # Binary case: probabilities are stacked as [foreground, background], so order
        # the one-hot target channels the same way ([class 1, class 0]).
        one_hot = F.one_hot(targs, 2).permute(0, 3, 1, 2)
        true_1_hot = torch.cat([one_hot[:, 1:2], one_hot[:, 0:1]], dim=1)
        pos_prob = torch.sigmoid(preds)
        probas = torch.cat([pos_prob, 1 - pos_prob], dim=1)
    else:
        # [B, H, W] -> one-hot [B, H, W, C] -> [B, C, H, W]. Built from the targets
        # themselves, so no CPU lookup table is mixed with GPU indices.
        true_1_hot = F.one_hot(targs, num_classes).permute(0, 3, 1, 2)
        # Softmax over classes: per-pixel class probabilities sum to 1.
        probas = F.softmax(preds, dim=1)

    true_1_hot = true_1_hot.to(dtype=probas.dtype, device=probas.device)

    # Reduce over the batch and every spatial axis of the [B, C, ...] tensors, leaving
    # one value per class. The axes must come from the prediction rank: deriving
    # them from a [B, H, W] target would skip the last spatial axis.
    dims = (0,) + tuple(range(2, probas.ndim))
    intersection = torch.sum(probas * true_1_hot, dims)  # [C]
    cardinality = torch.sum(probas + true_1_hot, dims)   # [C]
    union = cardinality - intersection
    return (intersection / (union + eps)).mean()

In [None]:
# Focal + Dice loss with the default settings (class axis 1, smooth=1, Dice weight 1).
combined_loss = CombinedLoss(axis=1, smooth=1., alpha=1.)

In [None]:
# U-Net learner with a ResNet-34 encoder, trained on the combined loss and
# reporting soft IoU as its metric.
learn = unet_learner(dls, resnet34, metrics=[IoU], loss_func=combined_loss)

# Optional: learn.summary() to inspect the architecture,
# learn.lr_find() to pick a learning rate.

In [None]:
# Patch for the /2 and /4 downscaled data. Deep in the network the feature maps can
# shrink to a single spatial value, and BatchNorm then fails in training. Every
# BatchNorm2d is wrapped so it is bypassed when either spatial dimension is 1.

# Get the model from the learner
model = learn.model


class ConditionalBatchNorm(nn.Module):
    """BatchNorm2d that acts as identity when either spatial dimension of the input is 1.

    Args:
        num_features: channel count, used to build a fresh BatchNorm2d when `bn` is None.
        bn: optional existing BatchNorm2d to wrap. Passing the original layer keeps its
            (pretrained) affine weights, running statistics, eps and momentum.
    """
    def __init__(self, num_features, bn=None):
        super().__init__()
        self.bn = bn if bn is not None else nn.BatchNorm2d(num_features)

    def forward(self, x):
        # Indexes dims 2 and 3, i.e. assumes NCHW input. The bypass applies in both
        # train and eval mode, so each layer behaves the same way in both.
        if x.size(2) > 1 and x.size(3) > 1:
            return self.bn(x)
        return x


def modify_batchnorm(module):
    """Recursively wrap every nn.BatchNorm2d inside `module` in a ConditionalBatchNorm.

    The original layer object is wrapped, not replaced by a freshly initialised one,
    so its learned parameters and running statistics survive. Layers that are
    already wrapped are skipped, so running this cell twice does not nest wrappers.
    """
    for child_name, child in module.named_children():
        if isinstance(child, ConditionalBatchNorm):
            continue
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, child_name, ConditionalBatchNorm(child.num_features, bn=child))
        else:
            modify_batchnorm(child)


# Apply the patch; modify_batchnorm mutates `model` in place.
modify_batchnorm(model)

# `model` is the same object as learn.model, so this reassignment is only for clarity.
learn.model = model

In [None]:
# Training callbacks:
#   - early stopping when the monitored quantity stops improving for 5 epochs,
#   - checkpointing of the best model as "model_resnet34_1015" (same name for /2 and /4),
#   - gradient accumulation to compensate for batch_size=1
#     (NOTE(review): fastai counts n_acc in samples, so 8 here means 8 batches of 1; confirm).
cbs = [
    EarlyStoppingCallback(patience=5),
    SaveModelCallback(fname="model_resnet34_1015"),
    GradientAccumulation(n_acc=8),
]

In [None]:
# Train for up to 40 epochs; the first argument of fine_tune is the epoch count, and
# the callbacks may stop it early. A base learning rate (e.g. one suggested by
# learn.lr_find()) can also be passed here.
learn.fine_tune(epochs=40, cbs=cbs)

In [None]:
# Visual comparison of predictions against the ground-truth masks.
learn.show_results()

# Top-loss inspection, disabled because it also raises a softmax error
# (possibly the same list-of-predictions issue as get_preds below; unverified):
# interp = SegmentationInterpretation.from_learner(learn)
# interp.plot_top_losses()

In [None]:
# Serialise the whole Learner so it can be reloaded with load_learner.
learn.export(fname="models/learner_down2_1015.pkl")

In [None]:
# Weights backup under a name that differs from the SaveModelCallback checkpoint on
# purpose, so the two files never overwrite each other.
learn.save(file="model_backup_down2_1015")

### Validation

In [None]:
# NOTE(review): this loads the *down4* export, while the training cells above use
# ../downscaled_2 and export learner_down2_1015.pkl. Confirm which model is meant.
learn2 = load_learner(fname="models/learner_down4_1015.pkl", cpu=True)

In [None]:
# Replace the exported weights with the SaveModelCallback checkpoint.
# NOTE(review): that checkpoint name is shared by the /2 and /4 runs, so it holds
# whichever run trained last.
learn2.load(file="model_resnet34_1015")

In [None]:
# Evaluation dataloaders: same files and split as training, but with ImageNet
# normalisation only (no augmentation).
dls2 = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes=["nothing", "lipid_sac"])),
    get_x=ColReader("image_file_processed"),
    get_y=ColReader("segmentation_file_processed"),
    splitter=ColSplitter("is_valid"),
    batch_tfms=[Normalize.from_stats(*imagenet_stats)],
).dataloaders(
    df,
    batch_size=1,
    drop_last=True,
)

In [None]:
# Predict on the validation split (dls2[1]), keeping dataset order and the inputs.
# dls2 applies no resize transform, so the validation images most likely differ in
# size. fastai then cannot torch.cat the per-batch outputs and falls back to an `L`
# (list) of per-image tensors. Applying the loss's softmax `activation` to that list
# is what raised "'list' object has no attribute 'softmax'". So skip the built-in
# activation (act=noop) and apply softmax to each prediction separately.
res = learn2.get_preds(dl=dls2[1], reorder=False, with_input=True, act=noop)
inputs, logits, targets = res
# Each item is [C, H, W] (class axis first), whether `logits` is an L of per-image
# tensors or one stacked [N, C, H, W] tensor.
probs = [F.softmax(p, dim=0) for p in logits]