In [None]:
from google.colab import drive

# Mount Google Drive inside the Colab VM so files under it can be read by path.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!pip install torch torchvision albumentations segmentation_models_pytorch PIL torchmetrics torchmetrics tqdm -q

In [None]:
!pip install PIL tqdm

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class UNet(nn.Module):
    """U-Net style encoder/decoder producing one logit map per class.

    The network has four encoder stages (each followed by 2x2 max pooling), a
    bottleneck, and four decoder stages. Every decoder stage upsamples with a
    stride-2 transposed convolution, concatenates the encoder feature map of the
    same level (skip connection) and applies a double 3x3 convolution block.
    Channel widths double per level: ``base_channels`` * (1, 2, 4, 8, 16).

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        in_channels: Number of channels of the input image. Defaults to 3.
        base_channels: Width of the first encoder stage. Defaults to 64, which
            gives the 64/128/256/512/1024 layout.

    Shape:
        Input ``(N, in_channels, H, W)`` -> output ``(N, num_classes, H, W)``.
        H and W must be at least 16 so that four poolings leave at least one
        pixel; they do not have to be multiples of 16 (see ``_match_size``).

    Sub-module names and their registration order are kept stable so that
    ``state_dict`` keys stay the same for the default configuration.
    """

    def __init__(self, num_classes, in_channels=3, base_channels=64):
        super().__init__()

        # Channel width per level: encoder 1-4 and bottleneck.
        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        # Encoder
        self.enc1 = self._conv_block(in_channels, c1)
        self.enc2 = self._conv_block(c1, c2)
        self.enc3 = self._conv_block(c2, c3)
        self.enc4 = self._conv_block(c3, c4)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = self._conv_block(c4, c5)

        # Decoder: each dec block sees the upsampled map concatenated with the
        # skip connection, hence twice the channels of the upconv output.
        self.upconv4 = self._up_conv(c5, c4)
        self.dec4 = self._conv_block(c5, c4)
        self.upconv3 = self._up_conv(c4, c3)
        self.dec3 = self._conv_block(c4, c3)
        self.upconv2 = self._up_conv(c3, c2)
        self.dec2 = self._conv_block(c3, c2)
        self.upconv1 = self._up_conv(c2, c1)
        self.dec1 = self._conv_block(c2, c1)

        # Output layer
        self.conv_final = nn.Conv2d(c1, num_classes, kernel_size=1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Two 3x3 convolutions (padding 1, size-preserving), each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _up_conv(in_channels, out_channels):
        """Transposed convolution that doubles the spatial resolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _match_size(x, reference):
        """Zero-pad ``x`` so that its height/width equal those of ``reference``.

        Max pooling floors odd sizes, so after upsampling the decoder map can be
        one pixel smaller than the encoder map it is concatenated with. When the
        sizes already agree the tensor is returned unchanged.
        """
        diff_h = reference.size(2) - x.size(2)
        diff_w = reference.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(
            x,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def _decode(self, upconv, block, x, skip):
        """Upsample ``x``, align it with ``skip``, concatenate on channels, convolve."""
        x = self._match_size(upconv(x), skip)
        return block(torch.cat((x, skip), dim=1))

    def forward(self, x):
        """Return per-pixel class logits with the same height/width as ``x``."""
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))

        # Decoder
        dec4 = self._decode(self.upconv4, self.dec4, bottleneck, enc4)
        dec3 = self._decode(self.upconv3, self.dec3, dec4, enc3)
        dec2 = self._decode(self.upconv2, self.dec2, dec3, enc2)
        dec1 = self._decode(self.upconv1, self.dec1, dec2, enc1)

        return self.conv_final(dec1)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


def show_augmented_images(dataloader, num_samples=3):
    """Plot image/mask pairs from the first batch of ``dataloader`` side by side.

    Args:
        dataloader: Iterable yielding ``(images, masks)`` batches of tensors.
            Images are indexed as ``[C, H, W]`` per sample and transposed to
            ``[H, W, C]`` for display; masks are squeezed before display.
        num_samples: Maximum number of rows to draw. It is clamped to the size
            of the batch so a small batch no longer raises ``IndexError``.

    Only one batch is consumed; nothing is returned.
    """
    images, masks = next(iter(dataloader))  # Get one batch

    images = images.cpu().numpy()
    masks = masks.cpu().numpy()

    # Never index past the end of the batch.
    num_samples = min(num_samples, len(images))

    # squeeze=False keeps `axes` two-dimensional even when a single row is drawn,
    # so `axes[row, col]` indexing works for num_samples == 1 as well.
    fig, axes = plt.subplots(
        num_samples, 2, figsize=(10, 5 * num_samples), squeeze=False
    )

    for row in range(num_samples):
        img = np.transpose(images[row], (1, 2, 0))  # [C, H, W] -> [H, W, C]

        if img.max() <= 1.0:
            img = (img * 255).astype(np.uint8)  # Convert back to 0-255 for display

        mask = masks[row].squeeze()

        image_ax, mask_ax = axes[row, 0], axes[row, 1]

        image_ax.imshow(img)
        image_ax.set_title("Augmentirana slika")
        image_ax.axis("off")

        mask_ax.imshow(mask, cmap="gray")
        mask_ax.set_title("Augmentirana maska")
        mask_ax.axis("off")

    plt.tight_layout()
    plt.show()

# UNet is defined in the previous cell and SegmentationDataset is defined just below,
# so no `from your_module import UNet, SegmentationDataset` is needed in this notebook.
# Define the SegmentationDataset class
class SegmentationDataset(Dataset):
    """Image / class-index-mask pairs read from two parallel directories.

    Every regular file in ``image_dir`` is one sample; its mask is the file with
    the same name in ``mask_dir``. File names are sorted so that an index always
    refers to the same sample regardless of directory listing order.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing one mask per image, with identical names.
        transform: Optional callable applied to the RGB ``PIL.Image`` only.
            It is deliberately NOT applied to the mask: an intensity transform
            such as ``ToTensor`` rescales 8-bit values to [0, 1], which turns
            every class id below 255 into 0 once the mask is cast to integers.
        mask_transform: Optional callable applied to the single-channel mask
            image before it is converted to class ids. It must keep pixel
            values as integer class ids (e.g. a nearest-neighbour resize).

    Each item is ``(image, mask)`` where ``image`` is whatever ``transform``
    returns (the ``PIL.Image`` itself when no transform is given) and ``mask``
    is an ``int64`` tensor of shape ``(H, W)`` holding remapped class ids.
    """

    # Raw mask value -> class id used for training. Class 5 is merged into
    # class 3; classes 6 and 7 are folded into class 0. Other values are kept.
    # NOTE(review): folding into 0 means those pixels are still trained on as
    # class 0 rather than being skipped — confirm this is the intent.
    CLASS_REMAP = {5: 3, 6: 0, 7: 0}

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.image_filenames = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.image_filenames)

    @classmethod
    def _remap_classes(cls, mask_np):
        """Return a copy of ``mask_np`` with ``CLASS_REMAP`` applied.

        Matches are computed against the untouched input so that one rule can
        never feed into another.
        """
        remapped = mask_np.copy()
        for raw_id, class_id in cls.CLASS_REMAP.items():
            remapped[mask_np == raw_id] = class_id
        return remapped

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``."""
        filename = self.image_filenames[idx]

        # Context managers close the underlying files as soon as the pixel data
        # has been decoded by convert().
        with Image.open(os.path.join(self.image_dir, filename)) as raw_image:
            image = raw_image.convert('RGB')
        # NOTE(review): assumes masks are stored so that the 8-bit grey value of
        # each pixel is its class id — confirm against the dataset.
        with Image.open(os.path.join(self.mask_dir, filename)) as raw_mask:
            mask = raw_mask.convert('L')

        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        # Read raw integer class ids (no rescaling), then merge/fold classes.
        mask_np = self._remap_classes(np.array(mask, dtype=np.int64))

        return image, torch.from_numpy(mask_np)
# Define the transformations for the training and validation sets
transform = transforms.Compose([
    transforms.ToTensor(),
    # Add other transformations if needed
])

# Tunable settings, gathered in one place.
DATA_ROOT = '/content/drive/MyDrive/RGB_datasets_segmentation_V2'  # images/<split>, masks/<split>
BATCH_SIZE = 3
NUM_WORKERS = 1
# NOTE(review): SegmentationDataset remaps raw ids 5 -> 3 and 6, 7 -> 0 but leaves raw id 4
# untouched; a mask containing 4 would be out of range for 4 output classes — confirm
# against the data.
NUM_CLASSES = 4
LEARNING_RATE = 0.0001

# Create the training and validation datasets
train_dataset = SegmentationDataset(
    image_dir=os.path.join(DATA_ROOT, 'images', 'train'),
    mask_dir=os.path.join(DATA_ROOT, 'masks', 'train'),
    transform=transform
)

val_dataset = SegmentationDataset(
    image_dir=os.path.join(DATA_ROOT, 'images', 'val'),
    mask_dir=os.path.join(DATA_ROOT, 'masks', 'val'),
    transform=transform
)

# Create the data loaders (only the training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Use the GPU if available, and place the model on it before the optimizer is built
# so the optimizer is constructed over the parameters in their final location.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(num_classes=NUM_CLASSES).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def _class_index_targets(masks):
    """Return masks as integer class ids shaped ``(N, H, W)``.

    Accepts both ``(N, 1, H, W)`` and ``(N, H, W)`` batches; only a 4-D batch has
    its channel axis squeezed, so a 3-D batch is passed through untouched.
    """
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    return masks.long()


def train_one_epoch(model, loader, criterion, optimizer, device, desc=None):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The running batch loss is shown in the tqdm bar instead of being printed for
    every batch, which keeps the cell output readable.
    Returns NaN when ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    progress = tqdm(loader, desc=desc, leave=False)
    for images, masks in progress:
        images = images.to(device)
        targets = _class_index_targets(masks.to(device))

        # Forward pass
        loss = criterion(model(images), targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress.set_postfix(loss=f'{batch_loss:.4f}')

    return total_loss / num_batches if num_batches else float('nan')


def evaluate(model, loader, criterion, device):
    """Return the mean batch loss over ``loader`` without updating the model.

    Returns NaN when ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            targets = _class_index_targets(masks.to(device))
            total_loss += criterion(model(images), targets).item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float('nan')


# Training loop
# NOTE(review): the trained weights are not saved anywhere in this cell; add
# torch.save(model.state_dict(), ...) if they are needed after the session ends.
num_epochs = 25
for epoch in range(num_epochs):
    train_loss = train_one_epoch(
        model, train_loader, criterion, optimizer, device,
        desc=f'Epoch {epoch+1}/{num_epochs}',
    )
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss}')

    # Validation loop
    val_loss = evaluate(model, val_loader, criterion, device)
    print(f'Validation Loss: {val_loss}')

print('Training complete')
