Import Essential Packages

In [None]:
import datetime
import math
import os
import pickle
import random
import time

import albumentations as A
import numpy as np
import torch
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from matplotlib import pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, Subset
from tqdm import tqdm

from .model import PRC_Net
from .Utils.load_train_data import *
from .Utils.load_test_data import *
from .Utils.loss import ccfl_dice
from .Utils.training_utils import EarlyStopping, compute_iou


## Model Training

#### 1. Setup Training Parameters

In [None]:
# Spectral channel combination of the input tiles.
# Options noted by the author: RGNIRRE (for Sequoia) | NGB (for Sesame Aerial)
# NOTE(review): `channel`, `AUGS` and `filepath` are not read by any code
# visible in this notebook - confirm whether they are still needed.
channel = 'RGBNIRRE'
SEED = 0

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input geometry and label space handed to the model constructor.
# NOTE(review): IMG_CHANNELS is 3 while the WeedMap loader comments mention a
# 5-channel .npy input - confirm the value matches the data being trained on.
image_size = 320
IMG_CHANNELS = 3
N_CLASSES = 3

# Optimisation settings.
batch_size = 16
EPOCHS = 50
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.00001
AUGS = [1, 2]

# Both patience values are derived from the epoch budget: the scheduler
# horizon is one fifth of the run, early stopping waits twice that long.
patience_lr = EPOCHS // 5
early_stopping_patience = 2 * (EPOCHS // 5)

# Placeholder locations - replace with real directories before running.
filepath = "PATH_TO_CHECKPOINT_DIR"
training_data_dir = "PATH_TO_DATA_DIR"
training_masks_dir = "PATH_TO_MASK_DIR"
testing_data_dir = "PATH_TO_DATA_DIR"
testing_masks_dir = "PATH_TO_MASK_DIR"
OUT_DIR = "PATH_for_Model_Trainig"

#### 2.1 Load Training data for WeedMap (RedEdge & Sequoia)

In [None]:

class SegmentationDataset(Dataset):
    """Image/mask pairs read from disk, with flip variants exposed as extra samples.

    Images are loaded with ``np.load`` (``.npy`` arrays) and masks with PIL.
    Files are paired purely by position after sorting both directory
    listings. Every base image appears three times in the index space:
    original, horizontal flip and vertical flip.

    Args:
        image_dir: Directory holding the image arrays.
        mask_dir: Directory holding the mask images.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            returning a mapping with ``"image"`` and ``"mask"`` entries.
        num_classes: Number of classes used to one-hot encode the mask.

    Raises:
        ValueError: If the two directories contain a different number of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None, num_classes=3):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.num_classes = num_classes

        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))

        # Pairing is positional, so unequal listings would silently mispair
        # samples or fail with IndexError in the middle of an epoch.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Mismatch between images and masks: {len(self.image_files)} "
                f"files in {image_dir!r} vs {len(self.mask_files)} files in {mask_dir!r}"
            )

        # Each image will have 3 variants: original, h-flip, v-flip
        self.total_variants = 3
        self.total_images = len(self.image_files)

    def __len__(self):
        """Return the number of samples, counting every flip variant."""
        return self.total_images * self.total_variants

    def __getitem__(self, idx):
        """Return ``(image, mask, mask_onehot)`` for the flat sample index ``idx``.

        ``mask_onehot`` is a float tensor with the class axis first.
        """
        # Flat index -> (base image, variant); 0=original, 1=horizontal, 2=vertical
        img_idx, variant = divmod(idx, self.total_variants)

        image_path = os.path.join(self.image_dir, self.image_files[img_idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[img_idx])

        # NOTE(review): allow_pickle=True lets np.load execute pickled objects;
        # only point this dataset at trusted files, or switch it off if the
        # arrays were saved as plain numeric data.
        image = np.load(image_path, allow_pickle=True).astype(np.float32)
        image = image / 127.5 - 1.0  # maps 0..255 onto [-1, 1]

        # Read the mask as single-channel 8-bit and release the file handle.
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.uint8)

        # np.flip returns a view with negative strides; copy() makes the
        # array contiguous so it can be converted to a tensor afterwards.
        if variant == 1:  # horizontal flip
            image = np.flip(image, axis=1).copy()
            mask = np.flip(mask, axis=1).copy()
        elif variant == 2:  # vertical flip
            image = np.flip(image, axis=0).copy()
            mask = np.flip(mask, axis=0).copy()

        if self.transform is not None:
            # The transform is expected to return tensors (e.g. a pipeline
            # ending in ToTensorV2): image channel-first, mask 2-D.
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        else:
            image = np.transpose(image, (2, 0, 1))  # H, W, C -> C, H, W
            image = torch.from_numpy(image).float()
            mask = torch.from_numpy(mask).long()

        # (H, W) class indices -> (num_classes, H, W) float one-hot
        mask_onehot = (
            F.one_hot(mask.long(), num_classes=self.num_classes)
            .permute(2, 0, 1)
            .float()
        )

        return image, mask, mask_onehot


# ------------- Data Augmentations -------------
# Train and validation use the same pipeline here: resize to a square of side
# `image_size`, then convert image and mask to tensors.
train_transform = A.Compose([
    A.Resize(height=image_size, width=image_size),
    ToTensorV2(transpose_mask=True),
])

val_transform = A.Compose([
    A.Resize(height=image_size, width=image_size),
    ToTensorV2(transpose_mask=True),
])

# Paths
image_dir = training_data_dir
mask_dir  = training_masks_dir

# Collect all images (files are paired by sorted position in each directory)
all_images = sorted(os.listdir(image_dir))
all_masks  = sorted(os.listdir(mask_dir))

# Explicit check rather than `assert`, which is stripped under `python -O`.
if len(all_images) != len(all_masks):
    raise ValueError("Mismatch between images and masks")
total = len(all_images)
print(f"Total base samples: {total}")

# Shuffle base-image indices with a dedicated RNG seeded from SEED so the
# train/validation split is identical on every run.
indices = list(range(total))
random.Random(SEED).shuffle(indices)

# Split sizes (80% train, 20% validation)
val_size = int(0.2 * total)
train_size = total - val_size

# Split on base images, so no flipped copy of a validation image can end up
# in the training subset.
val_indices  = indices[:val_size]
train_indices = indices[val_size:]

print(f"Train: {len(train_indices)}, Val: {len(val_indices)}")

# ---------------- Dataset Construction ----------------
# A Subset only references its parent dataset. With a single shared parent,
# assigning `.transform` through both subsets leaves just the last assignment
# in effect for training and validation alike, so each split gets its own
# dataset instance carrying its own transform.
full_dataset = SegmentationDataset(image_dir, mask_dir, transform=train_transform)
val_base_dataset = SegmentationDataset(image_dir, mask_dir, transform=val_transform)

# Flat sample index = base_index * variants + variant (0 is the original).
n_variants = full_dataset.total_variants

# Training subset (augmented: original + H-flip + V-flip)
train_dataset = Subset(
    full_dataset,
    [i * n_variants + j for i in train_indices for j in range(n_variants)],
)

# Validation subset (originals only)
val_dataset = Subset(val_base_dataset, [i * n_variants for i in val_indices])

# ---------------- DataLoaders ----------------
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Final dataset sizes → Train: {len(train_dataset)}, Val: {len(val_dataset)}")


#### 2.2 Load Training data for Sesame Aerial dataset

In [None]:

class SegmentationDataset(Dataset):
    """Image/mask pairs read from disk, with flip variants exposed as extra samples.

    Both images and masks are opened with PIL. Files are paired purely by
    position after sorting both directory listings. Every base image appears
    three times in the index space: original, horizontal flip and vertical
    flip.

    Args:
        image_dir: Directory holding the image files.
        mask_dir: Directory holding the mask images.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            returning a mapping with ``"image"`` and ``"mask"`` entries.
        num_classes: Number of classes used to one-hot encode the mask.

    Raises:
        ValueError: If the two directories contain a different number of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None, num_classes=3):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.num_classes = num_classes

        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))

        # Pairing is positional, so unequal listings would silently mispair
        # samples or fail with IndexError in the middle of an epoch.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Mismatch between images and masks: {len(self.image_files)} "
                f"files in {image_dir!r} vs {len(self.mask_files)} files in {mask_dir!r}"
            )

        # Each image will have 3 variants: original, h-flip, v-flip
        self.total_variants = 3
        self.total_images = len(self.image_files)

    def __len__(self):
        """Return the number of samples, counting every flip variant."""
        return self.total_images * self.total_variants

    def __getitem__(self, idx):
        """Return ``(image, mask, mask_onehot)`` for the flat sample index ``idx``.

        ``mask_onehot`` is a float tensor with the class axis first.
        """
        # Flat index -> (base image, variant); 0=original, 1=horizontal, 2=vertical
        img_idx, variant = divmod(idx, self.total_variants)

        image_path = os.path.join(self.image_dir, self.image_files[img_idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[img_idx])

        # Decode the image and release the file handle straight away.
        with Image.open(image_path) as image_file:
            image = np.array(image_file).astype(np.float32)
        image = image / 127.5 - 1.0  # maps 0..255 onto [-1, 1]

        # Read the mask as single-channel 8-bit.
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.uint8)

        # np.flip returns a view with negative strides; copy() makes the
        # array contiguous so it can be converted to a tensor afterwards.
        if variant == 1:  # horizontal flip
            image = np.flip(image, axis=1).copy()
            mask = np.flip(mask, axis=1).copy()
        elif variant == 2:  # vertical flip
            image = np.flip(image, axis=0).copy()
            mask = np.flip(mask, axis=0).copy()

        if self.transform is not None:
            # The transform is expected to return tensors (e.g. a pipeline
            # ending in ToTensorV2): image channel-first, mask 2-D.
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        else:
            image = np.transpose(image, (2, 0, 1))  # H, W, C -> C, H, W
            image = torch.from_numpy(image).float()
            mask = torch.from_numpy(mask).long()

        # (H, W) class indices -> (num_classes, H, W) float one-hot
        mask_onehot = (
            F.one_hot(mask.long(), num_classes=self.num_classes)
            .permute(2, 0, 1)
            .float()
        )

        return image, mask, mask_onehot


# ------------- Data Augmentations -------------
# Train and validation use the same pipeline here: resize to a square of side
# `image_size`, then convert image and mask to tensors.
train_transform = A.Compose([
    A.Resize(height=image_size, width=image_size),
    ToTensorV2(transpose_mask=True),
])

val_transform = A.Compose([
    A.Resize(height=image_size, width=image_size),
    ToTensorV2(transpose_mask=True),
])

# Paths
image_dir = training_data_dir
mask_dir  = training_masks_dir

# Collect all images (files are paired by sorted position in each directory)
all_images = sorted(os.listdir(image_dir))
all_masks  = sorted(os.listdir(mask_dir))

# Explicit check rather than `assert`, which is stripped under `python -O`.
if len(all_images) != len(all_masks):
    raise ValueError("Mismatch between images and masks")
total = len(all_images)
print(f"Total base samples: {total}")

# Shuffle base-image indices with a dedicated RNG seeded from SEED so the
# train/validation split is identical on every run.
indices = list(range(total))
random.Random(SEED).shuffle(indices)

# Split sizes (80% train, 20% validation)
val_size = int(0.2 * total)
train_size = total - val_size

# Split on base images, so no flipped copy of a validation image can end up
# in the training subset.
val_indices  = indices[:val_size]
train_indices = indices[val_size:]

print(f"Train: {len(train_indices)}, Val: {len(val_indices)}")

# ---------------- Dataset Construction ----------------
# A Subset only references its parent dataset. With a single shared parent,
# assigning `.transform` through both subsets leaves just the last assignment
# in effect for training and validation alike, so each split gets its own
# dataset instance carrying its own transform.
full_dataset = SegmentationDataset(image_dir, mask_dir, transform=train_transform)
val_base_dataset = SegmentationDataset(image_dir, mask_dir, transform=val_transform)

# Flat sample index = base_index * variants + variant (0 is the original).
n_variants = full_dataset.total_variants

# Training subset (augmented: original + H-flip + V-flip)
train_dataset = Subset(
    full_dataset,
    [i * n_variants + j for i in train_indices for j in range(n_variants)],
)

# Validation subset (originals only)
val_dataset = Subset(val_base_dataset, [i * n_variants for i in val_indices])

# ---------------- DataLoaders ----------------
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Final dataset sizes → Train: {len(train_dataset)}, Val: {len(val_dataset)}")


#### 3. Training

In [None]:


def loss_fn(y_true_onehot, logits):
    """Return the mean of ``ccfl_dice`` for a batch.

    Args:
        y_true_onehot: One-hot encoded ground-truth masks.
        logits: Model outputs for the same batch.
            NOTE(review): the value is forwarded with ``from_logits=False``,
            which suggests the model already emits probabilities despite the
            parameter name - confirm against ``PRC_Net`` and ``ccfl_dice``.

    Returns:
        The result of calling ``.mean()`` on the ``ccfl_dice`` output.
    """
    combined_loss = ccfl_dice(
        y_true_onehot,
        logits,
        lamda=0.7,
        from_logits=False,
    )
    return combined_loss.mean()

# ------------- Model, Optimizer, Loss -------------
# Build the network from the notebook-level settings and place it on DEVICE.
model = PRC_Net(
    n_classes=N_CLASSES,
    IMG_HEIGHT=image_size,
    IMG_WIDTH=image_size,
    IMG_CHANNELS=IMG_CHANNELS,
    dropout_rate=0.0,
).to(DEVICE)

# The training loop calls this as criterion(mask_onehot, outputs).
criterion = loss_fn

# Adam with weight decay; the cosine schedule uses a horizon of `patience_lr`.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=patience_lr)

# Per-epoch metric curves, filled in by the training loop.
history = {
    metric_name: []
    for metric_name in ('train_loss', 'val_loss', 'train_iou', 'val_iou')
}


Model Training Loop

In [None]:
# ------------- Training Loop -------------
# Each epoch runs one training pass and one validation pass, records
# loss/IoU in `history`, checkpoints the model whenever the validation loss
# improves, and stops early when it stalls.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_val_loss = float('inf')  # lowest validation loss seen so far
early_stopper = EarlyStopping(patience=early_stopping_patience, min_delta=0.001)
early_stopping_counter = 0  # epochs since the last validation-loss improvement
accumulation_steps = 1  # the optimizer steps once every `accumulation_steps` batches

# torch.save does not create missing directories; make sure the checkpoint
# directory exists before spending an epoch on training.
os.makedirs(OUT_DIR, exist_ok=True)

start_training_time = datetime.datetime.now()

for epoch in range(EPOCHS):
    model.train()
    running_loss, running_iou = 0, 0

    for step, (images, masks, mask_onehot) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Training")):
        images, masks, mask_onehot = images.to(device), masks.to(device), mask_onehot.to(device)
        outputs = model(images)
        loss = criterion(mask_onehot, outputs)

        # Scale the loss so gradients accumulated over several batches are
        # averaged rather than summed (a no-op while accumulation_steps == 1).
        (loss / accumulation_steps).backward()

        # Step at the end of every accumulation group, and on the last batch
        # so a trailing partial group is not carried into the next epoch.
        if (step + 1) % accumulation_steps == 0 or (step + 1) == len(train_loader):
            optimizer.step()
            optimizer.zero_grad()

        # Metrics use the unscaled loss so they stay comparable across settings.
        running_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        running_iou += compute_iou(preds.cpu(), masks.cpu())

    train_loss = running_loss / len(train_loader)
    train_iou = running_iou / len(train_loader)
    history['train_loss'].append(train_loss)
    history['train_iou'].append(train_iou)

    # Validation
    model.eval()
    valr_loss, valr_iou = 0, 0
    with torch.no_grad():
        for images, masks, mask_onehot in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Validation"):
            images, masks, mask_onehot = images.to(device), masks.to(device), mask_onehot.to(device)
            outputs = model(images)
            loss = criterion(mask_onehot, outputs)
            valr_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            valr_iou += compute_iou(preds.cpu(), masks.cpu())

    val_loss = valr_loss / len(val_loader)
    val_iou = valr_iou / len(val_loader)

    history['val_loss'].append(val_loss)
    history['val_iou'].append(val_iou)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
          f"Train IoU: {train_iou:.4f}, Val IoU: {val_iou:.4f}")

    # Save the best model (one file per improving epoch)
    if val_loss < best_val_loss:
        best_model_path = f"{OUT_DIR}/PRC_NET_best_E-{epoch+1}.pth"
        print(f"Validation loss improved from {best_val_loss} to {val_loss}, saving best model at {best_model_path}")
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Best Model Saved for Epoch {epoch+1}. And reset early stopping counter {early_stopping_counter} to 0")
        early_stopping_counter = 0
    else:
        # Manual patience counter: any non-improving epoch counts.
        early_stopping_counter += 1
        if early_stopping_counter >= early_stopping_patience:
            print("Early stopping has been triggered")
            break

    # The scheduler advances once per epoch.
    lr_scheduler.step()

    # Second stopping criterion via the project's EarlyStopping helper.
    # NOTE(review): it is configured with min_delta=0.001, so it presumably
    # requires a larger improvement than the manual counter above - confirm
    # that running both mechanisms is intended.
    early_stopper.step(val_loss)

    if early_stopper.early_stop:
        print(f"Early stopping at epoch {epoch+1}")
        break

end_training_time = datetime.datetime.now()

total_training_time = end_training_time - start_training_time


Saving Model and history

In [None]:
# Persist the final weights and the metric curves.
# The directory is created first: if no "best" checkpoint was ever written
# (for example the validation loss never improved), this is the first write
# into OUT_DIR, and neither torch.save nor open() creates missing directories.
os.makedirs(OUT_DIR, exist_ok=True)
torch.save(model.state_dict(), f"{OUT_DIR}/PRC_NET_last.pth")
with open(f"{OUT_DIR}/history.pkl", 'wb') as file:
    pickle.dump(history, file)


## Model Inference

#### 2.1 Load Testing data for WeedMap (RedEdge & Sequoia)

In [None]:


# Test-time dataset for WeedMap: one sample per file, no flip variants.
class SegmentationDataset(Dataset):
    """Image/mask pairs read from disk for evaluation.

    Images are loaded with ``np.load`` (``.npy`` arrays) and masks with PIL.
    Files are paired purely by position after sorting both directory
    listings.

    Args:
        image_dir: Directory holding the image arrays.
        mask_dir: Directory holding the mask images.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            returning a mapping with ``"image"`` and ``"mask"`` entries.
        num_classes: Number of classes used to one-hot encode the mask.

    Raises:
        ValueError: If the two directories contain a different number of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None, num_classes=3):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.num_classes = num_classes
        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))

        # Pairing is positional, so unequal listings would silently mispair
        # samples or fail with IndexError part-way through evaluation.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Mismatch between images and masks: {len(self.image_files)} "
                f"files in {image_dir!r} vs {len(self.mask_files)} files in {mask_dir!r}"
            )

    def __len__(self):
        """Return the number of image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, mask, mask_onehot)`` for sample ``idx``.

        ``mask_onehot`` is a float tensor with the class axis first.
        """
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
        # The per-sample debug prints were removed: they ran for every item
        # and flooded the output alongside the tqdm progress bar.

        # NOTE(review): allow_pickle=True lets np.load execute pickled objects;
        # only point this dataset at trusted files, or switch it off if the
        # arrays were saved as plain numeric data.
        image = np.load(image_path, allow_pickle=True).astype(np.float32)
        image = image / 127.5 - 1.0  # maps 0..255 onto [-1, 1]

        # Read the mask as single-channel 8-bit and release the file handle.
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.uint8)

        if self.transform is not None:
            # The transform receives the arrays as loaded and is expected to
            # return tensors (e.g. a pipeline ending in ToTensorV2).
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        else:
            # If no transform, manually convert
            image = np.transpose(image, (2, 0, 1))  # H, W, C -> C, H, W
            image = torch.from_numpy(image).float()
            mask = torch.from_numpy(mask).long()

        # (H, W) class indices -> (num_classes, H, W) float one-hot
        mask_onehot = (
            F.one_hot(mask.long(), num_classes=self.num_classes)
            .permute(2, 0, 1)
            .float()
        )

        return image, mask, mask_onehot



# ------------- Data Augmentations -------------
# Resize-then-tensorize pipelines. `train_transform` and `val_transform` are
# defined here but the test loader built below only uses `test_transform`.
train_transform = A.Compose(
    [
        A.Resize(height=image_size, width=image_size),
        ToTensorV2(transpose_mask=True),
    ]
)

val_transform = A.Compose(
    [
        A.Resize(height=image_size, width=image_size),
        ToTensorV2(transpose_mask=True),
    ]
)

# Locations of the held-out images and their masks.
test_image_dir, test_mask_dir = testing_data_dir, testing_masks_dir

# Evaluation settings: 320x320 inputs, one sample per batch.
image_size, batch_size = 320, 1

test_transform = A.Compose(
    [
        A.Resize(height=image_size, width=image_size),
        ToTensorV2(transpose_mask=True),
    ]
)

# Every test sample, served in directory order.
test_dataset = SegmentationDataset(
    image_dir=test_image_dir,
    mask_dir=test_mask_dir,
    transform=test_transform,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

#### 2.2 Load Testing data for Sesame Aerial dataset

In [None]:


# Test-time dataset for Sesame Aerial: one sample per file, no flip variants.
class SegmentationDataset(Dataset):
    """Image/mask pairs read from disk for evaluation.

    Both images and masks are opened with PIL. Files are paired purely by
    position after sorting both directory listings.

    Args:
        image_dir: Directory holding the image files.
        mask_dir: Directory holding the mask images.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            returning a mapping with ``"image"`` and ``"mask"`` entries.
        num_classes: Number of classes used to one-hot encode the mask.

    Raises:
        ValueError: If the two directories contain a different number of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None, num_classes=3):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.num_classes = num_classes
        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))

        # Pairing is positional, so unequal listings would silently mispair
        # samples or fail with IndexError part-way through evaluation.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Mismatch between images and masks: {len(self.image_files)} "
                f"files in {image_dir!r} vs {len(self.mask_files)} files in {mask_dir!r}"
            )

    def __len__(self):
        """Return the number of image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, mask, mask_onehot)`` for sample ``idx``.

        ``mask_onehot`` is a float tensor with the class axis first.
        """
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
        # The per-sample debug prints were removed: they ran for every item
        # and flooded the output alongside the tqdm progress bar.

        # Decode the image and release the file handle straight away.
        with Image.open(image_path) as image_file:
            image = np.array(image_file).astype(np.float32)
        image = image / 127.5 - 1.0  # maps 0..255 onto [-1, 1]

        # Read the mask as single-channel 8-bit.
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.uint8)

        if self.transform is not None:
            # The transform receives the arrays as loaded and is expected to
            # return tensors (e.g. a pipeline ending in ToTensorV2).
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        else:
            # If no transform, manually convert
            image = np.transpose(image, (2, 0, 1))  # H, W, C -> C, H, W
            image = torch.from_numpy(image).float()
            mask = torch.from_numpy(mask).long()

        # (H, W) class indices -> (num_classes, H, W) float one-hot
        mask_onehot = (
            F.one_hot(mask.long(), num_classes=self.num_classes)
            .permute(2, 0, 1)
            .float()
        )

        return image, mask, mask_onehot



# ------------- Data Augmentations -------------
# Resize-then-tensorize pipelines. `train_transform` and `val_transform` are
# defined here but the test loader built below only uses `test_transform`.
train_transform = A.Compose(
    [
        A.Resize(height=image_size, width=image_size),
        ToTensorV2(transpose_mask=True),
    ]
)

val_transform = A.Compose(
    [
        A.Resize(height=image_size, width=image_size),
        ToTensorV2(transpose_mask=True),
    ]
)

# Before running this cell, merge the images from all campaigns into a single
# image directory and the masks into a single mask directory.
test_image_dir, test_mask_dir = testing_data_dir, testing_masks_dir

# Evaluation settings: 320x320 inputs, one sample per batch.
image_size, batch_size = 320, 1

test_transform = A.Compose(
    [
        A.Resize(height=image_size, width=image_size),
        ToTensorV2(transpose_mask=True),
    ]
)

# Every test sample, served in directory order.
test_dataset = SegmentationDataset(
    image_dir=test_image_dir,
    mask_dir=test_mask_dir,
    transform=test_transform,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

#### 3. Make Predictions

In [None]:
def perform_inference(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect flattened predictions and labels.

    Args:
        model: Module called as ``model(images)``; the predicted class is the
            argmax of its output over dimension 1.
        dataloader: Iterable yielding ``(images, masks, mask_onehot)`` batches.
        device: Device the image batch is moved to before the forward pass.

    Returns:
        ``(y_pred, y_true)``: two 1-D CPU tensors holding every predicted
        class index and every mask value, concatenated in batch order. Both
        are empty ``int64`` tensors when the dataloader yields no batches.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        # The one-hot masks are not needed here, and the label masks are only
        # collected on the CPU, so only the images are moved to `device`.
        for images, masks, _ in tqdm(dataloader):
            class_scores = model(images.to(device))  # (B, 3, H, W)
            preds = torch.argmax(class_scores, dim=1)  # (B, H, W)
            # reshape also handles non-contiguous tensors, unlike view.
            all_preds.append(preds.cpu().reshape(-1))
            all_labels.append(masks.cpu().reshape(-1))

    # torch.cat raises on an empty list; return empty results instead.
    if not all_preds:
        return torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)

    y_pred = torch.cat(all_preds)
    y_true = torch.cat(all_labels)
    return y_pred, y_true

In [None]:
# -------- Load Model and Run --------
# Rebuild the architecture, restore trained weights, then score the test set.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = PRC_Net(
    n_classes=N_CLASSES,
    IMG_HEIGHT=image_size,
    IMG_WIDTH=image_size,
    IMG_CHANNELS=IMG_CHANNELS,
    dropout_rate=0.0,
)

# NOTE(review): "model_path" is a placeholder - point it at a saved
# checkpoint. torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(
    torch.load("model_path", map_location=device)
)
model = model.to(device)

y_pred, y_true = perform_inference(model, test_loader, device)

