In [10]:
import numpy as np
import torch
import os
from torch.utils.data import DataLoader
from models import *
from training_utils import *
import torchvision.transforms.v2 as v2
from torchvision import tv_tensors


In [11]:
# Constants for training
SEED = 42                 # fixed RNG seed so shuffling (and any random split) is repeatable
DATA_SPLIT = 0.9          # split ratio handed to dataset.split()
INIT_LR = 0.0005          # initial learning rate of the first training session
NUM_EPOCHS = 20           # epochs of the first training session
BATCH_SIZE = 16
INPUT_IMAGE_SIZE = 512    # also selects the "<size>dataset_*.npy" files below
NUM_CLASSES = 1  # Binary segmentation
THRESHOLD = 0.5
LOSS = CEJaccardLoss()    # unweighted default loss
EXPERIMENT_NAME = "Mangrove_NoAugs_weightedLoss"
DEVICE = setup_device()

# Seed NumPy and PyTorch before any dataset/loader is created so that a
# "Restart & Run All" reproduces the same shuffling order.
np.random.seed(SEED)
torch.manual_seed(SEED)

# define the path to each directory
# The data root can be overridden with the MANGROVE_DATA_DIR environment
# variable; the original absolute path remains the default.
BASE_DIR = os.environ.get("MANGROVE_DATA_DIR", "A:\\Drone_Data\\original_data")
# NOTE: despite the *_DIR suffix, these two are paths to single .npy files.
IMAGE_DIR = os.path.join(BASE_DIR, f"{INPUT_IMAGE_SIZE}dataset_images.npy")
LABEL_DIR = os.path.join(BASE_DIR, f"{INPUT_IMAGE_SIZE}dataset_labels.npy")

# ImageNet channel statistics. They are defined for images whose values lie
# in [0, 1], so the image must be scaled to that range before Normalize.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline (no augmentations): cast dtypes, then normalise the image.
# scale=True makes ToDtype rescale integer-typed images to [0, 1] while
# casting; float images are only cast, and masks are only cast to long.
# Without it, integer images would reach Normalize on their original integer
# scale and the ImageNet statistics would not apply.
transforms = v2.Compose([
    v2.ToDtype({tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.long}, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: identical preprocessing to the training pipeline.
val_transforms = v2.Compose([
    v2.ToDtype({tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.long}, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


# Build the dataset and split it into training and test sets.
# Both .npy files are opened as read-only memory maps instead of being read
# fully into RAM.
dataset = SegmentationDataset(
    np.load(IMAGE_DIR, mmap_mode='r'),
    np.load(LABEL_DIR, mmap_mode='r'),
    transforms,
)
trainDS, testDS = dataset.split(DATA_SPLIT, val_transforms)
del dataset


# Create data loaders: shuffle the training samples, keep the test order fixed.
trainLoader = DataLoader(dataset=trainDS, batch_size=BATCH_SIZE, shuffle=True)
testLoader = DataLoader(dataset=testDS, batch_size=BATCH_SIZE, shuffle=False)

# The count printed here is the size of the training split only.
print(f"Dataset containing {len(trainDS)} images loaded.")
del trainDS, testDS


Using CUDA device.
Dataset containing 220271 images loaded.


In [None]:
# Find the dataset's class distribution and use it to weight the loss function.
# The label file stores the two classes as pixel values 0 and 255.
label_names = {0: 'non-mangrove', 255: 'mangrove'}
class_weights = calculate_class_weights(LABEL_DIR, label_names)
class_weights = class_weights.to(DEVICE)

# Replace the default LOSS with a class-weighted one.
LOSS = CEJaccardLoss(weight=class_weights)

Labels shape: (244746, 1, 512, 512), dtype: uint8


Counting classes: 100%|██████████| 25/25 [03:20<00:00,  8.02s/it]



Class distribution:
  non-mangrove   : 21,817,867,448 pixels ( 43.1%)
  mangrove       : 28,838,449,312 pixels ( 56.9%)

Normalized class weights:
  non-mangrove   : 1.2720
  mangrove       : 0.7280
Class weights: tensor([1.2720, 0.7280])


In [None]:
# Build the SegFormer model and run the first training session with the
# settings from the constants cell.
model = SegFormer(num_classes=NUM_CLASSES, input_image_size=INPUT_IMAGE_SIZE)
Machine = TrainingSession(
    model,
    trainLoader,
    testLoader,
    LOSS,
    init_lr=INIT_LR,
    num_epochs=NUM_EPOCHS,
    device=DEVICE,
    experiment_name=EXPERIMENT_NAME,
)
Machine.learn()

In [None]:
# Second session: continue from the model object held by the first session,
# with one fifth of the initial learning rate and half the number of epochs.
Machine_two = TrainingSession(
    Machine.model,
    trainLoader,
    testLoader,
    LOSS,
    init_lr=INIT_LR/5,
    num_epochs=NUM_EPOCHS//2,
    device=DEVICE,
    experiment_name=EXPERIMENT_NAME+"_fineTuned",
)
# NOTE(review): presumably enables training of the backbone weights — confirm in models.
Machine_two.model.train_backbone()
Machine_two.learn()

2025-10-13 13:43:32,347 - INFO - Starting training: 10 epochs
2025-10-13 13:43:32,348 - INFO - Model parameters: 33,248,321
Epoch 1/10:  92%|█████████▏| 12630/13767 [4:25:32<26:12,  1.38s/it, loss=0.3440, lr=0.000100]  

In [None]:
# Show the loss plot of the first training session.
# NOTE(review): this calls plot_loss on Machine, not Machine_two — confirm
# whether the fine-tuning session's curves should be plotted as well.
Machine.plot_loss()

In [None]:
# List the metric names the first session reports, then draw its metrics plot.
available_metrics = Machine.get_available_metrics()
print("Available Metrics:", available_metrics)
Machine.plot_metrics("Metrics")

In [None]:
# Persist the model's weights (state dict only, not the full model object).
# NOTE(review): the file name says "resnet_unet" but the model built in this
# notebook is a SegFormer — confirm the intended checkpoint name before
# relying on this file.
model_weights = Machine.model.state_dict()
torch.save(model_weights, "resnet_unet.pth")