# FarmWise: Farmland Segmentation and Size Classification with U-Net

**Date**: April 14, 2025

This notebook implements a farm segmentation system that uses a U-Net architecture to identify agricultural fields in satellite imagery, calculate their sizes, and classify them by size for targeted recommendations.

## Project Overview

**Goal**: Create a system that can:
1. Detect and segment farmlands from satellite imagery
2. Calculate the size/area of each identified farm
3. Classify farms by size (small, medium, large)
4. Enable a recommendation system based on farm size classification

**Approach**: U-Net architecture for semantic segmentation

## 1. Business Understanding

### 1.1 Problem Statement

Agricultural recommendations are most effective when tailored to the specific context of a farm, with farm size being a crucial factor. Large farms may benefit from different techniques, equipment, and crop selections compared to small ones. This project aims to automatically classify farms by size from satellite imagery to enable targeted recommendations.

### 1.2 Success Criteria

- **Technical Success**: Achieve high accuracy in farmland segmentation (IoU > 0.75)
- **Business Success**: Enable accurate size-based classification of farms for targeted recommendations

## 2. Data Acquisition and Understanding

### 2.1 Setup and Environment Preparation

In [None]:
# Check for Kaggle environment and set up dependencies for GPU acceleration
!pip install torch torchvision matplotlib numpy pillow scikit-learn scikit-image opencv-python roboflow tqdm

In [None]:
# Import required libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
import cv2
from skimage import measure
from tqdm.notebook import tqdm
from roboflow import Roboflow

# Fix the PyTorch and NumPy seeds so repeated runs draw the same random numbers
torch.manual_seed(42)
np.random.seed(42)

# Report CUDA availability, then pick the device used by the rest of the notebook
print("CUDA Available:", torch.cuda.is_available())
if not torch.cuda.is_available():
    device = torch.device('cpu')
    print("No GPU available, using CPU. This will be slower.")
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")
    for i in range(num_gpus):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    # With several GPUs use the generic 'cuda' device; otherwise pin to GPU 0
    device = torch.device('cuda' if num_gpus > 1 else 'cuda:0')
    print(
        f"Using {num_gpus} GPUs for data parallel training"
        if num_gpus > 1
        else "Using single GPU"
    )

    # Extra CUDA diagnostics, only meaningful when a GPU is present
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")

### 2.2 Data Acquisition from Roboflow

In [None]:
# Initialize Roboflow and load dataset
# SECURITY: the API key is a credential and must not be committed with the notebook.
# It is read from the ROBOFLOW_API_KEY environment variable; when that is unset the
# user is prompted for it (input is not echoed). Any key that was previously
# hard-coded in this cell should be revoked in the Roboflow account settings.
import getpass

roboflow_api_key = os.environ.get("ROBOFLOW_API_KEY") or getpass.getpass("Roboflow API key: ")
rf = Roboflow(api_key=roboflow_api_key)
project = rf.workspace("sid-mp92l").project("final-detectron-2")
dataset = project.version(1).download("yolov8")

# Print dataset path
print(f"Dataset downloaded to: {dataset.location}")

### 2.3 Dataset Exploration

In [None]:
# Explore the dataset structure
def explore_directory(path, level=0, max_items=10):
    """Print an indented tree view of ``path``.

    Each entry is printed as ``|-- name``, indented two spaces per nesting
    level. Directories are descended into recursively. At most ``max_items``
    entries are shown per directory; the remainder is summarised on a single
    ``... (N more items)`` line.

    Args:
        path: File or directory to display.
        level: Current nesting depth (controls indentation); callers normally
            leave this at 0.
        max_items: Maximum number of entries listed per directory. Expected to
            be non-negative.
    """
    indent = '  ' * level
    # normpath drops a trailing separator so "data/" is shown as "data", not "".
    print(f"{indent}|-- {os.path.basename(os.path.normpath(path))}")
    if not os.path.isdir(path):
        return

    try:
        # List once (the directory is not re-read) and sort for a stable display order.
        entries = sorted(os.listdir(path))
    except OSError as exc:
        print(f"{indent}  |-- [unreadable: {exc}]")
        return

    for item in entries[:max_items]:
        item_path = os.path.join(path, item)
        if os.path.isdir(item_path):
            explore_directory(item_path, level + 1, max_items)
        else:
            print(f"{indent}  |-- {item}")

    hidden = len(entries) - max_items
    if hidden > 0:
        print(f"{indent}  |-- ... ({hidden} more items)")

# Print a tree of the downloaded dataset folder.
# Requires `dataset` from the Roboflow download cell to be defined.
print("Dataset Structure:")
explore_directory(dataset.location)

In [1]:
# Visualize some sample images and masks (Enhanced)
import yaml # Ensure yaml is imported if not already

def _load_class_names(yaml_path):
    """Read class names from a YOLO ``data.yaml`` and locate the farm class.

    Returns:
        ``(class_names, farm_class_id)`` where ``class_names`` maps integer
        class IDs to names and ``farm_class_id`` is the lowest ID whose name
        contains "farm" (case-insensitive), or ``None`` when no such class (or
        no readable file) exists.
    """
    class_names = {0: 'Unknown'}
    farm_class_id = None
    if not os.path.exists(yaml_path):
        print(f"Warning: data.yaml not found at {yaml_path}. Class names unknown.")
        return class_names, farm_class_id

    try:
        with open(yaml_path, 'r') as f:
            data_yaml = yaml.safe_load(f)
        # An empty file parses to None, so guard before looking up 'names'.
        names = data_yaml.get('names') if isinstance(data_yaml, dict) else None
        if names:
            print(f"Classes found in dataset: {names}")
            # 'names' may be a list (index == class ID) or an {id: name} mapping.
            pairs = names.items() if isinstance(names, dict) else enumerate(names)
            class_names = {int(cid): str(name) for cid, name in pairs}
            for cid in sorted(class_names):
                if "farm" in class_names[cid].lower():
                    farm_class_id = cid
                    print(f"Identified 'farm' class ID: {farm_class_id}")
                    break
    except Exception as e:
        print(f"Error reading class names from data.yaml: {e}")
    return class_names, farm_class_id


def _draw_yolo_annotation(tokens, line_no, img_width, img_height, filled_mask, overlay):
    """Rasterise one tokenised YOLO label line onto ``filled_mask`` and ``overlay``.

    ``tokens`` is ``[class, v1, v2, ...]``. More than four values after the
    class are treated as a polygon ``x1 y1 x2 y2 ...``; exactly four are
    treated as a bounding box ``x_center y_center width height``. All values
    are normalised to the image size.

    Returns:
        True when the line carried polygon data, False for a bounding box.

    Raises:
        ValueError: if a coordinate token is not a number.
    """
    # `overlay` is an RGB array shown by matplotlib, so colours are RGB tuples.
    outline_rgb = (0, 255, 0)   # lime green
    vertex_rgb = (255, 0, 0)    # red
    bbox_rgb = (0, 0, 255)      # blue

    coords = [float(tok) for tok in tokens[1:]]

    if len(coords) > 4:
        if len(coords) % 2:
            print(f"  Warning: Odd number of polygon coordinates on line {line_no}")
        # Vertices start immediately after the class ID; zip() drops an unpaired trailing value.
        vertices = [
            (int(x * img_width), int(y * img_height))
            for x, y in zip(coords[0::2], coords[1::2])
        ]
        print(f"    Found {len(vertices)} vertices.")

        if len(vertices) < 3:  # Need at least 3 points for a polygon
            print(f"    Skipping polygon on line {line_no}: Not enough vertices ({len(vertices)} found).")
            return True

        pts = np.array(vertices, np.int32).reshape((-1, 1, 2))
        cv2.fillPoly(filled_mask, [pts], 255)
        cv2.polylines(overlay, [pts], isClosed=True, color=outline_rgb, thickness=1)
        for px, py in vertices:
            cv2.circle(overlay, (px, py), radius=2, color=vertex_rgb, thickness=-1)
        return True

    print(f"    No polygon points found, using bounding box.")
    x_center = coords[0] * img_width
    y_center = coords[1] * img_height
    width = coords[2] * img_width
    height = coords[3] * img_height

    # Clamp the box to the image bounds.
    x1 = max(0, int(x_center - width / 2))
    y1 = max(0, int(y_center - height / 2))
    x2 = min(img_width - 1, int(x_center + width / 2))
    y2 = min(img_height - 1, int(y_center + height / 2))

    cv2.rectangle(filled_mask, (x1, y1), (x2, y2), 255, -1)  # filled
    cv2.rectangle(overlay, (x1, y1), (x2, y2), bbox_rgb, thickness=1)
    return False


def visualize_samples(data_dir, num_samples=3):
    """Plot sample training images next to masks built from their YOLO labels.

    For up to ``num_samples`` images in ``<data_dir>/train/images`` a row of
    three panels is drawn: the original image, a binary mask filled from the
    matching ``<data_dir>/train/labels/<name>.txt`` file, and the image with
    the raw polygon outlines/vertices (or bounding boxes) drawn on top. When
    ``data.yaml`` names a class containing "farm", only that class is drawn;
    otherwise every annotation is drawn. Progress and problems are reported
    with ``print``; nothing is returned.

    Args:
        data_dir: Root folder of the YOLO-format dataset.
        num_samples: Maximum number of images to show.
    """
    train_img_dir = os.path.join(data_dir, 'train', 'images')
    train_mask_dir = os.path.join(data_dir, 'train', 'labels')

    if not os.path.isdir(train_img_dir):
        print(f"Error: Train image directory not found: {train_img_dir}")
        return

    # Keep image files only; sort so the same samples are shown on every run.
    img_files = sorted(
        f for f in os.listdir(train_img_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )[:num_samples]

    if not img_files:
        print(f"No image files found in {train_img_dir}")
        return

    class_names, farm_class_id = _load_class_names(os.path.join(data_dir, 'data.yaml'))

    # One row per image actually found; squeeze=False keeps `axes` 2-D even for a single row.
    n_rows = len(img_files)
    _, axes = plt.subplots(n_rows, 3, figsize=(18, 6 * n_rows), squeeze=False)

    for (ax_img, ax_mask, ax_raw), img_file in zip(axes, img_files):
        img_path = os.path.join(train_img_dir, img_file)
        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_file}: {e}")
            ax_img.set_title(f"Error loading {img_file}")
            ax_mask.set_title("Mask N/A")
            ax_raw.set_title("Raw Polygons N/A")
            continue

        img_np = np.array(img)
        img_width, img_height = img.size
        overlay = img_np.copy()  # annotations are drawn here so the first panel stays clean
        filled_mask = np.zeros((img_height, img_width), dtype=np.uint8)

        # YOLO label file shares the image's base name
        mask_file = os.path.splitext(img_file)[0] + '.txt'
        mask_path = os.path.join(train_mask_dir, mask_file)

        ax_img.imshow(img_np)
        ax_img.set_title(f"Image: {img_file}")

        if os.path.exists(mask_path):
            with open(mask_path, 'r') as f:
                lines = f.readlines()

            print(f"\nProcessing Annotations for: {img_file}")
            found_polygons = False
            for line_no, line in enumerate(lines, start=1):
                tokens = line.split()  # tolerant of tabs and repeated spaces
                if not tokens:
                    continue  # blank line
                if len(tokens) < 5:
                    print(f"  Line {line_no}: Malformed annotation (less than 5 parts)")
                    continue

                try:
                    class_id = int(tokens[0])
                    current_class_name = class_names.get(class_id, f"Class {class_id}")

                    # Filter for farm class IF it was identified, otherwise process all
                    if farm_class_id is not None and class_id != farm_class_id:
                        print(f"  Line {line_no}: Skipping class '{current_class_name}' (ID {class_id}), not farm class.")
                        continue
                    print(f"  Line {line_no}: Processing class '{current_class_name}' (ID {class_id})")

                    if _draw_yolo_annotation(tokens, line_no, img_width, img_height, filled_mask, overlay):
                        found_polygons = True
                except ValueError as ve:
                    print(f"  Error parsing line {line_no}: {ve} - Line: '{line.strip()}'")
                except Exception as ex:
                    # Best effort: one bad annotation must not abort the whole figure.
                    print(f"  Unexpected error processing line {line_no}: {ex}")

            if not found_polygons:
                print(f"  Note: No polygon data found in file {mask_file}, only bounding boxes (if any).")

            mask_title = "Generated Mask (from YOLO file)"
            raw_title = f"Raw Polygons (File: {mask_file})"
        else:
            print(f"Mask file not found: {mask_path}")
            mask_title = "Mask file missing"
            raw_title = "Raw Polygons N/A"

        # Fixed value range so an all-empty mask renders black rather than auto-scaled.
        ax_mask.imshow(filled_mask, cmap='gray', vmin=0, vmax=255)
        ax_mask.set_title(mask_title)

        ax_raw.imshow(overlay)
        ax_raw.set_title(raw_title)

    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.97])  # leave room for the figure title
    plt.suptitle("Sample Images, Generated Masks, and Raw Annotations", fontsize=16)
    plt.show()

# --- Run the visualization ---
try:
    # The Roboflow download cell must have created `dataset` with a `.location`
    dataset_ready = 'dataset' in locals() and hasattr(dataset, 'location')
    if not dataset_ready:
        print("Error: 'dataset' variable or 'dataset.location' not defined.")
        print("Please ensure the Roboflow download cell has been run successfully.")
    else:
        # Raise num_samples to inspect more images
        visualize_samples(dataset.location, num_samples=5)
except NameError:
    print("Error: 'dataset' variable not defined.")
    print("Please ensure the Roboflow download cell has been run successfully.")
except Exception as e:
    print(f"An error occurred during visualization: {e}")
    print("Note: Adjust the visualization code based on the actual data format and file structure.")


ModuleNotFoundError: No module named 'yaml'

### 2.4 Data Preparation

We need to convert YOLOv8 format annotations to segmentation masks for U-Net training.

In [None]:
# === Section 2.4: Dataset Path Definitions ===

# Resolve the dataset root first; the split directories are derived from it below.
# Relies on the `dataset` object created by the Roboflow download cell.
if 'dataset' in locals() and hasattr(dataset, 'location'):
    dataset_base_dir = dataset.location  # root folder, also where data.yaml is looked up
    print(f"Dataset paths set using location: {dataset.location}")
else:
    print("ERROR: 'dataset' variable or 'dataset.location' not found.")
    print("Please run the Roboflow download cell first.")
    # Placeholder root so later cells still find the path variables; data loading will fail
    dataset_base_dir = "."

# Image and label folders for the training and validation splits
train_img_dir = os.path.join(dataset_base_dir, 'train', 'images')
train_mask_dir = os.path.join(dataset_base_dir, 'train', 'labels')
val_img_dir = os.path.join(dataset_base_dir, 'valid', 'images')
val_mask_dir = os.path.join(dataset_base_dir, 'valid', 'labels')

# --- Custom Dataset Class ---
# Create a custom dataset class to load images and generate masks
class FarmlandDataset(Dataset):
    """
    PyTorch Dataset for loading satellite images and generating segmentation masks
    from YOLOv8 polygon annotation files.

    Each item is an ``(image_tensor, mask_tensor)`` pair. The mask has shape
    ``[1, H, W]`` with values 0.0 (background) or 1.0 (inside a target-class
    polygon) and is resized with nearest-neighbour interpolation to the
    spatial size of the image tensor.

    Args:
        img_dir: Directory containing the image files (.png/.jpg/.jpeg).
        mask_dir: Directory containing the YOLO ``.txt`` label files, matched
            to images by base name.
        dataset_root_dir: Directory in which ``data.yaml`` is looked up.
        transform: Optional callable applied to the PIL image; it must return
            a tensor whose last two dimensions are height and width. It is
            NOT applied to the mask (the mask is only resized), so geometric
            augmentations such as flips would misalign image and mask.
        farm_class_name: Substring (case-insensitive) identifying the target
            class among the names in ``data.yaml``.
        img_size: Side length used for resizing when ``transform`` is None and
            for the dummy sample returned when an image cannot be loaded.
    """
    def __init__(self, img_dir, mask_dir, dataset_root_dir, transform=None, farm_class_name="farm", img_size=256):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.img_size = img_size  # Target size for model input (used if no transform)
        self.farm_class_name = farm_class_name

        # Keep valid image files only, sorted for a stable index -> file mapping
        try:
            self.img_files = sorted(
                f for f in os.listdir(img_dir)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            )
            if not self.img_files:
                print(f"Warning: No image files found in {self.img_dir}")
        except FileNotFoundError:
            print(f"Error: Image directory not found: {self.img_dir}")
            self.img_files = []

        self.class_names = ['Unknown']
        self.farm_class_id = None  # Determined from data.yaml
        self._load_class_info(dataset_root_dir)

    def _load_class_info(self, dataset_root_dir):
        """Loads class names and identifies the farm class ID from data.yaml.

        Leaves ``self.farm_class_id`` as None when the file, PyYAML, the
        ``names`` key, or a matching class name is missing; in that case
        annotations of every class are used when building masks.
        """
        yaml_path = os.path.join(dataset_root_dir, 'data.yaml')
        if not os.path.exists(yaml_path):
            print(f"Warning: data.yaml not found at {yaml_path}. Cannot determine farm class ID automatically.")
            print("The dataset will attempt to use ALL annotations if farm_class_id remains None.")
            return

        try:
            import yaml
        except ImportError:
            print("Warning: PyYAML not installed. Cannot read class names from data.yaml.")
            return

        try:
            with open(yaml_path, 'r') as f:
                data_yaml = yaml.safe_load(f)

            # An empty file parses to None, so guard before looking up 'names'.
            names = data_yaml.get('names') if isinstance(data_yaml, dict) else None
            if names is None:
                print(f"Warning: 'names' key not found in {yaml_path}")
                return

            self.class_names = names
            print(f"Dataset Class Names: {self.class_names}")

            # 'names' may be a list (index == class ID) or an {id: name} mapping.
            pairs = names.items() if isinstance(names, dict) else enumerate(names)
            target = self.farm_class_name.lower()
            for class_id, name in pairs:
                if target in str(name).lower():
                    self.farm_class_id = int(class_id)
                    print(f"Found target class '{name}' with ID: {self.farm_class_id}")
                    break
            if self.farm_class_id is None:
                print(f"Warning: Target class name '{self.farm_class_name}' not found in data.yaml names: {self.class_names}")
        except Exception as e:
            print(f"Error reading class names from data.yaml: {e}")

    def __len__(self):
        return len(self.img_files)

    @staticmethod
    def _parse_polygon_line(line, width, height):
        """Parse one YOLO segmentation line ``class x1 y1 x2 y2 ... xN yN``.

        Returns:
            ``(class_id, [(px, py), ...])`` with vertices scaled to pixel
            coordinates, or None when the line has five or fewer tokens
            (blank, malformed, or a plain bounding box), which carries no
            polygon to rasterise.

        Raises:
            ValueError: if the class ID or a coordinate is not numeric.
        """
        tokens = line.split()  # tolerant of tabs and repeated spaces
        if len(tokens) <= 5:
            return None
        class_id = int(tokens[0])
        coords = [float(tok) for tok in tokens[1:]]
        # zip() drops an unpaired trailing coordinate.
        vertices = [
            (int(x * width), int(y * height))
            for x, y in zip(coords[0::2], coords[1::2])
        ]
        return class_id, vertices

    def _build_mask(self, mask_path, img_filename, width, height):
        """Rasterise the target-class polygons of ``mask_path`` into a float32 mask.

        Returns an all-zero ``(height, width)`` array when the label file is
        missing or unreadable. A line that fails to parse is reported and
        skipped so the remaining polygons in the file are still drawn.
        """
        mask = np.zeros((height, width), dtype=np.float32)
        if not os.path.exists(mask_path):
            return mask

        try:
            with open(mask_path, 'r') as f:
                lines = f.readlines()
        except Exception as e:
            print(f"Error processing annotation file {mask_path} for image {img_filename}: {e}")
            return mask

        for line in lines:
            try:
                parsed = self._parse_polygon_line(line, width, height)
                if parsed is None:
                    continue
                class_id, vertices = parsed

                # Only the target class is drawn when its ID is known; otherwise all classes.
                if self.farm_class_id is not None and class_id != self.farm_class_id:
                    continue

                if len(vertices) >= 3:
                    pts = np.array(vertices, np.int32).reshape((-1, 1, 2))
                    cv2.fillPoly(mask, [pts], 1.0)  # 1.0 marks farmland in the float mask
            except ValueError as ve:
                print(f"Error parsing values in {mask_path} for image {img_filename}: {ve}. Line: '{line.strip()}'")
            except Exception as e:
                print(f"Error processing annotation file {mask_path} for image {img_filename}: {e}")
        return mask

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_tensor)`` for the image at ``idx``.

        If the image cannot be opened, an all-zero ``(3, img_size, img_size)``
        image and ``(1, img_size, img_size)`` mask are returned instead of
        raising, so a DataLoader worker is not killed by one bad file.

        Raises:
            IndexError: if ``idx`` is past the end of the dataset.
        """
        if idx >= len(self.img_files):
            raise IndexError("Index out of bounds")

        img_filename = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_filename)

        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}. Returning dummy data.")
            dummy_image = torch.zeros((3, self.img_size, self.img_size))
            dummy_mask = torch.zeros((1, self.img_size, self.img_size))
            return dummy_image, dummy_mask

        original_width, original_height = image.size

        # Build the mask at the original resolution; it is resized afterwards.
        mask_file = os.path.splitext(img_filename)[0] + '.txt'
        mask_path = os.path.join(self.mask_dir, mask_file)
        mask = self._build_mask(mask_path, img_filename, original_width, original_height)

        if self.transform:
            # Usually includes resize, ToTensor, Normalize
            image_tensor = self.transform(image)
        else:
            # Default: bilinear resize to img_size x img_size, then convert to a tensor
            resizer = transforms.Resize(
                (self.img_size, self.img_size),
                interpolation=transforms.InterpolationMode.BILINEAR,
            )
            image_tensor = transforms.ToTensor()(resizer(image))

        # Resize the mask to the image tensor's spatial size. NEAREST keeps the values
        # strictly 0/1, and sizing from the tensor guarantees matching shapes.
        target_h, target_w = image_tensor.shape[-2:]
        mask_resized = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
        # Add channel dimension [H, W] -> [1, H, W]
        mask_tensor = torch.from_numpy(mask_resized).float().unsqueeze(0)

        # Clamp to be safe, although fillPoly only ever writes 1.0
        mask_tensor = torch.clamp(mask_tensor, 0.0, 1.0)

        return image_tensor, mask_tensor


# --- Visualization Function (Debugging Raw Annotations) ---
# Visualize some sample images and how masks are generated from the raw YOLO files
import yaml # Ensure yaml is imported

def _find_farm_class_id(data_dir):
    """Return the ID of the first class in ``data.yaml`` whose name contains "farm".

    ``names`` may be a list (index = class ID) or an ``{id: name}`` mapping.
    Returns ``None`` when the file is missing, unreadable, has no ``names``
    entry, or lists no matching class. Problems are reported with ``print``
    rather than raised, because this is only used for best-effort debugging.
    """
    yaml_path = os.path.join(data_dir, 'data.yaml')
    if not os.path.exists(yaml_path):
        print(f"Warning: data.yaml not found at {yaml_path}. Class names unknown for visualization.")
        return None

    try:
        with open(yaml_path, 'r') as f:
            data_yaml = yaml.safe_load(f)
        # An empty YAML file parses to None, so check the type before the lookup.
        if not isinstance(data_yaml, dict) or 'names' not in data_yaml:
            return None
        class_names = data_yaml['names']
        print(f"Classes found in dataset: {class_names}")
        if isinstance(class_names, dict):
            id_name_pairs = class_names.items()
        else:
            id_name_pairs = enumerate(class_names)
        # Case-insensitive substring match; the first hit wins.
        for class_id, name in id_name_pairs:
            if "farm" in str(name).lower():
                farm_class_id = int(class_id)
                print(f"Identified 'farm' class ID for visualization: {farm_class_id}")
                return farm_class_id
    except Exception as e:
        print(f"Error reading class names from data.yaml: {e}")
    return None


def _draw_yolo_annotation(parts, line_no, img_width, img_height, filled_mask, overlay):
    """Draw one whitespace-split YOLO annotation line onto the mask and overlay.

    Args:
        parts: Fields of the line, class ID first; the caller guarantees at
            least five fields.
        line_no: 1-based line number, used only in messages.
        img_width, img_height: Pixel size used to scale the coordinates.
        filled_mask: uint8 array; polygons are filled with 255, boxes with 128.
        overlay: RGB copy of the image; outlines and vertices are drawn on it.

    Returns:
        ``"polygon"`` or ``"bbox"`` for what was drawn, ``None`` if nothing was.

    Raises:
        ValueError: If a coordinate field is not a number.
    """
    n_parts = len(parts)

    # Detection format: class x_center y_center width height
    if n_parts == 5:
        x_center, y_center, width, height = map(float, parts[1:])
        x1 = max(0, int((x_center - width / 2) * img_width))
        y1 = max(0, int((y_center - height / 2) * img_height))
        x2 = min(img_width - 1, int((x_center + width / 2) * img_width))
        y2 = min(img_height - 1, int((y_center + height / 2) * img_height))
        # Gray fill distinguishes boxes from polygons; visualization only.
        cv2.rectangle(filled_mask, (x1, y1), (x2, y2), 128, -1)
        # The overlay is an RGB array shown by matplotlib, so colours are RGB.
        cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 255), thickness=1)  # Blue
        return "bbox"

    # Segmentation format: class x1 y1 ... xN yN -> odd number of fields.
    if n_parts % 2 == 0:
        print(f"  Line {line_no}: Ambiguous annotation format (parts: {n_parts})")
        return None

    coords = [float(value) for value in parts[1:]]
    points = [
        (int(x * img_width), int(y * img_height))
        for x, y in zip(coords[0::2], coords[1::2])
    ]
    if len(points) < 3:
        print(f"    Skipping polygon on line {line_no}: Not enough vertices ({len(points)} found).")
        return None

    pts = np.array(points, np.int32).reshape((-1, 1, 2))
    cv2.fillPoly(filled_mask, [pts], 255)  # Fill with white
    cv2.polylines(overlay, [pts], isClosed=True, color=(0, 255, 0), thickness=1)  # Lime green
    for px, py in points:
        cv2.circle(overlay, (px, py), radius=2, color=(255, 0, 0), thickness=-1)  # Red
    return "polygon"


def visualize_samples(data_dir, num_samples=3):
    """Visualizes original images, generated masks from YOLO files, and raw annotations.

    For up to ``num_samples`` training images (taken in sorted filename order
    so the selection is reproducible) one row of three panels is plotted:
    the image, a mask built from its label file (polygons white, boxes gray),
    and the image with the raw annotations outlined.

    Args:
        data_dir: Dataset root containing ``train/images``, ``train/labels``
            and optionally ``data.yaml``.
        num_samples: Maximum number of images to show.

    Returns:
        None. Missing directories or an empty image folder are reported with
        ``print`` and end the function before any figure is created.
    """
    print("\n--- Starting Sample Visualization ---")
    train_img_dir = os.path.join(data_dir, 'train', 'images')
    train_mask_dir = os.path.join(data_dir, 'train', 'labels')

    if not os.path.isdir(train_img_dir):
        print(f"Error: Train image directory not found: {train_img_dir}")
        return
    if not os.path.isdir(train_mask_dir):
        print(f"Error: Train mask directory not found: {train_mask_dir}")
        return

    # Keep image files only; sort because os.listdir order is arbitrary.
    img_files = sorted(
        f for f in os.listdir(train_img_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )[:num_samples]

    if not img_files:
        print(f"No image files found in {train_img_dir}")
        return

    # When a farm class is identified only its annotations are drawn;
    # otherwise every class is drawn.
    farm_class_id = _find_farm_class_id(data_dir)

    # One row per image actually found (no blank rows); squeeze=False keeps
    # ``axes`` two-dimensional even for a single row.
    n_rows = len(img_files)
    fig, axes = plt.subplots(n_rows, 3, figsize=(18, 6 * n_rows), squeeze=False)
    if n_rows == 1:
        fig.suptitle("Sample Image, Generated Mask, and Raw Annotations", fontsize=16)
    else:
        fig.suptitle("Sample Images, Generated Masks, and Raw Annotations", fontsize=16)

    for (ax_img, ax_mask, ax_raw), img_file in zip(axes, img_files):
        img_path = os.path.join(train_img_dir, img_file)
        try:
            with Image.open(img_path) as img:
                img_np = np.array(img.convert("RGB"))
        except Exception as e:
            print(f"Error loading image {img_file}: {e}")
            ax_img.set_title(f"Error loading {img_file}")
            ax_img.axis('off')
            ax_mask.set_title("Mask N/A")
            ax_mask.axis('off')
            ax_raw.set_title("Raw Annotations N/A")
            ax_raw.axis('off')
            continue
        img_height, img_width = img_np.shape[:2]  # numpy order is (rows, cols)

        # Label file shares the image's base name.
        mask_file = os.path.splitext(img_file)[0] + '.txt'
        mask_path = os.path.join(train_mask_dir, mask_file)

        filled_mask = np.zeros((img_height, img_width), dtype=np.uint8)
        img_np_raw = img_np.copy()  # Drawn on, so the first panel stays clean

        ax_img.imshow(img_np)
        ax_img.set_title(f"Image: {img_file}")
        ax_img.axis('off')

        # Titles are chosen here and applied once at the end, so the
        # "missing"/"error" titles are not overwritten by the default one.
        mask_title = "Generated Mask (Polygons=White, BBox=Gray)"
        raw_title = f"Raw Annots (File: {mask_file})"

        lines = None
        if not os.path.exists(mask_path):
            print(f"Mask file not found: {mask_path}")
            mask_title = "Mask file missing"
            raw_title = "Annotation file missing"
        else:
            try:
                with open(mask_path, 'r') as f:
                    lines = f.readlines()
            except (OSError, ValueError) as read_ex:  # ValueError covers decode errors
                print(f"Error reading or processing mask file {mask_path}: {read_ex}")
                mask_title = "Error reading mask file"
                raw_title = "Error reading mask file"

        if lines is not None:
            print(f"\nProcessing Annotations for visualization: {img_file}")
            found_farm_polygon = False
            found_farm_bbox = False

            for line_no, line in enumerate(lines, start=1):
                # split() with no argument tolerates tabs and repeated spaces.
                parts = line.split()
                if not parts:
                    continue  # Blank line
                if len(parts) < 5:
                    print(f"  Line {line_no}: Malformed annotation (less than 5 parts)")
                    continue

                try:
                    class_id = int(parts[0])
                    if farm_class_id is not None and class_id != farm_class_id:
                        continue
                    drawn = _draw_yolo_annotation(
                        parts, line_no, img_width, img_height, filled_mask, img_np_raw
                    )
                except ValueError as ve:
                    print(f"  Error parsing values on line {line_no}: {ve} - Line: '{line.strip()}'")
                    continue
                except Exception as ex:
                    print(f"  Unexpected error processing line {line_no}: {ex}")
                    continue

                if drawn == "polygon":
                    found_farm_polygon = True
                elif drawn == "bbox":
                    found_farm_bbox = True

            # Report findings for the specific farm class
            if farm_class_id is not None:
                if found_farm_polygon:
                    print(f"  Visualized polygons for farm class {farm_class_id}.")
                if found_farm_bbox:
                    print(f"  Visualized bounding boxes for farm class {farm_class_id} (Note: Dataset ignores these for masks).")
                if not found_farm_polygon and not found_farm_bbox:
                    print(f"  No annotations found for farm class {farm_class_id} in {mask_file}.")

        ax_mask.imshow(filled_mask, cmap='gray', vmin=0, vmax=255)
        ax_mask.set_title(mask_title)
        ax_mask.axis('off')

        ax_raw.imshow(img_np_raw)
        ax_raw.set_title(raw_title)
        ax_raw.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave room for the suptitle
    plt.show()
    print("--- Finished Sample Visualization ---")


# --- Run the visualization ---
# Runs only once 'dataset_base_dir' has been defined by an earlier cell and
# points at an existing directory; otherwise explains what is missing.
if not ('dataset_base_dir' in locals() and os.path.isdir(dataset_base_dir)):
    print("Error: 'dataset_base_dir' is not defined or is not a valid directory.")
    print("Please ensure the Roboflow download cell and path definition cell have been run successfully.")
else:
    try:
        # Raise num_samples to inspect more images
        visualize_samples(dataset_base_dir, num_samples=5)
    except NameError as name_err:
        # Usually means a prerequisite cell has not been executed yet
        print(f"A NameError occurred during visualization: {name_err}")
        print("Ensure all necessary variables and functions are defined.")
    except Exception as err:
        print(f"An unexpected error occurred during visualization: {err}")
        import traceback
        traceback.print_exc()  # Full traceback for debugging
        print("Check the visualization code and dataset paths.")


In [None]:
# Set up preprocessing, datasets and data loaders.
#
# NOTE: these torchvision pipelines receive the image only. The dataset
# resizes the mask separately to the transformed image's size, so a random
# flip or rotation here would move the image without moving its mask and
# corrupt the training targets. Augmentation is therefore limited to
# photometric changes, which leave pixel positions untouched.
# TODO(review): reintroduce geometric augmentation inside the dataset, applied
# jointly to image and mask.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color jittering
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Keep validation transform simple (deterministic)
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create datasets with separate transforms
train_dataset = FarmlandDataset(train_img_dir, train_mask_dir, transform=train_transform)
val_dataset = FarmlandDataset(val_img_dir, val_mask_dir, transform=val_transform)

# Larger batches when several GPUs share the work
if torch.cuda.device_count() > 1:
    batch_size = 24
else:
    batch_size = 8

# One worker per CPU core, capped at 16; fall back to 4 if the count is unknown
num_workers = min(os.cpu_count() or 4, 16)

# Pinned host memory only helps CPU->GPU copies; requesting it without a GPU
# just triggers a warning.
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    prefetch_factor=4,              # Batches pre-loaded per worker
    persistent_workers=True,        # Keep worker processes alive between epochs
    drop_last=True                  # Drop the final incomplete batch
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
    prefetch_factor=4,
    persistent_workers=True
)

# cuDNN / TF32 settings that trade bit-exact reproducibility for speed
torch.backends.cudnn.benchmark = True  # Let cuDNN auto-tune the convolution algorithm
torch.backends.cudnn.deterministic = False
torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 on Ampere GPUs
torch.backends.cudnn.allow_tf32 = True        # Allow TF32 on Ampere GPUs

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Using batch size: {batch_size} with {torch.cuda.device_count()} GPUs")
print(f"DataLoader configured with {num_workers} worker processes")
print(f"Data augmentation enabled for training to prevent overfitting")
print(f"CUDA optimizations enabled: benchmark={torch.backends.cudnn.benchmark}, TF32={torch.backends.cuda.matmul.allow_tf32}")

## 3. Modeling

### 3.1 U-Net Architecture

In [None]:
# U-Net segmentation model built from GroupNorm double-convolution stages
class EfficientUNet(nn.Module):
    """U-Net with four down/up-sampling stages and GroupNorm normalisation.

    Args:
        n_channels: Number of input image channels.
        n_classes: Number of output channels.

    ``forward`` maps a ``(N, n_channels, H, W)`` tensor to a
    ``(N, n_classes, H, W)`` tensor of sigmoid probabilities. Because the
    sigmoid is applied inside the model, train it with a loss that expects
    probabilities (not a logits-based loss such as ``BCEWithLogitsLoss``).

    H and W do not have to be multiples of 16: when pooling rounds a size
    down, the upsampled decoder tensor is zero-padded to the size of its skip
    connection. Each side must still be at least 16 pixels so the four
    pooling steps leave a non-empty bottleneck.
    """

    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()

        # GroupNorm computes its statistics per sample, so normalisation does
        # not depend on the batch size or on how a batch is split across GPUs.

        # Encoder (downsampling)
        self.enc1 = self._double_conv(n_channels, 64, groups=8)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = self._double_conv(64, 128, groups=16)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = self._double_conv(128, 256, groups=32)
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = self._double_conv(256, 512, groups=32)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = self._double_conv(512, 1024, groups=32)

        # Decoder: each stage upsamples, concatenates the matching encoder
        # output (doubling the channels) and reduces with a double conv.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = self._double_conv(1024, 512, groups=32)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self._double_conv(512, 256, groups=32)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self._double_conv(256, 128, groups=16)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self._double_conv(128, 64, groups=8)

        # Final 1x1 projection to the requested number of output channels
        self.outconv = nn.Conv2d(64, n_classes, kernel_size=1)

        self._initialize_weights()

    @staticmethod
    def _double_conv(in_channels, out_channels, groups):
        """Build two (3x3 conv -> GroupNorm -> ReLU) steps as one Sequential.

        The layer order is fixed so sub-module indices, and therefore
        ``state_dict`` keys such as ``enc1.0.weight``, stay stable.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=groups, num_channels=out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=groups, num_channels=out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _join_skip(upsampled, skip):
        """Concatenate ``upsampled`` with ``skip`` along the channel axis.

        MaxPool2d floors odd sizes, so after the transposed convolution the
        decoder tensor can be one pixel short of the encoder tensor. It is
        zero-padded (split across both sides) to match before concatenating.
        """
        pad_h = skip.size(-2) - upsampled.size(-2)
        pad_w = skip.size(-1) - upsampled.size(-1)
        if pad_h or pad_w:
            upsampled = F.pad(
                upsampled,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )
        return torch.cat([upsampled, skip], dim=1)

    def _initialize_weights(self):
        """Kaiming-initialise convolutions; reset GroupNorm to identity scale/shift."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return per-pixel sigmoid probabilities with the spatial size of ``x``."""
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        # Bottleneck
        b = self.bottleneck(self.pool4(e4))

        # Decoder with skip connections
        d4 = self.dec4(self._join_skip(self.upconv4(b), e4))
        d3 = self.dec3(self._join_skip(self.upconv3(d4), e3))
        d2 = self.dec2(self._join_skip(self.upconv2(d3), e2))
        d1 = self.dec1(self._join_skip(self.upconv1(d2), e1))

        # Final output
        out = self.outconv(d1)
        return torch.sigmoid(out)  # Probabilities, not logits

In [None]:
# Initialize the model, loss function and optimizer
model = EfficientUNet(n_channels=3, n_classes=1)

# Use DataParallel if multiple GPUs are available
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for parallel training")
    model = nn.DataParallel(model)

model = model.to(device)


class _ProbabilityBCELoss(nn.Module):
    """Binary cross-entropy for predictions that are already probabilities.

    EfficientUNet.forward ends with ``torch.sigmoid``, so a logits-based loss
    (``nn.BCEWithLogitsLoss``) would apply the sigmoid a second time and
    squash every prediction into roughly [0.5, 0.73]. Plain binary
    cross-entropy is the matching loss; it is evaluated in float32 with
    autocast disabled because PyTorch refuses to run it inside an
    autocast-enabled region.
    """

    def forward(self, probabilities, targets):
        with torch.autocast(device_type=probabilities.device.type, enabled=False):
            return F.binary_cross_entropy(probabilities.float(), targets.float())


# Loss on probabilities, and AdamW with weight decay for regularization
criterion = _ProbabilityBCELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)

# Print model summary
print(model)

### 3.2 Training Functions

In [None]:
# Define training and validation functions
def train_epoch(model, dataloader, criterion, optimizer, device, amp_scaler=None):
    """Train ``model`` for one pass over ``dataloader`` and return the average loss.

    Args:
        model: Network to optimise; switched to training mode.
        dataloader: Yields ``(images, masks)`` batches.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        amp_scaler: Optional gradient scaler. When given, the forward pass
            runs under CUDA autocast and the scaler drives the backward pass
            and optimizer step; when ``None`` training is full precision.

    Returns:
        Sum of each batch loss weighted by its batch size, divided by the
        number of samples actually processed. Dividing by the processed count
        rather than ``len(dataloader.dataset)`` keeps the value correct when
        the loader drops its last incomplete batch. Returns ``0.0`` if the
        loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0

    for images, masks in tqdm(dataloader, desc="Training"):
        # non_blocking lets the copy overlap with compute when memory is pinned
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        # Clear gradients left over from the previous batch
        optimizer.zero_grad()

        if amp_scaler is not None:
            # Mixed precision: forward and loss under autocast
            with torch.amp.autocast(device_type='cuda'):
                outputs = model(images)
                loss = criterion(outputs, masks)

            # Scaled backward pass; step() unscales before updating
            amp_scaler.scale(loss).backward()
            amp_scaler.step(optimizer)
            amp_scaler.update()
        else:
            # Standard full-precision training
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

        batch_count = images.size(0)
        running_loss += loss.item() * batch_count
        samples_seen += batch_count

    if samples_seen == 0:
        return 0.0
    return running_loss / samples_seen

def validate_epoch(model, dataloader, criterion, device):
    """Compute the average loss of ``model`` over ``dataloader`` without training.

    The model is put in eval mode and gradients are disabled. Each batch loss
    is weighted by its batch size and the total is divided by
    ``len(dataloader.dataset)``.
    """
    model.eval()
    total_loss = 0.0
    progress = tqdm(dataloader, desc="Validation")

    with torch.no_grad():
        for batch_images, batch_masks in progress:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            batch_loss = criterion(model(batch_images), batch_masks)
            total_loss += batch_loss.item() * batch_images.size(0)

    return total_loss / len(dataloader.dataset)

# Intersection over Union between a thresholded prediction and its target
def calculate_iou(pred, target, threshold=0.5):
    """Return the IoU of ``pred > threshold`` against ``target`` as a Python float.

    A small epsilon is added to numerator and denominator, so two empty
    masks score 1.0 instead of dividing by zero.
    """
    eps = 1e-8
    hard_pred = (pred > threshold).float()
    overlap = torch.sum(hard_pred * target)
    combined = torch.sum(hard_pred) + torch.sum(target) - overlap
    return ((overlap + eps) / (combined + eps)).item()

# Mean per-image IoU of a model over a data loader
def evaluate_model(model, dataloader, device, threshold=0.5):
    """Return the mean of the per-image IoU scores over every image in ``dataloader``.

    The model runs in eval mode with gradients disabled. Each image's
    prediction is scored against its mask with ``calculate_iou`` at the given
    ``threshold``, and all scores are averaged with equal weight.
    """
    model.eval()
    per_image_iou = []

    with torch.no_grad():
        for batch_images, batch_masks in tqdm(dataloader, desc="Evaluating"):
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            predictions = model(batch_images)

            # Score each image separately so every image weighs the same
            per_image_iou.extend(
                calculate_iou(predictions[idx], batch_masks[idx], threshold)
                for idx in range(predictions.size(0))
            )

    return sum(per_image_iou) / len(per_image_iou)

### 3.3 Training Loop

In [None]:
# Training loop with mixed precision, reduce-on-plateau LR scheduling,
# best-model checkpointing and early stopping
num_epochs = 30
best_val_loss = float('inf')
best_model_path = 'best_unet_model.pth'
patience = 5  # Epochs without improvement tolerated before early stopping
no_improve_epochs = 0

# Per-epoch training history
train_losses = []
val_losses = []
val_ious = []

# Mixed precision needs a gradient scaler and is only used when CUDA is
# available; with scaler=None train_epoch falls back to full precision.
scaler = torch.amp.GradScaler() if torch.cuda.is_available() else None
if scaler is not None:
    print("Using mixed precision training for faster performance")

# Halve the learning rate after 3 epochs without validation-loss improvement
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    torch.cuda.empty_cache()  # Release cached GPU memory before each epoch

    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device, amp_scaler=scaler)
    train_losses.append(train_loss)

    # Validate
    val_loss = validate_epoch(model, val_loader, criterion, device)
    val_losses.append(val_loss)

    # Calculate IoU
    val_iou = evaluate_model(model, val_loader, device)
    val_ious.append(val_iou)

    # The optimizer's param groups hold the learning rate actually in use.
    # Read it before and after the scheduler step to report this epoch's
    # value and detect a reduction.
    last_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']

    # Print epoch results
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print(f"Validation IoU: {val_iou:.4f}")
    print(f"Learning rate for this epoch: {last_lr:.6f}")
    # Checked on every epoch: the scheduler ignores improvements below its
    # own threshold, so it can reduce the LR on an epoch that still counts
    # as an improvement below.
    if new_lr < last_lr:
        print(f"Learning rate reduced to {new_lr:.6f}")

    # Save best model
    if val_loss < best_val_loss:
        improvement = best_val_loss - val_loss
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Model improved by {improvement:.6f} and saved to {best_model_path}")
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        print(f"No improvement for {no_improve_epochs} epochs")

    # Early stopping
    if no_improve_epochs >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

    print("-" * 50)

print("Training completed!")

### 3.4 Visualize Training Results

In [None]:
# Plot training history: loss curves on the left, validation IoU on the right
history_fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Loss panel
loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()

# IoU panel
iou_ax.plot(val_ious, label='Validation IoU')
iou_ax.set_xlabel('Epoch')
iou_ax.set_ylabel('IoU')
iou_ax.set_title('Validation IoU')
iou_ax.legend()

history_fig.tight_layout()
plt.show()

## 4. Farm Size Calculation and Classification

Now we'll use our trained model to segment farms and calculate their sizes.

In [None]:
# Load the best checkpoint onto the active device.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only
# machine; weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state_dict contains.
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
model.eval()

# Function to segment farms in an image
def segment_farms(model, image_path, device, transform):
    """Predict a binary farm mask for one image file.

    Args:
        model: Segmentation model whose output is thresholded at 0.5; the
            caller is expected to have put it in eval mode.
        image_path: Path of the image to segment.
        device: Device the input tensor is moved to.
        transform: Callable turning the RGB PIL image into the model's
            input tensor (without the batch dimension).

    Returns:
        ``(mask, image)``: ``mask`` is a uint8 array holding only 0 and 255,
        resized to the original image's width and height; ``image`` is the
        original RGB image as a numpy array.
    """
    # Context manager closes the file handle once the pixels are converted
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    original_size = image.size  # (width, height), the order cv2.resize expects

    # Preprocess and add the batch dimension
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
    predicted_mask = output.squeeze().cpu().numpy()

    # Threshold to get binary mask
    binary_mask = (predicted_mask > 0.5).astype(np.uint8) * 255

    # Nearest-neighbour keeps the mask strictly 0/255. The default bilinear
    # interpolation would blend edges into gray values, which a later
    # connected-component labelling would treat as separate regions.
    binary_mask_resized = cv2.resize(
        binary_mask, original_size, interpolation=cv2.INTER_NEAREST
    )

    return binary_mask_resized, np.array(image)

# Function to calculate farm sizes from a segmentation mask
def calculate_farm_sizes(binary_mask, pixels_per_meter=None, min_area_pixels=100):
    """Label connected farm regions and measure the area of each one.

    Args:
        binary_mask: 2-D mask; every non-zero pixel counts as farmland.
        pixels_per_meter: Optional image resolution. When given, areas are
            returned in hectares; otherwise in pixels.
        min_area_pixels: Regions smaller than this many pixels are treated as
            noise and discarded.

    Returns:
        ``(farm_areas, labeled_mask)``. ``farm_areas`` lists the area of each
        kept region in ascending label order. ``labeled_mask`` assigns each
        kept region its integer label and 0 everywhere else. Discarded noise
        regions are cleared from it, so
        ``measure.regionprops(labeled_mask)`` yields exactly one region per
        entry of ``farm_areas``, in the same order.
    """
    # Binarise first: label() groups pixels by equal value, so stray
    # intermediate values would otherwise split one field into several.
    foreground = np.asarray(binary_mask) > 0

    # 8-connected components (connectivity=2 in 2-D)
    labeled_mask = measure.label(foreground, connectivity=2)

    farm_areas = []
    noise_labels = []

    for region in measure.regionprops(labeled_mask):
        area_pixels = region.area

        if area_pixels < min_area_pixels:
            noise_labels.append(region.label)
            continue

        if pixels_per_meter is not None:
            area_sq_meters = area_pixels / (pixels_per_meter ** 2)
            # 1 hectare = 10,000 square meters
            farm_areas.append(area_sq_meters / 10000)
        else:
            farm_areas.append(area_pixels)

    # Clear the noise only after the loop so no region is changed while it is
    # being measured.
    if noise_labels:
        labeled_mask[np.isin(labeled_mask, noise_labels)] = 0

    return farm_areas, labeled_mask

# Size classification of farms from their areas
def classify_farms(farm_areas, unit='pixels'):
    """Label each area as 'Small', 'Medium' or 'Large' and count each label.

    With ``unit='hectares'`` the cut-offs are 10 and 50; any other ``unit``
    uses the pixel cut-offs 5000 and 20000. An area below the first cut-off is
    'Small', below the second 'Medium', otherwise 'Large'.

    Returns:
        ``(farm_classes, class_counts)``: the labels in input order, and a
        dict with the keys 'Small', 'Medium', 'Large' mapped to their counts.
    """
    # (small/medium cut-off, medium/large cut-off) — tune to the imagery in use
    if unit == 'hectares':
        small_limit, medium_limit = 10, 50
    else:
        small_limit, medium_limit = 5000, 20000

    def size_label(area):
        if area < small_limit:
            return 'Small'
        return 'Medium' if area < medium_limit else 'Large'

    farm_classes = [size_label(area) for area in farm_areas]
    class_counts = {
        label: farm_classes.count(label)
        for label in ('Small', 'Medium', 'Large')
    }

    return farm_classes, class_counts

In [None]:
# Function to visualize farm segmentation and classification
def visualize_farm_classification(image, labeled_mask, farm_areas, farm_classes, min_region_area=100):
    """
    Plot the size-classified farms over the image next to a pie chart of the
    class distribution.

    Args:
        image: Image shown underneath the overlay (anything plt.imshow accepts)
        labeled_mask: Integer label image in which each farm region carries
            its own label
        farm_areas: Areas of the kept farms (not needed for drawing; kept so
            existing callers keep working)
        farm_classes: 'Small' / 'Medium' / 'Large' label per kept farm, in
            region order
        min_region_area: Regions with fewer pixels are treated as noise and
            not drawn. Must equal the threshold used when the areas were
            computed (100 pixels), otherwise regions and classes misalign.
    """
    # Create a colormap for visualization (0 = background, 1..3 = size classes)
    cmap = plt.cm.colors.ListedColormap(['black', 'green', 'yellow', 'red'])
    bounds = [0, 1, 2, 3, 4]
    norm = plt.cm.colors.BoundaryNorm(bounds, cmap.N)

    # Create a colored mask based on farm classification
    colored_mask = np.zeros_like(labeled_mask)
    class_to_value = {'Small': 1, 'Medium': 2}  # anything else is drawn as Large (3)

    # farm_classes only has entries for regions that survived the noise
    # filter, so drop the small regions BEFORE pairing; zipping the raw
    # region list would shift every later farm onto the wrong class.
    kept_regions = [
        region for region in measure.regionprops(labeled_mask)
        if region.area >= min_region_area
    ]
    for region, farm_class in zip(kept_regions, farm_classes):
        colored_mask[labeled_mask == region.label] = class_to_value.get(farm_class, 3)

    # Plot results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

    # Original image with segmentation overlay
    ax1.imshow(image)
    ax1.imshow(colored_mask, cmap=cmap, alpha=0.5, norm=norm)
    ax1.set_title('Farm Segmentation and Classification')
    ax1.axis('off')

    # Create custom legend
    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, color='green', alpha=0.5, label='Small Farms'),
        plt.Rectangle((0, 0), 1, 1, color='yellow', alpha=0.5, label='Medium Farms'),
        plt.Rectangle((0, 0), 1, 1, color='red', alpha=0.5, label='Large Farms')
    ]
    ax1.legend(handles=legend_elements, loc='upper right')

    # Pie chart of farm size distribution
    class_counts = {
        'Small': farm_classes.count('Small'),
        'Medium': farm_classes.count('Medium'),
        'Large': farm_classes.count('Large')
    }

    if sum(class_counts.values()) > 0:  # Check if we have any farms
        labels = list(class_counts.keys())
        sizes = list(class_counts.values())
        colors = ['green', 'yellow', 'red']

        ax2.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
        ax2.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
        ax2.set_title('Farm Size Distribution')
    else:
        ax2.text(0.5, 0.5, 'No farms detected', horizontalalignment='center', verticalalignment='center')
        ax2.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Process a sample image
def process_sample_image(image_path):
    """
    Segment one image, measure and classify the detected farms, then print
    and plot the outcome.

    Uses the notebook-level `model`, `device` and `transform`.

    Args:
        image_path: Path of the image to analyse

    Returns:
        tuple: (farm_areas, farm_classes, class_counts)
    """
    mask, source_image = segment_farms(model, image_path, device, transform)

    # Note: In a real application, you would need to determine pixels_per_meter based on image metadata.
    # No scale is passed here, so the areas are presumably left in pixel units.
    areas, region_labels = calculate_farm_sizes(mask)

    # Default unit of classify_farms is 'pixels', matching the call above
    size_classes, counts = classify_farms(areas)

    print(f"Number of farms detected: {len(areas)}")
    print(f"Farm size classification: {counts}")

    visualize_farm_classification(source_image, region_labels, areas, size_classes)

    return areas, size_classes, counts

# Try with a test image from the validation set
test_img_dir = os.path.join(dataset.location, 'valid', 'images')

# os.listdir returns entries in arbitrary order and may include non-image
# files (hidden/system files), so filter by extension and sort to make the
# chosen sample reproducible and guaranteed to be an image.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')
test_img_files = sorted(
    name for name in os.listdir(test_img_dir)
    if name.lower().endswith(image_extensions)
)

if test_img_files:
    test_img_path = os.path.join(test_img_dir, test_img_files[0])
    print(f"Processing test image: {test_img_path}")
    farm_areas, farm_classes, class_counts = process_sample_image(test_img_path)

## 5. Recommendation System Based on Farm Size

In [None]:
# Define recommendations based on farm size
def get_recommendations_by_size(farm_size):
    """
    Look up the advice catalogue for one farm size class.

    Args:
        farm_size: 'Small', 'Medium' or 'Large'

    Returns:
        dict: Maps each advice category ('Crop Selection', 'Equipment',
        'Marketing', 'Sustainability') to a list of recommendation strings.
        An empty dict is returned for any other size value.
    """
    small_farm_advice = {
        'Crop Selection': [
            'Focus on high-value crops (e.g., specialty vegetables, herbs, berries)',
            'Consider intercropping to maximize land use',
            'Explore vertical farming techniques for space optimization',
        ],
        'Equipment': [
            'Invest in versatile, small-scale equipment',
            'Consider equipment sharing programs or cooperatives',
            'Focus on precision hand tools for specialized tasks',
        ],
        'Marketing': [
            'Direct-to-consumer sales (farmers markets, CSA)',
            'Develop value-added products',
            'Leverage organic or specialty certifications',
        ],
        'Sustainability': [
            'Implement intensive organic practices',
            'Consider agroecological approaches',
            'Explore permaculture design principles',
        ],
    }

    medium_farm_advice = {
        'Crop Selection': [
            'Balance between specialty and commodity crops',
            'Consider crop rotation systems',
            'Explore diversification strategies',
        ],
        'Equipment': [
            'Invest in mid-sized tractors and implements',
            'Consider precision agriculture technology',
            'Develop efficient irrigation systems',
        ],
        'Marketing': [
            'Develop relationships with local wholesalers and restaurants',
            'Consider cooperative marketing',
            'Explore agritourism opportunities',
        ],
        'Sustainability': [
            'Implement integrated pest management',
            'Consider conservation tillage practices',
            'Develop soil health management plans',
        ],
    }

    large_farm_advice = {
        'Crop Selection': [
            'Focus on efficient production of commodity crops',
            'Consider dedicating portions to specialty high-value crops',
            'Implement strategic crop rotation systems',
        ],
        'Equipment': [
            'Invest in large-scale, efficient machinery',
            'Implement precision agriculture and automation',
            'Consider GPS guidance systems and variable rate technology',
        ],
        'Marketing': [
            'Develop contracts with processors and distributors',
            'Consider futures markets and hedging strategies',
            'Explore export opportunities',
        ],
        'Sustainability': [
            'Implement conservation agriculture practices at scale',
            'Consider renewable energy investments',
            'Develop comprehensive nutrient management plans',
        ],
    }

    catalogue = {
        'Small': small_farm_advice,
        'Medium': medium_farm_advice,
        'Large': large_farm_advice,
    }

    # Unknown size classes yield an empty result instead of raising
    if farm_size in catalogue:
        return catalogue[farm_size]
    return {}

# Function to display recommendations for a specific farm
def display_farm_recommendations(farm_class):
    """
    Print the recommendations for one farm size class as bulleted sections.

    Args:
        farm_class: Size class whose recommendations are looked up via
            get_recommendations_by_size; a notice is printed instead when
            the lookup comes back empty
    """
    advice = get_recommendations_by_size(farm_class)

    if not advice:
        print(f"No recommendations available for {farm_class} farms.")
        return

    # Assemble the report line by line, then emit it in one go
    report = [f"\n=== Recommendations for {farm_class} Farms ===\n"]
    for category, items in advice.items():
        report.append(f"\n{category}:")
        report.extend(f"  • {item}" for item in items)
    report.append("\n" + "=" * 50)

    print("\n".join(report))

In [None]:
# Print the advice catalogue for every size category, smallest first
for size in ('Small', 'Medium', 'Large'):
    display_farm_recommendations(farm_class=size)

## 6. End-to-End Pipeline

In [None]:
# End-to-end pipeline function
def process_farm_image(image_path, pixel_scale=None):
    """
    Process a satellite image to detect farms, classify them by size, and provide recommendations.

    Uses the notebook-level `model`, `device` and `transform` for segmentation.

    Args:
        image_path (str): Path to the satellite image
        pixel_scale (float, optional): Ground resolution in meters per pixel.
            When omitted (or 0) the scale is treated as unknown and all areas
            and thresholds are expressed in pixels.

    Returns:
        dict: Analysis results with the keys 'num_farms', 'farm_areas',
        'farm_classes', 'class_counts' and 'predominant_size' (None when no
        farm was detected)
    """
    print(f"Processing image: {image_path}")

    # Step 1: Segment farms using the trained U-Net model
    binary_mask, image = segment_farms(model, image_path, device, transform)

    # Step 2: Calculate farm sizes
    # calculate_farm_sizes divides the pixel count by pixels_per_meter ** 2,
    # i.e. it expects PIXELS PER METER, whereas this function is documented in
    # METERS PER PIXEL -- invert the scale instead of passing it through.
    # A falsy scale is passed on as None so the unit label and the area
    # computation can never disagree (and 0 cannot become a divisor).
    if pixel_scale:
        unit = 'hectares'
        pixels_per_meter = 1.0 / pixel_scale
    else:
        unit = 'pixels'
        pixels_per_meter = None
    farm_areas, labeled_mask = calculate_farm_sizes(binary_mask, pixels_per_meter)

    # Step 3: Classify farms by size
    farm_classes, class_counts = classify_farms(farm_areas, unit)

    # Step 4: Print summary
    print(f"\nFarm Analysis Summary:")
    print(f"Total farms detected: {len(farm_areas)}")
    print(f"Farm size distribution: {class_counts}")
    print(f"Area unit: {unit}")

    predominant_size = None
    if len(farm_areas) > 0:
        # Step 5: Calculate predominant farm size (ties resolve to the first
        # class in class_counts order)
        predominant_size = max(class_counts, key=class_counts.get)
        print(f"Predominant farm size: {predominant_size}")

        # Step 6: Provide recommendations based on predominant farm size
        display_farm_recommendations(predominant_size)

        # Step 7: Visualize results
        visualize_farm_classification(image, labeled_mask, farm_areas, farm_classes)
    else:
        print("No farms detected in the image.")

    # Return results
    return {
        'num_farms': len(farm_areas),
        'farm_areas': farm_areas,
        'farm_classes': farm_classes,
        'class_counts': class_counts,
        'predominant_size': predominant_size
    }

# Test the end-to-end pipeline on a sample image (if available)
if test_img_files:
    test_img_path = os.path.join(test_img_dir, test_img_files[0])
    results = process_farm_image(test_img_path)

    # Persist the analysis so it can be reloaded later (file or database)
    import json
    with open('farm_analysis_results.json', 'w') as f:
        # Copy the result fields in their original order ...
        serializable_results = {
            key: results[key]
            for key in ('num_farms', 'farm_areas', 'farm_classes',
                        'class_counts', 'predominant_size')
        }
        # ... and cast the areas to plain floats, since they may be numeric
        # types that json cannot encode
        serializable_results['farm_areas'] = list(map(float, results['farm_areas']))
        json.dump(serializable_results, f, indent=4)

## 7. Conclusion and Next Steps

### 7.1 Summary

In this project, we have:
1. Loaded and processed a dataset of satellite imagery with farm annotations
2. Implemented and trained a U-Net model for farmland segmentation
3. Developed methods to calculate farm sizes from segmentation masks
4. Created a classification system to categorize farms by size
5. Built a recommendation system providing tailored advice based on farm size
6. Integrated all components into an end-to-end pipeline

### 7.2 Limitations

Current limitations of the system include:
- Relies on image resolution and quality for accurate segmentation
- Size classification thresholds may need adjustment for different regions
- Lacks real-world unit calibration (meters/hectares) without proper image metadata
- Recommendations are general and not region-specific

### 7.3 Future Improvements

Potential next steps for improving the system:
1. **Model Enhancements**:
   - Experiment with other architectures (DeepLabv3+, HRNet)
   - Implement data augmentation for better generalization
   - Train on a larger, more diverse dataset

2. **Size Calculation**:
   - Integrate with GIS systems to obtain accurate geospatial coordinates
   - Develop methods to automatically determine image scale
   - Account for terrain variations in area calculations

3. **Recommendation System**:
   - Incorporate climate and soil data for more targeted recommendations
   - Develop region-specific recommendation models
   - Create a more interactive recommendation interface

4. **User Interface**:
   - Develop a web or mobile application for easier access
   - Allow users to upload their own imagery
   - Provide visualization tools for farmers to explore results

5. **Validation**:
   - Conduct field validation with actual farm measurements
   - Collect feedback from farmers on recommendation usefulness

## 8. Transfer Learning with Pretrained Models

U-Net can benefit greatly from transfer learning by using pretrained encoders. We'll implement this to improve both training speed and model accuracy.

### 8.1 Pretrained U-Net Implementation

In [None]:
# First, let's install the necessary packages
!pip install segmentation-models-pytorch

# Import libraries for transfer learning
import segmentation_models_pytorch as smp

# Define a U-Net model with a pretrained encoder
class PretrainedUNet(nn.Module):
    """
    Thin nn.Module wrapper around a segmentation_models_pytorch U-Net.

    The wrapped network takes 3-channel input and produces a single-channel
    output passed through a sigmoid activation.

    Args:
        encoder_name: Encoder backbone, e.g. "resnet34", "efficientnet-b0"
        encoder_weights: Pretrained weights to load (e.g. "imagenet") or None
    """

    def __init__(self, encoder_name="resnet34", encoder_weights="imagenet"):
        super().__init__()
        unet_config = dict(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,         # RGB images
            classes=1,             # binary segmentation
            activation="sigmoid",  # final activation function
        )
        self.model = smp.Unet(**unet_config)

    def forward(self, x):
        """Run the wrapped U-Net on batch `x` and return its output."""
        return self.model(x)

# List the encoder backbones offered by segmentation_models_pytorch
# (only printed when CUDA is available), five names per line
if torch.cuda.is_available():
    print("Available pretrained encoders:")
    for i, encoder in enumerate(smp.encoders.get_encoder_names()):
        if i > 0 and i % 5 == 0:
            print()  # Line break for readability
        print(f"{encoder}", end=", ")
    print("\n")

# Build the U-Net with an ImageNet-pretrained ResNet34 encoder
pretrained_model = PretrainedUNet(encoder_name="resnet34", encoder_weights="imagenet")

# Spread the model over all GPUs when more than one is present
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    print(f"Using {torch.cuda.device_count()} GPUs for parallel training of pretrained model")
    pretrained_model = nn.DataParallel(pretrained_model)

pretrained_model = pretrained_model.to(device)

# Split the parameters so the pretrained encoder can be fine-tuned with a
# lower learning rate than the rest of the network
encoder_params = []
decoder_params = []

# DataParallel keeps the real network behind `.module`
unwrapped_model = pretrained_model.module if multi_gpu else pretrained_model
for name, param in unwrapped_model.model.named_parameters():
    if name.startswith("encoder"):
        encoder_params.append(param)
    else:
        decoder_params.append(param)

optimizer_pretrained = optim.Adam([
    {'params': encoder_params, 'lr': 0.0001},  # pretrained encoder: small steps
    {'params': decoder_params, 'lr': 0.001},   # everything else: larger steps
])

criterion_pretrained = nn.BCELoss()

print("Pretrained U-Net model initialized with ResNet34 encoder")

In [None]:
# Training loop for the pretrained model: fine-tunes for a fixed number of
# epochs, records train/val loss and val IoU per epoch, and checkpoints the
# weights with the lowest validation loss.
num_epochs_pretrained = 15  # Usually needs fewer epochs due to pretrained weights
best_val_loss_pretrained = float('inf')
best_pretrained_model_path = 'best_pretrained_unet_model.pth'

# Lists to store training history
train_losses_pretrained = []
val_losses_pretrained = []
val_ious_pretrained = []

# Mixed precision training setup for faster training (CUDA only)
use_amp = torch.cuda.is_available()
if use_amp:
    # Initialize the scaler for mixed precision training
    from torch.cuda.amp import GradScaler, autocast
    scaler = GradScaler()
    print("Using mixed precision training for faster performance")

print(f"Beginning transfer learning with pretrained ResNet34 encoder...")
for epoch in range(num_epochs_pretrained):
    print(f"Epoch {epoch+1}/{num_epochs_pretrained}")

    # Train
    pretrained_model.train()
    running_loss = 0.0

    for images, masks in tqdm(train_loader, desc="Training"):
        images = images.to(device)
        masks = masks.to(device)

        # Clear gradients
        optimizer_pretrained.zero_grad()

        if use_amp:
            # Only the forward pass runs under autocast.
            with autocast():
                outputs = pretrained_model(images)

            # nn.BCELoss is flagged by PyTorch as unsafe to autocast and
            # raises when called inside an autocast region, so the loss is
            # computed outside of it on float32 tensors.
            # NOTE(review): emitting logits and using BCEWithLogitsLoss would
            # be the more numerically robust setup, but that requires changing
            # the model's final activation as well.
            loss = criterion_pretrained(outputs.float(), masks.float())

            # Backward pass with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(optimizer_pretrained)
            scaler.update()
        else:
            # Standard training on CPU
            outputs = pretrained_model(images)
            loss = criterion_pretrained(outputs, masks)
            loss.backward()
            optimizer_pretrained.step()

        # Weight the batch loss by batch size so the epoch value is a
        # per-sample mean
        running_loss += loss.item() * images.size(0)

    train_loss = running_loss / len(train_loader.dataset)
    train_losses_pretrained.append(train_loss)

    # Validate
    pretrained_model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validation"):
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = pretrained_model(images)
            loss = criterion_pretrained(outputs, masks)

            running_loss += loss.item() * images.size(0)

    val_loss = running_loss / len(val_loader.dataset)
    val_losses_pretrained.append(val_loss)

    # Calculate IoU
    val_iou = evaluate_model(pretrained_model, val_loader, device)
    val_ious_pretrained.append(val_iou)

    # Print epoch results
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print(f"Validation IoU: {val_iou:.4f}")

    # Save best model (selected on validation loss)
    if val_loss < best_val_loss_pretrained:
        best_val_loss_pretrained = val_loss
        torch.save(pretrained_model.state_dict(), best_pretrained_model_path)
        print(f"Model saved to {best_pretrained_model_path}")

    print("-" * 50)

print("Transfer learning completed!")

# Compare the performance of the basic UNet and pretrained UNet
from itertools import accumulate

plt.figure(figsize=(15, 5))

# Plot losses
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Basic UNet - Train')
plt.plot(val_losses, label='Basic UNet - Val')
plt.plot(train_losses_pretrained, label='Pretrained UNet - Train')
plt.plot(val_losses_pretrained, label='Pretrained UNet - Val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Comparison')
plt.legend()

# Plot IoU
plt.subplot(1, 3, 2)
plt.plot(val_ious, label='Basic UNet')
plt.plot(val_ious_pretrained, label='Pretrained UNet')
plt.xlabel('Epoch')
plt.ylabel('IoU')
plt.title('Validation IoU Comparison')
plt.legend()

# Plot convergence speed: best validation IoU reached up to each epoch.
# Higher IoU is better, so the "best so far" curve is a running MAXIMUM
# (a running minimum would track the worst epoch instead).
plt.subplot(1, 3, 3)
plt.plot(list(accumulate(val_ious, max)), label='Basic UNet')
plt.plot(list(accumulate(val_ious_pretrained, max)), label='Pretrained UNet')
plt.xlabel('Epoch')
plt.ylabel('Best IoU So Far')
plt.title('Convergence Speed Comparison')
plt.legend()

plt.tight_layout()
plt.show()

# Restore the best checkpoint (if one was written) before running inference
if not os.path.exists(best_pretrained_model_path):
    # No checkpoint on disk: keep whatever weights the model currently has
    print("Using current pretrained model state")
else:
    pretrained_model.load_state_dict(torch.load(best_pretrained_model_path))
    print(f"Loaded best pretrained model from {best_pretrained_model_path}")

# Switch to evaluation mode for inference
pretrained_model.eval()

## 9. Model Export and Optimization

In this section, we'll export our models to formats suitable for deployment and optimize them for inference.

### 9.1 TorchScript Export

In [None]:
# Export the model to TorchScript format
def export_to_torchscript(model, example_input_tensor, model_name="farm_segmentation_model.pt"):
    """
    Trace a PyTorch model and save it as a TorchScript file.

    Args:
        model: The PyTorch model to export (may be wrapped in nn.DataParallel)
        example_input_tensor: Example input with the correct shape, used to
            drive the trace
        model_name: The filename to save the model to

    Returns:
        bool: True when the file was written, False when tracing or saving
        raised (the error is printed, not re-raised)
    """
    model.eval()  # tracing should record inference-mode behaviour

    # DataParallel is only a wrapper; trace the module it holds
    model_to_trace = model.module if isinstance(model, nn.DataParallel) else model

    try:
        # torch.jit.trace runs the model once on the sample input and
        # records the executed operations as a ScriptModule
        torch.jit.trace(model_to_trace, example_input_tensor).save(model_name)
    except Exception as e:
        print(f"Error exporting model to TorchScript: {e}")
        return False

    print(f"TorchScript model saved to {model_name}")
    return True

# Export the basic U-Net model
# Dummy batch of one 3-channel 256x256 image used as the tracing input
sample_input = torch.randn(1, 3, 256, 256, device=device)
print("Exporting basic U-Net model to TorchScript...")
export_to_torchscript(model, sample_input, model_name="unet_farm_segmentation.pt")

# The pretrained model only exists when the transfer-learning cells were run
if 'pretrained_model' in locals():
    print("Exporting pretrained U-Net model to TorchScript...")
    export_to_torchscript(
        pretrained_model,
        sample_input,
        model_name="pretrained_unet_farm_segmentation.pt",
    )

### 9.2 ONNX Export for Deployment

In [None]:
# Export model to ONNX format for wider compatibility
def export_to_onnx(model, example_input_tensor, model_name="farm_segmentation_model.onnx"):
    """
    Export a PyTorch model to an ONNX file and validate the written file.

    Args:
        model: The PyTorch model to export (may be wrapped in nn.DataParallel)
        example_input_tensor: Example input with the correct shape
        model_name: The filename to save the model to

    Returns:
        bool: True when the export and the ONNX check both succeeded, False
        when any step raised (the error is printed, not re-raised)
    """
    # DataParallel is only a wrapper; export the module it holds
    model_to_export = model.module if isinstance(model, nn.DataParallel) else model
    model_to_export.eval()

    # Leave the batch dimension of both input and output variable
    variable_axes = {
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    }

    try:
        torch.onnx.export(
            model_to_export,
            example_input_tensor,
            model_name,
            export_params=True,        # store trained weights in the file
            opset_version=12,          # ONNX operator set to target
            do_constant_folding=True,  # fold constants as an optimization
            input_names=['input'],
            output_names=['output'],
            dynamic_axes=variable_axes,
        )
        print(f"ONNX model saved to {model_name}")

        # Reload the file and let the ONNX checker validate it
        import onnx
        onnx.checker.check_model(onnx.load(model_name))
        print("ONNX model checked - model is valid")
        return True
    except Exception as e:
        print(f"Error exporting model to ONNX: {e}")
        return False

# Export the basic U-Net model to ONNX
# Dummy batch of one 3-channel 256x256 image used as the export input
sample_input = torch.randn(1, 3, 256, 256, device=device)
print("Exporting basic U-Net model to ONNX...")
export_to_onnx(model, sample_input, model_name="unet_farm_segmentation.onnx")

# The pretrained model only exists when the transfer-learning cells were run
if 'pretrained_model' in locals():
    print("Exporting pretrained U-Net model to ONNX...")
    export_to_onnx(
        pretrained_model,
        sample_input,
        model_name="pretrained_unet_farm_segmentation.onnx",
    )

In [None]:
# Function to test and visualize model predictions
def _yolo_labels_to_mask(label_path, img_width, img_height):
    """Rasterize a YOLO label file into a uint8 mask (255 inside annotations, 0 elsewhere)."""
    mask = np.zeros((img_height, img_width), dtype=np.uint8)

    with open(label_path, 'r') as f:
        for line in f:
            # split() without arguments tolerates tabs and repeated spaces
            parts = line.split()
            if len(parts) < 5:
                continue  # blank or malformed row

            if len(parts) > 5:
                # Segmentation row: "<class> x1 y1 x2 y2 ... xn yn" with
                # normalized coordinates. The vertices start right after the
                # class id; reading from a later token would silently drop
                # the first two vertices of every polygon.
                # NOTE(review): any other code that rasterizes these labels
                # (e.g. the training dataset) should parse them the same way.
                coords = [float(value) for value in parts[1:]]
                polygon_points = [
                    (int(x * img_width), int(y * img_height))
                    for x, y in zip(coords[0::2], coords[1::2])  # a dangling odd value is ignored
                ]
                if polygon_points:
                    # OpenCV expects an (N, 1, 2) int32 array
                    pts = np.array(polygon_points, np.int32).reshape((-1, 1, 2))
                    cv2.fillPoly(mask, [pts], 255)
            else:
                # Detection row: "<class> x_center y_center width height"
                x_center = float(parts[1]) * img_width
                y_center = float(parts[2]) * img_height
                width = float(parts[3]) * img_width
                height = float(parts[4]) * img_height

                # Clamp the box to the image bounds
                x1 = max(0, int(x_center - width / 2))
                y1 = max(0, int(y_center - height / 2))
                x2 = min(img_width - 1, int(x_center + width / 2))
                y2 = min(img_height - 1, int(y_center + height / 2))

                cv2.rectangle(mask, (x1, y1), (x2, y2), 255, -1)

    return mask


def visualize_model_predictions(model, test_img_path, device, transform):
    """
    Test the model on a single image and visualize the results with original image,
    ground truth mask (if available), and predicted segmentation mask.

    The ground truth is looked up as a YOLO '.txt' label file with the same
    stem as the image, in the 'labels' folder that sits next to the image's
    'images' folder.

    Args:
        model: Trained segmentation model
        test_img_path: Path to the test image
        device: Device to run inference on (cuda/cpu)
        transform: Image transformations for model input

    Returns:
        tuple: (prediction mask resized to the original image size, with
        values 0/255; IoU against the ground truth, or None when no label
        file was found)
    """
    model.eval()

    # Load image
    image = Image.open(test_img_path).convert("RGB")
    filename = os.path.basename(test_img_path)
    img_width, img_height = image.size  # PIL reports (width, height)

    # Create figure
    plt.figure(figsize=(16, 8))

    # Plot original image
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')

    # Locate the matching label file. Only the final 'images' folder is
    # swapped for 'labels'; a blanket str.replace would also rewrite any
    # parent directory whose name happens to contain 'images'.
    img_dir = os.path.dirname(test_img_path)
    parent_dir, leaf_dir = os.path.split(img_dir)
    if leaf_dir == 'images':
        label_dir = os.path.join(parent_dir, 'labels')
    else:
        label_dir = img_dir.replace('images', 'labels')
    mask_path = os.path.join(label_dir, os.path.splitext(filename)[0] + '.txt')

    mask_exists = os.path.exists(mask_path)
    ground_truth_mask = None
    if mask_exists:
        ground_truth_mask = _yolo_labels_to_mask(mask_path, img_width, img_height)

        # Plot ground truth mask if found
        plt.subplot(1, 3, 2)
        plt.imshow(ground_truth_mask, cmap='gray')
        plt.title('Ground Truth Mask')
        plt.axis('off')

    # Process image for model prediction
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Get prediction
    with torch.no_grad():
        output = model(input_tensor)
        pred_mask = output.squeeze().cpu().numpy()

    # Convert prediction to binary mask
    # (the 0.5 cut-off assumes the model outputs probabilities -- TODO confirm)
    pred_binary = (pred_mask > 0.5).astype(np.uint8) * 255

    # Resize prediction back to original image size (cv2 takes (width, height));
    # nearest-neighbour keeps the mask strictly 0/255
    pred_resized = cv2.resize(pred_binary, (img_width, img_height), interpolation=cv2.INTER_NEAREST)

    # Plot prediction in the next free panel
    subplot_pos = 3 if mask_exists else 2
    plt.subplot(1, 3, subplot_pos)
    plt.imshow(pred_resized, cmap='gray')
    plt.title('Model Prediction')
    plt.axis('off')

    # If ground truth exists, calculate IoU
    iou = None
    if mask_exists:
        intersection = np.logical_and(ground_truth_mask > 0, pred_resized > 0).sum()
        union = np.logical_or(ground_truth_mask > 0, pred_resized > 0).sum()
        iou = intersection / union if union > 0 else 0
        plt.suptitle(f'Farmland Segmentation - IoU: {iou:.4f}', fontsize=16)
    else:
        plt.suptitle('Farmland Segmentation Prediction', fontsize=16)

    plt.tight_layout()
    plt.show()

    # Create and show overlay image for better visualization
    plt.figure(figsize=(12, 10))

    # RGBA overlay: green where farmland is predicted, fully transparent elsewhere
    pred_color = np.zeros((pred_resized.shape[0], pred_resized.shape[1], 4), dtype=np.uint8)
    pred_color[pred_resized > 0] = [0, 255, 0, 128]  # Semi-transparent green for predictions

    # Convert PIL image to numpy array
    img_array = np.array(image)

    # Plot image with prediction overlay
    plt.imshow(img_array)
    plt.imshow(pred_color, alpha=0.5)
    plt.title('Prediction Overlay on Original Image')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    return pred_resized, iou