In [None]:
# Create an isolated virtual environment for this project.
!python -m venv fastai_env
# NOTE(review): every `!` line appears to run in its own short-lived shell, so this
# activation most likely does not carry over to later cells or to the running
# kernel - confirm, and prefer launching Jupyter from the activated environment.
!fastai_env\Scripts\activate

In [None]:
!pip install opencv-python
!pip install torch
!pip install torchvision
!pip install numpy
!pip install pillow
!pip install fastai
!pip install scikit-learn
!pip install mss
!pip install pynput
!pip install keyboard

In [6]:
# Standard libraries
import os
import time
from datetime import datetime

# Computer vision and image processing
import cv2
import numpy as np
from PIL import Image

# PyTorch and related modules
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# FastAI for deep learning
from fastai.vision.all import *

# Scikit-learn for model evaluation
from sklearn.model_selection import train_test_split

# Screen capture and keyboard control
import mss
from pynput import keyboard as key
import keyboard  # Another keyboard control library, different from pynput

In [None]:
# Root directory for the collected screenshots
data_dir = "trackmania_data"
os.makedirs(data_dir, exist_ok=True)

# One sub-directory per class label:
#   0: straight (w), 1: left (a), 2: right (d), 3: backward
# exist_ok=True keeps the cell safe to re-run and avoids the check-then-create race.
for label in ['0', '1', '2', '3']:
    label_dir = os.path.join(data_dir, label)
    os.makedirs(label_dir, exist_ok=True)

# Label of the steering key that currently drives the recording (None = record nothing).
# Written by the keyboard callbacks, read by the capture loop.
key_pressed = None

# Keyboard callbacks: translate steering keys into class labels.
# Character -> label: 0 straight (w), 1 left (a), 2 right (d), 3 backward (s).
_KEY_LABELS = {'w': '0', 'a': '1', 'd': '2', 's': '3'}

# Labels of the steering keys that are held right now, oldest first.
_held_labels = []


def _label_for(key):
    """Return the class label for a key event, or None if it is not a steering key."""
    # Keys without a usable `char` attribute are not steering keys.
    char = getattr(key, 'char', None)
    if not isinstance(char, str):
        return None
    # Lower-casing makes the mapping independent of Shift / Caps Lock.
    return _KEY_LABELS.get(char.lower())


def on_press(key):
    """Make the pressed steering key the active label.

    Sets the module-level ``key_pressed`` to the label of the key and remembers
    the key as held. Keys that are not steering keys leave the state unchanged.

    Args:
        key: Key event object passed in by the listener. Inside this function
            the name shadows the module alias ``key``.
    """
    global key_pressed
    label = _label_for(key)
    if label is None:
        return
    # Repeated press events for the same key must not create duplicates.
    if label in _held_labels:
        _held_labels.remove(label)
    _held_labels.append(label)
    key_pressed = label


def on_release(key):
    """Drop the released steering key and fall back to the one still held.

    Releasing an unrelated key no longer wipes the active label; when another
    steering key is still down (e.g. 'w' held while 'a' is tapped) its label
    becomes active again. ``key_pressed`` is None only when no steering key is held.

    Args:
        key: Key event object passed in by the listener.
    """
    global key_pressed
    label = _label_for(key)
    if label is None:
        return
    if label in _held_labels:
        _held_labels.remove(label)
    key_pressed = _held_labels[-1] if _held_labels else None

# Hook both callbacks into a pynput keyboard listener and start it
listener = key.Listener(
    on_release=on_release,
    on_press=on_press,
)
listener.start()

# Function to capture and save screenshots
def collect_data(poll_interval=0.005):
    """Save labelled screenshots while a steering key is active.

    Each pass grabs monitor 1, converts the frame from BGRA to BGR, shrinks it
    to 100x100 and writes it as a JPEG into ``data_dir/<label>/``. Frames
    labelled left ('1') or right ('2') are additionally written mirrored into
    the directory of the opposite label.

    The loop has no exit condition; stop it by interrupting the kernel.

    Args:
        poll_interval: Seconds to sleep between checks while no key is active,
            so the idle loop does not spin at full CPU.
    """
    # A horizontally mirrored frame belongs to the opposite steering class.
    mirrored_label = {'1': '2', '2': '1'}

    with mss.mss() as sct:
        monitor = sct.monitors[1]  # Screen to capture

        while True:
            if key_pressed is None:
                time.sleep(poll_interval)
                continue

            # Take a screenshot and convert BGRA -> BGR (for OpenCV)
            screenshot = np.array(sct.grab(monitor))
            img_bgr = cv2.cvtColor(screenshot, cv2.COLOR_BGRA2BGR)
            img_bgr = cv2.resize(img_bgr, (100, 100))

            # Read the label exactly once: the keyboard callbacks update
            # key_pressed asynchronously, and a second read could return None.
            label = key_pressed
            if label is None:
                continue

            # Unique filename based on the timestamp (microsecond resolution)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S%f")
            img_path = os.path.join(data_dir, label, f"{timestamp}.jpg")
            cv2.imwrite(img_path, img_bgr)
            print(f"Screenshot saved at {img_path}")

            if label in mirrored_label:
                flipped_img_path = os.path.join(
                    data_dir, mirrored_label[label], f"{timestamp}_flipped.jpg"
                )
                cv2.imwrite(flipped_img_path, cv2.flip(img_bgr, 1))  # 1 = horizontal flip
                print(f"Flipped screenshot saved at {flipped_img_path}")

# Start data collection.
# NOTE: collect_data() loops in a `while True` without a break, so this cell keeps
# running until the kernel is interrupted.
collect_data()

In [None]:
# Parameters
IMG_HEIGHT, IMG_WIDTH = 100, 100
EPOCHS = 5
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# Training-time augmentation (torchvision transforms).
# Deliberately no RandomHorizontalFlip: mirroring a frame turns a left turn into a
# right turn, but an image transform cannot swap the label (1 <-> 2), so it would
# feed wrongly labelled samples to the model. Mirrored copies with the swapped
# label are already written by the data-collection step.
transform = transforms.Compose([
    transforms.RandomRotation(degrees=15),
    transforms.RandomResizedCrop((IMG_HEIGHT, IMG_WIDTH), scale=(0.8, 1.0)),
])

# Dataset class for in-memory data
class TrackmaniaDataset(Dataset):
    """In-memory image dataset that yields ``(image, label)`` tensor pairs.

    Args:
        images: Indexable collection of HxWxC arrays with values scaled to [0, 1].
        labels: Indexable collection of integer class labels, aligned with ``images``.
        transform: Optional callable applied to the PIL version of each image.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return sample ``idx`` as (float32 CxHxW tensor in [0, 1], long label tensor)."""
        # Back to 8-bit pixels. Round instead of truncating: k / 255 * 255 can land a
        # hair below k in floating point, which a plain cast would turn into k - 1.
        # Clipping prevents uint8 wrap-around for values slightly outside [0, 1].
        scaled = np.asarray(self.images[idx]) * 255.0
        pixels = np.rint(np.clip(scaled, 0, 255)).astype(np.uint8)

        # The PIL round-trip is only needed when there is a transform to apply.
        if self.transform is not None:
            pixels = np.array(self.transform(Image.fromarray(pixels)))

        # Normalize to [0, 1] and convert HxWxC -> CxHxW
        image = torch.tensor(pixels.astype(np.float32) / 255.0).permute(2, 0, 1)
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return image, label

# Function to load data from directory
def load_data(data_dir):
    """Load every readable image below ``data_dir`` together with its label.

    ``data_dir`` is expected to contain one sub-directory per class whose name
    is the integer label (0 straight, 1 left, 2 right, 3 backward). Entries that
    are not such directories (stray files, ``.ipynb_checkpoints``, ...) and
    files OpenCV cannot decode are skipped instead of crashing the load.

    Args:
        data_dir: Path of the dataset root directory.

    Returns:
        Tuple ``(images, labels)`` of NumPy arrays; each image is resized to
        ``(IMG_WIDTH, IMG_HEIGHT)``.
    """
    images = []
    labels = []
    # Sorted listings make the sample order (and thus the seeded split) reproducible;
    # os.listdir returns entries in arbitrary order.
    for label_name in sorted(os.listdir(data_dir)):
        label_path = os.path.join(data_dir, label_name)
        if not (os.path.isdir(label_path) and label_name.isdecimal()):
            continue
        label = int(label_name)
        for img_file in sorted(os.listdir(label_path)):
            img = cv2.imread(os.path.join(label_path, img_file))
            if img is None:  # cv2.imread signals an unreadable file with None
                continue
            images.append(cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT)))
            labels.append(label)
    return np.array(images), np.array(labels)

# Load and preprocess the data
data_dir = "trackmania_data"
X, y = load_data(data_dir)
X = X / 255.0  # Normalize images to [0, 1]

# Handle any invalid labels outside expected range
if np.any((y < 0) | (y > 3)):
    raise ValueError("Some labels are outside the expected range 0-3.")

# Train-test split BEFORE oversampling: duplicating samples first would put copies
# of the same frame on both sides of the split and inflate the validation accuracy.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Balance the training data only: repeat the underrepresented steering classes
# (1 = left, 2 = right) OVERSAMPLE_FACTOR times, keep every other sample once.
OVERSAMPLE_FACTOR = 5
repeat_counts = np.where(np.isin(y_train, [1, 2]), OVERSAMPLE_FACTOR, 1)
X_train_balanced = np.repeat(X_train, repeat_counts, axis=0)
y_train_balanced = np.repeat(y_train, repeat_counts)

# Create dataset objects (augmentation on the training set only)
train_dataset = TrackmaniaDataset(X_train_balanced, y_train_balanced, transform=transform)
test_dataset = TrackmaniaDataset(X_test, y_test, transform=transforms.Resize((IMG_HEIGHT, IMG_WIDTH)))

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Wrap the PyTorch loaders in FastAI's DataLoaders (train first, validation second)
dls = DataLoaders(train_loader, test_loader)

# CNN Model definition
class CNNModel(nn.Module):
    """Small CNN classifier: two conv + max-pool stages and two linear layers.

    Layer attribute names (conv1, conv2, pool, fc1, fc2) are part of the
    state_dict layout and must stay stable for saved weights to load.

    Args:
        num_classes: Number of output logits (default 4: forward, left, right, backward).
        input_size: ``(height, width)`` of the input images (default 100x100).
            Each of the two pooling stages halves both dimensions.
    """

    def __init__(self, num_classes=4, input_size=(100, 100)):
        super().__init__()
        height, width = input_size
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Two 2x2 poolings shrink H and W by a factor of 4 (100 -> 25 by default).
        self.fc1 = nn.Linear(32 * (height // 4) * (width // 4), 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch of shape (N, 3, H, W) to logits of shape (N, num_classes)."""
        # torch.relu throughout: no dependency on an `F` alias from a wildcard import.
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # Flatten everything except the batch dimension
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Create the model
model = CNNModel()

# Create a learner using FastAI's Learner
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), opt_func=Adam, metrics=accuracy)

# Train the model
learn.fit_one_cycle(EPOCHS, LEARNING_RATE)

# Save a fastai checkpoint. Per the fastai docs this goes to
# <learn.path>/<learn.model_dir>/trackmania_model.pth and, by default, bundles the
# optimizer state with the weights - it is not a plain state_dict.
learn.save('trackmania_model')

# Also export the bare weights to the path the driving cell loads with
# model.load_state_dict(torch.load('trackmania_model.pth')).
torch.save(learn.model.state_dict(), 'trackmania_model.pth')

# Evaluate the model on the test set
learn.validate()

In [None]:
# Load the trained model
model = CNNModel()

# Prefer the plain weights file next to the notebook; otherwise fall back to the
# location fastai's learn.save() uses by default (models/<name>.pth).
_checkpoint_candidates = ('trackmania_model.pth', os.path.join('models', 'trackmania_model.pth'))
_checkpoint_path = next(
    (path for path in _checkpoint_candidates if os.path.exists(path)),
    _checkpoint_candidates[0],
)

# map_location='cpu' lets weights trained on a GPU load on a CPU-only machine.
_checkpoint = torch.load(_checkpoint_path, map_location='cpu')

# A fastai checkpoint wraps the weights as {'model': state_dict, 'opt': ...}.
if isinstance(_checkpoint, dict) and 'model' in _checkpoint:
    _checkpoint = _checkpoint['model']

model.load_state_dict(_checkpoint)
model.eval()  # Put the model in evaluation mode

IMG_HEIGHT, IMG_WIDTH = 100, 100


# Function that feeds live screenshots to the model and presses the predicted key
def drive(class_keys=("w", "a", "d", "s")):
    """Steer the game from live screenshots using the trained model.

    Each pass grabs monitor 1, preprocesses the frame like the training data
    (BGRA -> BGR, resize to IMG_WIDTH x IMG_HEIGHT, scale to [0, 1]), predicts
    a class, presses that class's key and releases the keys of all other classes.

    The loop has no exit condition; stop it by interrupting the kernel. All
    keys in ``class_keys`` are released on the way out so none stays stuck.

    Args:
        class_keys: Key name per class index, in the order
            (0 forward, 1 left, 2 right, 3 backward). The default is the
            QWERTY layout; on AZERTY pass ``("z", "q", "d", "s")``.
    """
    try:
        with mss.mss() as sct:
            monitor = sct.monitors[1]  # Screen to capture (adjust if needed)

            while True:
                # Take a screenshot and convert BGRA -> BGR (for OpenCV)
                screenshot = np.array(sct.grab(monitor))
                img_bgr = cv2.cvtColor(screenshot, cv2.COLOR_BGRA2BGR)

                # Same size as during training
                img_resized = cv2.resize(img_bgr, (IMG_WIDTH, IMG_HEIGHT))

                # HxWxC uint8 -> 1xCxHxW float tensor in [0, 1]
                img_tensor = (
                    torch.tensor(img_resized, dtype=torch.float32)
                    .permute(2, 0, 1)
                    .unsqueeze(0)
                    / 255.0
                )

                # Make a prediction
                with torch.no_grad():
                    output = model(img_tensor)
                    predicted_class = torch.argmax(output, dim=1).item()

                print(f"Predicted class: {predicted_class}")

                # Press the predicted key first, then release every other class key.
                if 0 <= predicted_class < len(class_keys):
                    keyboard.press(class_keys[predicted_class])
                    for class_index, other_key in enumerate(class_keys):
                        if class_index != predicted_class:
                            keyboard.release(other_key)
    finally:
        # Runs on interrupt or error as well: never leave a key held down.
        for held_key in class_keys:
            keyboard.release(held_key)

# Start the drive function
drive()