---
### Imports

In [1]:
import json
from rich import print as rprint
import pandas as pd
import numpy as np

import warnings
# Filter the specific UserWarning from torch regarding TF32/matmul precision
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

import torch
torch.set_float32_matmul_precision('high')
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

import cv2
from tqdm import tqdm

import segmentation_models_pytorch as smp

import albumentations as A
from torch.optim.lr_scheduler import ReduceLROnPlateau
import copy
import os
import math

print('OpenCV version: ', cv2.__version__)

OpenCV version:  4.12.0


In [2]:
import random
import os

def seed_everything(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus all CUDA devices),
    sets the ``PYTHONHASHSEED`` environment variable, and switches cuDNN and
    PyTorch to their deterministic modes.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Apply the same seed to each library-level generator in turn.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # if you are using multi-GPU
    ):
        seed_fn(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # warn_only=True: ops lacking a deterministic implementation warn instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)


seed_everything(42)


---
### Load data

In [3]:
# Training metadata CSV -> DataFrame (its `image_id` column is used to locate image files)
train_csv_path = '../data/raw/Training/Train.csv'
df_train = pd.read_csv(train_csv_path)

In [4]:
# Image 332 is a corrupted file: exclude its row and renumber the index from 0
is_corrupted = df_train['image_id'].eq(332)
df_train = df_train.loc[~is_corrupted].reset_index(drop=True)

In [5]:
# Read cropped RGB images (values in [0, 255]) into a dict keyed by image_id

bgr_path_train = '../data/processed/crops/Training/RGBImages/'

cropped_rgb_images_train = {}

for image_id in df_train['image_id'].values:
    # Skip only the corrupted image itself. The previous `% 332 == 0` test
    # also dropped id 0 and every other multiple of 332.
    if image_id == 332:
        continue
    img_path = f'{bgr_path_train}cropped_RGB_{image_id}.png'
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError(f'Could not read RGB image: {img_path}')
    # OpenCV loads channels as BGR; store them as RGB
    cropped_rgb_images_train[image_id] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
print(f'Loaded {len(cropped_rgb_images_train)} training images.')

Loaded 230 training images.


In [6]:
# Read depth images (16 bit, loaded unchanged) into a dict keyed by image_id
depth_path_train = '../data/processed/crops/Training/DepthImages/'

cropped_depth_images_train = {}

for image_id in df_train['image_id'].values:
    # Skip only the corrupted image itself. The previous `% 332 == 0` test
    # also dropped id 0 and every other multiple of 332.
    if image_id == 332:
        continue
    img_path = f'{depth_path_train}cropped_Depth_{image_id}.png'
    # IMREAD_UNCHANGED keeps the file's original bit depth instead of converting to 8-bit
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None; fail here rather than
        # storing None and crashing later during inference
        raise FileNotFoundError(f'Could not read depth image: {img_path}')
    cropped_depth_images_train[image_id] = img
print(f'Loaded {len(cropped_depth_images_train)} training depth images.')


Loaded 230 training depth images.


---
### Inference

In [7]:
def load_model(path, device=None):
    """Build the segmentation U-Net and restore trained weights from ``path``.

    Args:
        path: Checkpoint file holding a ``state_dict`` for the architecture
            built below (U-Net, efficientnet-b3 encoder, 3 input channels,
            1 output class).
        device: Device to place the model on. Defaults to CUDA when it is
            available, otherwise CPU.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # encoder_weights=None skips the default ImageNet download; every weight
    # is overwritten by the checkpoint (strict load) right below anyway.
    model = smp.Unet(
        encoder_name="efficientnet-b3",
        encoder_weights=None,
        in_channels=3,
        classes=1,
    )
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints (consider weights_only=True if the torch version supports it).
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

# One trained segmentation network per target: lettuce and crate
lettuce_weights_path = "../weights/best_lettuce_segmentation_model.pth"
crate_weights_path = "../weights/best_crate_segmentation_model.pth"

lettuce_model = load_model(lettuce_weights_path)
crate_model = load_model(crate_weights_path)

def _predict_binary_mask(model, img_t, out_size):
    """Run one segmentation model and return a uint8 {0, 1} mask resized to ``out_size`` (width, height)."""
    # Send the input to wherever the model's weights live instead of assuming CUDA
    device = next(model.parameters()).device

    with torch.no_grad():
        prob = torch.sigmoid(model(img_t.to(device)))

    # Threshold at 0.5, then convert from boolean to uint8 so OpenCV can resize it
    mask_small = (prob > 0.5).cpu().numpy()[0, 0].astype(np.uint8)
    # Nearest-neighbour interpolation keeps the mask strictly binary
    return cv2.resize(mask_small, out_size, interpolation=cv2.INTER_NEAREST)


def predict_biomass(rgb_800, depth_800):
    """Segment lettuce and crate, then sum lettuce heights above the crate floor.

    Uses the module-level ``lettuce_model`` and ``crate_model``.

    Args:
        rgb_800: H x W x 3 RGB image with values in [0, 255]
            (presumably 800 x 800 - confirm against the cropping step).
        depth_800: 2-D depth map pixel-aligned with ``rgb_800``. Heights are
            computed as ``floor - depth``, i.e. this assumes larger values
            are farther from the camera - TODO confirm.

    Returns:
        Tuple ``(biomass_estimate, l_mask, c_mask)``: the sum of positive
        lettuce heights, and the lettuce / crate masks as uint8 {0, 1} arrays
        at the resolution of ``depth_800``.
    """
    # Preprocess for 512p model: downscale, scale to [0, 1], HWC -> 1 x C x H x W
    img_512 = cv2.resize(rgb_800, (512, 512), interpolation=cv2.INTER_AREA)
    img_t = torch.from_numpy(img_512.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0)

    # Scale masks back to the depth map's own resolution (instead of a
    # hard-coded 800 x 800) so they can always index it.
    depth_h, depth_w = depth_800.shape[:2]
    out_size = (depth_w, depth_h)  # cv2.resize expects (width, height)

    l_mask = _predict_binary_mask(lettuce_model, img_t, out_size)
    c_mask = _predict_binary_mask(crate_model, img_t, out_size)

    crate_depths = depth_800[c_mask == 1]
    if crate_depths.size == 0:
        # No crate detected: there is no floor reference, so no heights.
        # (np.median of an empty array would yield NaN plus a RuntimeWarning.)
        heights = np.empty(0, dtype=np.float64)
    else:
        # Z-baseline: median of the crate area to avoid outliers
        z_floor = np.median(crate_depths)

        # Isolate lettuce heights relative to the floor
        # NOTE(review): if missing depth readings are encoded as 0 they would
        # produce large positive heights here - confirm and mask if needed.
        heights = z_floor - depth_800[l_mask == 1]

        # Filter noise (e.g., negative heights)
        heights = heights[heights > 0]

    # Volume/biomass proxy: sum of heights (not scaled by pixel area);
    # adjust this based on your regression model
    biomass_estimate = np.sum(heights)

    return biomass_estimate, l_mask, c_mask

In [11]:
def process_dataset(rgb_dict, depth_dict):
    """Compute the volume proxy for every image in ``rgb_dict``.

    Args:
        rgb_dict: Mapping of image key -> RGB image.
        depth_dict: Mapping of image key -> depth image; must contain every
            key present in ``rgb_dict``.

    Returns:
        dict mapping each key to the first value returned by
        ``predict_biomass`` (the volume estimate).
    """
    progress = tqdm(rgb_dict.items(), desc="Processing Biomass")
    # Only the volume is kept; the two masks returned alongside it are dropped
    return {
        image_key: predict_biomass(rgb_image, depth_dict[image_key])[0]
        for image_key, rgb_image in progress
    }

# Process the training split.
# Note: the volume estimate ultimately does not get used, just the masks.
train_volumes = process_dataset(
    rgb_dict=cropped_rgb_images_train,
    depth_dict=cropped_depth_images_train,
)

Processing Biomass: 100%|██████████| 230/230 [00:20<00:00, 11.29it/s]


In [12]:
def save_all_masks(rgb_dict, depth_dict, split_name):
    """Run inference on every image and write both masks to disk as PNGs.

    Files are written to ``../data/processed/masks_inferred/<split_name>/``
    as ``<key>_lettuce.png`` and ``<key>_crate.png``.

    Args:
        rgb_dict: Mapping of image key -> RGB image.
        depth_dict: Mapping of image key -> depth image; must contain every
            key present in ``rgb_dict``.
        split_name: Name of the sub-directory for this split (e.g. "Train").

    Raises:
        OSError: If OpenCV reports that a mask file could not be written.
    """
    # Create directory for this split
    output_dir = f"../data/processed/masks_inferred/{split_name}"
    os.makedirs(output_dir, exist_ok=True)

    for key in tqdm(rgb_dict.keys(), desc=f"Saving {split_name} masks"):
        # Re-run inference to get the full-resolution masks
        _, l_mask, c_mask = predict_biomass(rgb_dict[key], depth_dict[key])

        for label, mask in (("lettuce", l_mask), ("crate", c_mask)):
            save_path = os.path.join(output_dir, f"{key}_{label}.png")
            # Scale {0, 1} -> {0, 255} for standard image viewing.
            # cv2.imwrite reports failure via its return value rather than
            # raising, so check it instead of silently losing the file.
            if not cv2.imwrite(save_path, mask * 255):
                raise OSError(f"Failed to write mask: {save_path}")

# Export the inferred masks for the training split
save_all_masks(
    rgb_dict=cropped_rgb_images_train,
    depth_dict=cropped_depth_images_train,
    split_name="Train",
)

Saving Train masks: 100%|██████████| 230/230 [00:25<00:00,  8.94it/s]
