In [6]:
import os
import glob
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T
import torch
import numpy as np


In [27]:

class WheatSegmentationDataset(Dataset):
    """Paired image / mask dataset for semantic segmentation.

    Every ``*.png`` found recursively under ``image_root`` is one sample. Its
    mask is looked up at the same relative path under ``mask_root``; the mask
    files are not checked for existence until a sample is actually loaded.

    Args:
        image_root: Directory searched recursively for ``*.png`` images.
        mask_root: Directory mirroring the layout of ``image_root`` that holds
            the masks.
        transform: Optional callable applied to the resized RGB PIL image only;
            the mask is never passed through it.
        size: ``(height, width)`` that image and mask are resized to before
            ``transform`` runs. Defaults to ``(512, 512)``.
    """

    def __init__(self, image_root, mask_root, transform=None, size=(512, 512)):
        self.image_root = image_root
        self.mask_root = mask_root
        self.transform = transform
        self.size = tuple(size)

        # Sorted so that sample indices are stable from one run to the next.
        pattern = os.path.join(image_root, '**', '*.png')
        self.image_paths = sorted(glob.glob(pattern, recursive=True))

        # A mask lives at the same relative location, rooted at mask_root.
        self.mask_paths = [
            os.path.join(mask_root, os.path.relpath(path, image_root))
            for path in self.image_paths
        ]

    def __len__(self):
        """Return the number of images discovered under ``image_root``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        Returns:
            tuple: ``image`` is the output of ``transform`` (or the resized RGB
            PIL image when no transform was given); ``mask`` is a
            ``torch.long`` tensor built from the resized single-channel mask.
        """
        # convert() returns an independent, fully loaded copy, so the file
        # handles can be released as soon as the with-blocks exit.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(self.mask_paths[idx]) as raw_mask:
            mask = raw_mask.convert("L")  # grayscale mask

        # Photographs are resampled bilinearly to avoid the aliasing that
        # nearest-neighbour downsampling introduces. The mask must stay
        # nearest-neighbour so that no in-between label values are invented.
        image = T.Resize(self.size, interpolation=Image.BILINEAR)(image)
        mask = T.Resize(self.size, interpolation=Image.NEAREST)(mask)

        if self.transform:
            image = self.transform(image)

        # NOTE(review): assumes the grayscale values already are class ids —
        # confirm against how the masks were generated.
        mask = torch.as_tensor(np.array(mask), dtype=torch.long)

        return image, mask

In [28]:
import numpy as np
from torch.utils.data import DataLoader, random_split


# NOTE(review): absolute, machine-specific paths — anyone else running this
# notebook has to edit them.
IMG_PATH = r"C:\Users\harry\Documents\Harry\UQ\sem2_2025\DATA7903\DATA7901\gwfss\images"
MASK_PATH = r"C:\Users\harry\Documents\Harry\UQ\sem2_2025\DATA7903\DATA7901\gwfss\masks_grayscale"

# Seed for the train/validation partition. Without it every execution of this
# cell draws a different split, so validation numbers are not comparable
# between runs and a re-run after training mixes seen samples into validation.
SPLIT_SEED = 42

# Image-only preprocessing: shrink to 256x256, convert to a tensor and rescale
# each channel with mean 0.5 / std 0.5. Masks do not pass through this
# pipeline, so they keep the resolution chosen by the dataset itself.
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize([0.5]*3, [0.5]*3)
])

dataset = WheatSegmentationDataset(IMG_PATH, MASK_PATH, transform=transform)

# 80 % of the samples for training, the remainder for validation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training batches are reshuffled each epoch.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)

In [29]:
import torchvision.models.segmentation as models

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# DeepLabV3 with a ResNet-50 backbone and a 4-class output head;
# pretrained=False means no pretrained segmentation checkpoint is requested.
model = models.deeplabv3_resnet50(pretrained=False, num_classes=4)

# .to() moves the parameters and, as the cell's last expression, echoes the model.
model.to(device)

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [44]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Multi-class loss applied to the model's logits and the integer class masks.
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the notebook-level `model`, learning rate 1e-4.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train_one_epoch(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Uses the notebook-level ``device``, ``criterion`` and ``optimizer``.

    Args:
        model: Network whose forward pass returns a mapping with an ``'out'``
            entry holding the class logits.
        loader: Iterable yielding ``(images, masks)`` batches.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches,
        or ``nan`` when ``loader`` yields no batch (a loss of 0 would look
        like a perfect fit).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        logits = model(images)['out']
        # Upsample to whatever spatial size the masks have instead of a
        # hard-coded 512x512, so the loss stays valid if the dataset's
        # resize target ever changes.
        logits = F.interpolate(
            logits, size=tuple(masks.shape[-2:]), mode='bilinear', align_corners=False
        )

        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Counting batches (rather than len(loader)) also supports un-sized
    # iterables and avoids a ZeroDivisionError on an empty loader.
    if num_batches == 0:
        return float("nan")

    return total_loss / num_batches

In [45]:
def evaluate(model, loader):
    """Return the pixel accuracy of ``model`` over ``loader`` as a percentage.

    Uses the notebook-level ``device``. Gradients are disabled and the model
    is switched to evaluation mode.

    Args:
        model: Network whose forward pass returns a mapping with an ``'out'``
            entry holding the class logits.
        loader: Iterable yielding ``(images, masks)`` batches.

    Returns:
        float: ``100 * correct_pixels / total_pixels``, or ``0.0`` when no
        pixel was evaluated.
    """
    model.eval()
    total_correct = 0
    total_pixels = 0

    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)

            logits = model(images)['out']
            # Compare at the masks' own resolution rather than a hard-coded
            # 512x512, so predictions and labels always line up.
            logits = F.interpolate(
                logits, size=tuple(masks.shape[-2:]), mode='bilinear', align_corners=False
            )
            preds = logits.argmax(dim=1)

            total_correct += (preds == masks).sum().item()
            total_pixels += masks.numel()

    # Avoid division by zero
    if total_pixels == 0:
        return 0.0

    return 100.0 * total_correct / total_pixels

In [46]:
# Number of full passes over the training loader.
epochs = 10
for epoch in range(epochs):
    # Train first, then score the freshly updated weights on the validation loader.
    train_loss = train_one_epoch(model, train_loader)
    val_acc = evaluate(model, val_loader)
    # NOTE(review): the two DEBUG prints only echo value and type of the metrics
    # and look like leftover debugging — consider dropping them.
    print("DEBUG -- train_loss:", train_loss, type(train_loss))
    print("DEBUG -- val_acc:", val_acc, type(val_acc))
    # epoch is zero-based, hence the +1 in the human-readable summary.
    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Pixel Acc = {val_acc:.2f}%")

DEBUG -- train_loss: 0.5100697308778763 <class 'float'>
DEBUG -- val_acc: 80.25204225019975 <class 'float'>
Epoch 1: Train Loss = 0.5101, Val Pixel Acc = 80.25%
DEBUG -- train_loss: 0.471881209991195 <class 'float'>
DEBUG -- val_acc: 81.56801917336203 <class 'float'>
Epoch 2: Train Loss = 0.4719, Val Pixel Acc = 81.57%
DEBUG -- train_loss: 0.4452829965136268 <class 'float'>
DEBUG -- val_acc: 82.82566590742631 <class 'float'>
Epoch 3: Train Loss = 0.4453, Val Pixel Acc = 82.83%
DEBUG -- train_loss: 0.42222753871570934 <class 'float'>
DEBUG -- val_acc: 82.16665788130327 <class 'float'>
Epoch 4: Train Loss = 0.4222, Val Pixel Acc = 82.17%
DEBUG -- train_loss: 0.3978924810886383 <class 'float'>
DEBUG -- val_acc: 82.58771029385653 <class 'float'>
Epoch 5: Train Loss = 0.3979, Val Pixel Acc = 82.59%
DEBUG -- train_loss: 0.38515672954646024 <class 'float'>
DEBUG -- val_acc: 82.58490302345969 <class 'float'>
Epoch 6: Train Loss = 0.3852, Val Pixel Acc = 82.58%
DEBUG -- train_loss: 0.3667399671

In [25]:
import os
from PIL import Image
import glob

# Collect every PNG below the image directory, at any depth.
image_root = r"C:\Users\harry\Documents\Harry\UQ\sem2_2025\DATA7903\DATA7901\gwfss\images"
pattern = os.path.join(image_root, '**', '*.png')
image_paths = glob.glob(pattern, recursive=True)

# Tally how many images share each (width, height).
sizes = {}

for path in image_paths:
    with Image.open(path) as img:
        width_height = img.size  # (width, height)
    if width_height in sizes:
        sizes[width_height] += 1
    else:
        sizes[width_height] = 1

# One summary line per distinct resolution.
for size, count in sizes.items():
    print(f"Size {size}: {count} images")

Size (512, 512): 987 images
Size (1024, 1024): 109 images


In [None]:
import matplotlib.pyplot as plt

# Qualitative check: show input, ground truth and prediction side by side for
# up to two samples of the first validation batch.
model.eval()
with torch.no_grad():
    for images, masks in val_loader:
        images = images.to(device)
        outputs = model(images)['out']
        # Bring the logits to the masks' resolution before the argmax, exactly
        # as evaluate() does, so prediction and ground truth are pixel-aligned.
        outputs = F.interpolate(
            outputs, size=tuple(masks.shape[-2:]), mode='bilinear', align_corners=False
        )
        preds = torch.argmax(outputs, dim=1)

        # A batch may hold fewer than two samples; never index past its end.
        for i in range(min(2, len(images))):
            fig, (ax_image, ax_truth, ax_pred) = plt.subplots(1, 3, figsize=(12, 4))

            # Undo Normalize(mean=0.5, std=0.5) and move channels last for imshow.
            ax_image.imshow(images[i].cpu().permute(1, 2, 0) * 0.5 + 0.5)
            ax_image.set_title("Input Image")

            # Same colour map and value range for both label maps (class ids 0..3).
            ax_truth.imshow(masks[i].cpu(), cmap="tab10", vmin=0, vmax=3)
            ax_truth.set_title("Ground Truth")

            ax_pred.imshow(preds[i].cpu(), cmap="tab10", vmin=0, vmax=3)
            ax_pred.set_title("Prediction")

            plt.show()
        # Only the first validation batch is visualised.
        break