## Data

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
# import torch_directml
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Target size that every image and mask is resized to, as (height, width):
# the tuple is passed as-is to torchvision's Resize, as
# A.Resize(img_size[0], img_size[1]), and reversed for cv2.resize.
# TODO try with 720 x 1280 -> 704 x 1280 (height x width) because it needs to be a multiple of 32
img_size = (224, 384) # (height, width)
# img_size = (448, 768) # (height, width)


class FoosballDataset(Dataset):
    """Image / binary-mask pairs for segmentation training.

    Images and masks are paired by position in the sorted directory
    listings, so both directories must hold the same number of files in
    matching order.

    Each item is ``(image, mask)``, both resized to the module-level
    ``img_size``:

    * ``image`` - float tensor normalised with mean 0.5 / std 0.5. The
      channel order is whatever ``cv2.imread`` returns (BGR); it is not
      converted to RGB.
    * ``mask`` - ``float32`` tensor of shape ``(1, H, W)`` holding 0.0 / 1.0.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks (any non-zero pixel counts
            as foreground).
        augment: When True the albumentations pipeline (random flips,
            rotation, noise, ...) is applied to image and mask together;
            when False only resize + normalisation is applied.

    Raises:
        ValueError: If the two directories contain a different number of
            files, because the positional pairing would then be wrong.
    """

    def __init__(self, image_dir: str, mask_dir: str, augment: bool = True):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.augment = augment
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))

        # Pairing is purely positional, so a count mismatch means at least
        # one image would be trained against the wrong mask.
        if len(self.image_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"Number of images ({len(self.image_filenames)}) in '{image_dir}' does not "
                f"match number of masks ({len(self.mask_filenames)}) in '{mask_dir}'"
            )

        # Torchvision base transforms (resize & normalize), used when augment=False
        self.base_transform = transforms.Compose([
            transforms.Resize(img_size),  # img_size is (height, width)
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1,1]
        ])

        # Albumentations pipeline, used when augment=True. Geometric
        # transforms are applied to image and mask together.
        self.augmentation = A.Compose([
            A.Resize(img_size[0], img_size[1]),  # (height, width)
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.7),
            A.GaussianBlur(blur_limit=(3, 7), p=0.2),
            A.GaussNoise(p=0.5),
            A.Rotate(limit=(-90, 90), p=0.4),
            A.ElasticTransform(alpha=1, sigma=50, p=0.4),
            A.CoarseDropout(num_holes_range=(1, 8), hole_height_range=[30, 60], hole_width_range=(30, 60), p=0.5),
            # TODO translation and scaling
            # A.RandomScale((-0.1,0.1),cv2.INTER_LINEAR,cv2.INTER_NEAREST,p=0.5),
            # A.Resize(180,320),  # Resize
            A.Normalize(mean=0.5, std=0.5),
            ToTensorV2()
        ])

    def __len__(self) -> int:
        """Return the number of image/mask pairs."""
        return len(self.image_filenames)

    def __getitem__(self, idx: int):
        """Load, transform and return the ``(image, mask)`` pair at ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or mask file.
        """
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # cv2.imread returns None instead of raising when a file is missing
        # or not decodable, so fail here with the offending path.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        # image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        # image = image[:, :, 1]
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")

        # Ensure binary mask (0 or 1)
        mask = (mask > 0).astype(np.uint8)

        if self.augment:
            augmented = self.augmentation(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]
        else:
            image = self.base_transform(Image.fromarray(image))
            # cv2.resize expects (width, height); nearest keeps the mask binary
            mask = cv2.resize(mask, img_size[::-1], interpolation=cv2.INTER_NEAREST)

        # TODO use bytes instead, also in the model - performance increase?
        # `mask` is a torch tensor after ToTensorV2 and a numpy array
        # otherwise; as_tensor handles both without the copy-construct
        # warning that torch.tensor(<tensor>) emits.
        mask = torch.as_tensor(mask, dtype=torch.float32).unsqueeze(0)
        return image, mask

def _foosball_dataset(root, augment):
    """Build a FoosballDataset from ``<root>/images`` and ``<root>/masks``."""
    return FoosballDataset(os.path.join(root, "images"), os.path.join(root, "masks"), augment=augment)


hand_labeled_dir = "training_images"
vid_003_dir = os.path.join("training_images", "test_003_hand_cleaned")
video1_dir = os.path.join("training_images", "rec-20250110-130045_hand_cleaned")

# add the hand labeled images in different augmented ways as well as without augmentations
hand_labeled1 = _foosball_dataset(hand_labeled_dir, augment=False)
# NOTE(review): this reads test_003_hand_cleaned, not the directory used by
# hand_labeled1 - confirm that this is intended.
hand_labeled_aug1 = _foosball_dataset(vid_003_dir, augment=True)

# NOTE(review): the names look swapped - `vid_003_aug` is the NON-augmented
# view and `vid_003` the augmented one. Kept as-is so other cells that use
# these names keep working.
vid_003_aug = _foosball_dataset(vid_003_dir, augment=False)
vid_003 = _foosball_dataset(vid_003_dir, augment=True)

hand_labeled = ConcatDataset([hand_labeled1, hand_labeled_aug1, vid_003, vid_003_aug])

vid1 = _foosball_dataset(video1_dir, augment=True)

full_dataset = ConcatDataset([hand_labeled, vid1])

# Non-augmented mirror of `full_dataset`: same directories in the same order,
# so index i addresses the same file in both datasets. Used for validation.
eval_dataset = ConcatDataset([
    _foosball_dataset(hand_labeled_dir, augment=False),
    _foosball_dataset(vid_003_dir, augment=False),
    _foosball_dataset(vid_003_dir, augment=False),
    _foosball_dataset(vid_003_dir, augment=False),
    _foosball_dataset(video1_dir, augment=False),
])

# Split dataset into train (80%) and validation (20%)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_split = random_split(full_dataset, [train_size, val_size])

# Disable augmentation for the validation split only.
# Setting `val_split.dataset.augment = False` would have no effect: `.dataset`
# is the ConcatDataset, which never reads that attribute, and the underlying
# FoosballDataset objects are shared with the training split anyway. Instead
# the validation indices are applied to the non-augmented mirror.
# NOTE(review): the same directory is part of the dataset several times, so
# one file can end up in both the training and the validation split.
val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

print(f"total: {len(full_dataset)}, training: {len(train_loader.dataset)}, validation: {len(val_loader.dataset)}")

# Sanity check: fetch one batch per loader and print its shape.
for split_name, loader in (("Training", train_loader), ("Validation", val_loader)):
    images, masks = next(iter(loader))
    print(f"{split_name}: Image batch shape: {images.shape}")  # (batch_size, 3, H, W)
    print(f"{split_name}: Mask batch shape: {masks.shape}")  # (batch_size, 1, H, W)


## Show examples

In [None]:
import matplotlib.pyplot as plt

# Show the first image/mask pair of one training batch.
images, masks = next(iter(train_loader))

image = images[0]
mask = masks[0]

# The dataset normalises with mean 0.5 / std 0.5 and keeps OpenCV's BGR
# channel order. imshow expects RGB floats in [0, 1], so undo the
# normalisation and reverse the channel axis for display only.
display_image = (image * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).flip(dims=(-1,)).numpy()

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(display_image)  # rgb image
# plt.imshow(image.squeeze(0),cmap="gray")  # single channel images
plt.title("Augmented Image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(mask.squeeze(0), cmap="gray")  # Remove channel dimension
plt.title("Augmented Mask")
plt.axis("off")

plt.show()


## Model

In [None]:
class ShallowUNet(nn.Module):
    """Small U-Net with five stride-2 downsampling stages and skip connections.

    Channel widths grow 4 -> 8 -> 16 -> 32 -> 64 in the encoder, 128 in the
    bottleneck, and shrink symmetrically in the decoder. Downsampling uses
    stride-2 convolutions and upsampling uses 2x2 transposed convolutions.
    The forward pass returns sigmoid probabilities.

    Args:
        input_channels: Number of channels of the input tensor.
        output_channels: Number of channels of the predicted map.
    """

    @staticmethod
    def _conv_pair(in_ch, out_ch):
        """Two 3x3 same-padding convolutions, each followed by LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
        )

    @staticmethod
    def _conv_single(in_ch, out_ch):
        """One 3x3 same-padding convolution followed by LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
        )

    @staticmethod
    def _downsample(channels):
        """Stride-2 3x3 convolution (halves H and W) followed by LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(inplace=True),
        )

    def __init__(self, input_channels=1, output_channels=1):
        super().__init__()

        # Encoder: conv block, then learned downsampling.
        # Attribute names and creation order are kept stable because they
        # determine the state_dict keys and the order of weight initialisation.
        self.enc1 = self._conv_pair(input_channels, 4)
        self.pool1 = self._downsample(4)
        self.enc2 = self._conv_pair(4, 8)
        self.pool2 = self._downsample(8)
        self.enc3 = self._conv_pair(8, 16)
        self.pool3 = self._downsample(16)
        self.enc4 = self._conv_pair(16, 32)
        self.pool4 = self._downsample(32)
        self.enc5 = self._conv_pair(32, 64)
        self.pool5 = self._downsample(64)

        # Bottleneck
        self.bottleneck = self._conv_single(64, 128)

        # Decoder: transposed conv doubles H and W; the following conv block
        # takes the upsampled map concatenated with the matching encoder
        # output, hence its input width is twice the upsampled width.
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self._conv_pair(128, 64)
        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec2 = self._conv_pair(64, 32)
        self.up3 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.dec3 = self._conv_pair(32, 16)
        self.up4 = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)
        self.dec4 = self._conv_pair(16, 8)
        self.up5 = nn.ConvTranspose2d(8, 4, kernel_size=2, stride=2)
        self.dec5 = self._conv_single(8, 4)

        # 1x1 convolution mapping the last feature map to the output channels
        self.out = nn.Conv2d(4, output_channels, kernel_size=1)

    def forward(self, x):
        """Return per-pixel sigmoid probabilities for input batch ``x``."""
        # Encoder - keep every stage's output for the skip connections
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        skip4 = self.enc4(self.pool3(skip3))
        skip5 = self.enc5(self.pool4(skip4))

        # Bottleneck
        features = self.bottleneck(self.pool5(skip5))

        # Decoder - upsample, concatenate with the skip along channels, refine
        features = self.dec1(torch.cat([self.up1(features), skip5], dim=1))
        features = self.dec2(torch.cat([self.up2(features), skip4], dim=1))
        features = self.dec3(torch.cat([self.up3(features), skip3], dim=1))
        features = self.dec4(torch.cat([self.up4(features), skip2], dim=1))
        features = self.dec5(torch.cat([self.up5(features), skip1], dim=1))

        # Output
        logits = self.out(features)
        return torch.sigmoid(logits)


# Three input channels (colour image), one output channel (mask)
model = ShallowUNet(input_channels=3, output_channels=1)

print(model)


## training

In [None]:
##################################################################################################################################
# utility
##################################################################################################################################

class bcolors:
    """ANSI escape sequences used to colour console output."""
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


# Lowest validation loss seen so far, and the per-epoch loss history
best_val_loss = float('inf')
train_losses = []
val_losses = []


##################################################################################################################################
# config
##################################################################################################################################

criterion = nn.BCELoss()  # the model already applies a sigmoid to its output
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch_directml.device()

model.to(device)
epochs = 64
model.train()

##################################################################################################################################
# training
##################################################################################################################################

# NOTE: `plt` is not imported in this cell; it has to be available from an
# earlier cell.
print("Starting training")
for epoch in range(epochs):
    # ---- training pass -------------------------------------------------------
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses
    train_loss /= len(train_loader)
    train_losses.append(train_loss)

    # ---- validation pass (no gradients) --------------------------------------
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_images, val_masks in val_loader:
            val_images = val_images.to(device)
            val_masks = val_masks.to(device)

            val_outputs = model(val_images)
            val_loss += criterion(val_outputs, val_masks).item()

    # Mean of the per-batch losses
    val_loss /= len(val_loader)
    val_losses.append(val_loss)

    # ---- loss curves up to and including this epoch --------------------------
    plt.figure(figsize=(8, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss", marker='o', color='blue')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label="Val Loss", marker='o', color='red')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()
    plt.grid()
    plt.show()

    # ---- report; a new best validation loss is printed in green --------------
    if not (val_loss < best_val_loss):
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    else:
        best_val_loss = val_loss
        print(f"{bcolors.OKGREEN}Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best){bcolors.ENDC}")


## Finetuning

In [None]:
# Fine-tune the current `model` on the hand-labelled data only.
# train_losses / val_losses / best_val_loss are not reset here, so this cell
# appends to the values created by the training cell.
# train_losses = []
# val_losses = []
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # fresh optimizer state
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.train()
print("Starting training")
epochs = 32

# NOTE(review): this loader iterates over all of `hand_labeled`, while
# `val_loader` was split from a dataset that also contains `hand_labeled`.
# The validation loss below is therefore partly measured on files that are
# trained on here.
finetuning_loader = DataLoader(hand_labeled,batch_size=16,shuffle=True)
for epoch in range(epochs):
    # ---- training pass -------------------------------------------------------
    model.train()
    train_loss = 0.0
    for images, masks in finetuning_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses
    train_loss /= len(finetuning_loader)
    train_losses.append(train_loss)

    # ---- validation pass (no gradients) --------------------------------------
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_images, val_masks in val_loader:
            val_images = val_images.to(device)
            val_masks = val_masks.to(device)

            val_outputs = model(val_images)
            val_loss += criterion(val_outputs, val_masks).item()

    # Mean of the per-batch losses
    val_loss /= len(val_loader)
    val_losses.append(val_loss)

    # ---- loss curves over all epochs recorded so far -------------------------
    plt.figure(figsize=(8, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss", marker='o', color='blue')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label="Val Loss", marker='o', color='red')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()
    plt.grid()
    plt.show()

    # ---- report; a new best validation loss is printed in green --------------
    if not (val_loss < best_val_loss):
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    else:
        best_val_loss = val_loss
        print(f"{bcolors.OKGREEN}Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best){bcolors.ENDC}")


## save model

In [None]:
# Persist only the parameters (state_dict) of the current `model`; loading them
# back requires constructing a ShallowUNet with the same channel configuration.
torch.save(model.state_dict(), "unet_foosball.pth")

## load model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch_directml.device()

model = ShallowUNet(input_channels=3, output_channels=1)  # needs to match training config
model.to(device)
# The checkpoint is a plain state_dict, so restrict unpickling to tensors and
# primitive containers (weights_only=True) instead of arbitrary objects.
state_dict = torch.load("unet_foosball.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)  # Load weights
model.eval()


## validation

In [None]:
import os
import random
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Pick a random unlabeled image
unlabeled_dir = "training_images/unlabeled_images/"
img_path = os.path.join(unlabeled_dir, random.choice(os.listdir(unlabeled_dir)))
# img_path = os.path.join("training_images/rec-20250110-130045/images/", random.choice(os.listdir("training_images/rec-20250110-130045/images/")))
print(os.path.basename(img_path))

model.eval()

# Same resize + normalisation as the non-augmented dataset path
transform = transforms.Compose([
    transforms.Resize(img_size),  # Resize to match model input, (height, width)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1,1]
])

# cv2.imread returns None (no exception) if the file is not a readable image
image = cv2.imread(img_path)
if image is None:
    raise FileNotFoundError(f"Could not read image file: {img_path}")
# image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# image = image[:, :, 1]

# `image` stays in OpenCV's BGR order because the dataset feeds the model BGR
# images as well; `rgb_image` is only used for display.
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

start_time = time.time()
image_pil = Image.fromarray(image)
input_tensor = transform(image_pil).unsqueeze(0).to(device)  # Add batch dimension: (1, 3, H, W)

# Run inference
with torch.no_grad():
    output = model(input_tensor)

# Print statistics about the output
print(f"Output min: {output.min().item()}, max: {output.max().item()}")
print(f"Output mean: {output.mean().item()}, std: {output.std().item()}")

# Post-process the output
output_mask = output.cpu().squeeze(0).squeeze(0).numpy()  # Remove batch and channel dims
output_mask = (output_mask > 0.5).astype(np.uint8)  # Threshold to binary mask

# Resize mask back to original image size; cv2.resize expects (width, height)
output_mask = cv2.resize(output_mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
print(f"time taken: {time.time()-start_time:.3f}s")

# Overlay mask on original image
overlay = rgb_image.copy()
overlay[output_mask == 1] = [255,0,0]  # Color detected areas in red

# Display results
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(rgb_image)
plt.title("Original Image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(overlay)
plt.title("Overlayed Segmentation")
plt.axis("off")

plt.show()

