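# Recognition

Train the `models.detection` model on the Cybathlon image data, run a quick inference check on the test split, and export the model to TFLite.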

In [3]:
%load_ext autoreload
%autoreload 2
import sys
import os
import os
import glob
import random
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from torchvision import transforms

repo_path = '/content/Cybathlon'
git_url = 'https://Leonpa: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@github.com/Leonpa/Cybathlon.git'
base_path = '/content/drive/MyDrive/Cybathlon'

if 'google.colab' in str(get_ipython()):
    from google.colab import drive
    print('Running on CoLab')
    if os.path.exists(repo_path):
        print('Repository already cloned. Pulling changes...')
        %cd $repo_path
        !git reset --hard
        !git pull
        %cd /content
    else:
        print('Cloning repository for the first time...')
        !git clone $git_url

    drive.mount('/content/drive', force_remount=True)
    sys.path.append(repo_path)

    # gpu_info = !nvidia-smi
    # gpu_info = '\n'.join(gpu_info)
    # if gpu_info.find('failed') >= 0:
    #   print('Not connected to a GPU')
    # else:
    #   print(gpu_info)
else:
    print('Running locally')
    base_path = ''

sys.path.append('/content/Cybathlon/')
from models.detection import Model, ModelTrainer, CustomDataset, Inference

In [4]:
def _label_path_for(image_path):
    """Return the label-file path paired with ``image_path``.

    ``<root>/images/<stem>.jpg`` maps to ``<root>/labels/<stem>.txt``. Only the image's
    own parent directory and its extension are rewritten; a blanket ``str.replace`` would
    also rewrite any other ``'images'`` / ``'.jpg'`` occurring elsewhere in the path.
    NOTE(review): CustomDataset presumably locates labels the same way — keep them in sync.
    """
    images_dir, filename = os.path.split(image_path)
    stem, _ = os.path.splitext(filename)
    return os.path.join(os.path.dirname(images_dir), 'labels', stem + '.txt')


def prepare_dataloaders(data_dirs, batch_size=32, valid_split=0.1, test_split=0.1,
                        seed=None, image_size=(128, 128)):
    """Build train/validation/test DataLoaders from labelled images.

    Each directory in ``data_dirs`` is searched for ``imgs_and_labels/images/*.jpg``;
    an image is kept only if its ``imgs_and_labels/labels/<stem>.txt`` file exists.
    The kept images are shuffled and split by count into train/valid/test.

    Args:
        data_dirs: Iterable of dataset root directories.
        batch_size: Batch size used by all three loaders.
        valid_split: Fraction of images for validation (count truncated to int).
        test_split: Fraction of images for testing (count truncated to int).
        seed: Optional seed for the split shuffle. ``None`` uses the global ``random``
            module state, so an earlier ``random.seed(...)`` still controls the split.
        image_size: ``(height, width)`` passed to ``transforms.Resize``.

    Returns:
        ``(train_loader, valid_loader, test_loader)``; only the train loader shuffles.

    Raises:
        ValueError: If a split fraction is outside [0, 1) or the two together leave no
            room for training data.
        FileNotFoundError: If no labelled images are found.
    """
    if not (0 <= valid_split < 1 and 0 <= test_split < 1 and valid_split + test_split < 1):
        raise ValueError(
            f"valid_split ({valid_split}) and test_split ({test_split}) must each be in "
            f"[0, 1) and sum to less than 1"
        )

    data_dirs = list(data_dirs)
    all_images = []
    for data_dir in data_dirs:
        all_images += glob.glob(os.path.join(data_dir, 'imgs_and_labels', 'images', '*.jpg'))

    # Keep only images with a label file. Sorting removes the filesystem-dependent glob
    # order, so a given seed always produces the same split.
    all_images = sorted(img for img in all_images if os.path.exists(_label_path_for(img)))

    print(f"Total images after filtering: {len(all_images)}")
    if not all_images:
        # Fail here with a clear message instead of inside DataLoader's sampler.
        raise FileNotFoundError(f"No labelled .jpg images found under: {data_dirs}")

    rng = random if seed is None else random.Random(seed)
    rng.shuffle(all_images)
    total_images = len(all_images)
    test_size = int(total_images * test_split)
    valid_size = int(total_images * valid_split)
    train_size = total_images - test_size - valid_size

    train_images = all_images[:train_size]
    valid_images = all_images[train_size:train_size + valid_size]
    test_images = all_images[train_size + valid_size:]

    print(f"Training images: {len(train_images)}")
    print(f"Validation images: {len(valid_images)}")
    print(f"Test images: {len(test_images)}")

    transform = transforms.Compose([
        transforms.Resize(tuple(image_size)),
        transforms.ToTensor()
    ])

    train_dataset = CustomDataset(train_images, transform=transform)
    valid_dataset = CustomDataset(valid_images, transform=transform)
    test_dataset = CustomDataset(test_images, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader, test_loader

In [5]:
# Dataset folders, relative to base_path. Uncomment entries to train on more recordings;
# every entry keeps its trailing comma so uncommenting one can never silently concatenate
# two adjacent string literals.
rel_data_dirs = [
    'data/real_1geo_bright_512867',
    # 'data/real_1geo_onlyflash_512875',
    # 'data/real_2geos_bright_512830',
    # 'data/real_2geos_onlyflash_512887',
    # 'data/real_3geos_bright_512712',
    # 'data/real_4geos_bright_512639',
    # 'data/real_4geos_onlyflash_512894',
]
data_dirs = [os.path.join(base_path, rel_dir) for rel_dir in rel_data_dirs]

# Fail early on a typo or unmounted Drive instead of silently training on fewer images.
missing_dirs = [path for path in data_dirs if not os.path.isdir(path)]
if missing_dirs:
    raise FileNotFoundError(f"Data directories not found (base_path={base_path!r}): {missing_dirs}")

In [6]:
# Prepare DataLoaders
train_loader, valid_loader, test_loader = prepare_dataloaders(data_dirs, batch_size=16)

# Debugging: Print a few batches from the train_loader to check the data
# for images, labels in train_loader:
    # print("Batch of images:", images.shape)
    # print("Batch of labels:", labels)
    # Print example of unpacked labels
    # for label in labels:
        # print("Label:", label)
    # break  # Print only the first batch

In [5]:
# Training hyperparameters.
NUM_CHANNELS = 3       # RGB input (images are converted by transforms.ToTensor)
NUM_CLASSES = 4        # NOTE(review): assumed to match the classes in the label files — confirm
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
TORCH_SEED = 42

# Seed torch's default generator. The shuffling train DataLoader draws from it, and so
# presumably does the model's weight initialisation, making the run repeatable.
torch.manual_seed(TORCH_SEED)

model = Model(num_channels=NUM_CHANNELS, num_classes=NUM_CLASSES)
trainer = ModelTrainer(model, train_loader, val_loader=valid_loader, learning_rate=LEARNING_RATE)
trainer.train(num_epochs=NUM_EPOCHS)

In [6]:
# Qualitative check: run the trained model on 10 samples from the held-out test loader.
# NOTE(review): whether Inference puts the model in eval mode / disables gradients is not
# visible here — confirm in models.detection.
inference = Inference(model, test_loader)
inference.run_inference(num_samples=10)

# TFLite export

In [7]:
import ai_edge_torch