---
### Imports

In [1]:
import json
from rich import print as rprint
import pandas as pd
import numpy as np

import warnings
# Filter the specific UserWarning from torch regarding TF32/matmul precision
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

import torch
torch.set_float32_matmul_precision('high')
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

import cv2
from tqdm import tqdm

import segmentation_models_pytorch as smp

import albumentations as A
from torch.optim.lr_scheduler import ReduceLROnPlateau
import copy

print('OpenCV version: ', cv2.__version__)

OpenCV version:  4.12.0


In [3]:
import random
import os

def seed_everything(seed=42):
    """Seed every RNG used in this notebook and request deterministic kernels.

    Args:
        seed (int): Seed passed to Python's ``random``, NumPy and PyTorch
            (CPU generator and all CUDA devices).
    """
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only
    # affects subprocesses spawned later, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Deterministic cuBLAS matmuls on CUDA >= 10.2 require a fixed workspace
    # configuration; without it torch only warns (warn_only=True below) and the
    # op stays nondeterministic. setdefault() keeps any value the user already
    # exported. Must be in place before the first cuBLAS call.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers multi-GPU setups
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # warn_only=True: ops without a deterministic implementation emit a
    # warning instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)

seed_everything(42)


---
### Load data

In [4]:
# Read the training metadata CSV into a DataFrame.
# Later cells use its 'image_id' column to decide which image files to load.
df_train = pd.read_csv('../data/raw/Training/Train.csv')

In [5]:
# Exclude image 332 (its file is corrupted) and renumber the remaining rows from 0.
df_train = (
    df_train
    .loc[df_train['image_id'].ne(332)]
    .reset_index(drop=True)
)

In [6]:
# Read cropped RGB images (values in [0, 255]) into a dict keyed by image_id.
# OpenCV decodes files as BGR, so each image is converted to RGB before storing.

bgr_path_train = '../data/processed/crops/Training/RGBImages/'

cropped_rgb_images_train = {}

for image_id in df_train['image_id'].values:
    # Image 332 is corrupted. Compare by equality: a `% 332 == 0` test would
    # also silently skip every other multiple of 332 (0, 664, ...).
    if image_id == 332:
        continue
    img_path = bgr_path_train + 'cropped_RGB_' + str(image_id) + '.png'
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising; fail here with the offending path.
        raise FileNotFoundError(f'Could not read RGB image: {img_path}')
    cropped_rgb_images_train[image_id] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
print(f'Loaded {len(cropped_rgb_images_train)} training images.')


Loaded 230 training images.


In [7]:
# Read cropped depth images into a dict keyed by image_id.
# IMREAD_UNCHANGED keeps the stored bit depth (16-bit, per the original note).
depth_path_train = '../data/processed/crops/Training/DepthImages/'

cropped_depth_images_train = {}

for image_id in df_train['image_id'].values:
    # Image 332 is corrupted. Compare by equality: a `% 332 == 0` test would
    # also silently skip every other multiple of 332 (0, 664, ...).
    if image_id == 332:
        continue
    depth_file = depth_path_train + 'cropped_Depth_' + str(image_id) + '.png'
    img = cv2.imread(depth_file, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread returns None for a missing/unreadable file; without this
        # check the None would be stored and only fail much later.
        raise FileNotFoundError(f'Could not read depth image: {depth_file}')
    cropped_depth_images_train[image_id] = img
print(f'Loaded {len(cropped_depth_images_train)} training depth images.')


Loaded 230 training depth images.


In [8]:
# Sanity check: display the array shapes of the RGB and depth crops stored for image_id 1.
cropped_rgb_images_train[1].shape, cropped_depth_images_train[1].shape

((800, 800, 3), (800, 800))

In [9]:
def get_specific_mask(mask_uint8, target_name):
    """
    Build a boolean mask selecting one object class from a label image.

    INPUTS:
        mask_uint8 (np.ndarray): Grayscale label image whose pixel intensity
                                 encodes the class (169 = lettuce, 145 = crate).
        target_name (str):       Class to isolate, 'lettuce' or 'crate'
                                 (case-insensitive).

    OUTPUT:
        binary_mask (np.ndarray): Element-wise comparison result; True where
                                  the pixel equals the class intensity.

    RAISES:
        ValueError: If target_name is neither 'lettuce' nor 'crate'.
    """
    intensity_by_class = {'lettuce': 169, 'crate': 145}
    requested = target_name.lower()
    if requested not in intensity_by_class:
        raise ValueError("target_name must be 'lettuce' or 'crate'")
    return mask_uint8 == intensity_by_class[requested]

In [10]:
# Load segmentation masks and split each label image into two binary (0/1) uint8 masks.
# File names are taken from the directory because not every image has a mask.
# File names are in format 'cropped_RGB_{image_id}.png'.

mask_path_train = '../data/labels/SegmentationClass/'

lettuce_masks_800 = {}
crate_masks_800 = {}

# os.listdir() returns entries in arbitrary order; sorting makes the dict
# insertion order (and anything derived from it) reproducible across machines.
for filename in sorted(os.listdir(mask_path_train)):
    if not filename.endswith('.png'):
        continue
    image_id = int(filename.split('_')[2].split('.')[0])
    mask_file = os.path.join(mask_path_train, filename)
    raw_img = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
    if raw_img is None:
        # cv2.imread returns None instead of raising on an unreadable file.
        raise FileNotFoundError(f'Could not read segmentation mask: {mask_file}')

    # Reuse the shared class-intensity lookup (lettuce = 169, crate = 145)
    # instead of repeating the magic numbers here.
    lettuce_masks_800[image_id] = get_specific_mask(raw_img, 'lettuce').astype(np.uint8)
    crate_masks_800[image_id] = get_specific_mask(raw_img, 'crate').astype(np.uint8)

print(f"Ready to train: {len(lettuce_masks_800)} Lettuce masks and {len(crate_masks_800)} Crate masks.")


Ready to train: 70 Lettuce masks and 70 Crate masks.


---
### Train segmenter

In [11]:
# Augmentation pipeline for the training datasets (validation datasets are built
# with transform=None). Every step is applied with its own probability, so the
# same source image yields a different sample each time it is drawn.
train_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        # Affine is used in place of ShiftScaleRotate: small shift, +/-10% zoom,
        # rotation of up to 45 degrees either way.
        A.Affine(
            p=0.5,
            rotate=(-45, 45),
            scale=(0.9, 1.1),
            translate_percent=dict(x=(-0.06, 0.06), y=(-0.06, 0.06)),
        ),
        A.RandomBrightnessContrast(p=0.2),
    ]
)

class CompetitiveDataset(Dataset):
    """Binary-segmentation dataset served from in-memory image/mask dicts.

    Args:
        rgb_dict (dict): Maps image id -> image array. Indexed as H x W x C
            below; assumed to hold values in [0, 255] since it is divided by 255.
        mask_dict (dict): Maps image id -> mask array; any value > 0 counts as
            foreground.
        ids (sequence): Image ids this dataset serves; each must be a key of
            both dicts.
        size (tuple): Target size handed to ``cv2.resize`` as ``dsize``
            (i.e. ``(width, height)``).
        transform (callable, optional): Called as ``transform(image=, mask=)``
            and expected to return a dict with 'image' and 'mask' entries.
    """

    def __init__(self, rgb_dict, mask_dict, ids, size=(512, 512), transform=None):
        self.rgb_dict = rgb_dict
        self.mask_dict = mask_dict
        self.ids = ids
        self.size = size
        self.transform = transform

    def __len__(self):
        """Return the number of ids served by this dataset."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for the id at position ``idx``.

        The image is a float32 tensor in C x H x W layout scaled by 1/255;
        the mask is a float32 tensor with a leading channel axis of size 1.
        """
        key = self.ids[idx]
        # INTER_AREA for the image; INTER_NEAREST for the mask so resizing
        # never invents intermediate label values.
        img = cv2.resize(self.rgb_dict[key], self.size, interpolation=cv2.INTER_AREA)
        mask = cv2.resize(self.mask_dict[key].astype(np.uint8), self.size, interpolation=cv2.INTER_NEAREST)
        mask = (mask > 0).astype(np.float32)

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']

        # torch.from_numpy() rejects arrays with negative strides, which
        # flip/rotate style transforms may hand back as views. Coerce both
        # arrays to contiguous float32 first (no copy if already so).
        img = np.ascontiguousarray(img, dtype=np.float32) / 255.0
        mask = np.ascontiguousarray(mask, dtype=np.float32)

        img_t = torch.from_numpy(img).permute(2, 0, 1)
        mask_t = torch.from_numpy(mask).unsqueeze(0)
        return img_t, mask_t


In [12]:
def train_specialist(model, train_loader, val_loader, optimizer, scheduler, criterion, epochs=100, name="Specialist"):
    """Train a binary segmentation model and keep the weights with the best validation IoU.

    Args:
        model (nn.Module): Network producing raw logits; batches are moved to
            the device its parameters live on.
        train_loader: Iterable of ``(images, masks)`` batches used for optimisation.
        val_loader: Iterable of ``(images, masks)`` batches used for validation.
        optimizer: Optimizer already bound to ``model``'s parameters.
        scheduler: Stepped once per epoch with the mean validation loss
            (``ReduceLROnPlateau``-style ``step(metric)`` interface).
        criterion (callable): ``criterion(outputs, masks) -> loss`` tensor.
        epochs (int): Number of passes over ``train_loader``.
        name (str): Label used in progress output and in the checkpoint file name.

    Returns:
        The same ``model`` object, with the best-IoU weights loaded.

    Side effects:
        Creates ``../weights`` if needed and writes
        ``../weights/best_{name.lower()}_segmentation_model.pth`` every time
        the validation IoU improves.
    """
    # Ensure the weights directory exists
    os.makedirs('../weights', exist_ok=True)

    # Follow the model's own device instead of hard-coding CUDA; fall back to
    # CUDA (the previous behaviour) for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # Initialize tracking variables
    best_iou = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(epochs):
        # --- TRAINING PHASE ---
        model.train()
        train_loss = 0.0
        pbar = tqdm(train_loader, desc=f"{name} Train - Epoch {epoch+1}/{epochs}")
        for imgs, msks in pbar:
            imgs = imgs.to(device, non_blocking=True)
            msks = msks.to(device, non_blocking=True)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, msks)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()  # single host sync per step
            train_loss += batch_loss
            pbar.set_postfix(loss=f"{batch_loss:.4f}")

        avg_train_loss = train_loss / len(train_loader)

        # --- VALIDATION PHASE ---
        model.eval()
        val_loss = 0.0
        # Accumulate intersection/union over the whole validation set. Averaging
        # per-batch IoUs would make the metric depend on the batch size and give
        # a small final batch the same weight as a full one.
        total_intersection = 0.0
        total_union = 0.0
        with torch.no_grad():
            for imgs, msks in val_loader:
                imgs = imgs.to(device, non_blocking=True)
                msks = msks.to(device, non_blocking=True)
                outputs = model(imgs)

                # Calculate Loss
                val_loss += criterion(outputs, msks).item()

                # Binary prediction: sigmoid probability thresholded at 0.5
                preds = (torch.sigmoid(outputs) > 0.5).float()
                total_intersection += (preds * msks).sum().item()
                total_union += (preds + msks).gt(0).sum().item()

        avg_val_loss = val_loss / len(val_loader)
        # The epsilon keeps the ratio defined when the union is empty.
        avg_iou = (total_intersection + 1e-6) / (total_union + 1e-6)

        # Step scheduler based on Validation Loss
        scheduler.step(avg_val_loss)

        # Model selection uses IoU (spatial fit), not the loss
        if avg_iou > best_iou:
            best_iou = avg_iou
            best_model_wts = copy.deepcopy(model.state_dict())
            save_path = f"../weights/best_{name.lower()}_segmentation_model.pth"
            torch.save(best_model_wts, save_path)
            print(f"--> [NEW BEST IoU] {name}: {avg_iou:.4f} (Saved to {save_path})")

        current_lr = optimizer.param_groups[0]['lr']
        print(f"{name} Ep {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | IoU: {avg_iou:.4f} | LR: {current_lr:.6f}")

    # Load the best weights back before returning to the main notebook
    model.load_state_dict(best_model_wts)
    return model

In [13]:
# Combined Loss: Dice (shape) + BCE (pixel-level detail)
# Both terms consume RAW LOGITS: smp.Unet does NOT have sigmoid by default,
# nn.BCEWithLogitsLoss applies the sigmoid internally, and smp's DiceLoss does
# the same when from_logits=True (its default, spelled out here on purpose).
dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
bce_loss = nn.BCEWithLogitsLoss()

def criterion(outputs, masks):
    """Return Dice loss + BCE-with-logits loss for a batch.

    Args:
        outputs (torch.Tensor): Raw model logits (no sigmoid applied).
        masks (torch.Tensor): Float target masks, same shape as ``outputs``.

    Returns:
        torch.Tensor: Sum of the two loss terms.
    """
    # Do not pre-apply torch.sigmoid() for the Dice term: DiceLoss already
    # converts logits to probabilities, so a second sigmoid would squash the
    # predictions into (0.5, 0.73) and distort the loss.
    return dice_loss(outputs, masks) + bce_loss(outputs, masks)

In [None]:
# --- CREATE TRAIN/VAL SPLIT ---
# Sort the ids so the seeded split does not depend on the insertion order of the
# mask dict (which follows the directory listing), and keep only ids that also
# have a loaded RGB crop - otherwise the dataset raises KeyError inside a worker.
all_ids = sorted(image_id for image_id in lettuce_masks_800 if image_id in cropped_rgb_images_train)
skipped_ids = sorted(set(lettuce_masks_800) - set(all_ids))
if skipped_ids:
    print(f"Skipping {len(skipped_ids)} mask(s) without a loaded RGB image: {skipped_ids}")
train_ids, val_ids = train_test_split(all_ids, test_size=0.15, random_state=42)

# BATCH
BATCH_SIZE = 16
NUM_WORKERS = 4


def _build_specialist(mask_dict):
    """Create model, optimizer, scheduler, datasets and loaders for one mask class.

    Args:
        mask_dict (dict): Maps image id -> binary mask for the class to learn.

    Returns:
        tuple: ``(model, optimizer, scheduler, train_ds, val_ds, train_loader, val_loader)``.
    """
    model = smp.Unet(encoder_name="efficientnet-b3", encoder_weights="imagenet", in_channels=3, classes=1).cuda()
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    sched = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=10)
    # Augmentation on the training set only; validation sees the plain resized images.
    train_ds = CompetitiveDataset(cropped_rgb_images_train, mask_dict, train_ids, size=(512,512), transform=train_transform)
    val_ds = CompetitiveDataset(cropped_rgb_images_train, mask_dict, val_ids, size=(512,512), transform=None)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, prefetch_factor=2)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    return model, opt, sched, train_ds, val_ds, train_loader, val_loader


# --- SET UP LETTUCE SPECIALIST ---
(lettuce_model, lettuce_opt, lettuce_sched,
 l_train_ds, l_val_ds, l_train_loader, l_val_loader) = _build_specialist(lettuce_masks_800)

# --- SET UP CRATE SPECIALIST ---
(crate_model, crate_opt, crate_sched,
 c_train_ds, c_val_ds, c_train_loader, c_val_loader) = _build_specialist(crate_masks_800)

# --- TRAIN BOTH ---
# train_specialist returns the model it was given (best weights loaded); binding
# the result also keeps the cell from dumping the full model repr as its output.
# Train Lettuce (100 epochs for organic complexity)
lettuce_model = train_specialist(lettuce_model, l_train_loader, l_val_loader, lettuce_opt, lettuce_sched, criterion, epochs=100, name="Lettuce")

# Train Crate (50 epochs for geometric simplicity)
crate_model = train_specialist(crate_model, c_train_loader, c_val_loader, crate_opt, crate_sched, criterion, epochs=50, name="Crate")