# This notebook contains the code for training a random forest model for image segmentation

In [37]:
import os
import glob
import tifffile
import itertools
import logging
modlogger: logging.Logger = logging.getLogger(__name__)
logging.basicConfig(filename='myapp.log', level=logging.INFO)

from typing import List, Tuple, Dict, Any, Union

import numpy as np
import einops as eop
import pandas as pd
from scipy import ndimage
import matplotlib

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
from PIL import Image, ImageEnhance, ImageOps

from iife import iife



In [38]:
## Prepare file path
from constants import datasets

trainset = datasets / "training"
images = trainset / "images"
mask = trainset / "mask"
manual = trainset / "1st_manual"
img_files = sorted(glob.glob((images / "*.tif").__str__()))
mask_files = sorted(glob.glob((mask / "*.gif").__str__() ))
manual_files = sorted(glob.glob( (manual / "*.gif").__str__() ) )
assert len(img_files) == len(mask_files) == len(manual_files), "Number of images, masks and manual files do not match!"

modlogger.info(f"{images=}\n, {mask=}\n, {manual=}\n")



testset = datasets / "test"
testImages = testset / "images"; 
testImg_files = sorted(glob.glob( (testImages / "*.tif").__str__() ))
testMask = testset / "mask"; 
testMask_files = sorted(glob.glob((testMask / "*.gif").__str__() ))
modlogger.info(f"{testImages=}\n, {testMask=}\n")
assert len(testImg_files) == len(testMask_files), "Number of test images and masks do not match!"





In [42]:

# Utility Functions 




# Image Augmentation
def augment_color(image: Image, factor_color=1.5, factor_brightness=1.2, factor_contrast=1.3):
    enhancer_color = ImageEnhance.Color(image)
    image = enhancer_color.enhance(factor_color)
    enhancer_brightness = ImageEnhance.Brightness(image)
    image = enhancer_brightness.enhance(factor_brightness)
    enhancer_contrast = ImageEnhance.Contrast(image)
    image = enhancer_contrast.enhance(factor_contrast)
    return image

### Load Data 
# @iife(img_path= img_files[0], mask_path= mask_files[0], manual_path= manual_files[0])
def load_sample(img_path, mask_path, manual_path: str | None, augment = True):
    width, height = 565, 584
    # Load all components
    img = tifffile.imread(img_path)
    if augment:
        # Convert the numpy array to a PIL.Image object
        img_pil = Image.fromarray(img.astype("uint8"))
        # Enhance color, brightness, contrast.
        img_pil = augment_color(img_pil, factor_color=1.5, factor_brightness=1.2, factor_contrast=1.3)
        # Convert back to numpy array
        img = np.array(img_pil)
    mask = np.array(Image.open(mask_path))
    manual = np.array(Image.open(manual_path)) if manual_path is not None else None
    # Apply mask to image (optional)
    img_masked = img.copy()
    for i in range(3):
        img_masked[:, :, i] = img[:, :, i] * (mask > 0)  # Assuming binary mask
    # Prepare for training
    x = img_masked
    
    y = np.zeros(shape=x.shape[:-1], dtype = np.uint8)
    y[manual == 0] = 1
    y[mask == 0] = 0
    y[manual == 255] = 2

    # print(f"{np.unique(y)=}")
    #  if manual is not None else mask.astype(np.int32)
    # print(f"{img_path=} {xs.shape=}, {ys.shape=}, {xs.max()=}, {xs.min()=}, {xs.mean()=},  {ys.max()=}, {ys.min()=}, {ys.mean()=}")
    return x, y



def LoadAll(img_files: List[str], mask_files : List[str], manual_files: List[str|None], augment: bool = True) -> Tuple[np.array, np.array]:
    x_list = []
    y_list = []
    for img_path, mask_path, manual_path in zip(img_files, mask_files, manual_files):
        x, y = load_sample(img_path, mask_path, manual_path, augment)
        x_list.append(x)
        y_list.append(y)
    modlogger.getChild("LoadAll").info(f"Loaded {len(x_list)} samples")
    xs = np.array(x_list)
    ys = np.array(y_list)
    
    @iife
    def log(logger = modlogger.getChild("LoadAll")):
        logger.info(f"{xs.shape=}, {ys.shape=}")
    
    return xs.astype(np.float32), ys

xs, ys = LoadAll(img_files, mask_files, manual_files, augment=False)
modlogger.info(f"{xs.shape=}, {ys.shape=}")





In [43]:
print(f"{xs.dtype}, {ys.dtype}, {xs.max()}, {xs.min()}, {xs.mean()},  {ys.max()=}, {ys.min()}, {ys.mean()}")

float32, uint8, 255.0, 0.0, 76.86566925048828,  ys.max()=np.uint8(2), 0, 0.7744606922051158


In [53]:
## Train Random Forest Classifier
from sklearn.ensemble import RandomForestRegressor
rgr = RandomForestRegressor(n_jobs=-1)
sublist = lambda l : [l[2], l[3], l[10], l[12]] 

xtrain = eop.rearrange(sublist(xs), 'n h w c -> (n h w) c')
ytrain = eop.rearrange(sublist(ys), 'n h w -> (n h w)')

rgr.fit(xtrain, ytrain)




In [None]:
## Test
from sklearn.metrics import classification_report

def predict(img : np.array) -> np.array:
    """
    Parameters
    ----------
    img : np.array
        Image to predict. Shape is (584, 565, 3).
    Returns
    -------
    np.array
        Predicted mask. Shape is (584, 565).
    """    
    img = eop.rearrange(img, 'h w c -> (h w) c')
    pred = rgr.predict(img)
    pred = eop.rearrange(pred, '(h w) -> h w', h=584, w=565)
    modlogger.getChild("predict").info(f"{img.shape=}, {pred.shape=}")
    return pred



for i in range(6, 7):
    x, y = xs[i], ys[i]
    pixels = eop.rearrange(x, 'h w c -> (h w) c')
    pred = rgr.predict(pixels)
    roundpred = np.round(pred).astype(np.uint8)
    # display(np.unique(pred))

    print(f"{x.shape=} {y.shape=}, {pixels.shape=} ")
    correct = np.rint(y.flatten()).astype(np.uint8)
   
    
    print(f"{np.unique(correct)}, {np.unique(roundpred)}")
    

    manual = np.array(Image.open(manual_files[i]))
    
    roundpredimg = np.zeros_like(roundpred)
    roundpredimg[roundpred == 2] = 255;
    roundpredimg[np.logical_and(roundpred == 0 , manual.flatten() > 0)] = 127;
    imgarray = eop.rearrange(roundpredimg, '(h w) -> h w', h=584, w=565).astype(np.uint8) 
    display(np.unique(roundpred[correct > 0]), np.unique(correct[correct > 0]))
    
    with open(f"Report{i+20}.txt", "w") as text_file:
        text_file.write(classification_report(correct[correct > 0], roundpred[correct > 0]))
    
    Image.fromarray(imgarray).save(f"predict_{i+20}.png")


    


x.shape=(584, 565, 3) y.shape=(584, 565), pixels.shape=(329960, 3) 
[0 1 2], [0 1 2]


array([0, 1, 2], dtype=uint8)

array([1, 2], dtype=uint8)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [97]:
np.unique(ys[16])

array([  0, 255], dtype=uint8)

In [None]:
ys[10, 300:400, 300:400]


array([[  0.,   0.,   0., ...,   0.,   0.,   0.],
       [255., 255., 255., ...,   0.,   0.,   0.],
       [255., 255., 255., ...,   0.,   0.,   0.],
       ...,
       [  0.,   0.,   0., ...,   0.,   0.,   0.],
       [  0.,   0.,   0., ...,   0.,   0.,   0.],
       [  0.,   0.,   0., ...,   0.,   0.,   0.]],
      shape=(100, 100), dtype=float32)