In [9]:
import cv2
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

In [10]:
# Mapping categories in TEST/masks
category_colors = {
    "Saliency": [0, 0, 0],  # Background (waterbody) (Called also BW)
    "HD": [0, 0, 255],      # Human divers
    "PF": [0, 255, 0],      # Aquatic plants and sea-grass
    "WR": [0, 255, 255],    # Wrecks and ruins
    "RO": [255, 0, 0],      # Robots (AUVs/ROVs/instruments)
    "RI": [255, 0, 255],    # Reefs and invertebrates
    "FV": [255, 255, 0],    # Fish and vertebrates
    "SR": [255, 255, 255],  # Sea-floor and rocks
}

CLASSES = list(category_colors.keys())
color_to_class = {tuple(value): idx for idx, (_, value) in enumerate(category_colors.items())}
idx_to_class = {idx: class_name for idx, class_name in enumerate(CLASSES)}

In [11]:
class_colors = {v: k for k, v in color_to_class.items()}

def create_color_mask(pred_mask, class_colors):
    # Initialize the colored mask with black (default background)
    height, width = pred_mask.shape
    colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
    
    masks_per_class = []

    # Loop through each class and apply the corresponding color
    for class_id, color in class_colors.items():
        class_mask = (pred_mask == class_id)  # Create a binary mask for the class
        colored_mask[class_mask] = color      # Apply the color where the class mask is 1
        
        mask_per_class = np.full((height, width, 3), fill_value=(200, 200, 200), dtype=np.uint8) # Set a light grey color to see also the black mask
        mask_per_class[class_mask] = color
        masks_per_class.append(mask_per_class)

    return colored_mask, masks_per_class

In [12]:
MODEL_NAME = "Unet-resnet50-ssl"
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'ssl'
EPOCHS = 50
INDEX_RUN = 2
  
RESULTS_FOLDER = f"./results/{MODEL_NAME}/{INDEX_RUN}-epochs{EPOCHS}"

BEST_MODEL = f'{RESULTS_FOLDER}/best_model.pth'

In [13]:
ACTIVATION = 'softmax2d'  # Use sigmoid  if doing one-single-class segmentation
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class UNetWithDropout(smp.Unet):
    def __init__(self, encoder_name, encoder_weights, classes, activation, in_channels=3, dropout_prob=0.5):
        super().__init__(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=activation,
            in_channels=in_channels
        )
        self.dropout = nn.Dropout2d(p=dropout_prob)  # Dropout layer

    def forward(self, x):
        """Forward method with dropout added to the encoder and decoder outputs."""
        features = self.encoder(x)
        features = [self.dropout(feature) for feature in features]  # Apply dropout after each encoder layer
        decoder_output = self.decoder(*features)  # Decode features
        masks = self.segmentation_head(decoder_output)  # Generate segmentation mask
        return masks

# Create FPN model with pretrained encoder
model = UNetWithDropout(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,  # Use 'imagenet' pretrained weights for encoder initialization
    classes=len(CLASSES),  # Number of classes in your dataset
    activation=ACTIVATION,  # Activation function for the output
    in_channels=3,  # Model input channels (1 for gray-scale images, 3 for RGB, etc.)
    dropout_prob=0.3
    # decoder_dropout=0.5
)

preprocessing_fn = get_preprocessing_fn(ENCODER, pretrained=ENCODER_WEIGHTS)

print("Running on: ", DEVICE)

Running on:  cuda


In [14]:
from torchvision.transforms import functional as F

def preprocess_frame(frame, device, input_size=(256, 256)):
    """Preprocess a single video frame for the segmentation model."""
    frame_resized = cv2.resize(frame, input_size)
    frame_resized = frame_resized / 255.0
    # Convert to tensor and permute to (C, H, W)
    frame_tensor = torch.tensor(frame_resized, dtype=torch.float32).permute(2, 0, 1)
    # Add batch dimension and move to the appropriate device
    frame_tensor = frame_tensor.unsqueeze(0).to(device)
    return frame_tensor

In [21]:
def process_video_with_segmentation(model, input_video_path, output_video_path, device, input_size=(256, 256)):
    cap = cv2.VideoCapture(input_video_path)
    
    if not cap.isOpened():
        print("Error: Could not open the video file.")
        return

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    with tqdm(total=total_frames, desc="Processing video frames", unit="frame") as pbar:
        while cap.isOpened():
            ret, frame = cap.read()
            
            if not ret:
                break

            input_tensor = preprocess_frame(frame, device, input_size)

            with torch.no_grad():
                pr_mask = model(input_tensor)  # Forward pass
                pr_mask = torch.argmax(pr_mask, dim=1).squeeze(0).cpu().numpy()
                    
            colored_mask, colored_mask_per_class = create_color_mask(pr_mask, class_colors)

            overlay = cv2.resize(colored_mask, (frame_width, frame_height))
            result_frame = cv2.addWeighted(frame, 0.7, overlay, 0.3, 0)

            out.write(result_frame)

            pbar.update(1)

    cap.release()
    out.release()
    print("Segmented video saved at:", output_video_path)

In [None]:
segmentation_model = torch.load(BEST_MODEL)

input_video = "video/test2.mp4"  
output_video = "video/segmented-test2.mp4"

process_video_with_segmentation(
    model=segmentation_model,
    input_video_path=input_video,
    output_video_path=output_video,
    device=DEVICE,
    input_size=(1280, 768)
)


  segmentation_model = torch.load(BEST_MODEL)
Processing video frames:   0%|          | 0/1154 [00:00<?, ?frame/s]


TypeError: conv2d() received an invalid combination of arguments - got (numpy.ndarray, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, tuple of ints padding = 0, tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, str padding = "valid", tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
