In [19]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from PIL import Image

# Define U-Net architecture
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.down_conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.down_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up_conv1 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.up_conv2 = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x):
        # Downward path
        x1 = F.relu(self.down_conv1(x))
        x2 = self.pool(F.relu(self.down_conv2(x1)))

        # Upward path
        x = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=True)
        x = F.relu(self.up_conv1(x))
        x = F.relu(self.up_conv2(x))

        return x


class CustomDataset(Dataset):
    def __init__(self, root, folder_name, batch_size=8, transform=None):
        self.root = root
        self.folder_name = folder_name+"/images"
        self.batch_size = batch_size
        self.transform = transform
        self.images_folder = os.path.join(self.root, 'train', self.folder_name)

        self.image_files = sorted(os.listdir(self.images_folder))
        self.num_samples = len(self.image_files)

    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        start_idx = idx * self.batch_size
        end_idx = min((idx + 1) * self.batch_size, self.num_samples)

        images = []
        for i in range(start_idx, end_idx):
            img_name = self.image_files[i]
            img_path = os.path.join(self.images_folder, img_name)
            image = Image.open(img_path).convert('RGB')

            if self.transform:
                image = self.transform(image)

            images.append(image)

        return images

# Example usage
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Specify the folder you want to access inside the train folder
image_type = 'church'  
batch_size = 8 
train_data = CustomDataset(root="./data", folder_name=image_type, batch_size=batch_size, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)


# Initialize model, loss function, and optimizer
model = UNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print(train_data.image_files)

# Training loop
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels.squeeze(1).long())  # CrossEntropyLoss expects labels to be 1D
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print(f"[Epoch {epoch + 1}, Batch {i + 1}] Loss: {running_loss / 100:.3f}")
            running_loss = 0.0

print("Finished Training")

['00001.png', '00002.png', '00003.png', '00004.png', '00005.png', '00006.png', '00007.png', '00008.png', '00009.png', '00010.png', '00011.png', '00012.png', '00013.png', '00014.png', '00015.png', '00016.png', '00017.png', '00018.png', '00019.png', '00020.png', '00021.png', '00022.png', '00023.png', '00024.png', '00025.png', '00026.png', '00027.png', '00028.png', '00029.png', '00030.png', '00031.png', '00032.png', '00033.png', '00034.png', '00035.png', '00036.png', '00037.png', '00038.png', '00039.png', '00040.png', '00041.png', '00042.png', '00043.png', '00044.png', '00045.png', '00046.png', '00047.png', '00048.png', '00049.png', '00050.png', '00051.png', '00052.png', '00053.png', '00054.png', '00055.png', '00056.png', '00057.png', '00058.png', '00059.png', '00060.png', '00061.png', '00062.png', '00063.png', '00064.png', '00065.png', '00066.png', '00067.png', '00068.png', '00069.png', '00070.png', '00071.png', '00072.png', '00073.png', '00074.png', '00075.png', '00076.png', '00077.png'

RuntimeError: each element in list of batch should be of equal size

In [2]:
import matplotlib.pyplot as plt
import numpy as np

# Function to convert segmentation mask to RGB for visualization
def mask_to_rgb(mask):
    colors = np.array([
        [128, 64, 128],  # road
        [244, 35, 232],  # sidewalk
        [70, 70, 70],    # building
        [102, 102, 156], # wall
        [190, 153, 153], # fence
        [153, 153, 153], # pole
        [250, 170, 30],  # traffic light
        [220, 220, 0],   # traffic sign
        [107, 142, 35],  # vegetation
        [152, 251, 152], # terrain
        [70, 130, 180],  # sky
        [220, 20, 60],   # person
        [255, 0, 0],     # rider
        [0, 0, 142],     # car
        [0, 0, 70],      # truck
        [0, 60, 100],    # bus
        [0, 80, 100],    # train
        [0, 0, 230],     # motorcycle
        [119, 11, 32]    # bicycle
    ])
    colored_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for label in range(len(colors)):
        colored_mask[mask == label] = colors[label]
    return colored_mask

# Sample some images from the dataset for visualization
sample_loader = DataLoader(train_data, batch_size=1, shuffle=True)
num_samples = 5
fig, axes = plt.subplots(num_samples, 2, figsize=(10, num_samples * 5))

model.eval()
with torch.no_grad():
    for i, (image, target) in enumerate(sample_loader):
        if i >= num_samples:
            break
        output = model(image)
        pred_mask = torch.argmax(output.squeeze(), dim=0).numpy()
        axes[i, 0].imshow(np.transpose(image.squeeze().numpy(), (1, 2, 0)))
        axes[i, 0].set_title('Input Image')
        axes[i, 0].axis('off')
        axes[i, 1].imshow(mask_to_rgb(pred_mask))
        axes[i, 1].set_title('Predicted Segmentation Mask')
        axes[i, 1].axis('off')

plt.tight_layout()
plt.show()

NameError: name 'train_data' is not defined