# Train a mask-rcnn model for leaf morphometrics

## Load packages

In [None]:
# Import os and set environment variables
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Import standard libraries
import sys
import time
from datetime import datetime
import gc
import traceback

# Import third-party libraries
import numpy as np
from PIL import Image
import torch
import torchvision
import torchvision.transforms as T
from pycocotools.coco import COCO

# Import specific modules from torchvision
import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [86]:
# Check Python version
print(f"Python version: {sys.version}")

# Check Torch version
print(f"Torch version: {torch.__version__}")

# Check Torchvision version
print(f"Torchvision version: {torchvision.__version__}")

Python version: 3.11.11 | packaged by conda-forge | (main, Mar  3 2025, 20:44:07) [Clang 18.1.8 ]
Torch version: 2.7.0.dev20250305
Torchvision version: 0.22.0.dev20250305


## Establish LeafDataset class

In [87]:
class LeafDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_file, transforms):
        self.root = root
        self.transforms = transforms
        self.coco = torchvision.datasets.CocoDetection(root, annotations_file)
        self.ids = list(sorted(self.coco.ids))

    def __getitem__(self, index):
        img_id = self.ids[index]
        img, target = self.coco[index]

        boxes = []
        labels = []
        masks = []
        area = []
        iscrowd = []

        for annotation in target:
            bbox = annotation['bbox']
            x1, y1, width, height = bbox
            x2 = x1 + width
            y2 = y1 + height
            boxes.append([x1, y1, x2, y2])

            category_id = annotation['category_id']
            labels.append(category_id)

            if 'segmentation' in annotation:
                mask = self.coco.coco.annToMask(annotation)
                masks.append(mask)

            if 'area' in annotation:
                area.append(annotation['area'])
            else:
                area.append(width * height)

            if 'iscrowd' in annotation:
                iscrowd.append(annotation['iscrowd'])
            else:
                iscrowd.append(0)

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        if masks:
            masks = np.stack(masks)
            masks = torch.as_tensor(masks, dtype=torch.uint8)
        else:
            masks = torch.zeros((0, img.height, img.width), dtype=torch.uint8)

        area = torch.as_tensor(area, dtype=torch.float32)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)

        target = {}
        target['boxes'] = boxes
        target['labels'] = labels
        target['masks'] = masks
        target['image_id'] = torch.tensor([img_id])
        target['area'] = area
        target['iscrowd'] = iscrowd

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.ids)

## Define top level functions

### Define instance segmentation model

In [88]:
def get_instance_segmentation_model(num_classes):
    weights = torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    return model

### Define the transform function

In [89]:
def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

### Define collate function

In [90]:
def collate_fn(batch):
    return tuple(zip(*batch))

### Define function for model evaluation

In [91]:
def evaluate_model(model, data_loader, device, log_message=print):
    """
    Evaluate the model on a dataset without computing gradients.
    Returns the average loss.
    """
    model.train()  # Temporarily set to train mode to compute losses
    total_loss = 0
    batch_count = 0

    try:
        with torch.no_grad():  # No gradients needed for evaluation
            for i, (images, targets) in enumerate(data_loader):
                images = list(image.to(device) for image in images)
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                # Forward pass and compute loss
                loss_dict = model(images, targets)

                # Sum all losses
                losses = sum(loss for loss in loss_dict.values())

                total_loss += losses.item()
                batch_count += 1

                if (i + 1) % 5 == 0:  # Log progress
                    log_message(f"  Eval batch {i+1}/{len(data_loader)}, Loss: {losses.item():.4f}")

                # Clear memory
                del images, targets, loss_dict, losses
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        model.eval()  # Set back to eval mode
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        return avg_loss
    except Exception as e:
        model.eval()  # Ensure model is set back to eval mode even if an error occurs
        log_message(f"Error during evaluation: {e}")
        traceback.print_exc()
        return float('inf')


## Define data paths

In [92]:
# Set the path to your annotated images and annotations
data_path = "/Users/aja294/Documents/Grape_local/projects/leaf_morphometrics/data/annotations"
checkpoint_path = "/Users/aja294/Documents/Grape_local/projects/leaf_morphometrics/models/mask_rcnn"

In [93]:
# Define directories for train, validation, and test
train_dir = os.path.join(data_path, "coco", "train")
val_dir = os.path.join(data_path, "coco", "valid")
test_dir = os.path.join(data_path, "coco", "test")

In [94]:
# Define annotation files
train_annotations_file = os.path.join(train_dir, "_annotations.coco.json")
val_annotations_file = os.path.join(val_dir, "_annotations.coco.json")
test_annotations_file = os.path.join(test_dir, "_annotations.coco.json")

In [95]:
 # Create directories if they don't exist
for directory in [train_dir, val_dir, test_dir]:
    os.makedirs(directory, exist_ok=True)

print(f"Data path: {data_path}")
print(f"Train path: {train_dir}")
print(f"Validation path: {val_dir}")
print(f"Test path: {test_dir}")

Data path: /Users/aja294/Documents/Grape_local/projects/leaf_morphometrics/data/annotations
Train path: /Users/aja294/Documents/Grape_local/projects/leaf_morphometrics/data/annotations/coco/train
Validation path: /Users/aja294/Documents/Grape_local/projects/leaf_morphometrics/data/annotations/coco/valid
Test path: /Users/aja294/Documents/Grape_local/projects/leaf_morphometrics/data/annotations/coco/test


In [96]:
# Check if annotation files exist
missing_files = []
for file_path, name in [(train_annotations_file, "Training"), 
                        (val_annotations_file, "Validation"), 
                        (test_annotations_file, "Testing")]:
    if not os.path.exists(file_path):
        missing_files.append((name, file_path))

if missing_files:
    for name, path in missing_files:
        print(f"Warning: {name} annotations file not found at {path}")
    print("Please ensure all annotation files exist before proceeding.")

In [97]:
# Create checkpoint directory with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join(checkpoint_path, f"checkpoints_{timestamp}")
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"Checkpoint directory: {checkpoint_dir}")

Checkpoint directory: /Users/aja294/Documents/Grape_local/projects/leaf_morphometrics/models/mask_rcnn/checkpoints_20250305_165324


In [98]:
# Create log file
log_file = os.path.join(checkpoint_dir, "training_log.txt")

def log_message(message):
    """Write message to log file and print to console"""
    print(message)
    with open(log_file, "a") as f:
        f.write(f"{message}\n")

log_message(f"=== Training started at {timestamp} ===")

# Load the datasets
log_message("Loading datasets...")


=== Training started at 20250305_165324 ===
Loading datasets...


In [99]:
# Training dataset
train_dataset = LeafDataset(train_dir, train_annotations_file, get_transform(train=True))
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn
)
log_message(f"Training dataset loaded with {len(train_dataset)} images")

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Training dataset loaded with 35 images


In [100]:
# Training dataset V2
train_dataset = LeafDataset(train_dir, train_annotations_file, get_transform(train=True))
train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=16,           # Increased from 8 for better MPS utilization
    shuffle=True,
    num_workers=4,           # Increased from 0 to improve data loading speed
    collate_fn=collate_fn,
    pin_memory=True,         # Added to speed up data transfer to MPS
    persistent_workers=True  # Added to maintain workers between epochs
)
log_message(f"Training dataset loaded with {len(train_dataset)} images")

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Training dataset loaded with 35 images


In [101]:
# Validation dataset
val_dataset = None
val_loader = None
if os.path.exists(val_annotations_file):
    val_dataset = LeafDataset(val_dir, val_annotations_file, get_transform(train=False))
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn
    )
    log_message(f"Validation dataset loaded with {len(val_dataset)} images")
else:
    log_message("Validation dataset not available")

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Validation dataset loaded with 10 images


In [102]:
# Test dataset
test_dataset = None
test_loader = None
if os.path.exists(test_annotations_file):
    test_dataset = LeafDataset(test_dir, test_annotations_file, get_transform(train=False))
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn
    )
    log_message(f"Test dataset loaded with {len(test_dataset)} images")
else:
    log_message("Test dataset not available")

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Test dataset loaded with 5 images


### Force run on CPU

In [103]:
device = torch.device('cpu')
log_message(f"Device forced to: {device}")

Device forced to: cpu


### Run on GPU

In [104]:
# Set up device
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
log_message(f"Using device: {device}")

Using device: mps


In [105]:
# Get number of classes from training dataset
num_classes = len(train_dataset.coco.coco.getCatIds()) + 1  # +1 for background
log_message(f"Number of classes: {num_classes}")

Number of classes: 3


In [106]:
# Initialize model
log_message("Initializing model...")
model = get_instance_segmentation_model(num_classes)
model.to(device)
log_message("Model initialized and moved to device")

Initializing model...
Model initialized and moved to device


In [107]:
# Check if there are existing checkpoints to resume from
start_epoch = 0
best_val_loss = float('inf')

In [108]:
# Look for existing checkpoints in data_path/checkpoints
base_checkpoint_dir = os.path.join(data_path, "checkpoints")
if os.path.exists(base_checkpoint_dir):
    checkpoint_dirs = [d for d in os.listdir(base_checkpoint_dir) 
                        if os.path.isdir(os.path.join(base_checkpoint_dir, d))]

    if checkpoint_dirs:
        # Find the latest checkpoint directory
        latest_dir = max(checkpoint_dirs)
        latest_checkpoint_dir = os.path.join(base_checkpoint_dir, latest_dir)

        # Find the latest checkpoint file
        checkpoints = [f for f in os.listdir(latest_checkpoint_dir) 
                        if f.endswith('.pth') and f.startswith('mask_rcnn_checkpoint_epoch_')]

        if checkpoints:
            latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('_')[-1].split('.')[0]))
            checkpoint_path = os.path.join(latest_checkpoint_dir, latest_checkpoint)
            log_message(f"Loading checkpoint from {checkpoint_path}")

            try:
                checkpoint = torch.load(checkpoint_path, map_location=device)
                model.load_state_dict(checkpoint['model_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint.get('best_val_loss', float('inf'))
                log_message(f"Resuming from epoch {start_epoch}")
            except Exception as e:
                log_message(f"Error loading checkpoint: {e}")
                log_message("Starting training from scratch")
                start_epoch = 0

In [109]:
# Set up optimizer and learning rate scheduler
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [110]:
# Load optimizer and scheduler states if resuming
if start_epoch > 0:
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        log_message("Loaded optimizer and scheduler states")
    except Exception as e:
        log_message(f"Error loading optimizer/scheduler states: {e}")

log_message("Optimizer and learning rate scheduler initialized")

Optimizer and learning rate scheduler initialized


In [111]:
# Training parameters
num_epochs = 1
# Print the device
print(f"Device: {device}")
log_message(f"Starting training for {num_epochs} epochs")

# Training loop
for epoch in range(start_epoch, num_epochs):
    epoch_start_time = time.time()
    model.train()
    epoch_loss = 0
    batch_count = 0

    log_message(f"Epoch {epoch+1}/{num_epochs}")

    # Training phase
    for i, (images, targets) in enumerate(train_loader):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        losses.backward()
        optimizer.step()

        epoch_loss += losses.item()
        batch_count += 1

        if (i + 1) % 10 == 0:  # Print every 10 batches
            log_message(f"  Batch {i+1}/{len(train_loader)}, Loss: {losses.item():.4f}")

    # Print epoch training summary
    avg_train_loss = epoch_loss / batch_count if batch_count > 0 else 0
    log_message(f"  Epoch {epoch+1} training completed. Average Loss: {avg_train_loss:.4f}")

    # Validation phase
    # Validation phase
    val_loss = None
    if val_loader:
        val_loss = evaluate_model(model, val_loader, device, log_message)
        log_message(f"  Validation Loss: {val_loss:.4f}")

    # Update learning rate
    lr_scheduler.step()
    log_message(f"  Learning rate updated to: {optimizer.param_groups[0]['lr']:.6f}")

    # Save checkpoint every epoch
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
        'train_loss': avg_train_loss,
        'val_loss': val_loss,
        'best_val_loss': best_val_loss,
        'num_classes': num_classes
    }

    checkpoint_path = os.path.join(checkpoint_dir, f'mask_rcnn_checkpoint_epoch_{epoch+1}.pth')
    torch.save(checkpoint, checkpoint_path)
    log_message(f"  Checkpoint saved to {checkpoint_path}")

    # Save best model based on validation loss
    if val_loss is not None and (epoch == start_epoch or val_loss < best_val_loss):
        best_val_loss = val_loss
        best_model_path = os.path.join(checkpoint_dir, 'mask_rcnn_best_model.pth')
        torch.save(checkpoint, best_model_path)
        log_message(f"  New best model saved to {best_model_path} (val_loss: {best_val_loss:.4f})")

    # Calculate epoch duration
    epoch_duration = time.time() - epoch_start_time
    log_message(f"  Epoch duration: {epoch_duration:.2f} seconds")

log_message("Training completed!")


Device: mps
Starting training for 1 epochs
Epoch 1/1


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/aja294/miniforge3/envs/pytorch_nightly-env/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/aja294/miniforge3/envs/pytorch_nightly-env/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
    exitcode = _main(fd, parent_sentinel)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/aja294/miniforge3/envs/pytorch_nightly-env/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
  File "/Users/aja294/miniforge3/envs/pytorch_nightly-env/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
    self = reduction.pickle.load(from_parent)
             ^ ^ ^ ^ ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^AttributeError^: ^C

RuntimeError: DataLoader worker (pid(s) 41099) exited unexpectedly

In [None]:
# Evaluate on test set if available
if test_loader:
    log_message("Evaluating on test set...")
    test_loss = evaluate_model(model, test_loader, device)
    log_message(f"Test Loss: {test_loss:.4f}")

Evaluating on test set...
Test Loss: 0.7329


In [None]:
  # Save the final model
final_model_path = os.path.join(checkpoint_dir, "maskrcnn_model_final.pth")
torch.save({
    'model_state_dict': model.state_dict(),
    'num_classes': num_classes,
    'train_loss': avg_train_loss,
    'val_loss': val_loss if val_loader else None,
    'test_loss': test_loss if test_loader else None
}, final_model_path)
log_message(f"Final model saved to {final_model_path}")

Final model saved to /Users/aja294/Documents/Grape_local/projects/leaf_morphometrics/models/mask_rcnn/checkpoints_20250305_135524/maskrcnn_model_final.pth


In [None]:
# Create a symlink to the latest checkpoint directory
# Extract the base directory (without the checkpoints_timestamp part)
base_dir = os.path.dirname(checkpoint_dir)
latest_link = os.path.join(base_dir, "latest")

# Check if the symlink already exists and remove it
if os.path.exists(latest_link):
    try:
        os.remove(latest_link)
    except:
        # On some systems, we might need to use unlink for symlinks
        os.unlink(latest_link)

# Create the symlink pointing to the current checkpoint directory
try:
    os.symlink(checkpoint_dir, latest_link)
    log_message(f"Created symlink to latest checkpoint directory: {latest_link}")
except Exception as e:
    log_message(f"Error creating symlink: {e}")

Created symlink to latest checkpoint directory: /Users/aja294/Documents/Grape_local/projects/leaf_morphometrics/models/mask_rcnn/latest
