<a href="https://colab.research.google.com/github/AnikethDandu/traffic-sign-classification/blob/main/TrafficSignClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Traffic Sign Classification**


## **Google Drive Dataset Import**
*Make sure to follow dataset download instructions on [Github](https://github.com/AnikethDandu/traffic-sign-classification)*

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')
!cp -r /content/gdrive/My\ Drive/ColabNotebooks/Data/ traffic_sign_images.zip
!unzip traffic_sign_images.zip/traffic_sign_images.zip
!rm -r traffic_sign_images.zip/

## **Import Libraries**
Python standard libraries. If there are problems with torch, run the following command
```
!pip install torch
```



In [2]:
import cv2
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm.notebook import tqdm

## **Main Script Variable Initialization**
Initializes variables for the main script. It includes hyperparameters for the network, a standard image size, the datatsets and dataloaders, and the path to the training data directories.

In [11]:
# Network hyperparameters
EPOCHS = 10
BATCH_SIZE = 128
learning_rate = 0.002

# Tuple for desired image size
image_size = (-1, 3, 50, 50)

# Training and evaluation Datasets and DataLoaders
training_dataset = None
testing_dataset = None
train_dataloader = None
test_dataloader = None

# Path to class sorted training data directories
train_path = 'traffic_sign_images/Train'

## **Convolutional Neural Network**
A CNN class with four 2D Convolutional layers and two Linear layers. The ReLU function is appliead after each layer. The max pooling function is applied after each Convolutional layer.

In [4]:
class ConvNet(nn.Module):
  """
  Convolution Neural Network class
  Extends torch.nn.Module base network initialization
  Overrides torch.nn.Module forward method

  PUBLIC METHODS:
    - forward(self, x)

  INSTANCE VARIABLES:
    - PADDING_SIZE
    - KERNEL_SIZE
    - STRIDE
    - POOL_SIZE
    - conv1
    - conv2
    - conv3
    - conv4
    - fc1
    - fc2
  """

  def __init__(self):
    """
    Initializes network layers for input image of size 32x32x3

    :var PADDING_SIZE: size of padding applied to all sides of input matrix to preserve input volume
    :type: int
    :var KERNEL_SIZE: size of feature extraction filter
    :type: int
    :var STRIDE: pixel translation length during convolution operation
    :type: int
    :var POOL_SIZE: size of pooled feature map
    :type: int
    :return: None
    """
    self.PADDING_SIZE = 1
    self.KERNEL_SIZE = 3
    self.STRIDE = 1
    self.POOL_SIZE = 2

    super().__init__()

    self.conv1 = nn.Conv2d(3, 32, 
                           kernel_size=self.KERNEL_SIZE, 
                           stride=self.STRIDE, 
                           padding=self.PADDING_SIZE)
    self.conv2 = nn.Conv2d(32, 64, 
                           kernel_size=self.KERNEL_SIZE, 
                           stride=self.STRIDE, 
                           padding=self.PADDING_SIZE)
    self.conv3 = nn.Conv2d(64, 128, 
                           kernel_size=self.KERNEL_SIZE, 
                           stride=self.STRIDE, 
                           padding=self.PADDING_SIZE)
    self.conv4 = nn.Conv2d(128, 256, 
                           kernel_size=self.KERNEL_SIZE, 
                           stride=self.STRIDE, 
                           padding=self.PADDING_SIZE)
    self.fc1 = nn.Linear(2304, 1024)
    self.fc2 = nn.Linear(1024, 43)
    
  def forward(self, x):
    """
    Passes input matrix through network convolutional and linear layers while 
    applying pooling and ReLU function

    :param x: input matrix
    :type: torch.tensor
    :return: class output matrix
    :rtype: torch.tensor
    """
    x = F.max_pool2d(F.relu(self.conv1(x)), self.POOL_SIZE)
    x = F.max_pool2d(F.relu(self.conv2(x)), self.POOL_SIZE)
    x = F.max_pool2d(F.relu(self.conv3(x)), self.POOL_SIZE)
    x = F.max_pool2d(F.relu(self.conv4(x)), self.POOL_SIZE)
    x = x.flatten(start_dim=1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return x 

## **Dataset**

###**Image Transformation Function**###
The function quadruples the training dataset size by applying an image rotation transformation. The function updates the corresponding csv and saves the new images in the appropriate training folders

In [5]:
def transform_training_images(csv_path, train_path):
  df = pd.read_csv(csv_path)
  training_count = len(df)
  new_data = []
  for i in tqdm(range(training_count)):
    line = df.iloc[i]
    image = cv2.imread(os.path.join('traffic_sign_images', line[7]), cv2.IMREAD_COLOR)
    rows, cols, x = image.shape
    for j in range(3):
      matrix = cv2.getRotationMatrix2D(((cols-1)/2.0,(rows-1)/2.0),90*(i+1),1)
      image = cv2.warpAffine(image, matrix, (cols,rows))
      new_path = new_path = os.path.join(train_path, str(line[6]), f'{j}x{i}.png')
      cv2.imwrite(str(new_path), image)
      new_line = f'{",".join(list([str(line[i]) for i in range(7)]))},{str(new_path[20:])}'
      new_data.append(new_line)
  with open(csv_path, 'a') as f:
    for line in tqdm(new_data):
      f.write(line + '\n')
  f.close()

### **Custom Dataset Class**
The class subclasses the PyTorch Dataset class to read the training and evaluation csv and iterate over the dataset

In [6]:
class TrafficSignDataset(Dataset):
  """
  Custom Dataset class
  Subclasses torch.utils.data.Dataset

  DUNDER METHODS:
    - __len__(self)
    - __getitem__(self, idx)

  INSTANCE VARIABLES:
    - train
    - root_dir
    - img_size
    - df
  """

  def __init__(self, train, root_dir, img_size):
    """
    Initializes instance variables for given parameters

    :param train: whether dataset is a train or evaluation dataset
    :type: bool
    :param root_dir: root directory of files
    :type: str
    :param img_size: desired length of one side of resized image
    :type: int
    :var df: reads and displays csv file as two-axis display
    :type: DataFrame
    """
    self.train = train
    self.root_dir = root_dir
    self.df = pd.read_csv(os.path.join(root_dir, 'Train.csv' if train else 'Test.csv'))
    self.img_size = img_size

  def __len__(self):
    """
    Returns length of csv file corresponding to dataset

    :return: length of csv file
    :rtype: int
    """
    return len(self.df)

  def __getitem__(self, idx):
    """
    Returns image from dataset at specific index along with class label

    :param idx: index of desired item to get
    :type: int
    :return: returns dictionary of image with corresponding label
    :rtype: dict
    """
    image = cv2.imread(os.path.join(self.root_dir, self.df.iloc[idx][7]), cv2.IMREAD_COLOR)
    image = cv2.resize(image, (self.img_size, self.img_size))
    sample = {'image': torch.tensor(image), 'label': self.df.iloc[idx][6]}
    return sample


### **Dataset Creation Function**
The function reassigns global variables to the corresponding Dataset and DataLoader classes, creating shuffled training and evaluation datasets separated into batches

In [7]:
def create_datasets():
  """
  Reassigns global training variables to corresponding datasets and dataloaders
  """
  global training_dataset
  global testing_dataset
  global train_dataloader
  global test_dataloader

  training_dataset = TrafficSignDataset(train=True, root_dir='traffic_sign_images', img_size=50)
  testing_dataset = TrafficSignDataset(train=False, root_dir='traffic_sign_images', img_size=50)
  
  train_dataloader = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
  test_dataloader = DataLoader(testing_dataset, batch_size=1, shuffle=True)


## **Model Training and Evaluation**
The functions below train / evaluate the network given in the function parameter

### **Model Evaluation Function**
The function iterates through the evaluation dataset, evaluating the network response by comparing it to each image's class label

In [8]:
def evaluate_model(net):
  """
  Iterates over evaluation dataset without adjusting model gradients, comparing model predicted class to true class
  Prints accuracy for each class, raw class total correct, and total image accuracy and raw score

  :param net: CNN to be trained
  :type net: torch.nn.Module
  """
  total_classes = {}
  class_correct = {}
  total_images = 0
  total_correct = 0
  with torch.no_grad():
    for batch_idx, batch in tqdm(enumerate(test_dataloader)):
      test_image, test_label = batch['image'].view(image_size) / 255.0, batch['label'].item()
      correct_class = test_label
      test_image = test_image.to(device)
      predicted_class = torch.argmax(net(test_image)[0])
      
      total_images += 1
      total_classes[predicted_class.item()] = total_classes[predicted_class.item()] + 1 if predicted_class.item() in total_classes else 1
      
      if predicted_class == correct_class:
        total_correct += 1
        class_correct[correct_class] = class_correct[correct_class] + 1 if correct_class in class_correct else 1
  print([f'Accuracy for {img_class}: {round(100 * class_correct[img_class] / total_classes[img_class], 3)}%' for img_class in class_correct])
  print(f'Raw class score: {class_correct}')
  print(f'Total images correct: {total_correct}, Total images: {total_images}, Total accuracy: {round(100 * total_correct / total_images, 3)}%')

### **Model Training Function**
The function trains the model, printing loss for every epoch

In [9]:
def train_model(net):
  """
  Iterates over training dataset, adjusting model gradients
  Prints epoch number and corresponding loss after every epoch

  :param net: CNN to be trained
  :type net: torch.nn.Module
  """
  for epoch in range(EPOCHS):
    for batch_idx, batch in tqdm(enumerate(train_dataloader)):
      batch_imgs, batch_lbls = batch["image"].view(image_size) / 255.0, batch["label"]
      batch_labels = [0 for i in range(BATCH_SIZE)]
      for label_idx, label in enumerate(batch_lbls):
        batch_labels[label_idx] = label.item()
      
      optimizer.zero_grad()
      outputs = net(batch_imgs.to(device))
      loss = criterion(outputs, torch.tensor([label for label in batch_labels], device=device).long())
      loss.backward()
      optimizer.step()
    print(f'Epoch: {epoch + 1}, Loss: {loss}')
    evaluate_model(net)

## **Main Script**
The main script uses the GPU if available, creates the datasets, CNN, Adam optimizer, CrossEntropyLoss criterion with class weights, and trains and evaluates the model

In [13]:
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# Increase training dataset size by applying image transforms
transform_training_images('traffic_sign_images/Train.csv', 'traffic_sign_images/Train/')

# Create both training and evaluation datasets
create_datasets()

# Calculate total number of images in training dataset
total_images = len(training_dataset)
class_count = []
for folder in os.listdir(train_path):
  if folder != '.DS_Store':
    image_count = len([img for img in os.listdir(os.path.join(train_path, folder))])
    class_count.append(image_count)

# Calculate the class weights (due to unequal class image sizes)
final_weights = torch.Tensor([1 - img_count/total_images for img_count in class_count]).to(device)
# Create the CNN on the GPU
conv_net = ConvNet().to(device)                        
# Initializes optimizer and criterion
optimizer = optim.Adam(conv_net.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(weight=final_weights)

# Train and evaluate the model
train_model(conv_net)
# evaluate_model(conv_net)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: 1, Loss: 0.7266603112220764


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


['Accuracy for 10: 71.48%', 'Accuracy for 12: 90.469%', 'Accuracy for 4: 75.986%', 'Accuracy for 40: 72.152%', 'Accuracy for 25: 57.614%', 'Accuracy for 31: 66.055%', 'Accuracy for 1: 72.019%', 'Accuracy for 38: 80.519%', 'Accuracy for 23: 83.051%', 'Accuracy for 17: 98.316%', 'Accuracy for 8: 55.556%', 'Accuracy for 29: 68.0%', 'Accuracy for 3: 59.625%', 'Accuracy for 6: 79.452%', 'Accuracy for 5: 62.626%', 'Accuracy for 32: 83.333%', 'Accuracy for 11: 73.775%', 'Accuracy for 33: 90.355%', 'Accuracy for 9: 83.878%', 'Accuracy for 18: 71.141%', 'Accuracy for 13: 92.708%', 'Accuracy for 7: 72.664%', 'Accuracy for 37: 84.615%', 'Accuracy for 2: 85.318%', 'Accuracy for 35: 92.767%', 'Accuracy for 15: 96.711%', 'Accuracy for 14: 99.02%', 'Accuracy for 26: 72.611%', 'Accuracy for 36: 88.793%', 'Accuracy for 20: 49.038%', 'Accuracy for 34: 97.087%', 'Accuracy for 16: 89.474%', 'Accuracy for 27: 51.613%', 'Accuracy for 28: 60.116%', 'Accuracy for 39: 70.0%', 'Accuracy for 24: 31.944%', 'Accu

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: 2, Loss: 0.4995872676372528


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


['Accuracy for 18: 76.55%', 'Accuracy for 2: 79.378%', 'Accuracy for 1: 79.117%', 'Accuracy for 28: 71.333%', 'Accuracy for 7: 91.74%', 'Accuracy for 8: 76.739%', 'Accuracy for 31: 66.578%', 'Accuracy for 13: 85.088%', 'Accuracy for 3: 83.765%', 'Accuracy for 5: 86.957%', 'Accuracy for 10: 85.221%', 'Accuracy for 11: 83.29%', 'Accuracy for 9: 86.693%', 'Accuracy for 12: 95.728%', 'Accuracy for 38: 86.798%', 'Accuracy for 34: 83.206%', 'Accuracy for 15: 93.22%', 'Accuracy for 29: 66.304%', 'Accuracy for 25: 76.907%', 'Accuracy for 4: 88.832%', 'Accuracy for 17: 98.635%', 'Accuracy for 6: 87.692%', 'Accuracy for 40: 73.563%', 'Accuracy for 14: 98.413%', 'Accuracy for 22: 75.94%', 'Accuracy for 23: 46.853%', 'Accuracy for 35: 90.585%', 'Accuracy for 19: 56.897%', 'Accuracy for 36: 72.18%', 'Accuracy for 32: 62.069%', 'Accuracy for 37: 60.759%', 'Accuracy for 26: 84.058%', 'Accuracy for 16: 86.905%', 'Accuracy for 20: 77.143%', 'Accuracy for 30: 45.665%', 'Accuracy for 33: 91.121%', 'Accu

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: 3, Loss: 0.21891508996486664


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


['Accuracy for 4: 81.965%', 'Accuracy for 13: 87.879%', 'Accuracy for 31: 68.876%', 'Accuracy for 8: 78.444%', 'Accuracy for 12: 95.618%', 'Accuracy for 7: 82.311%', 'Accuracy for 14: 97.297%', 'Accuracy for 1: 82.055%', 'Accuracy for 37: 97.727%', 'Accuracy for 2: 84.065%', 'Accuracy for 38: 86.047%', 'Accuracy for 9: 95.329%', 'Accuracy for 5: 76.691%', 'Accuracy for 34: 91.597%', 'Accuracy for 28: 58.333%', 'Accuracy for 10: 94.082%', 'Accuracy for 24: 40.0%', 'Accuracy for 20: 35.437%', 'Accuracy for 3: 90.836%', 'Accuracy for 6: 90.0%', 'Accuracy for 18: 76.301%', 'Accuracy for 15: 99.432%', 'Accuracy for 17: 99.329%', 'Accuracy for 35: 94.22%', 'Accuracy for 25: 65.0%', 'Accuracy for 16: 97.973%', 'Accuracy for 33: 78.0%', 'Accuracy for 21: 70.423%', 'Accuracy for 36: 68.421%', 'Accuracy for 40: 72.881%', 'Accuracy for 22: 78.151%', 'Accuracy for 42: 81.0%', 'Accuracy for 26: 74.138%', 'Accuracy for 39: 86.275%', 'Accuracy for 29: 81.429%', 'Accuracy for 30: 63.208%', 'Accuracy 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: 4, Loss: 0.1264994591474533


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


['Accuracy for 2: 85.213%', 'Accuracy for 8: 83.493%', 'Accuracy for 5: 83.953%', 'Accuracy for 37: 92.308%', 'Accuracy for 12: 95.604%', 'Accuracy for 1: 83.721%', 'Accuracy for 4: 86.755%', 'Accuracy for 6: 90.517%', 'Accuracy for 7: 85.086%', 'Accuracy for 33: 70.196%', 'Accuracy for 15: 94.472%', 'Accuracy for 29: 56.923%', 'Accuracy for 10: 78.818%', 'Accuracy for 14: 95.865%', 'Accuracy for 35: 93.036%', 'Accuracy for 3: 76.163%', 'Accuracy for 13: 88.774%', 'Accuracy for 25: 75.828%', 'Accuracy for 22: 45.05%', 'Accuracy for 38: 85.18%', 'Accuracy for 28: 52.5%', 'Accuracy for 32: 56.566%', 'Accuracy for 40: 81.818%', 'Accuracy for 34: 85.246%', 'Accuracy for 11: 75.708%', 'Accuracy for 23: 63.025%', 'Accuracy for 39: 97.619%', 'Accuracy for 36: 54.918%', 'Accuracy for 9: 89.101%', 'Accuracy for 31: 74.434%', 'Accuracy for 30: 38.406%', 'Accuracy for 0: 65.0%', 'Accuracy for 17: 98.457%', 'Accuracy for 16: 96.667%', 'Accuracy for 18: 87.764%', 'Accuracy for 26: 62.381%', 'Accur

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

KeyboardInterrupt: ignored