# Evaluating the trained model

In [None]:
# importing necessary libraries

import os
import csv
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim

In [None]:
# Dataset class for test dataset
class RoadFollowingDataset(Dataset):
    """
    Dataset of labelled images listed in a CSV file.

    Each CSV row supplies an image path in its first column and a label
    string in its second column. Only the base name of the image path is
    kept; the file is looked up inside ``root_dir`` when a sample is fetched.
    """

    # Maps the label strings found in the CSV to integer class indices.
    # Defined once on the class instead of being rebuilt for every sample.
    LABEL_TO_INDEX = {'up': 0, 'down': 1, 'left': 2, 'right': 3, 'obstacle': 4}

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Initialize the dataset.

        Args:
            csv_file (str): Path to the CSV file containing image paths and labels.
            root_dir (str): Directory containing the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data = []
        # newline='' is what the csv module expects so that it can handle
        # line endings (and embedded newlines) itself.
        with open(csv_file, mode='r', newline='') as file:
            reader = csv.reader(file)
            next(reader, None)  # Skip the header (tolerates a completely empty file)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; skip them
                    # instead of failing on row[0].
                    continue
                self.data.append([os.path.basename(row[0]), row[1]])
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """
        Return the number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Fetch and return a sample from the dataset.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB image (after
            ``transform`` if one was given) and ``label`` is the integer
            class index from ``LABEL_TO_INDEX``.

        Raises:
            FileNotFoundError: If the image file does not exist in ``root_dir``.
            KeyError: If the label string is not a key of ``LABEL_TO_INDEX``.
        """
        file_name, label_name = self.data[idx]
        img_name = os.path.join(self.root_dir, file_name)
        if not os.path.exists(img_name):
            raise FileNotFoundError(f"File not found: {img_name}")

        image = Image.open(img_name).convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.LABEL_TO_INDEX[label_name]

        return image, label

In [None]:
# Preprocessing for the test dataset. Every step is deterministic (no random
# augmentation): resize to 224x224, convert to a tensor, then normalize each
# channel with the mean/std below.
# NOTE(review): these look like the usual ImageNet channel statistics - confirm
# they match the normalization used during training.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


In [None]:
# Build the test dataset from the collected images and their label CSV, then
# wrap it in a loader that yields shuffled batches of 32 samples.
root_dir = '/workspace/jetbot/notebooks/testing/data collection/collected_data'
csv_file_path = '/workspace/jetbot/notebooks/testing/data collection/collected_data/labels.csv'
dataset = RoadFollowingDataset(
    csv_file=csv_file_path,
    root_dir=root_dir,
    transform=transform,
)
data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [None]:
# Model definition
class SimpleCNN(nn.Module):
    """
    ResNet-18 classifier whose final fully connected layer is replaced by a
    linear layer with ``num_classes`` outputs.
    """

    def __init__(self, num_classes=5, pretrained=True):
        """
        Initialize the model.

        Args:
            num_classes (int, optional): Number of output classes. Defaults to 5.
            pretrained (bool, optional): Passed through to
                ``models.resnet18``; when True the backbone starts from
                torchvision's pretrained weights. Defaults to True.
        """
        super().__init__()
        # The attribute name `features` becomes the prefix of every key in
        # this module's state_dict, so it must stay unchanged for saved
        # checkpoints to keep loading.
        self.features = models.resnet18(pretrained=pretrained)
        self.features.fc = nn.Linear(self.features.fc.in_features, num_classes)

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        return self.features(x)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# pretrained=False: all parameters are overwritten when the trained checkpoint
# is loaded with load_state_dict(), so fetching the pretrained backbone weights
# here would only add a needless download (and a network dependency).
model = SimpleCNN(num_classes=5, pretrained=False)
model = model.to(device)

In [None]:
# Load the trained model
# map_location=device remaps the stored tensors onto the device selected
# earlier, so a checkpoint saved on a GPU still loads on a CPU-only machine.
# NOTE(review): the directory name 'training ' ends with a space - confirm it
# matches the folder on disk.
model.load_state_dict(
    torch.load(
        '/workspace/jetbot/notebooks/testing/training /road_following_model.pth',
        map_location=device,
    )
)
model.eval()

In [None]:
# Function to visualize predictions
def visualize_predictions(model, data_loader, device, class_names):
    """
    Visualize the model's predictions.

    Takes the first batch from ``data_loader``, runs the model on it and plots
    up to 10 images in a 2x5 grid. Each title reads
    ``"<predicted> (<true label>)"`` and is green when the two match, red
    otherwise.

    Args:
        model (nn.Module): The model to be evaluated.
        data_loader (DataLoader): DataLoader for the test dataset.
        device (torch.device): Device to run the model on.
        class_names (list): List of class names.
    """
    model.eval()
    images, labels = next(iter(data_loader))
    images, labels = images.to(device), labels.to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(images)
    _, preds = torch.max(outputs, 1)

    # Plain Python ints are simpler to index and compare with than 0-dim tensors.
    pred_indices = preds.tolist()
    true_indices = labels.tolist()

    # Per-channel constants used to undo the normalization before display;
    # they mirror the values hard-coded in the original implementation.
    std = np.array([0.229, 0.224, 0.225])
    mean = np.array([0.485, 0.456, 0.406])

    fig = plt.figure(figsize=(15, 6))
    for idx in range(min(10, len(images))):
        ax = fig.add_subplot(2, 5, idx + 1, xticks=[], yticks=[])
        # (C, H, W) tensor -> (H, W, C) array, the layout imshow expects
        img = images[idx].cpu().numpy().transpose((1, 2, 0))
        img = np.clip(img * std + mean, 0, 1)
        # Draw on this subplot explicitly rather than on the implicit
        # "current axes".
        ax.imshow(img)
        predicted, actual = pred_indices[idx], true_indices[idx]
        ax.set_title(f"{class_names[predicted]} ({class_names[actual]})",
                     color=("green" if predicted == actual else "red"))

# Class names, ordered so that each list index equals the integer label the
# dataset assigns ('up' -> 0, ..., 'obstacle' -> 4)
class_names = ['up', 'down', 'left', 'right', 'obstacle']

# Plot one batch of test images with predicted and true labels
visualize_predictions(
    model=model,
    data_loader=data_loader,
    device=device,
    class_names=class_names,
)