In [12]:
import torch
import torchvision.models as models

def load_resnet_rgbn(model_path, architecture='resnet18', map_location='cpu'):
    """Build a headless torchvision ResNet with a 4-channel stem and load weights.

    The stock ``conv1`` is replaced by one that accepts four input channels
    (presumably RGB + NIR, per the function name -- confirm against the
    checkpoint). The last child module is dropped (the ``fc`` classifier for
    torchvision ResNets), and a ``Flatten`` is appended. The returned network
    therefore outputs a flat feature vector, not class logits.

    Args:
        model_path: path to a state dict whose keys match the returned
            ``nn.Sequential`` (children are indexed "0", "1", ...).
        architecture: name of a torchvision ResNet constructor, e.g. 'resnet18'.
        map_location: forwarded to ``torch.load``. Defaults to 'cpu' so a
            checkpoint saved from a GPU also loads on a CPU-only machine. Move
            the model afterwards with ``.to(device)``.

    Returns:
        torch.nn.Sequential: the ResNet backbone without its last child,
        followed by ``Flatten``.
    """
    constructor = getattr(models, architecture)
    try:
        backbone = constructor(weights=None)
    except TypeError:
        # torchvision < 0.13 has no `weights` argument; `pretrained=False` is the
        # older (now deprecated) spelling of "random initialisation".
        backbone = constructor(pretrained=False)

    # Four-channel stem; the output width follows the stock conv1 so the rest of
    # the network stays compatible.
    backbone.conv1 = torch.nn.Conv2d(
        4, backbone.conv1.out_channels, kernel_size=7, stride=2, padding=3, bias=False
    )
    model = torch.nn.Sequential(*list(backbone.children())[:-1], torch.nn.Flatten())

    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

# Example: build the 4-channel ResNet-18 backbone from the local checkpoint.
# NOTE(review): 'Res_18.pth' must already exist on disk (see the wget cell), so
# run the download cell before this one on a fresh kernel.
model = load_resnet_rgbn(model_path='Res_18.pth', architecture='resnet18')

In [2]:
!wget -nc https://zenodo.org/records/8170135/files/Res_18.pth

--2025-04-28 11:34:11--  https://zenodo.org/record/8170135/files/Res_18.pth
Resolving zenodo.org (zenodo.org)... 188.185.43.25, 188.185.48.194, 188.185.45.92, ...
Connecting to zenodo.org (zenodo.org)|188.185.43.25|:443... connected.
HTTP request sent, awaiting response... 301 MOVED PERMANENTLY
Location: /records/8170135/files/Res_18.pth [following]
--2025-04-28 11:34:12--  https://zenodo.org/records/8170135/files/Res_18.pth
Reusing existing connection to zenodo.org:443.
HTTP request sent, awaiting response... 200 OK
Length: 44793425 (43M) [application/octet-stream]
Saving to: ‘Res_18.pth’


2025-04-28 11:34:22 (4.30 MB/s) - ‘Res_18.pth’ saved [44793425/44793425]



In [4]:
def compute_modified_miou(preds, targets, num_classes=9):
    """Mean IoU between thresholded multi-label logits and binary targets.

    A pixel is predicted positive for a channel when ``sigmoid(logit) > 0.5``.
    It is a target positive when the target value is > 0. IoU is computed per
    channel as TP / (TP + FP + FN).

    Channels where both prediction and target are empty (union == 0) carry no
    information and are skipped. Previously they were scored as IoU 0, which
    unfairly lowered the mean.

    Args:
        preds: raw logits, shape [C, H, W] or [B, C, H, W]. For 4-D input the
            batch is folded into the channel axis, so every (sample, class)
            pair is scored separately.
        targets: tensor with the same shape as ``preds``.
        num_classes: currently unused; kept for interface compatibility.

    Returns:
        float: mean IoU over channels with a non-empty union, or 1.0 if every
        channel is empty in both prediction and target (perfect agreement).

    Raises:
        ValueError: if ``preds`` is not 3-D/4-D or shapes do not match.
    """
    if preds.shape != targets.shape:
        raise ValueError(
            f"preds and targets must have the same shape, got "
            f"{tuple(preds.shape)} and {tuple(targets.shape)}"
        )
    if preds.dim() == 4:
        preds = preds.flatten(0, 1)
        targets = targets.flatten(0, 1)
    if preds.dim() != 3:
        raise ValueError(f"expected [C, H, W] or [B, C, H, W], got {tuple(preds.shape)}")

    preds_bool = torch.sigmoid(preds) > 0.5
    targets_bool = targets > 0

    # Per-channel counts, reduced over the spatial dims (H, W).
    tp = (preds_bool & targets_bool).sum(dim=(1, 2)).float()
    fp = (preds_bool & ~targets_bool).sum(dim=(1, 2)).float()
    fn = (~preds_bool & targets_bool).sum(dim=(1, 2)).float()
    union = tp + fp + fn

    present = union > 0
    if not bool(present.any()):
        return 1.0
    ious = tp[present] / union[present]
    return ious.mean().item()

In [7]:
import os
import sys
import time
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from model import NetworkCIFAR as Network
from collections import namedtuple
from tqdm import tqdm
import wandb

class CustomPTSegmentationDataset(Dataset):
    """Segmentation dataset backed by paired ``.pt`` tensor files.

    Expected layout under ``root_dir``::

        X/<name>.pt       image tensor
        labels/<name>.pt  label tensor with the same file name

    Images are cast to float and divided by 255; labels are cast to float.
    Both then have their last two axes swapped with ``permute(0, 2, 1)``.
    NOTE(review): the swap assumes files are stored channel-first as
    [C, W, H] -- confirm against the preprocessing script.
    """

    def __init__(self, root_dir, transform=None):
        """Index every ``X/*.pt`` file and its same-named label file.

        Args:
            root_dir: directory containing the ``X`` and ``labels`` subdirectories.
            transform: optional callable ``(image, label) -> (image, label)``
                applied jointly to each sample.

        Raises:
            FileNotFoundError: if any image lacks a matching label file.
        """
        super().__init__()
        self.images_dir = os.path.join(root_dir, "X")
        self.labels_dir = os.path.join(root_dir, "labels")
        self.transform = transform

        self.image_paths = sorted(glob.glob(os.path.join(self.images_dir, '*.pt')))
        self.label_paths = [os.path.join(self.labels_dir, os.path.basename(p)) for p in self.image_paths]

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        missing = [p for p in self.label_paths if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(
                f"Missing label files ({len(missing)} of {len(self.label_paths)}), e.g. {missing[0]}"
            )

    def __len__(self):
        """Number of image/label pairs found at construction time."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, scale and axis-swap the ``idx``-th (image, label) pair."""
        # map_location='cpu' keeps tensors on the CPU even if they were saved from
        # a GPU; this matters inside DataLoader worker processes.
        image = torch.load(self.image_paths[idx], map_location='cpu').float().div(255.0)
        label = torch.load(self.label_paths[idx], map_location='cpu').float()
        image = image.permute(0, 2, 1)
        label = label.permute(0, 2, 1)
        if self.transform is not None:
            image, label = self.transform(image, label)
        return image, label
# Reproducible train/validation split over (at most) the first N_SPLIT_SAMPLES files.
SPLIT_SEED = 42
N_SPLIT_SAMPLES = 5000
BATCH_SIZE = 16

full_dataset = CustomPTSegmentationDataset(root_dir="../../../Agriculture-Vision-2021_processed_zip/trainrandcrop256/")

# Guard against a dataset smaller than the requested subset. Out-of-range indices
# would otherwise only fail later, inside a DataLoader worker.
n_split = min(N_SPLIT_SAMPLES, len(full_dataset))
if n_split < 2:
    raise ValueError(f"Need at least 2 samples to split, found {len(full_dataset)}.")

# Seeded permutation so the split is identical across kernel restarts.
indices = np.random.default_rng(SPLIT_SEED).permutation(n_split).tolist()
train_indices = indices[: n_split // 2]
val_indices = indices[n_split // 2:]

train_loader = torch.utils.data.DataLoader(
    full_dataset,
    batch_size=BATCH_SIZE,
    sampler=torch.utils.data.SubsetRandomSampler(
        train_indices, generator=torch.Generator().manual_seed(SPLIT_SEED)
    ),
    num_workers=2,
)
# Evaluation metrics are order-independent, so iterate validation indices in a
# fixed order instead of reshuffling every epoch.
val_loader = torch.utils.data.DataLoader(
    full_dataset,
    batch_size=BATCH_SIZE,
    sampler=val_indices,
    num_workers=2,
)

In [8]:
def evaluate_model(model, val_loader, device, num_classes=9):
    """Batch-size-weighted average of ``compute_modified_miou`` over a loader.

    ``model(images)`` must return per-pixel logits with the same shape as
    ``labels`` (e.g. [B, num_classes, H, W]). A 2-D output -- a flattened
    feature vector such as the one produced by ``load_resnet_rgbn`` -- cannot
    be scored against masks and is rejected.

    Args:
        model: module to evaluate; it is switched to eval mode and moved to ``device``.
        val_loader: iterable of ``(images, labels)`` batches.
        device: torch device for the forward pass.
        num_classes: forwarded to ``compute_modified_miou``.

    Returns:
        float: mean per-batch mIoU weighted by batch size.

    Raises:
        ValueError: if the model returns a 2-D tensor, or the loader yields no samples.
    """
    model.eval()
    model.to(device)

    total_miou = 0.0
    total_samples = 0

    pbar = tqdm(val_loader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)

            # A [B, feature_dim] output (e.g. a backbone ending in Flatten) has no
            # spatial layout, so a segmentation mIoU cannot be computed from it.
            if logits.dim() == 2:
                raise ValueError("Model output is flatten vector! Bạn cần chỉnh lại model để output ra mask segmentation.")

            batch_miou = compute_modified_miou(logits, labels, num_classes=num_classes)
            batch_size = images.size(0)

            total_miou += batch_miou * batch_size
            total_samples += batch_size

            pbar.set_postfix({"Batch mIoU": batch_miou})

    if total_samples == 0:
        raise ValueError("val_loader yielded no samples; cannot compute mIoU.")
    return total_miou / total_samples


In [13]:
def evaluate_flatten_model(model, val_loader, device, num_classes=9, threshold=0.5, head=None):
    """Image-level multi-label accuracy for a model that outputs a flat vector.

    Each label mask is reduced to image-level presence [B, C]: a class is
    present if any of its pixels is > 0. The prediction for a class is
    ``sigmoid(score) > threshold``, taken over the first ``num_classes``
    output columns. The result is the fraction of (image, class) entries where
    the prediction equals the presence label.

    NOTE(review): with ``head=None`` the scores are simply the first
    ``num_classes`` columns of whatever ``model`` returns. For the backbone
    built by ``load_resnet_rgbn`` those are pooled features, not trained class
    logits, so the resulting "accuracy" is not a meaningful classification
    metric. If those features are non-negative, ``sigmoid(x) > 0.5`` merely
    tests ``x > 0``. Pass a trained ``head`` (e.g. ``nn.Linear(512, num_classes)``)
    to score real class predictions.

    Args:
        model: module to evaluate; switched to eval mode and moved to ``device``.
        val_loader: iterable of ``(images, labels)`` batches, labels [B, C, H, W].
        device: torch device for the forward pass.
        num_classes: number of leading output columns compared with the labels.
        threshold: probability threshold applied after the sigmoid.
        head: optional module mapping the model output to class logits. The
            default ``None`` keeps the original behaviour.

    Returns:
        float: element-wise accuracy in [0, 1].

    Raises:
        ValueError: if the (head-mapped) output is not [B, >= num_classes], or
            the loader yields no samples.
    """
    model.eval()
    model.to(device)
    if head is not None:
        head.eval()
        head.to(device)

    total_correct = 0
    total_labels = 0

    pbar = tqdm(val_loader, desc="Evaluating (Flatten)", leave=False)
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            if head is not None:
                outputs = head(outputs)
            # Fewer columns than classes would otherwise broadcast silently
            # (e.g. a single column compared against every class).
            if outputs.dim() != 2 or outputs.size(1) < num_classes:
                raise ValueError(
                    f"Expected model output of shape [B, >= {num_classes}], got {tuple(outputs.shape)}."
                )

            # Image-level presence, [B, C, H, W] -> [B, C]: any positive pixel marks the class present.
            labels_reduced = (labels > 0).flatten(2).any(dim=2).float()

            preds_binary = (torch.sigmoid(outputs[:, :num_classes]) > threshold).float()

            correct = (preds_binary == labels_reduced).sum()
            total_correct += correct.item()
            total_labels += labels_reduced.numel()

            pbar.set_postfix({"Acc": total_correct / total_labels})

    if total_labels == 0:
        raise ValueError("val_loader yielded no samples; cannot compute accuracy.")
    return total_correct / total_labels


In [16]:
# Evaluate the RGB-N ResNet-18 backbone on the validation split.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = load_resnet_rgbn(model_path='Res_18.pth', architecture='resnet18')
model = model.to(device)

acc = evaluate_flatten_model(model, val_loader, device, num_classes=9)
print("Validation accuracy:", acc)


                                                                                  

Validation accuracy: 0.43862222222222225


