In [19]:
import torch
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as transforms

class CustomCSVDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        # อ่านข้อมูลจากไฟล์ CSV
        self.data = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.iloc[idx, 0]
        depth_path = self.data.iloc[idx, 1]

        # โหลดภาพและ depth map
        image = Image.open(img_path).convert('RGB')
        depth = Image.open(depth_path).convert('L')  # Depth map เป็น grayscale

        if self.transform:
            image = self.transform(image)
            depth = self.transform(depth)

        return image, depth


In [20]:
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
])

csv_file = './data/nyu2_test.csv'  # กำหนด path ของไฟล์ CSV
test_dataset = CustomCSVDataset(csv_file, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)


In [22]:
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader

# กำหนด Loss Function
loss_fn = nn.MSELoss()  

# ฟังก์ชันสำหรับคำนวณ IoU
def compute_iou(output, target, threshold=0.5):
    output = (output > threshold).float()
    intersection = (output * target).sum()
    union = output.sum() + target.sum() - intersection
    iou = intersection / union if union != 0 else 0  # ป้องกันการหารด้วยศูนย์
    return iou.item()

# ฟังก์ชันคำนวณ Accuracy
def compute_accuracy(output, target, threshold=1.0):
    """นับจำนวนพิกเซลที่ความแตกต่างน้อยกว่าหรือเท่ากับ threshold"""
    correct_pixels = torch.abs(output - target) <= threshold
    accuracy = correct_pixels.float().mean().item()
    return accuracy

# Valid Loop สำหรับ Evaluation
def evaluate_model(model, dataloader, loss_fn, device):
    model.eval()
    running_loss = 0.0
    running_iou = 0.0
    running_accuracy = 0.0

    with torch.no_grad():
        for batch_idx, (images, depths) in enumerate(dataloader):
            images, depths = images.to(device), depths.to(device)

            outputs = model(images).squeeze(1)  # [B, 1, H, W] -> [B, H, W]
            depths = depths.squeeze(1)  # [B, 1, H, W] -> [B, H, W]

            if outputs.shape != depths.shape:
                print(f"Shape mismatch: {outputs.shape} vs {depths.shape}")
                continue

            # คำนวณ Loss, IoU, และ Accuracy
            loss = loss_fn(outputs, depths)
            iou = compute_iou(outputs, depths)
            accuracy = compute_accuracy(outputs, depths, threshold=1.0)

            running_loss += loss.item()
            running_iou += iou
            running_accuracy += accuracy

        avg_loss = running_loss / len(dataloader)
        avg_iou = running_iou / len(dataloader)
        avg_accuracy = running_accuracy / len(dataloader)

    print(f"Test Loss: {avg_loss:.4f}, Test IoU: {avg_iou:.4f}, Test Accuracy: {avg_accuracy:.4f}")
    return avg_loss, avg_iou, avg_accuracy

# เรียกใช้ FPN แทน U-Net
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = smp.FPN(encoder_name='resnet50', encoder_weights='imagenet', classes=1, activation=None)
model.load_state_dict(torch.load('./final_model.pth', map_location=DEVICE))
model.to(DEVICE)

# รันการประเมินผล
test_loss, test_iou, test_accuracy = evaluate_model(model, test_dataloader, loss_fn, DEVICE)


  model.load_state_dict(torch.load('./final_model.pth', map_location=DEVICE))


RuntimeError: Error(s) in loading state_dict for FPN:
	Missing key(s) in state_dict: "decoder.p5.weight", "decoder.p5.bias", "decoder.p4.skip_conv.weight", "decoder.p4.skip_conv.bias", "decoder.p3.skip_conv.weight", "decoder.p3.skip_conv.bias", "decoder.p2.skip_conv.weight", "decoder.p2.skip_conv.bias", "decoder.seg_blocks.0.block.0.block.0.weight", "decoder.seg_blocks.0.block.0.block.1.weight", "decoder.seg_blocks.0.block.0.block.1.bias", "decoder.seg_blocks.0.block.1.block.0.weight", "decoder.seg_blocks.0.block.1.block.1.weight", "decoder.seg_blocks.0.block.1.block.1.bias", "decoder.seg_blocks.0.block.2.block.0.weight", "decoder.seg_blocks.0.block.2.block.1.weight", "decoder.seg_blocks.0.block.2.block.1.bias", "decoder.seg_blocks.1.block.0.block.0.weight", "decoder.seg_blocks.1.block.0.block.1.weight", "decoder.seg_blocks.1.block.0.block.1.bias", "decoder.seg_blocks.1.block.1.block.0.weight", "decoder.seg_blocks.1.block.1.block.1.weight", "decoder.seg_blocks.1.block.1.block.1.bias", "decoder.seg_blocks.2.block.0.block.0.weight", "decoder.seg_blocks.2.block.0.block.1.weight", "decoder.seg_blocks.2.block.0.block.1.bias", "decoder.seg_blocks.3.block.0.block.0.weight", "decoder.seg_blocks.3.block.0.block.1.weight", "decoder.seg_blocks.3.block.0.block.1.bias". 
	Unexpected key(s) in state_dict: "decoder.blocks.0.conv1.0.weight", "decoder.blocks.0.conv1.1.weight", "decoder.blocks.0.conv1.1.bias", "decoder.blocks.0.conv1.1.running_mean", "decoder.blocks.0.conv1.1.running_var", "decoder.blocks.0.conv1.1.num_batches_tracked", "decoder.blocks.0.conv2.0.weight", "decoder.blocks.0.conv2.1.weight", "decoder.blocks.0.conv2.1.bias", "decoder.blocks.0.conv2.1.running_mean", "decoder.blocks.0.conv2.1.running_var", "decoder.blocks.0.conv2.1.num_batches_tracked", "decoder.blocks.1.conv1.0.weight", "decoder.blocks.1.conv1.1.weight", "decoder.blocks.1.conv1.1.bias", "decoder.blocks.1.conv1.1.running_mean", "decoder.blocks.1.conv1.1.running_var", "decoder.blocks.1.conv1.1.num_batches_tracked", "decoder.blocks.1.conv2.0.weight", "decoder.blocks.1.conv2.1.weight", "decoder.blocks.1.conv2.1.bias", "decoder.blocks.1.conv2.1.running_mean", "decoder.blocks.1.conv2.1.running_var", "decoder.blocks.1.conv2.1.num_batches_tracked", "decoder.blocks.2.conv1.0.weight", "decoder.blocks.2.conv1.1.weight", "decoder.blocks.2.conv1.1.bias", "decoder.blocks.2.conv1.1.running_mean", "decoder.blocks.2.conv1.1.running_var", "decoder.blocks.2.conv1.1.num_batches_tracked", "decoder.blocks.2.conv2.0.weight", "decoder.blocks.2.conv2.1.weight", "decoder.blocks.2.conv2.1.bias", "decoder.blocks.2.conv2.1.running_mean", "decoder.blocks.2.conv2.1.running_var", "decoder.blocks.2.conv2.1.num_batches_tracked", "decoder.blocks.3.conv1.0.weight", "decoder.blocks.3.conv1.1.weight", "decoder.blocks.3.conv1.1.bias", "decoder.blocks.3.conv1.1.running_mean", "decoder.blocks.3.conv1.1.running_var", "decoder.blocks.3.conv1.1.num_batches_tracked", "decoder.blocks.3.conv2.0.weight", "decoder.blocks.3.conv2.1.weight", "decoder.blocks.3.conv2.1.bias", "decoder.blocks.3.conv2.1.running_mean", "decoder.blocks.3.conv2.1.running_var", "decoder.blocks.3.conv2.1.num_batches_tracked", "decoder.blocks.4.conv1.0.weight", "decoder.blocks.4.conv1.1.weight", "decoder.blocks.4.conv1.1.bias", "decoder.blocks.4.conv1.1.running_mean", "decoder.blocks.4.conv1.1.running_var", "decoder.blocks.4.conv1.1.num_batches_tracked", "decoder.blocks.4.conv2.0.weight", "decoder.blocks.4.conv2.1.weight", "decoder.blocks.4.conv2.1.bias", "decoder.blocks.4.conv2.1.running_mean", "decoder.blocks.4.conv2.1.running_var", "decoder.blocks.4.conv2.1.num_batches_tracked". 
	size mismatch for segmentation_head.0.weight: copying a param with shape torch.Size([1, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).