In [None]:
import torch
# from options import get_arguments
import albumentations as A
from albumentations.pytorch import ToTensorV2
from dataset import WE3DSDataset
from model import create_deeplabv3
import matplotlib.pyplot as plt
import numpy as np

class args:
    """Hard-coded option namespace used by the cells below.

    Presumably a stand-in for the commented-out ``options.get_arguments`` parser;
    values are read as plain class attributes (``args.batch_size`` etc.).
    """

    # WE3DS image / annotation folders, relative to the notebook's working directory.
    train_images_dir = '../WE3DS_DATASET/Train/images'
    train_segmentations_dir = '../WE3DS_DATASET/Train/annotations'
    test_images_dir = '../WE3DS_DATASET/Test/images'
    test_segmentations_dir = '../WE3DS_DATASET/Test/annotations'

    # Loader batch size and number of segmentation classes (length of class_names).
    batch_size = 4
    num_classes = 18

def _build_resize_normalize_transform():
    """Return a fresh pipeline: resize to 480x640 (height, width), normalize, convert to tensor."""
    # These values match the widely used ImageNet per-channel mean / std.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    return A.Compose([
        A.Resize(480, 640),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ])


# Train and validation use the same deterministic preprocessing (no augmentation here);
# each gets its own Compose instance.
transform_train = _build_resize_normalize_transform()
transform_val = _build_resize_normalize_transform()


# The `train` flag is forwarded to WE3DSDataset; its effect is defined in dataset.py (not shown here).
train_dataset = WE3DSDataset(
    args.train_images_dir,
    args.train_segmentations_dir,
    transform=transform_train,
    train=True,
)
# The "val" split is the WE3DS Test folder.
val_dataset = WE3DSDataset(
    args.test_images_dir,
    args.test_segmentations_dir,
    transform=transform_val,
    train=False,
)

# Same batch size and worker count for both loaders; only the training loader shuffles.
train_loader, val_loader = (
    torch.utils.data.DataLoader(split, batch_size=args.batch_size, shuffle=do_shuffle, num_workers=4)
    for split, do_shuffle in ((train_dataset, True), (val_dataset, False))
)

In [3]:
# Build the network via create_deeplabv3 with one output per class, then place it
# on the GPU when CUDA is available, otherwise on the CPU.
model = create_deeplabv3(args.num_classes)
model = model.to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))



In [None]:
from sklearn.metrics import confusion_matrix

# Human-readable class names in label order: entry i names label id i.
# There is no entry for a void/ignore label; such pixels are expected to be masked out.
class_names = [
    'soil',
    'broad_bean',
    'corn_spurry',
    'red-root_amaranth',
    'common_buckwheat',
    'pea',
    'red_fingergrass',
    'common_wild_oat',
    'cornflower',
    'corn_cockle',
    'corn',
    'milk_thistle',
    'rye_brome',
    'soybean',
    'sunflower',
    'narrow-leaved_plantain',
    'small-flower_geranium',
    'sugar_beet',
]

# Index -> name lookup used when printing per-class results.
class_indices = dict(enumerate(class_names))

def compute_confusion_matrix(preds, labels, num_classes):
    """Return the pixel confusion matrix for one prediction / ground-truth pair.

    Rows index the ground-truth class and columns the predicted class. This gives the
    same counts as ``sklearn.metrics.confusion_matrix(labels, preds,
    labels=np.arange(num_classes))``, computed with a single ``np.bincount`` instead of
    sklearn's per-call validation and ``np.unique`` passes.

    Args:
        preds: array-like of predicted class ids, any shape (flattened here).
        labels: array-like of ground-truth class ids with as many elements as ``preds``.
            Pixels whose label is outside ``[0, num_classes)`` (e.g. a void/ignore id)
            are skipped.
        num_classes: number of valid classes.

    Returns:
        ``np.int64`` array of shape ``(num_classes, num_classes)``.
    """
    preds = np.asarray(preds).ravel()
    labels = np.asarray(labels).ravel()
    # Keep pixels whose label AND prediction are valid ids. sklearn with an explicit
    # `labels=` list silently drops out-of-range predictions too, so this matches it.
    valid = (labels >= 0) & (labels < num_classes) & (preds >= 0) & (preds < num_classes)
    # Widen before arithmetic: a uint8 mask would overflow in `label * num_classes + pred`
    # as soon as num_classes**2 exceeds 255.
    flat_index = labels[valid].astype(np.int64) * num_classes + preds[valid].astype(np.int64)
    counts = np.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes).astype(np.int64, copy=False)

def compute_miou(confusion):
    """Compute per-class IoU and their mean from a confusion matrix.

    Args:
        confusion: square array with ground-truth classes on rows and predicted
            classes on columns.

    Returns:
        ``(mean_iou, iou)``, where ``iou[c] = TP / (TP + FP + FN)`` for class ``c``.
        A class present in neither ground truth nor predictions has ``union == 0`` and
        gets ``iou[c] = nan``. ``mean_iou`` averages only the non-nan entries and is
        ``nan`` when every entry is nan.
    """
    confusion = np.asarray(confusion)
    intersection = np.diag(confusion)
    ground_truth_set = confusion.sum(axis=1)
    predicted_set = confusion.sum(axis=0)
    union = ground_truth_set + predicted_set - intersection
    # 0/0 for absent classes is expected and intentionally produces nan; don't warn about it.
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = intersection / union
    if np.isnan(iou).all():
        # np.nanmean would warn "Mean of empty slice" here; the answer is nan regardless.
        return np.float64(np.nan), iou
    return np.nanmean(iou), iou

def evaluate_model(model, dataloader, device, num_classes):
    """Run ``model`` over ``dataloader`` and return IoU scores from the accumulated confusion matrix.

    Args:
        model: segmentation network. It is called as ``model(img)`` and must return a
            mapping whose ``'out'`` entry holds per-class scores with the class axis at dim 1.
        dataloader: iterable yielding ``(img, mask, _)`` triples; the third item is ignored.
        device: device the images are sent to before the forward pass (the model is
            assumed to already live there).
        num_classes: number of classes; mask values outside ``[0, num_classes)`` are ignored.

    Returns:
        ``(mean_iou, class_iou)`` as produced by ``compute_miou``.
    """
    model.eval()
    confusion_matrix_all = np.zeros((num_classes, num_classes), dtype=np.int64)

    with torch.no_grad():
        for img, mask, _ in dataloader:
            outputs = model(img.to(device))['out']
            preds = outputs.argmax(dim=1).cpu().numpy()
            # Ground truth is only compared on the host, so it never needs to visit `device`.
            true_masks = mask.cpu().numpy()
            # compute_confusion_matrix flattens its inputs, so one call per batch yields the
            # same counts as summing one matrix per image.
            confusion_matrix_all += compute_confusion_matrix(preds, true_masks, num_classes)

    mean_iou, class_iou = compute_miou(confusion_matrix_all)
    return mean_iou, class_iou

### Patch-level synthetic dataset 5x: evaluation of the epoch-60 checkpoint

In [None]:
# Fall back to the CPU when CUDA is unavailable, and keep the model on the same device
# the evaluation cells will send images to.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

ckpt_path = './output/patch_level_5x/ckpt/epoch_60.pth'
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [None]:
# evaluation on test data
mean_iou, class_iou = evaluate_model(model, val_loader, device, args.num_classes)

# nanmean: a class absent from both ground truth and predictions has IoU = nan
# (0/0 in compute_miou) and must not turn the "exclude soil" mean into nan.
print(f"Mean IOU: {mean_iou}, Mean IOU (exclude soil): {np.nanmean(class_iou[1:])}")
# Labels outside [0, num_classes), such as a void id, are already masked out in
# compute_confusion_matrix, so every index here is a real class.
for i, iou in enumerate(class_iou):
    print(f"Class {i} ({class_indices[i]}) IOU: {iou}")

Mean IOU: 0.6316171713221412, Mean IOU (exclude soil): 0.6101243663888996
Class 0 (soil) IOU: 0.9969948551872515
Class 1 (broad_bean) IOU: 0.6064831162431015
Class 2 (corn_spurry) IOU: 0.49750581975390756
Class 3 (red-root_amaranth) IOU: 0.5278951201747997
Class 4 (common_buckwheat) IOU: 0.8788716949276759
Class 5 (pea) IOU: 0.6632309026105385
Class 6 (red_fingergrass) IOU: 0.6488552934532343
Class 7 (common_wild_oat) IOU: 0.4464625054816547
Class 8 (cornflower) IOU: 0.5931937567479517
Class 9 (corn_cockle) IOU: 0.6860049114925776
Class 10 (corn) IOU: 0.8920001238275083
Class 11 (milk_thistle) IOU: 0.8798643491458885
Class 12 (rye_brome) IOU: 0.6165264313478682
Class 13 (soybean) IOU: 0.8569258579910619
Class 14 (sunflower) IOU: 0.7938006538731172
Class 15 (narrow-leaved_plantain) IOU: 0.03619309782331651
Class 16 (small-flower_geranium) IOU: 0.3156181806845813
Class 17 (sugar_beet) IOU: 0.43268241303250854


In [None]:
# evaluation on train data
# NOTE(review): train_loader was built with train=True and shuffle=True. Shuffling does not
# affect the confusion matrix, but if WE3DSDataset applies training-time augmentation for
# train=True, these scores are measured on augmented images; confirm in dataset.py.
mean_iou, class_iou = evaluate_model(model, train_loader, device, args.num_classes)

# nanmean: a class absent from both ground truth and predictions has IoU = nan
# (0/0 in compute_miou) and must not turn the "exclude soil" mean into nan.
print(f"Mean IOU: {mean_iou}, Mean IOU (exclude soil): {np.nanmean(class_iou[1:])}")
# Labels outside [0, num_classes), such as a void id, are already masked out in
# compute_confusion_matrix, so every index here is a real class.
for i, iou in enumerate(class_iou):
    print(f"Class {i} ({class_indices[i]}) IOU: {iou}")

Mean IOU: 0.6809964787445622, Mean IOU (exclude soil): 0.6623694630294328
Class 0 (soil) IOU: 0.9976557459017616
Class 1 (broad_bean) IOU: 0.7775239513080976
Class 2 (corn_spurry) IOU: 0.5371477756636768
Class 3 (red-root_amaranth) IOU: 0.6262602579132474
Class 4 (common_buckwheat) IOU: 0.9289007754835092
Class 5 (pea) IOU: 0.7426858358961167
Class 6 (red_fingergrass) IOU: 0.7066281843813083
Class 7 (common_wild_oat) IOU: 0.47090604570749367
Class 8 (cornflower) IOU: 0.6357016622904504
Class 9 (corn_cockle) IOU: 0.6868151595338349
Class 10 (corn) IOU: 0.8882298010122572
Class 11 (milk_thistle) IOU: 0.9300657430201714
Class 12 (rye_brome) IOU: 0.399595540603804
Class 13 (soybean) IOU: 0.8692106130236403
Class 14 (sunflower) IOU: 0.7894858954086459
Class 15 (narrow-leaved_plantain) IOU: 0.3312119160460393
Class 16 (small-flower_geranium) IOU: 0.40954197631596356
Class 17 (sugar_beet) IOU: 0.5303697378920993
