In [1]:
import os
import random
import numpy as np

import torch
import torch.nn as
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
import torchvision.transforms.functional as TF
from PIL import Image

In [2]:
# ---------------------------------
# 1. VOCSegmentation 데이터셋 (RGB + Dummy Depth)
# ---------------------------------
class VOCSegWithDepth(Dataset):
    """
    VOCSegmentation 데이터셋에 대해,
    각 샘플을 (RGB image, Dummy Depth image, Target segmentation mask)으로 반환합니다.

    Dummy Depth는 RGB 이미지를 그레이스케일로 변환하여 만듭니다.
    """
    def __init__(self, root, year='2012', image_set='train',
                 transform_rgb=None, transform_depth=None, transform_target=None, download=False):
        self.voc = VOCSegmentation(root=root, year=year, image_set=image_set, download=download)
        self.transform_rgb = transform_rgb
        self.transform_depth = transform_depth  # Dummy depth transform
        self.transform_target = transform_target

    def __len__(self):
        return len(self.voc)

    def __getitem__(self, index):
        # voc[ index ] returns (image, target) both as PIL images
        rgb, target = self.voc[index]

        # Dummy Depth: convert RGB to grayscale
        depth = TF.to_grayscale(rgb, num_output_channels=1)

        if self.transform_rgb:
            rgb = self.transform_rgb(rgb)
        if self.transform_depth:
            depth = self.transform_depth(depth)
        if self.transform_target:
            target = self.transform_target(target)
        return rgb, depth, target

In [3]:
# ---------------------------------
# 2. Transform 정의
# ---------------------------------
img_size = (224, 224)

# RGB 이미지: bilinear interpolation, tensor 변환 및 정규화 (ImageNet 평균/표준편차 사용)
transform_rgb = transforms.Compose([
    transforms.Resize(img_size, interpolation=Image.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std =[0.229, 0.224, 0.225])
])

# Dummy Depth: Nearest Neighbor로 리사이즈, 단일 채널 Tensor로 변환
transform_depth = transforms.Compose([
    transforms.Resize(img_size, interpolation=Image.NEAREST),
    transforms.ToTensor()  # 출력은 (1, H, W) 범위 [0,1]
])

# Segmentation target: Nearest Neighbor, Tensor 변환 (VOC segmentation은 255가 ignore index)
# VOCSegmentation의 target은 PIL 이미지로, 각 픽셀이 0~21 (또는 0~20) 클래스 값을 가짐.
class ToLongTensor(object):
    def __call__(self, pic):
        return torch.from_numpy(np.array(pic)).long()

transform_target = transforms.Compose([
    transforms.Resize(img_size, interpolation=Image.NEAREST),
    ToLongTensor()
])

In [4]:
# ---------------------------------
# 3. 데이터셋 및 DataLoader 준비
# ---------------------------------
root = './VOC2012'  # 저장 경로
# 학습 데이터셋 (train)
train_dataset = VOCSegWithDepth(root=root, year='2012', image_set='train', download=True,
                                  transform_rgb=transform_rgb,
                                  transform_depth=transform_depth,
                                  transform_target=transform_target)
# 테스트 데이터셋 (val)
test_dataset = VOCSegWithDepth(root=root, year='2012', image_set='val', download=True,
                                 transform_rgb=transform_rgb,
                                 transform_depth=transform_depth,
                                 transform_target=transform_target)

batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

100%|██████████| 2.00G/2.00G [00:24<00:00, 82.1MB/s]


In [5]:
# ---------------------------------
# 4. 간단한 FuseNet 모델 구현 (Simplified Version)
# ---------------------------------
# 기본 구성 블록: Conv -> BatchNorm -> ReLU (CBR)
class CBR(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(CBR, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.bn   = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

# Encoder branch: 간단하게 한 블록 사용 (CBR + MaxPool)
class EncoderBranch(nn.Module):
    def __init__(self, in_channels):
        super(EncoderBranch, self).__init__()
        self.block = nn.Sequential(
            CBR(in_channels, 64),
            CBR(64, 64)
        )
        self.pool  = nn.MaxPool2d(2,2)  # 다운샘플링
    def forward(self, x):
        features = self.block(x)
        pooled   = self.pool(features)
        return features, pooled

# 간단한 디코더: 업샘플링(Interpolation) + CBR 블록 + 1x1 Convolution으로 예측 채널 맞춤
class Decoder(nn.Module):
    def __init__(self, num_classes):
        super(Decoder, self).__init__()
        self.conv = nn.Sequential(
            CBR(64, 64),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )
    def forward(self, x, output_size):
        # 단순하게 bilinear interpolation으로 업샘플링
        x = nn.functional.interpolate(x, size=output_size, mode='bilinear', align_corners=True)
        x = self.conv(x)
        return x

# FuseNet: 두 개의 branch (RGB, Depth) → Fusion → Decoder
class SimpleFuseNet(nn.Module):
    def __init__(self, num_classes=21):
        super(SimpleFuseNet, self).__init__()
        self.rgb_encoder = EncoderBranch(in_channels=3)
        self.depth_encoder = EncoderBranch(in_channels=1)
        # 여기서는 단순 fusion: 각 branch의 block 출력을 element-wise sum
        self.decoder = Decoder(num_classes)
    def forward(self, rgb, depth):
        # 인코더 각각에서 feature 추출
        rgb_features, _ = self.rgb_encoder(rgb)       # (batch, 64, H', W')
        depth_features, _ = self.depth_encoder(depth)   # (batch, 64, H', W')
        # Fusion: element-wise summation
        fused = rgb_features + depth_features
        # 디코더: 업샘플링. 원래 입력 이미지 해상도 (224,224)로 복원
        out = self.decoder(fused, output_size=rgb.size()[2:])
        return out

In [6]:
# ---------------------------------
# 5. 학습 및 평가 준비
# ---------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 21  # VOC Segmentation의 경우 21개 클래스 (0 ~ 20)

model = SimpleFuseNet(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=255)  # 255를 무시 (VOC 라벨 기준)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)

In [7]:
# ---------------------------------
# 6. 학습 루프
# ---------------------------------
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (rgb, depth, target) in enumerate(train_loader):
        rgb = rgb.to(device)
        depth = depth.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        outputs = model(rgb, depth)  # 출력: (batch, num_classes, H, W)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f}")

Epoch [1/10], Step [10/366], Loss: 2.6160
Epoch [1/10], Step [20/366], Loss: 2.2132
Epoch [1/10], Step [30/366], Loss: 1.6581
Epoch [1/10], Step [40/366], Loss: 1.2841
Epoch [1/10], Step [50/366], Loss: 1.8149
Epoch [1/10], Step [60/366], Loss: 1.0158
Epoch [1/10], Step [70/366], Loss: 1.7720
Epoch [1/10], Step [80/366], Loss: 0.8940
Epoch [1/10], Step [90/366], Loss: 1.2676
Epoch [1/10], Step [100/366], Loss: 1.6330
Epoch [1/10], Step [110/366], Loss: 0.8756
Epoch [1/10], Step [120/366], Loss: 0.7867
Epoch [1/10], Step [130/366], Loss: 0.6887
Epoch [1/10], Step [140/366], Loss: 1.1580
Epoch [1/10], Step [150/366], Loss: 1.4648
Epoch [1/10], Step [160/366], Loss: 1.6416
Epoch [1/10], Step [170/366], Loss: 1.1076
Epoch [1/10], Step [180/366], Loss: 1.2719
Epoch [1/10], Step [190/366], Loss: 2.6285
Epoch [1/10], Step [200/366], Loss: 1.2336
Epoch [1/10], Step [210/366], Loss: 0.9260
Epoch [1/10], Step [220/366], Loss: 1.2159
Epoch [1/10], Step [230/366], Loss: 1.1352
Epoch [1/10], Step [

In [8]:
# ---------------------------------
# 7. 평가 루프 (Global Accuracy 예시)
# ---------------------------------
model.eval()
total_correct = 0
total_pixels = 0
with torch.no_grad():
    for rgb, depth, target in test_loader:
        rgb = rgb.to(device)
        depth = depth.to(device)
        target = target.to(device)

        outputs = model(rgb, depth)
        _, preds = torch.max(outputs, 1)
        total_correct += (preds == target).sum().item()
        total_pixels += torch.numel(target)

print(f"Test Global Accuracy: {100 * total_correct / total_pixels:.2f}%")

Test Global Accuracy: 69.27%
