In [None]:
# Mount Google Drive so the dataset folders under /content/drive/MyDrive are readable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!pip install torch torchvision albumentations segmentation_models_pytorch PIL torchmetrics torchmetrics tqdm -q

In [None]:
!pip install PIL

In [None]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """U-Net for semantic segmentation with four down/up-sampling stages.

    Encoder: five double-conv blocks (64 -> 128 -> 256 -> 512 -> 1024 channels)
    separated by 2x2 max pooling. Decoder: four 2x2 stride-2 transposed
    convolutions, each concatenated with the matching encoder feature map
    (skip connection) and refined by a double-conv block. A final 1x1
    convolution maps the 64-channel features to per-class logits.

    Args:
        num_classes: number of output channels (one logit map per class).
        in_channels: number of channels of the input image (3 for RGB).
            Optional; the default matches the original RGB-only model.

    ``forward(x)`` takes ``x`` of shape (N, in_channels, H, W) and returns raw
    (un-normalised) logits of shape (N, num_classes, H, W). H and W do not
    have to be multiples of 16: when a pooling step floors an odd size, the
    upsampled map is zero-padded to its skip connection's size before the
    concatenation.
    """

    def __init__(self, num_classes=4, in_channels=3):
        super().__init__()

        # Layer attribute names and creation order are kept stable so that
        # existing state_dict checkpoints still load.

        # Encoder
        self.conv1 = self._conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = self._conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = self._conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = self._conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = self._conv_block(512, 1024)

        # Decoder: each conv block's input is upsampled channels + skip
        # channels, hence the doubled in_channels (e.g. 512 + 512 = 1024).
        self.up6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv6 = self._conv_block(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv7 = self._conv_block(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv8 = self._conv_block(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = self._conv_block(128, 64)
        self.conv10 = nn.Conv2d(64, num_classes, kernel_size=1)

    def _conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H and W."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _pad_to(upsampled, skip):
        """Zero-pad ``upsampled`` on H/W so its spatial size equals ``skip``'s.

        MaxPool2d floors odd sizes, so 2x upsampling can come back one pixel
        short of the encoder map it is concatenated with; the difference is
        therefore 0 or 1 per spatial dimension and is split between sides.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        # Encoder
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)

        # Decoder: upsample, align to the skip connection, concat on channels.
        up_6 = self._pad_to(self.up6(c5), c4)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self._pad_to(self.up7(c6), c3)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self._pad_to(self.up8(c7), c2)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self._pad_to(self.up9(c8), c1)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)

        return c10

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


def show_augmented_images(dataloader, num_samples=3):
    """Plot image/mask pairs from the first batch of ``dataloader`` side by side.

    Args:
        dataloader: iterable yielding ``(images, masks)`` tensor batches. Images
            are indexed as (N, C, H, W); each mask, after ``squeeze()``, is
            assumed to be an (H, W) map of integer class labels.
        num_samples: maximum number of rows to draw. It is clamped to the batch
            size so a short batch does not raise ``IndexError``.
    """
    # Display color (RGB) per class label; labels not listed render black.
    class_colors = {
        0: [0, 0, 0],      # Background or Ignore
        1: [0, 255, 0],    # Static Obstacle (Green)
        2: [255, 0, 0],    # Dynamic Obstacle (Red)
        3: [0, 0, 255]     # Water (Blue)
        # Add more classes as needed
    }

    images, masks = next(iter(dataloader))  # only the first batch is shown
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()

    num_samples = min(num_samples, len(images))
    # squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single row.
    fig, axes = plt.subplots(num_samples, 2, figsize=(10, 5 * num_samples), squeeze=False)

    for i in range(num_samples):
        img = np.transpose(images[i], (1, 2, 0))  # [C, H, W] -> [H, W, C]
        if img.max() <= 1.0:
            img = (img * 255).astype(np.uint8)  # back to 0-255 for display

        mask = masks[i].squeeze()
        mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        for class_label, color in class_colors.items():
            mask_rgb[mask == class_label] = color

        axes[i, 0].imshow(img)
        axes[i, 0].set_title("Augmented Image")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(mask_rgb)
        axes[i, 1].set_title("Augmented Mask")
        axes[i, 1].axis("off")

    plt.tight_layout()
    plt.show()


#from your_module import UNet, SegmentationDataset  # Make sure to import your UNet model and SegmentationDataset
# Define the SegmentationDataset class
class SegmentationDataset(Dataset):
    """Pairs every file in ``image_dir`` with the same-named file in ``mask_dir``.

    Images are loaded as RGB and masks as single-channel ('L') label maps. The
    optional ``transform`` is applied to the image and, in a separate call, to
    the mask. Returned masks are ``torch.long`` class-index tensors with these
    labels remapped: 5 -> 3, and 6, 7 -> 0. Label 0 is the ``ignore_index`` of
    the training criterion in this notebook.

    NOTE(review): other labels (e.g. 4 or 255) pass through unchanged. With a
    4-class model they would be out of range for CrossEntropyLoss, so confirm
    the dataset's label set.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted so that index -> file mapping is reproducible (os.listdir
        # order is arbitrary).
        self.image_filenames = sorted(
            f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))
        )

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        filename = self.image_filenames[idx]
        # `with` closes the file handles; convert() returns an independent copy.
        with Image.open(os.path.join(self.image_dir, filename)) as img_file:
            image = img_file.convert('RGB')
        with Image.open(os.path.join(self.mask_dir, filename)) as mask_file:
            mask = mask_file.convert('L')

        if self.transform:
            image = self.transform(image)
            # NOTE(review): random transforms would be sampled independently for
            # the image and the mask here; only deterministic ones stay aligned.
            mask = self.transform(mask)

        mask_np = np.array(mask)
        if np.issubdtype(mask_np.dtype, np.floating):
            # ToTensor divides uint8 labels by 255; undo it. rint guards against
            # float32 round-off (e.g. 2.9999998) truncating to the wrong class.
            if mask_np.max() <= 1.0:
                mask_np = mask_np * 255
            mask_np = np.rint(mask_np)
        # Integer masks (no transform) already hold class indices and are never
        # rescaled; a mask with only labels 0/1 must not turn 1 into 255.
        mask_np = mask_np.astype(np.int64)

        # Merge label 5 into class 3 and send labels 6 and 7 to 0 (ignored).
        mask_np[mask_np == 5] = 3
        mask_np[mask_np == 6] = 0
        mask_np[mask_np == 7] = 0

        mask = torch.from_numpy(mask_np).long()
        return image, mask
# Root of the Drive-hosted dataset: images/<split> and masks/<split> below it.
DATA_ROOT = '/content/drive/MyDrive/RGB_datasets_segmentation_V2'
BATCH_SIZE = 3
NUM_WORKERS = 1

# Tensor conversion only. SegmentationDataset applies this same transform to
# the mask as well, so add only transforms that are safe for label maps.
transform = transforms.Compose([
    transforms.ToTensor(),
    # Add other transformations if needed
])


def make_split_dataset(split):
    """Build the SegmentationDataset for one split name ('train' or 'val')."""
    return SegmentationDataset(
        image_dir=f'{DATA_ROOT}/images/{split}',
        mask_dir=f'{DATA_ROOT}/masks/{split}',
        transform=transform,
    )


train_dataset = make_split_dataset('train')
val_dataset = make_split_dataset('val')

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Seed torch's global RNG so weight initialisation is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# Choose the device and move the model there *before* building the optimizer,
# as the torch.optim documentation recommends.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(num_classes=4).to(device)  # Adjust the number of classes as needed

# Label 0 (background, plus labels 6 and 7 that the dataset remaps to 0) is
# excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 25


def run_epoch(loader, train, desc=None):
    """Run one pass of ``model`` over ``loader``; return the mean per-batch loss.

    With ``train=True`` the model is put in training mode and updated with
    ``optimizer``. Otherwise it runs in eval mode with gradients disabled.

    Batches in which every target pixel equals ``criterion.ignore_index`` are
    skipped. CrossEntropyLoss with mean reduction yields NaN when it has no
    non-ignored targets. Back-propagating that NaN would corrupt the weights
    and Adam's state, and averaging it would make the reported loss NaN.

    Returns NaN if no batch contributed to the loss.
    """
    model.train(train)
    total_loss, counted = 0.0, 0
    grad_mode = torch.enable_grad() if train else torch.no_grad()
    with grad_mode:
        progress = tqdm(loader, desc=desc, leave=False)
        for images, masks in progress:
            images = images.to(device)
            masks = masks.to(device).squeeze(1).long()  # (N, 1, H, W) -> (N, H, W)
            if not (masks != criterion.ignore_index).any():
                continue

            outputs = model(images)
            loss = criterion(outputs, masks)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            counted += 1
            # Live loss in the progress bar instead of one print per step.
            progress.set_postfix(loss=f'{batch_loss:.4f}')
    return total_loss / counted if counted else float('nan')


for epoch in range(num_epochs):
    train_loss = run_epoch(train_loader, train=True, desc=f'Epoch {epoch+1}/{num_epochs}')
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss}')

    val_loss = run_epoch(val_loader, train=False, desc='Validation')
    print(f'Validation Loss: {val_loss}')

print('Training complete')