In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np


class RoadDataset(Dataset):
    """Road images paired with their segmentation masks, read from disk.

    ``path`` must contain two sub-directories: ``image_2`` (input images) and
    ``gt_image_2`` (segmentation masks). One sample is produced per mask file.

    Masks are paired with images *by file name*, not by list position.
    ``os.listdir`` returns entries in arbitrary order and the two directories
    need not contain the same number of files, so positional pairing can
    return a mask together with an unrelated image or raise ``IndexError``.

    Attributes:
        path: Root directory passed to the constructor.
        segmentation_name_list: Mask file names from ``gt_image_2``, sorted.
        image_name_list: Image file names from ``image_2``, aligned
            index-for-index with ``segmentation_name_list``. An image name is
            repeated when several masks refer to the same image.

    Raises:
        FileNotFoundError: If a mask has no matching image (raised by the
            constructor), or if one of the two sub-directories is missing
            (raised by ``os.listdir``).
    """

    def __init__(self, path):
        self.path = path
        # Sorted so that sample order is reproducible across runs and platforms.
        self.segmentation_name_list = sorted(
            os.listdir(os.path.join(path, "gt_image_2"))
        )
        available_images = set(os.listdir(os.path.join(path, "image_2")))
        self.image_name_list = [
            self._match_image_name(name, available_images)
            for name in self.segmentation_name_list
        ]

    @staticmethod
    def _match_image_name(segmentation_name, image_names):
        """Return the name in ``image_names`` that belongs to ``segmentation_name``.

        An identical file name is preferred. Otherwise the second
        underscore-separated token of the mask name is dropped
        (``um_road_000000.png`` -> ``um_000000.png``).
        NOTE(review): the second rule assumes KITTI-road-style naming
        (``<prefix>_<kind>_<id>`` for masks, ``<prefix>_<id>`` for images) --
        confirm against the actual data set.
        """
        if segmentation_name in image_names:
            return segmentation_name

        stem, extension = os.path.splitext(segmentation_name)
        tokens = stem.split("_")
        if len(tokens) >= 3:
            candidate = "_".join(tokens[:1] + tokens[2:]) + extension
            if candidate in image_names:
                return candidate

        raise FileNotFoundError(
            f"No image in 'image_2' matches segmentation mask {segmentation_name!r}"
        )

    def __len__(self):
        """Number of samples, i.e. number of segmentation masks."""
        return len(self.segmentation_name_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, segmentation_tensor)``.

        Both files are decoded as RGB and converted with
        ``transforms.ToTensor()``.
        """
        segmentation_path = os.path.join(
            self.path, "gt_image_2", self.segmentation_name_list[idx]
        )
        image_path = os.path.join(self.path, "image_2", self.image_name_list[idx])

        # ``convert`` returns an independent in-memory copy, so the underlying
        # files can be closed as soon as the ``with`` blocks exit.
        with Image.open(segmentation_path) as segmentation_file:
            segmentation_img = segmentation_file.convert("RGB")
        with Image.open(image_path) as image_file:
            image_img = image_file.convert("RGB")

        to_tensor = transforms.ToTensor()
        return to_tensor(image_img), to_tensor(segmentation_img)


In [2]:
# Training split on the author's machine -- absolute and machine-specific,
# so point this at your own copy of the data before running the cell.
TRAINING_DIR = r'P:\编程\python\MIT-DL\Road-Detection-using-CNN\data_road\training'

dataset = RoadDataset(TRAINING_DIR)
print(f"Dataset length: {len(dataset)}")

# Smoke test: load the first sample and report both tensor shapes.
image, segmentation = dataset[0]
print(f"Image shape: {image.shape}")
print(f"Segmentation shape: {segmentation.shape}")

Dataset length: 384
Image shape: torch.Size([3, 375, 1242])
Segmentation shape: torch.Size([3, 375, 1242])


In [None]:
import torch.nn.functional as F


class AdaptivePadding:
    """Pad a batch so its height and width are multiples of ``multiple``.

    ``pad`` appends padding on the bottom and right edges only (``F.pad`` with
    its default constant mode) and remembers the size it was given, so ``crop``
    can later restore that size by slicing from the top-left corner.

    Args:
        multiple: Positive integer the spatial dimensions are rounded up to.
            The default of 16 corresponds to four successive 2x down-samplings.

    Raises:
        ValueError: If ``multiple`` is not a positive integer.
    """

    def __init__(self, multiple=16):
        if not isinstance(multiple, int) or multiple < 1:
            raise ValueError(f"multiple must be a positive integer, got {multiple!r}")
        self.multiple = multiple
        # (height, width) of the most recently padded input; None until pad() runs.
        self.original_size = None

    def pad(self, x):
        """Return ``x`` padded on the bottom/right up to the next multiple.

        ``x`` is indexed with ``size(2)`` and ``size(3)``, i.e. it is treated
        as a 4-D ``(N, C, H, W)`` batch. Inputs that are already aligned are
        returned with an unchanged shape.
        """
        height, width = x.size(2), x.size(3)
        self.original_size = (height, width)
        # ``-n % m`` is the distance from n up to the next multiple of m (0 if aligned).
        pad_h = -height % self.multiple
        pad_w = -width % self.multiple
        # F.pad lists the last dimension first: (left, right, top, bottom).
        return F.pad(x, (0, pad_w, 0, pad_h))

    def crop(self, x, target_size=None):
        """Slice ``x`` back to ``target_size`` ``(height, width)``.

        When ``target_size`` is omitted, the size recorded by the most recent
        ``pad`` call is used.

        Raises:
            RuntimeError: If ``target_size`` is omitted and ``pad`` has not
                been called yet.
        """
        if target_size is None:
            target_size = self.original_size
        if target_size is None:
            raise RuntimeError(
                "crop() needs an explicit target_size when pad() has not been called"
            )

        return x[:, :, :target_size[0], :target_size[1]]


class Conv_Block(nn.Module):
    """Two stacked 3x3 convolution stages that keep the spatial size.

    Each stage is ``Conv2d`` (3x3, stride 1, padding 1, no bias) ->
    ``BatchNorm2d`` -> ``Dropout2d`` (default probability) -> in-place ``ReLU``.
    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        # Stage 1 consumes the block input; stage 2 consumes stage 1's output.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.Dropout2d(),
                nn.ReLU(inplace=True),
            ])
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        out = self.layer(x)
        return out
    

class UNet(nn.Module):
    """U-shaped encoder/decoder network with skip connections.

    The encoder applies one ``Conv_Block`` per entry of ``features``, each
    followed by 2x2 max pooling. A bottleneck ``Conv_Block`` doubles the last
    feature width. The decoder mirrors the encoder: each level upsamples
    bilinearly by 2, halves the channels with a 3x3 convolution, concatenates
    the matching encoder output and fuses the result with a ``Conv_Block``.
    A final 1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: Channel count of the input batch.
        out_channels: Channel count produced by the final 1x1 convolution.
        features: Channel width of each encoder level, shallowest first.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]
            ):
        # ``features`` is only read, never modified, so the shared list
        # default is not mutated here.
        super().__init__()

        # The attribute assignment order is kept as-is so that the order in
        # which sub-modules are registered does not change.
        self.upsample = nn.ModuleList()
        self.downsample = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: level k maps the previous width (or the input) to features[k].
        stage_inputs = [in_channels, *features[:-1]]
        for stage_in, stage_out in zip(stage_inputs, features):
            self.downsample.append(Conv_Block(stage_in, stage_out))

        # Bottleneck: twice the width of the deepest encoder level.
        widest = features[-1]
        self.bottleneck = Conv_Block(widest, 2 * widest)

        # Decoder: modules are stored in pairs -- even index = upsample + 3x3
        # conv (2w -> w), odd index = Conv_Block fusing skip and upsampled maps.
        for width in features[::-1]:
            upscale = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(2 * width, width, kernel_size=3, padding=1),
            )
            self.upsample.append(upscale)
            self.upsample.append(Conv_Block(2 * width, width))

        # 1x1 projection to the requested number of output channels.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        self.padding = AdaptivePadding()

    def forward(self, x):
        """Return the network output cropped to the height/width of ``x``."""
        input_hw = x.shape[2:]
        # Pad first so the spatial size divides cleanly through the poolings.
        x = self.padding.pad(x)

        encoder_outputs = []
        for encode in self.downsample:
            x = encode(x)
            encoder_outputs.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk the decoder pairs from deepest to shallowest; decoder level k
        # uses the k-th encoder output counted from the end.
        for level in range(len(self.upsample) // 2):
            x = self.upsample[2 * level](x)
            skip = encoder_outputs[-(level + 1)]

            gap_h = skip.size(2) - x.size(2)
            gap_w = skip.size(3) - x.size(3)
            if gap_h != 0 or gap_w != 0:
                # Spread the difference over both sides so x lines up with
                # the skip tensor before concatenation.
                left, top = gap_w // 2, gap_h // 2
                x = F.pad(x, [left, gap_w - left, top, gap_h - top])

            x = self.upsample[2 * level + 1](torch.cat((skip, x), dim=1))

        x = self.final_conv(x)
        return self.padding.crop(x, input_hw)
