In [1]:
from __future__ import print_function
import argparse
from tqdm import tqdm
import os
import pandas as pd
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np
import torchvision
from torchvision import transforms
from glob import glob
from model import GTSRBnet
import utils

In [2]:
# Locations of the trained weights and the GTSRB final-test images.
PROJECT_DIR = '/home/stringlab/Desktop/DLCV_midterm_project'
model_file = os.path.join(PROJECT_DIR, 'trained_models', 'model_49.pth')
test_dir = os.path.join(PROJECT_DIR, 'GTSRB_Final_Test_Images', 'GTSRB', 'Final_Test', 'Images')
# Prediction CSV ("pred.csv" in the working directory); the inference cell below
# writes its header and one row per test image.
output_file = open("pred.csv", "w")
# Inference in this notebook runs on the CPU (neither the model nor the inputs are
# ever moved to a GPU), so map any GPU-saved tensors to the CPU at load time.
# Without this, a checkpoint saved on a GPU cannot be loaded on a CPU-only machine.
state_dict = torch.load(model_file, map_location='cpu')
model = GTSRBnet(n_classes=43)
model.load_state_dict(state_dict)
model.eval();  # evaluation mode (affects dropout/batch-norm layers, if any); `;` hides the repr

In [None]:
# Preprocessing applied to every test image before it is fed to the model.
# NOTE(review): presumably mirrors the training-time transform -- confirm against
# the training script.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize([212, 256]),
    transforms.ToTensor(),
])

# Create, write and close the prediction CSV in this one cell. The cell can then be
# re-run safely, and the file is closed even if inference fails part-way through.
# no_grad is entered once for the whole loop: no gradients are needed at inference.
with open("pred.csv", "w") as output_file, torch.no_grad():
    output_file.write("Filename,ClassId\n")
    for image_path in tqdm(sorted(glob(os.path.join(test_dir, "*.ppm")))):
        # Add a leading batch dimension of size 1: (C, H, W) -> (1, C, H, W).
        data = transform(utils.pil_loader(image_path)).unsqueeze(0)
        logits = model(data)
        class_id = logits.argmax(dim=1).item()
        # Record the image's own file name. Slicing the full path (f[0:5]) produced
        # the same "/home" prefix for every row.
        file_id = os.path.basename(image_path)
        output_file.write("%s,%d\n" % (file_id, class_id))

In [None]:
# Calculate test accuracy: load the ground-truth labels and the predictions
# written by the inference cell.
gt_file = '/home/stringlab/Desktop/DLCV_midterm_project/GTSRB_Final_Test_GT/GT-final_test.csv'
gt = pd.read_csv(gt_file, sep=';')
# Read the predictions from the same relative path the inference cell writes them to
# ("pred.csv" in the working directory). A hard-coded absolute path only matches when
# the kernel was started in the project directory.
pred_file = 'pred.csv'
pred = pd.read_csv(pred_file, sep=',')
# The metrics below compare gt and pred row by row, so both need one row per test image.
# NOTE(review): this also assumes the GT file lists images in the same sorted
# file-name order used for the predictions -- confirm.
assert len(pred) == len(gt), "expected %d predictions, got %d" % (len(gt), len(pred))

In [None]:
# Percentage of test images whose predicted class matches the ground-truth class.
n_correct = (pred['ClassId'] == gt['ClassId']).sum()
print("Accuracy: ", n_correct / len(gt) * 100, "%")

In [None]:
# Class confusion matrix showing the relative distribution (in percentages) of
# classifications for each class, rather than raw counts. Rows are true classes and
# each row is normalised by that class's number of test images.
n_classes = 43  # must match GTSRBnet(n_classes=43) used to build the model
# Pin the label set so the matrix is always n_classes x n_classes. Otherwise a class
# missing from both gt and predictions would shift every later tick label.
cm_counts = confusion_matrix(gt['ClassId'], pred['ClassId'], labels=np.arange(n_classes))
row_totals = cm_counts.sum(axis=1, keepdims=True)
# A class with no test images would give 0/0 (NaN); leave such a row at 0 instead.
cm = np.divide(cm_counts, row_totals,
               out=np.zeros(cm_counts.shape, dtype=float),
               where=row_totals > 0) * 100

fig, ax = plt.subplots(figsize=(25, 20))
sns.heatmap(cm, annot=True, fmt='.1f', cmap='Blues', ax=ax)
ax.set_title('Row-normalised confusion matrix (% of each true class)', fontsize=24)
ax.set_ylabel('True label', fontsize=20)
ax.set_xlabel('Predicted label', fontsize=20)
plt.show()