In [3]:
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet101

In [4]:
# Locate the image dataset relative to the (absolute) working directory
current_folder = Path(".").resolve()
base_dir = current_folder.joinpath("dataset/4th_dwarf_tomato/image")

In [5]:
# Stem shared by every file that belongs to this capture
image = "B_0328ab97"

# Inputs: the RGB frame and its matching depth map
rgb_file = base_dir.joinpath("rgb", f"{image}.png")
depth_file = base_dir.joinpath("depth", f"{image}_depth.png")

# Outputs: where the segmented RGB and the filtered depth images are meant to be written
segmented_file = base_dir.joinpath("segmented_rgb", f"{image}_segmented.png")
filtered_depth_file = base_dir.joinpath("masked_depth", f"{image}_filtered.png")

In [7]:
# Load the RGB frame and its depth map
rgb_image = Image.open(rgb_file)
depth_image = Image.open(depth_file)

# Preprocessing applied to the RGB frame before it is fed to the network
transform = T.Compose(
    [
        # Fixed 520x520 input; note this does not preserve the aspect ratio, and the
        # prediction map produced below has this size rather than the original image size
        T.Resize((520, 520)),
        T.ToTensor(),  # PIL image -> float tensor (C, H, W) scaled to [0, 1]
        # ImageNet channel statistics, as expected by torchvision's pretrained backbones
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# A PNG may decode to RGBA, palette or greyscale data, but Normalize above is configured
# for exactly three channels, so force RGB before transforming. `rgb_image` itself is left
# untouched. unsqueeze(0) adds the batch dimension the model expects.
input_tensor = transform(rgb_image.convert("RGB")).unsqueeze(0)

# Load a pre-trained DeepLabV3 model and switch it to inference mode
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour of the
# `weights=` argument; kept as-is because the installed torchvision version is not known.
model = deeplabv3_resnet101(pretrained=True)
model.eval()

# Run the segmentation without tracking gradients; ["out"][0] selects the main output
# for the single image in the batch
with torch.no_grad():
    output = model(input_tensor)["out"][0]

# Per-pixel label map: index of the highest-scoring entry along dimension 0
output_predictions = output.argmax(0)



In [None]:
# Build a plant mask from the segmentation output and apply it to the RGB and depth images.

# NOTE(review): torchvision's pretrained DeepLabV3 predicts the 21 Pascal VOC categories, in
# which index 16 is "potted plant". Confirm this is the class the tomato plants are assigned
# to; use `output_predictions != 0` below instead to keep every non-background class.
PLANT_CLASS_INDEX = 16


def _resize_mask(mask, size):
    """Resize a 2D boolean mask to a new resolution without blending its values.

    Args:
        mask: 2D boolean numpy array.
        size: Target size as (width, height), i.e. the order used by ``PIL.Image.size``.

    Returns:
        2D boolean numpy array of shape (height, width).
    """
    mask_image = Image.fromarray(mask.astype(np.uint8) * 255)
    # Nearest-neighbour keeps the mask strictly binary (no interpolated edge values)
    return np.asarray(mask_image.resize(size, resample=Image.NEAREST)) > 0


# The network ran on a resized copy of the RGB frame, so the label map has the network's
# input resolution and must be scaled back before it can index the original images.
predicted_plant_mask = (output_predictions == PLANT_CLASS_INDEX).cpu().numpy()

# np.where operates on arrays, so convert the PIL images first
rgb_array = np.asarray(rgb_image.convert("RGB"))
depth_array = np.asarray(depth_image)

# Binary mask (True = plant, False = background) at the resolution of the RGB image
tomato_plant_mask = _resize_mask(predicted_plant_mask, rgb_image.size)

# For the RGB image the mask must be 3D to match the image's channels
rgb_mask_3d = np.repeat(tomato_plant_mask[:, :, np.newaxis], 3, axis=2)

# Apply the mask to the RGB image
segmented_rgb_image = np.where(rgb_mask_3d, rgb_array, 0)  # Replace 0 with desired background value

# Apply the mask to the depth image, scaled to the depth map's own resolution.
# NOTE(review): assumes the depth map covers the same field of view as the RGB frame - confirm.
depth_mask = _resize_mask(predicted_plant_mask, depth_image.size)
if depth_array.ndim == 3:
    # Depth stored with a channel axis: let the 2D mask broadcast over it
    depth_mask = depth_mask[:, :, np.newaxis]
segmented_depth_image = np.where(depth_mask, depth_array, 0)  # Replace 0 with desired background depth value