In [3]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Accuracy, Precision, Recall, MeanIoU
from PIL import Image
import os


In [9]:
# Step 1: Load the dataset (RGB images and class labels with RGB values)

# Load RGB images
rgb_images_dir = "/home/don/Git/aerial-semantic-segmentation/dataset_here/dataset/semantic_drone_dataset/original_images"
rgb_images = []
rgb_images_names = []
# Load RGB images
for image_name in sorted(os.listdir(rgb_images_dir)):
    image_path = os.path.join(rgb_images_dir, image_name)
    if image_name.lower().endswith('.jpg') or image_name.lower().endswith('.jpeg'):
        label = os.path.splitext(image_name)[0]
        rgb_images.append(image_path)
        rgb_images_names.append(label)

# Load class labels from CSV
# class_labels_df = pd.read_csv("/home/don/Git/aerial-semantic-segmentation/dataset_here/class_dict_seg.csv")
labels = pd.read_csv('/home/don/Git/aerial-semantic-segmentation/dataset_here/class_dict_seg.csv')
# convert to to list so each channel can be accessed
labels = labels.values.tolist()
print(labels)

[['unlabeled', 0, 0, 0], ['paved-area', 128, 64, 128], ['dirt', 130, 76, 0], ['grass', 0, 102, 0], ['gravel', 112, 103, 87], ['water', 28, 42, 168], ['rocks', 48, 41, 30], ['pool', 0, 50, 89], ['vegetation', 107, 142, 35], ['roof', 70, 70, 70], ['wall', 102, 102, 156], ['window', 254, 228, 12], ['door', 254, 148, 12], ['fence', 190, 153, 153], ['fence-pole', 153, 153, 153], ['person', 255, 22, 96], ['dog', 102, 51, 0], ['car', 9, 143, 150], ['bicycle', 119, 11, 32], ['tree', 51, 51, 0], ['bald-tree', 190, 250, 190], ['ar-marker', 112, 150, 146], ['obstacle', 2, 135, 115], ['conflicting', 255, 0, 0]]


In [10]:
class_labels = labels

# Step 2: Prepare the data
# Split the class labels into separate color channels
red_channel = class_labels[:, 0]
green_channel = class_labels[:, 1]
blue_channel = class_labels[:, 2]
# Combine the class labels into a single categorical label
combined_labels = red_channel + 256 * green_channel + 256 * 256 * blue_channel
# Apply one-hot encoding to convert categorical labels to binary vectors
encoder = LabelEncoder()
categorical_labels = keras.utils.to_categorical(encoder.fit_transform(combined_labels.ravel()))

# Step 3: Split the data into train and test sets
train_images, test_images, train_labels, test_labels = train_test_split(rgb_images, categorical_labels, test_size=0.2, random_state=42)

# Step 4: Create a U-Net model
def unet_model():
    num_channels = 3
    num_classes = categorical_labels.shape[1]
    
    # Define the U-Net architecture
    inputs = Input((256, 256, num_channels))
    
    # Encoder
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    # Decoder
    conv9 = Conv2D(num_classes, 1, activation='softmax')(pool1)
    
    # Create the model
    model = Model(inputs=inputs, outputs=conv9)
    return model

# Create an instance of the U-Net model
unet_model = unet_model()

# Step 5: Compile the model
unet_model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=[Accuracy(), Precision(), Recall(), MeanIoU(num_classes=num_classes)])

# Step 6: Train the model
unet_model.fit(train_images, train_labels, batch_size=16, epochs=10, validation_data=(test_images, test_labels))

# Step 7: Evaluate the model
evaluation = unet_model.evaluate(test_images, test_labels)
print("Evaluation loss:", evaluation[0])
print("Accuracy:", evaluation[1])
print("Precision:", evaluation[2])
print("Recall:", evaluation[3])
print("Mean IoU:", evaluation[4])

# Step 8: Make predictions
predictions = unet_model.predict(test_images)


TypeError: list indices must be integers or slices, not tuple