In [14]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchsummary
from torch.utils.data import DataLoader

from dgzc.autoencoder import Encoder, Decoder, AutoEncoderFaceImages
from dgzc.classifier import ClassificationBackbone, AutoEncoderClassifierAmalgamation
from dgzc.dataset import DGZCAutoEncoderDataset, DGZCClassifierDataset, DGZCInferenceDataset

In [2]:
# Root directory of the gaze dataset. Set the DGZC_DATA_PATH environment
# variable to run this notebook on another machine; when it is unset the
# original author's local checkout is used, so existing behaviour is unchanged.
data_path = os.environ.get(
    "DGZC_DATA_PATH",
    "/home/shivam/2021-22-2/ML4CE/Assignments/Assignment3/Driver-Gaze-Zone-Classification/data/gaze_dataset",
)

# Only the inference dataset is needed in this notebook. size=(200, 200)
# matches the (3, 200, 200) input shape passed to the model summary below.
dataset_inference = DGZCInferenceDataset(data_path, size=(200, 200))

In [18]:
# Maps a predicted class index back to its gaze-zone name.
# Indices 0-4 follow the order of the tuple; -1 is the catch-all 'other' entry.
inv_class_map = {
    **dict(enumerate((
        'Centerstack',
        'Forward',
        'Left_wing_mirror',
        'Rearview_mirror',
        'Right_wing_mirror',
    ))),
    -1: 'other',
}

In [3]:
# Batched loader over the inference set. Shuffling is disabled, presumably so
# that predictions stay aligned with dataset_inference.images_path, which is
# used later to build the filename column -- confirm against the dataset class.
dataloader = DataLoader(
    dataset_inference,
    batch_size=10,
    shuffle=False,
    num_workers=5,
)

In [4]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [5]:
# Number of batches the loader yields, and one third of that (floored).
# Neither value is read by the cells visible in this notebook section.
data_len = len(dataloader)
log_n = data_len // 3

In [6]:
# Build the combined autoencoder + classifier network and place it on `device`.
# `.to()` returns the module itself, so the constructor and the move can be chained.
model = AutoEncoderClassifierAmalgamation().to(device)
# Echo the architecture as the cell output.
model

AutoEncoderClassifierAmalgamation(
  (encoder): Encoder(
    (conv1): Conv2DNormActivation(
      (0): Conv2d(3, 5, kernel_size=(5, 5), stride=(2, 2))
      (1): SELU()
    )
    (conv2): Conv2DNormActivation(
      (0): Conv2d(5, 10, kernel_size=(5, 5), stride=(1, 1), bias=False)
      (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SELU()
    )
    (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2DNormActivation(
      (0): Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1))
      (1): SELU()
    )
    (conv4): Conv2DNormActivation(
      (0): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1))
      (1): SELU()
    )
    (conv5): Conv2DNormActivation(
      (0): Conv2d(20, 30, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (1): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SELU()
    )
    (conv6): Conv2DNormActivation(
      (0): Co

In [33]:
# Checkpoint selection. At most one branch runs and the autoencoder flag takes
# precedence: when it is True only the encoder/decoder weights are restored,
# otherwise the full amalgamated model is restored if its flag is True.
# If both flags are False nothing is loaded and the model keeps whatever
# weights it was constructed with.
load_prev_auto_enc_state = False
load_prev_model_state = True

# map_location=device remaps tensors that were saved from a GPU session onto
# the device chosen for this session, so the notebook also runs on CPU-only
# machines instead of failing while deserialising CUDA tensors.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
if load_prev_auto_enc_state:
    model.encoder.load_state_dict(torch.load('./encoder_state', map_location=device))
    model.decoder.load_state_dict(torch.load('./decoder_state', map_location=device))
elif load_prev_model_state:
    model.load_state_dict(torch.load('./amalgam_model', map_location=device))

In [34]:
# torchsummary pushes a random batch through the network to record layer
# output shapes. In training mode that forward pass would update the
# BatchNorm running statistics of the checkpoint that was just loaded, so the
# model is switched to evaluation mode before the summary is produced.
model.eval()
torchsummary.summary(model, (3, 200, 200))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 5, 98, 98]             380
              SELU-2            [-1, 5, 98, 98]               0
            Conv2d-3           [-1, 10, 94, 94]           1,250
       BatchNorm2d-4           [-1, 10, 94, 94]              20
              SELU-5           [-1, 10, 94, 94]               0
         MaxPool2d-6           [-1, 10, 47, 47]               0
            Conv2d-7           [-1, 20, 45, 45]           1,820
              SELU-8           [-1, 20, 45, 45]               0
            Conv2d-9           [-1, 20, 43, 43]           3,620
             SELU-10           [-1, 20, 43, 43]               0
           Conv2d-11           [-1, 30, 41, 41]           5,400
      BatchNorm2d-12           [-1, 30, 41, 41]              60
             SELU-13           [-1, 30, 41, 41]               0
           Conv2d-14           [-1, 40,

In [35]:
# Put the network in evaluation mode before running inference (the BatchNorm2d
# layers then use their stored running statistics). eval() returns the module
# itself, which is why the architecture is echoed as this cell's output.
model.eval()

AutoEncoderClassifierAmalgamation(
  (encoder): Encoder(
    (conv1): Conv2DNormActivation(
      (0): Conv2d(3, 5, kernel_size=(5, 5), stride=(2, 2))
      (1): SELU()
    )
    (conv2): Conv2DNormActivation(
      (0): Conv2d(5, 10, kernel_size=(5, 5), stride=(1, 1), bias=False)
      (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SELU()
    )
    (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2DNormActivation(
      (0): Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1))
      (1): SELU()
    )
    (conv4): Conv2DNormActivation(
      (0): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1))
      (1): SELU()
    )
    (conv5): Conv2DNormActivation(
      (0): Conv2d(20, 30, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (1): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SELU()
    )
    (conv6): Conv2DNormActivation(
      (0): Co

In [36]:
# Accumulator for the predicted class index of every image, in loader order.
labels = list()

In [37]:
# Start from an empty list so that re-running this cell replaces the
# predictions instead of appending a second copy of them.
labels = []

# Inference only: no_grad() stops autograd from recording the forward pass,
# which saves memory and time and makes .detach() unnecessary.
with torch.no_grad():
    for data, target in dataloader:
        data = data.to(device)  # Move the batch to the model's device

        # The model returns the class scores and the reconstruction; only the
        # scores are needed here.
        cls, recon = model(data)

        # Index of the highest score along dim 1 is the predicted class.
        predictions = cls.argmax(dim=1)
        # tolist() yields plain Python ints, usable directly as dictionary keys.
        labels.extend(predictions.cpu().tolist())

In [38]:
# Translate every predicted class index into its gaze-zone name.
# An index missing from inv_class_map raises KeyError.
str_labels = list(map(inv_class_map.__getitem__, labels))

In [39]:
# Keep only the part of each image path after its last '/' separator.
f_names = [path.rsplit('/', 1)[-1] for path in dataset_inference.images_path]

In [40]:
# Every file must have exactly one prediction; fail loudly here instead of
# writing a misaligned or truncated submission file.
assert len(f_names) == len(str_labels), (
    f"{len(f_names)} filenames but {len(str_labels)} predictions"
)

# Build the two-column table directly from the lists. This avoids the detour
# through a 2-D NumPy string array and its transpose.
result_df = pd.DataFrame({'filename': f_names, 'class': str_labels})

In [41]:
# Write the predictions next to the notebook; the row index is not written.
result_df.to_csv(
    path_or_buf='./result3.csv',
    index=False,
)

In [42]:
# Echo the predicted class names as the cell output for a quick visual check.
str_labels

['Forward',
 'Forward',
 'Right_wing_mirror',
 'Right_wing_mirror',
 'Forward',
 'Left_wing_mirror',
 'Rearview_mirror',
 'Right_wing_mirror',
 'Right_wing_mirror',
 'Right_wing_mirror',
 'Rearview_mirror',
 'Centerstack',
 'Left_wing_mirror',
 'Centerstack',
 'Left_wing_mirror',
 'Forward',
 'Forward',
 'Centerstack',
 'Forward',
 'Forward',
 'Centerstack',
 'Centerstack',
 'Left_wing_mirror',
 'Centerstack',
 'Rearview_mirror',
 'Centerstack',
 'Forward',
 'Right_wing_mirror',
 'Forward',
 'Right_wing_mirror',
 'Forward',
 'Forward',
 'Centerstack',
 'Centerstack',
 'Left_wing_mirror',
 'Centerstack',
 'Forward',
 'Left_wing_mirror',
 'Forward',
 'Right_wing_mirror',
 'Left_wing_mirror',
 'Rearview_mirror',
 'Forward',
 'Right_wing_mirror',
 'Centerstack',
 'Left_wing_mirror',
 'Forward',
 'Rearview_mirror',
 'Left_wing_mirror',
 'Left_wing_mirror',
 'Centerstack',
 'Centerstack',
 'Rearview_mirror',
 'Rearview_mirror',
 'Left_wing_mirror',
 'Left_wing_mirror',
 'Centerstack',
 'Rear