In [33]:
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
# torch imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F

In [3]:
# Colab-only setup: mount Google Drive and move into the project folder.
# Run this once
from google.colab import drive
drive.mount('/content/gdrive')
# You will have to modify this path based on your Google Drive directory structure
%cd /content/gdrive/MyDrive/Colab Notebooks/RideStream/
# Print the working directory to confirm the directory change
!pwd

Mounted at /content/gdrive
/content/gdrive/MyDrive/Colab Notebooks/RideStream
/content/gdrive/MyDrive/Colab Notebooks/RideStream


In [4]:
class UTKDataset(Dataset):
    '''
        PyTorch Dataset that serves 48x48 single channel images together with
        their (age bin, gender, ethnicity) labels.

        Inputs:
            dataFrame : Pandas dataFrame with the columns
                          - pixels    : string of space separated pixel values (48*48 per row)
                          - bins      : age bin label
                          - gender    : gender label
                          - ethnicity : ethnicity label
            transform : Optional callable applied to each image in __getitem__.
                        When None, the image is returned exactly as stored
                        (float32 numpy array of shape (48, 48, 1), pixel values divided by 255)
    '''
    def __init__(self, dataFrame, transform=None):
        # read in the transforms
        self.transform = transform

        # Parse every "p0 p1 p2 ..." string into a 1-D float array
        data_holder = dataFrame.pixels.apply(lambda x: np.array(x.split(" "), dtype=float))
        arr = np.stack(data_holder.tolist())
        arr = arr / 255.0
        arr = arr.astype('float32')
        # reshape into 48x48x1 (height, width, channel) per sample
        arr = arr.reshape(arr.shape[0], 48, 48, 1)
        self.data = arr

        # get the age, gender, and ethnicity label arrays
        self.age_label = np.array(dataFrame.bins[:])        # Note : Changed dataFrame.age to dataFrame.bins with most recent change
        self.gender_label = np.array(dataFrame.gender[:])
        self.eth_label = np.array(dataFrame.ethnicity[:])

    # override the length function
    def __len__(self):
        return len(self.data)

    # override the getitem function
    def __getitem__(self, index):
        '''
            Returns (data, labels) where labels is a tensor holding
            [age bin, gender, ethnicity] for the sample at `index`.
        '''
        # load the data at index and apply the transform (if one was given)
        data = self.data[index]
        if self.transform is not None:
            # transform is optional; calling None here used to raise a TypeError
            data = self.transform(data)

        # load the labels into a tuple and convert to a tensor
        labels = torch.tensor((self.age_label[index], self.gender_label[index], self.eth_label[index]))

        # return data labels
        return data, labels

In [5]:
# High level feature extractor network (Adopted VGG type structure)
class highLevelNN(nn.Module):
    '''
        Shared convolutional trunk made of three stages (32, 64 and 128 channels).
        Each stage is Conv -> ReLU -> Conv -> MaxPool -> ReLU, with 3x3 convolutions
        (padding 1) and a max pool of kernel 3, stride 2, padding 1.
    '''
    def __init__(self):
        super().__init__()
        layers = []
        channels_in = 1     # single channel input image
        for width in (32, 64, 128):
            layers += [
                nn.Conv2d(in_channels=channels_in, out_channels=width, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, padding=1),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                nn.ReLU(),
            ]
            channels_in = width
        # Layer order (and therefore the Sequential indices) is part of the checkpoint format
        self.CNN = nn.Sequential(*layers)

    def forward(self, x):
        # Run the input through the whole convolutional stack
        return self.CNN(x)

# Low level feature extraction module
class lowLevelNN(nn.Module):
    '''
        Task head: two Conv -> MaxPool -> ReLU steps followed by four fully
        connected layers (ReLU after the first three, none after the last).

        Inputs:
            num_out : number of output features of the final linear layer

        Note: fc1 expects 2048 flattened features, i.e. 512 channels x 2 x 2
        after the two pooling steps.
    '''
    def __init__(self, num_out):
        super().__init__()
        self.conv1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, num_out)

    def forward(self, x):
        # Convolutional part: conv, down-sample, then non-linearity
        for conv in (self.conv1, self.conv2):
            x = F.relu(F.max_pool2d(conv(x), kernel_size=3, stride=2, padding=1))

        # Keep the batch dimension, flatten everything else
        x = x.flatten(start_dim=1)

        # Hidden fully connected layers
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden(x))

        # Final layer returns raw (un-activated) outputs
        return self.fc4(x)


class TridentNN(nn.Module):
    '''
        One shared highLevelNN trunk feeding three separate lowLevelNN heads.

        Inputs:
            num_age : number of outputs of the age head
            num_gen : number of outputs of the gender head
            num_eth : number of outputs of the ethnicity head
    '''
    def __init__(self, num_age, num_gen, num_eth):
        super().__init__()
        # Shared feature extractor
        self.CNN = highLevelNN()
        # One head per task, each with its own weights
        self.ageNN = lowLevelNN(num_out=num_age)
        self.genNN = lowLevelNN(num_out=num_gen)
        self.ethNN = lowLevelNN(num_out=num_eth)

    def forward(self, x):
        # The trunk runs once; every head receives the same feature map
        features = self.CNN(x)
        return self.ageNN(features), self.genNN(features), self.ethNN(features)

In [35]:
'''
    Function to test the trained model

    Inputs:
      - testloader : PyTorch DataLoader containing the test dataset
      - modle : Trained NeuralNetwork
    
    Outputs:
      - Prints out test accuracy for gender and ethnicity and loss for age
'''
def test(testloader, model):
  device = 'cuda' if torch.cuda.is_available() else 'cpu' 
  size = len(testloader.dataset)
  # put the moel in evaluation mode so we aren't storing anything in the graph
  model.eval()

  age_acc, gen_acc, eth_acc = 0, 0, 0

  with torch.no_grad():
      for X, y in testloader:
          X = X.to(device)
          age, gen, eth = y[:,0].to(device), y[:,1].to(device), y[:,2].to(device)
          pred = model(X)

          age_acc += (pred[0].argmax(1) == age).type(torch.float).sum().item()
          gen_acc += (pred[1].argmax(1) == gen).type(torch.float).sum().item()
          eth_acc += (pred[2].argmax(1) == eth).type(torch.float).sum().item()

  age_acc /= size
  gen_acc /= size
  eth_acc /= size

  print(f"Age Accuracy : {age_acc*100}%,     Gender Accuracy : {gen_acc*100},    Ethnicity Accuracy : {eth_acc*100}\n")

In [18]:
# Read in the dataframe
dataFrame = pd.read_csv('age_gender.gz', compression='gzip')

# Construct age bins: label i covers the ages in (age_bins[i], age_bins[i+1]]
age_bins = [0,10,15,20,25,30,40,50,60,120]
age_labels = list(range(len(age_bins) - 1))     # one integer label per bin: 0..8
dataFrame['bins'] = pd.cut(dataFrame.age, bins=age_bins, labels=age_labels)

# Split into training and testing (fixed seed so the split, and therefore the
# reported accuracies, can be reproduced from a fresh kernel)
SPLIT_SEED = 42
train_dataFrame, test_dataFrame = train_test_split(dataFrame, test_size=0.2, random_state=SPLIT_SEED)

# get the number of classes for each group
# The age head needs one output per bin label even if a bin happens to be empty in
# the data, so its size comes from the label list rather than from the values seen
class_nums = {'age_num':len(age_labels), 'eth_num':len(dataFrame['ethnicity'].unique()),
              'gen_num':len(dataFrame['gender'].unique())}

# Define train and test transforms (currently identical: HWC numpy image -> CHW
# tensor, then normalization with the mean / std below)
PIXEL_MEAN, PIXEL_STD = (0.49,), (0.23,)

train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(PIXEL_MEAN, PIXEL_STD)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(PIXEL_MEAN, PIXEL_STD)
])

# Construct the custom pytorch datasets
train_set = UTKDataset(train_dataFrame, transform=train_transform)
test_set = UTKDataset(test_dataFrame, transform=test_transform)

# Load the datasets into dataloaders
trainloader = DataLoader(train_set, batch_size=64, shuffle=True)
testloader = DataLoader(test_set, batch_size=128, shuffle=False)

# Sanity Check: look at the shapes of a single training batch
for X, y in trainloader:
    print(f'Shape of training X: {X.shape}')
    print(f'Shape of y: {y.shape}')
    break

Shape of training X: torch.Size([64, 1, 48, 48])
Shape of y: torch.Size([64, 3])


In [22]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Hyperparameters used for training
hyperparameters = dict(learning_rate=0.001, epochs=30)

# Build the TridentNN model with one output per class for each task
model = TridentNN(
    num_age=class_nums['age_num'],
    num_gen=class_nums['gen_num'],
    num_eth=class_nums['eth_num'],
)
# Move the model to the device; as the last expression this also displays the model
model.to(device)

cuda


TridentNN(
  (CNN): highLevelNN(
    (CNN): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): ReLU()
      (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU()
      (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (9): ReLU()
      (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU()
      (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (14): ReLU()
    )
  )
  (ageNN): lowLevelNN(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1,

In [23]:
'''
  Functions to load and save a PyTorch model
'''
def save_checkpoint(state, epoch):
  '''
      Write `state` with torch.save to "tridentNN_epoch<epoch>.pth.tar"
      (relative to the current working directory).

      Inputs:
        - state : object to serialize
        - epoch : value embedded in the file name
  '''
  print("Saving Checkpoint")
  torch.save(state, f"tridentNN_epoch{epoch}.pth.tar")

def load_checkpoint(checkpoint, model=None, optimizer=None):
  '''
      Restore model and optimizer state from a checkpoint dictionary.

      Inputs:
        - checkpoint : dict with the keys 'state_dict' (model weights) and
                       'optimizer' (optimizer state)
        - model : module to load the weights into. Defaults to the notebook-level
                  `model` variable
        - optimizer : optimizer to load the state into. Defaults to the
                      notebook-level `opt` variable

      Raises:
        - KeyError if an argument is omitted and the matching notebook-level
          variable has not been defined yet
  '''
  print("Loading Checkpoint")
  # Passing the targets explicitly avoids depending on which cells have been run;
  # the global lookup only keeps existing one-argument calls working
  if model is None:
    model = globals()['model']
  if optimizer is None:
    optimizer = globals()['opt']
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])


In [34]:
# Train the model.
# Uses `hyperparameters`, `model`, `device`, `trainloader` and `save_checkpoint`
# defined in earlier cells.

# Load hyperparameters
learning_rate = hyperparameters['learning_rate']
num_epoch = hyperparameters['epochs']

# Define loss functions (one per head, all treated as multi-class classification)
age_loss = nn.CrossEntropyLoss()
gen_loss = nn.CrossEntropyLoss() # TODO : Explore using Binary Cross Entropy Loss?
eth_loss = nn.CrossEntropyLoss()

# Define optimizer
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epoch):
  # test() switches the model to eval mode; make sure we are back in training
  # mode whenever this cell is (re-)run
  model.train()

  # Construct tqdm loop to keep track of training
  loop = tqdm(enumerate(trainloader), total=len(trainloader), position=0, leave=True)
  age_correct, gen_correct, eth_correct, total = 0,0,0,0

  # save the model every 10 epochs (at the start of epochs 0, 10, 20, ...)
  if epoch % 10 == 0:
    # NOTE(review): the '*_loss' entries store the loss module objects, not loss values
    checkpoint = {'state_dict' : model.state_dict(), 'optimizer' : opt.state_dict(), 
                  'age_loss' : age_loss, 'gen_loss' : gen_loss, 'eth_loss' : eth_loss}
    save_checkpoint(checkpoint, epoch)

  # Loop through dataLoader
  for _, (X,y) in loop:
    # Unpack y to get true age, gen, and eth values (columns 0, 1 and 2)
    age, gen, eth = y[:,0].to(device), y[:,1].to(device), y[:,2].to(device)
    X = X.to(device)
    pred = model(X)          # Forward pass
    # Total loss is the unweighted sum of the three task losses
    loss = age_loss(pred[0],age) + gen_loss(pred[1],gen) + eth_loss(pred[2],eth)   # Loss calculation

    # Backpropagation
    opt.zero_grad()          # Zero the gradient
    loss.backward()          # Calculate updates

    # Gradient Descent
    opt.step()               # Apply updates

    # Update num correct and total
    age_correct += (pred[0].argmax(1) == age).type(torch.float).sum().item()
    gen_correct += (pred[1].argmax(1) == gen).type(torch.float).sum().item()
    eth_correct += (pred[2].argmax(1) == eth).type(torch.float).sum().item()

    total += len(y)

    # Update progress bar
    loop.set_description(f"Epoch [{epoch+1}/{num_epoch}]")
    loop.set_postfix(loss = loss.item())

  # Update epoch accuracy (training accuracy accumulated over the epoch)
  gen_acc, eth_acc, age_acc = gen_correct/total, eth_correct/total, age_correct/total

  # print out accuracy for epoch
  print(f'Epoch : {epoch+1}/{num_epoch},    Age Accuracy : {age_acc*100},    Gender Accuracy : {gen_acc*100},    Ethnicity Accuracy : {eth_acc*100}\n')

Epoch [1/30]:   0%|          | 1/297 [00:00<00:44,  6.65it/s, loss=1.48]

Saving Checkpoint


Epoch [1/30]: 100%|██████████| 297/297 [00:10<00:00, 27.40it/s, loss=1.84]
Epoch [2/30]:   1%|          | 3/297 [00:00<00:10, 28.54it/s, loss=1.12]

Epoch : 1/30,    Age Accuracy : 56.950010546298245,    Gender Accuracy : 94.49483231385784,    Ethnicity Accuracy : 85.62539548618435



Epoch [2/30]: 100%|██████████| 297/297 [00:10<00:00, 28.32it/s, loss=1.21]
Epoch [3/30]:   1%|          | 3/297 [00:00<00:10, 28.19it/s, loss=1.26]

Epoch : 2/30,    Age Accuracy : 59.51276102088167,    Gender Accuracy : 95.13288335794137,    Ethnicity Accuracy : 87.6028264079308



Epoch [3/30]: 100%|██████████| 297/297 [00:10<00:00, 28.08it/s, loss=1.97]
Epoch [4/30]:   1%|          | 3/297 [00:00<00:10, 28.46it/s, loss=1.32]

Epoch : 3/30,    Age Accuracy : 61.78021514448429,    Gender Accuracy : 95.58110103353724,    Ethnicity Accuracy : 88.60999789074036



Epoch [4/30]: 100%|██████████| 297/297 [00:10<00:00, 28.07it/s, loss=1.35]
Epoch [5/30]:   1%|          | 3/297 [00:00<00:10, 28.70it/s, loss=1.3] 

Epoch : 4/30,    Age Accuracy : 63.70491457498419,    Gender Accuracy : 96.00822611263446,    Ethnicity Accuracy : 89.43788230331154



Epoch [5/30]: 100%|██████████| 297/297 [00:10<00:00, 28.08it/s, loss=1.36]
Epoch [6/30]:   1%|          | 3/297 [00:00<00:10, 29.11it/s, loss=1.11] 

Epoch : 5/30,    Age Accuracy : 65.38704914574984,    Gender Accuracy : 96.37734655136047,    Ethnicity Accuracy : 91.05146593545666



Epoch [6/30]: 100%|██████████| 297/297 [00:10<00:00, 27.99it/s, loss=1.26]
Epoch [7/30]:   1%|          | 3/297 [00:00<00:09, 29.75it/s, loss=1.01] 

Epoch : 6/30,    Age Accuracy : 67.71250790972368,    Gender Accuracy : 96.57245306897279,    Ethnicity Accuracy : 91.67369753216622



Epoch [7/30]: 100%|██████████| 297/297 [00:10<00:00, 28.04it/s, loss=1.03]
Epoch [8/30]:   1%|          | 3/297 [00:00<00:10, 28.73it/s, loss=0.844]

Epoch : 7/30,    Age Accuracy : 69.51592491035646,    Gender Accuracy : 97.25268930605357,    Ethnicity Accuracy : 92.7968782957182



Epoch [8/30]: 100%|██████████| 297/297 [00:10<00:00, 28.12it/s, loss=0.853]
Epoch [9/30]:   1%|          | 3/297 [00:00<00:10, 28.13it/s, loss=0.931]

Epoch : 8/30,    Age Accuracy : 72.12086057793714,    Gender Accuracy : 97.51107361316178,    Ethnicity Accuracy : 93.45602193630036



Epoch [9/30]: 100%|██████████| 297/297 [00:10<00:00, 28.16it/s, loss=1.03]
Epoch [10/30]:   1%|          | 3/297 [00:00<00:10, 28.34it/s, loss=0.734]

Epoch : 9/30,    Age Accuracy : 74.36722210504114,    Gender Accuracy : 97.57962455178233,    Ethnicity Accuracy : 94.27336005062223



Epoch [10/30]: 100%|██████████| 297/297 [00:10<00:00, 28.05it/s, loss=0.517]
  0%|          | 0/297 [00:00<?, ?it/s]

Epoch : 10/30,    Age Accuracy : 77.00379666736976,    Gender Accuracy : 97.82746256064121,    Ethnicity Accuracy : 94.7584897700907

Saving Checkpoint


Epoch [11/30]: 100%|██████████| 297/297 [00:10<00:00, 27.11it/s, loss=0.688]
Epoch [12/30]:   1%|          | 3/297 [00:00<00:10, 28.18it/s, loss=0.404]

Epoch : 11/30,    Age Accuracy : 78.45918582577515,    Gender Accuracy : 98.13330520987134,    Ethnicity Accuracy : 95.46509175279478



Epoch [12/30]: 100%|██████████| 297/297 [00:10<00:00, 27.88it/s, loss=0.456]
Epoch [13/30]:   1%|          | 3/297 [00:00<00:10, 29.20it/s, loss=0.513]

Epoch : 12/30,    Age Accuracy : 79.17633410672855,    Gender Accuracy : 98.38641636785488,    Ethnicity Accuracy : 95.96076777051255



Epoch [13/30]: 100%|██████████| 297/297 [00:10<00:00, 28.11it/s, loss=0.508]
Epoch [14/30]:   1%|          | 3/297 [00:00<00:10, 29.21it/s, loss=0.701]

Epoch : 13/30,    Age Accuracy : 80.94283906348872,    Gender Accuracy : 98.41278211347817,    Ethnicity Accuracy : 96.13478169162623



Epoch [14/30]: 100%|██████████| 297/297 [00:10<00:00, 28.07it/s, loss=0.733]
Epoch [15/30]:   1%|          | 3/297 [00:00<00:10, 27.56it/s, loss=0.584]

Epoch : 14/30,    Age Accuracy : 82.15039021303522,    Gender Accuracy : 98.46551360472475,    Ethnicity Accuracy : 96.53554102510019



Epoch [15/30]: 100%|██████████| 297/297 [00:10<00:00, 28.03it/s, loss=0.785]
Epoch [16/30]:   1%|          | 3/297 [00:00<00:10, 28.28it/s, loss=0.589]

Epoch : 15/30,    Age Accuracy : 84.37565914364058,    Gender Accuracy : 98.66062012233706,    Ethnicity Accuracy : 96.66209660409196



Epoch [16/30]: 100%|██████████| 297/297 [00:10<00:00, 28.03it/s, loss=0.642]
Epoch [17/30]:   1%|          | 3/297 [00:00<00:10, 29.08it/s, loss=0.477]

Epoch : 16/30,    Age Accuracy : 84.77114532798987,    Gender Accuracy : 98.69753216620965,    Ethnicity Accuracy : 96.67264290234128



Epoch [17/30]: 100%|██████████| 297/297 [00:10<00:00, 27.98it/s, loss=0.235]
Epoch [18/30]:   1%|          | 3/297 [00:00<00:10, 28.22it/s, loss=0.312]

Epoch : 17/30,    Age Accuracy : 86.44273360050623,    Gender Accuracy : 98.82408774520144,    Ethnicity Accuracy : 97.30542079730014



Epoch [18/30]: 100%|██████████| 297/297 [00:10<00:00, 28.01it/s, loss=0.47]
Epoch [19/30]:   1%|          | 3/297 [00:00<00:10, 28.39it/s, loss=0.368]

Epoch : 18/30,    Age Accuracy : 86.77494199535964,    Gender Accuracy : 98.80826829782747,    Ethnicity Accuracy : 97.5321662096604



Epoch [19/30]: 100%|██████████| 297/297 [00:10<00:00, 27.81it/s, loss=1.13]
Epoch [20/30]:   1%|          | 3/297 [00:00<00:10, 28.26it/s, loss=0.439]

Epoch : 19/30,    Age Accuracy : 88.21978485551571,    Gender Accuracy : 98.97700906981649,    Ethnicity Accuracy : 97.54271250790973



Epoch [20/30]: 100%|██████████| 297/297 [00:10<00:00, 27.98it/s, loss=0.329]
  0%|          | 0/297 [00:00<?, ?it/s]

Epoch : 20/30,    Age Accuracy : 88.9000210925965,    Gender Accuracy : 98.89791183294663,    Ethnicity Accuracy : 97.7061801307741

Saving Checkpoint


Epoch [21/30]: 100%|██████████| 297/297 [00:11<00:00, 26.91it/s, loss=0.517]
Epoch [22/30]:   1%|          | 3/297 [00:00<00:10, 28.08it/s, loss=0.282]

Epoch : 21/30,    Age Accuracy : 89.37987766294032,    Gender Accuracy : 99.05610630668636,    Ethnicity Accuracy : 97.71145327989875



Epoch [22/30]: 100%|██████████| 297/297 [00:10<00:00, 27.95it/s, loss=0.702]
Epoch [23/30]:   1%|          | 3/297 [00:00<00:10, 28.26it/s, loss=0.487]

Epoch : 22/30,    Age Accuracy : 89.98628981227588,    Gender Accuracy : 99.04028685931237,    Ethnicity Accuracy : 97.73781902552204



Epoch [23/30]: 100%|██████████| 297/297 [00:10<00:00, 27.92it/s, loss=0.0925]
Epoch [24/30]:   1%|          | 3/297 [00:00<00:10, 26.90it/s, loss=0.297]

Epoch : 23/30,    Age Accuracy : 90.42396118962243,    Gender Accuracy : 99.10356464880827,    Ethnicity Accuracy : 97.62180974477958



Epoch [24/30]:  81%|████████  | 240/297 [00:08<00:02, 27.54it/s, loss=0.28] 

KeyboardInterrupt: ignored

I manually interrupted the training because I wanted every task to reach a training accuracy > 90%, and I hadn't coded that stopping condition yet.
<br> <br>
Now I am going to test the model

In [40]:
# Evaluate the trained model on the test split and print the accuracy of each task
test(testloader, model)

Age Accuracy : 45.85530478801941%,     Gender Accuracy : 89.0107572242143,    Ethnicity Accuracy : 76.46066230753006



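As you can see, the test accuracy is not that great, and it is far below the training accuracy (for age, roughly 46% on the test set versus about 90% during training), which suggests the model is overfitting the training data. My hypothesis is that predicting age is actually a very difficult task because there is so much variation in how people age.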
<br> <br> 
Even between different genders and ethnicities there is a lot of variance. Therefore, we have both inter-group and intra-group variance when it comes to age.
<br> <br>
Perhaps a better approach would be to feed the outputs of the gender and ethnicity classifiers to the age classifier so it can use that information as well. But that's a project for another day.