In [None]:
import numpy as np
import torch
import os
from torch.utils.data import DataLoader
from models import *
from training_utils import *
import torchvision.transforms.v2 as v2
from torchvision import tv_tensors


In [None]:
# Constants for training
DATA_SPLIT = 0.9  # passed to SegmentationDataset.split; presumably the training fraction — confirm
INIT_LR = 0.0001
NUM_EPOCHS = 2
BATCH_SIZE = 16
INPUT_IMAGE_SIZE = 512
NUM_CLASSES = 1  # Binary segmentation
THRESHOLD = 0.5
LOSS = JaccardLoss()  # unweighted default; the class-weighting cell overwrites LOSS
EXPERIMENT_NAME = "Mangrove_Test"
DEVICE = setup_device()
CLASSES = {0: 'non-mangrove', 1: 'mangrove'}
SEED = 42

# Seed the global RNGs so DataLoader shuffling (torch's default generator) and, presumably,
# dataset.split (confirm which RNG it uses) are reproducible across Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Define the path to each dataset file. Set MANGROVE_DATA_DIR to point at the data
# without editing the notebook; the author's local path remains the default.
BASE_DIR = os.environ.get("MANGROVE_DATA_DIR", "/Users/gage/Desktop/Mangrove/Drone Data")
# NOTE: despite the *_DIR names, these are paths to single .npy files.
IMAGE_DIR = os.path.join(BASE_DIR, f"{INPUT_IMAGE_SIZE}dataset_images.npy")
LABEL_DIR = os.path.join(BASE_DIR, f"{INPUT_IMAGE_SIZE}dataset_labels.npy")

# Get the dataset. mmap_mode='r' memory-maps the arrays read-only instead of
# reading them fully into RAM.
images = np.load(IMAGE_DIR, mmap_mode='r')
labels = np.load(LABEL_DIR, mmap_mode='r')
print(f"Dataset shape: {images.shape}, {labels.shape}")
assert len(images) == len(labels), "image and label arrays must have the same number of samples"

# ImageNet statistics, shared by both pipelines so they cannot drift apart.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Normalize() with ImageNet statistics expects images in [0, 1]. scale=True rescales
# integer-typed images (e.g. uint8 0-255) into that range and leaves floating-point images
# unscaled; masks are only cast to int64, never rescaled.
transforms = v2.Compose([
    # Convert to tensor dtype (scaling images to [0, 1]) and normalize
    v2.ToDtype({tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.long}, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transforms = v2.Compose([
    v2.ToDtype({tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.long}, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Build the dataset and split it into training and test sets
dataset = SegmentationDataset(images, labels, transforms)
trainDS, testDS = dataset.split(DATA_SPLIT, val_transforms)
del dataset

# create data loaders
trainLoader = DataLoader(trainDS, shuffle=True, batch_size=BATCH_SIZE)
testLoader = DataLoader(testDS, shuffle=False, batch_size=BATCH_SIZE)

print(f"Training set: {len(trainDS)} images, test set: {len(testDS)} images.")

# The DataLoaders keep their own references to the datasets, so these names can go.
del trainDS, testDS

Using Apple Metal Performance Shaders (MPS) device.

Dataset shape: (1000, 3, 512, 512), (1000, 1, 512, 512)
Dataset containing 900 images loaded.


In [None]:
# Find the dataset distribution for weighting the loss function: class weights are
# computed from the label file, moved to the training device, and fed to a weighted
# JaccardLoss that replaces the unweighted default.
# NOTE(review): NUM_CLASSES is 1 while CLASSES has two entries — confirm the weight
# length returned by calculate_class_weights is what JaccardLoss expects.
class_weights = calculate_class_weights(LABEL_DIR, CLASSES)
class_weights = class_weights.to(DEVICE)
LOSS = JaccardLoss(weight=class_weights, num_classes=NUM_CLASSES)

In [None]:
# Build the ResNet U-Net on DEVICE, wrap it with the loaders and loss in a
# TrainingSession, and start training (learn() presumably runs the epoch loop — see
# training_utils).
model = ResNet_UNet(
    input_image_size=INPUT_IMAGE_SIZE,
    num_classes=NUM_CLASSES,
).to(DEVICE)
Machine = TrainingSession(
    model,
    trainLoader,
    testLoader,
    LOSS,
    init_lr=INIT_LR,
    num_epochs=NUM_EPOCHS,
    device=DEVICE,
    experiment_name=EXPERIMENT_NAME,
)
Machine.learn()