In [1]:
import os
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
from tensorflow.keras.models import Model

In [2]:
dataset_dir = "/Users/aryan/Downloads/rs19_val"

In [3]:
def unet(input_size=(256, 256, 3)):
    inputs = tf.keras.Input(input_size)

    # Encoder
    c1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    p1 = MaxPooling2D((2, 2))(c1)
    c2 = Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
    p2 = MaxPooling2D((2, 2))(c2)

    # Bottleneck
    c3 = Conv2D(256, (3, 3), activation='relu', padding='same')(p2)

    # Decoder
    u1 = Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same')(c3)
    m1 = concatenate([u1, c2])
    c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(m1)
    u2 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same')(c4)
    m2 = concatenate([u2, c1])
    c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(m2)

    # Output layer
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(c5)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model

In [4]:
def load_data(dataset_dir, img_size=(256, 256)):
    images = []
    labels = []

    for filename in os.listdir(os.path.join(dataset_dir, 'jpgs', 'rs19_val')):
        img_path = os.path.join(dataset_dir, 'jpgs', 'rs19_val', filename)
        label_path = os.path.join(dataset_dir, 'uint8', 'rs19_val', filename.replace('.jpg', '.png'))

        print(f"Image path: {img_path}")  # Debugging print statement
        print(f"Label path: {label_path}")  # Debugging print statement

        img = cv2.imread(img_path)
        
        if img is None:
            print(f"Failed to load image: {img_path}")  # Debugging print statement
            continue

        img = cv2.resize(img, img_size)

        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            print(f"Failed to load label: {label_path}")  # Debugging print statement
            continue

        label = cv2.resize(label, img_size)
        label = np.expand_dims(label, axis=-1)

        images.append(img)
        labels.append(label)

    images = np.array(images, dtype=np.float32) / 255.0
    labels = np.array(labels, dtype=np.float32) / 255.0

    return train_test_split(images, labels, test_size=0.2, random_state=42)


In [5]:
def train_model(dataset_dir):
    X_train, X_test, y_train, y_test = load_data(dataset_dir)
    model = unet()
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=5)

    return model


In [6]:
train_model(dataset_dir)

Image path: /Users/aryan/Downloads/rs19_val/jpgs/rs19_val/rs05276.jpg
Label path: /Users/aryan/Downloads/rs19_val/uint8/rs19_val/rs05276.png
Image path: /Users/aryan/Downloads/rs19_val/jpgs/rs19_val/rs02519.jpg
Label path: /Users/aryan/Downloads/rs19_val/uint8/rs19_val/rs02519.png
Image path: /Users/aryan/Downloads/rs19_val/jpgs/rs19_val/rs03607.jpg
Label path: /Users/aryan/Downloads/rs19_val/uint8/rs19_val/rs03607.png
Image path: /Users/aryan/Downloads/rs19_val/jpgs/rs19_val/rs04168.jpg
Label path: /Users/aryan/Downloads/rs19_val/uint8/rs19_val/rs04168.png
Image path: /Users/aryan/Downloads/rs19_val/jpgs/rs19_val/rs07461.jpg
Label path: /Users/aryan/Downloads/rs19_val/uint8/rs19_val/rs07461.png
Image path: /Users/aryan/Downloads/rs19_val/jpgs/rs19_val/rs01010.jpg
Label path: /Users/aryan/Downloads/rs19_val/uint8/rs19_val/rs01010.png
Image path: /Users/aryan/Downloads/rs19_val/jpgs/rs19_val/rs06019.jpg
Label path: /Users/aryan/Downloads/rs19_val/uint8/rs19_val/rs06019.png
Image path: /

2024-03-28 19:09:50.623634: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.engine.functional.Functional at 0x179d5c700>

In [54]:
def segment_video(model, video_path, output_path, frame_size=(256, 256)):
    cap = cv2.VideoCapture(video_path)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_path, fourcc, 20.0, frame_size)

    while(cap.isOpened()):
        ret, frame = cap.read()
        if ret:
            frame = cv2.resize(frame, frame_size)
            frame = np.expand_dims(frame, axis=0) / 255.0
            prediction = model.predict(frame)
            mask = (prediction > 0.5).astype(np.uint8) * 255
            mask = np.squeeze(mask, axis=0)
            out.write(mask)
        else:
            break

    cap.release()
    out.release()