<a href="https://colab.research.google.com/github/Zahra2351373/demo-repo/blob/Master/TrainTridentMain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
# torch imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F

In [2]:
# Run this once
from google.colab import drive
drive.mount('/content/gdrive')
# You will have to modify this based on your Google Drive directory structure
%cd /content/gdrive/MyDrive/Colab Notebooks/RideStream/
!pwd

Mounted at /content/gdrive
[Errno 2] No such file or directory: '/content/gdrive/MyDrive/Colab Notebooks/RideStream/'
/content
/content


In [3]:
class UTKDataset(Dataset):
    '''
        Inputs:
            dataFrame : Pandas dataFrame
            transform : The transform to apply to the dataset
    '''
    def __init__(self, dataFrame, transform=None):
        # read in the transforms
        self.transform = transform
        
        # Use the dataFrame to get the pixel values
        data_holder = dataFrame.pixels.apply(lambda x: np.array(x.split(" "),dtype=float))
        arr = np.stack(data_holder)
        arr = arr / 255.0
        arr = arr.astype('float32')
        arr = arr.reshape(arr.shape[0], 48, 48, 1)
        # reshape into 48x48x1
        self.data = arr
        
        # get the age, gender, and ethnicity label arrays
        self.age_label = np.array(dataFrame.bins[:])        # Note : Changed dataFrame.age to dataFrame.bins with most recent change
        self.gender_label = np.array(dataFrame.gender[:])
        self.eth_label = np.array(dataFrame.ethnicity[:])
    
    # override the length function
    def __len__(self):
        return len(self.data)
    
    # override the getitem function
    def __getitem__(self, index):
        # load the data at index and apply transform
        data = self.data[index]
        data = self.transform(data)
        
        # load the labels into a list and convert to tensors
        labels = torch.tensor((self.age_label[index], self.gender_label[index], self.eth_label[index]))
        
        # return data labels
        return data, labels

In [4]:
# High level feature extractor network (Adopted VGG type structure)
class highLevelNN(nn.Module):
    def __init__(self):
        super(highLevelNN, self).__init__()
        self.CNN = nn.Sequential(
            # first batch (32)
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            # second batch (64)
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            # Third Batch (128)
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        out = self.CNN(x)

        return out

# Low level feature extraction module
class lowLevelNN(nn.Module):
    def __init__(self, num_out):
        super(lowLevelNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=2048, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=64)
        self.fc4 = nn.Linear(in_features=64, out_features=num_out)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=2, padding=1))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=3, stride=2, padding=1))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)

        return x


class TridentNN(nn.Module):
    def __init__(self, num_age, num_gen, num_eth):
        super(TridentNN, self).__init__()
        # Construct the high level neural network
        self.CNN = highLevelNN()
        # Construct the low level neural networks
        self.ageNN = lowLevelNN(num_out=num_age)
        self.genNN = lowLevelNN(num_out=num_gen)
        self.ethNN = lowLevelNN(num_out=num_eth)

    def forward(self, x):
        x = self.CNN(x)
        age = self.ageNN(x)
        gen = self.genNN(x)
        eth = self.ethNN(x)

        return age, gen, eth

In [5]:
'''
    Function to test the trained model

    Inputs:
      - testloader : PyTorch DataLoader containing the test dataset
      - modle : Trained NeuralNetwork
    
    Outputs:
      - Prints out test accuracy for gender and ethnicity and loss for age
'''
def test(testloader, model):
  device = 'cuda' if torch.cuda.is_available() else 'cpu' 
  size = len(testloader.dataset)
  # put the moel in evaluation mode so we aren't storing anything in the graph
  model.eval()

  age_acc, gen_acc, eth_acc = 0, 0, 0

  with torch.no_grad():
      for X, y in testloader:
          X = X.to(device)
          age, gen, eth = y[:,0].to(device), y[:,1].to(device), y[:,2].to(device)
          pred = model(X)

          age_acc += (pred[0].argmax(1) == age).type(torch.float).sum().item()
          gen_acc += (pred[1].argmax(1) == gen).type(torch.float).sum().item()
          eth_acc += (pred[2].argmax(1) == eth).type(torch.float).sum().item()

  age_acc /= size
  gen_acc /= size
  eth_acc /= size

  print(f"Age Accuracy : {age_acc*100}%,     Gender Accuracy : {gen_acc*100},    Ethnicity Accuracy : {eth_acc*100}\n")

In [8]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [9]:
# Read in the dataframe
dataFrame = pd.read_csv("/content/drive/MyDrive/age_gender.gz", compression='gzip')

# Construct age bins
age_bins = [0,10,15,20,25,30,40,50,60,120]
age_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8]
dataFrame['bins'] = pd.cut(dataFrame.age, bins=age_bins, labels=age_labels)

# Split into training and testing
train_dataFrame, test_dataFrame = train_test_split(dataFrame, test_size=0.2)

# get the number of unique classes for each group
class_nums = {'age_num':len(dataFrame['bins'].unique()), 'eth_num':len(dataFrame['ethnicity'].unique()),
              'gen_num':len(dataFrame['gender'].unique())}

# Define train and test transforms
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.49,), (0.23,))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.49,), (0.23,))
])

# Construct the custom pytorch datasets
train_set = UTKDataset(train_dataFrame, transform=train_transform)
test_set = UTKDataset(test_dataFrame, transform=test_transform)

# Load the datasets into dataloaders
trainloader = DataLoader(train_set, batch_size=64, shuffle=True)
testloader = DataLoader(test_set, batch_size=128, shuffle=False)

# Sanity Check
for X, y in trainloader:
    print(f'Shape of training X: {X.shape}')
    print(f'Shape of y: {y.shape}')
    break

Shape of training X: torch.Size([64, 1, 48, 48])
Shape of y: torch.Size([64, 3])


In [10]:
# Configure the device 
device = 'cuda' if torch.cuda.is_available() else 'cpu' 
print(device)

# Define the list of hyperparameters
hyperparameters = {'learning_rate':0.001, 'epochs':30}

# Initialize the TridentNN model and put on device
model = TridentNN(class_nums['age_num'], class_nums['gen_num'], class_nums['eth_num'])
model.to(device)

cuda


TridentNN(
  (CNN): highLevelNN(
    (CNN): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): ReLU()
      (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU()
      (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (9): ReLU()
      (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU()
      (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (14): ReLU()
    )
  )
  (ageNN): lowLevelNN(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1,

In [11]:
'''
  Functions to load and save a PyTorch model
'''
def save_checkpoint(state, epoch):
  print("Saving Checkpoint")
  filename = "tridentNN_epoch"+str(epoch)+".pth.tar"
  torch.save(state,filename)

def load_checkpoint(checkpoint):
  print("Loading Checkpoint")
  model.load_state_dict(checkpoint['state_dict'])
  opt.load_state_dict(checkpoint['optimizer'])


In [12]:
'''
train the model
''' 
# Load hyperparameters
learning_rate = hyperparameters['learning_rate']
num_epoch = hyperparameters['epochs']

# Define loss functions
age_loss = nn.CrossEntropyLoss()
gen_loss = nn.CrossEntropyLoss() # TODO : Explore using Binary Cross Entropy Loss?
eth_loss = nn.CrossEntropyLoss()

# Define optimizer
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epoch):
  # Construct tqdm loop to keep track of training
  loop = tqdm(enumerate(trainloader), total=len(trainloader), position=0, leave=True)
  age_correct, gen_correct, eth_correct, total = 0,0,0,0

  # save the model every 10 epochs
  if epoch % 10 == 0:
    checkpoint = {'state_dict' : model.state_dict(), 'optimizer' : opt.state_dict(), 
                  'age_loss' : age_loss, 'gen_loss' : gen_loss, 'eth_loss' : eth_loss}
    save_checkpoint(checkpoint, epoch)

  # Loop through dataLoader
  for _, (X,y) in loop:
    # Unpack y to get true age, eth, and gen values
    # Have to do some special changes to age label to make it compatible with NN output and Loss function
    #age, gen, eth = y[:,0].resize_(len(y[:,0]),1).float().to(device), y[:,1].to(device), y[:,2].to(device)
    age, gen, eth = y[:,0].to(device), y[:,1].to(device), y[:,2].to(device)
    X = X.to(device)
    pred = model(X)          # Forward pass
    loss = age_loss(pred[0],age) + gen_loss(pred[1],gen) + eth_loss(pred[2],eth)   # Loss calculation

    # Backpropagation
    opt.zero_grad()          # Zero the gradient
    loss.backward()          # Calculate updates

    # Gradient Descent
    opt.step()               # Apply updates

    # Update num correct and total
    age_correct += (pred[0].argmax(1) == age).type(torch.float).sum().item()
    gen_correct += (pred[1].argmax(1) == gen).type(torch.float).sum().item()
    eth_correct += (pred[2].argmax(1) == eth).type(torch.float).sum().item()

    total += len(y)

    # Update progress bar
    loop.set_description(f"Epoch [{epoch+1}/{num_epoch}]")
    loop.set_postfix(loss = loss.item())

  # Update epoch accuracy
  gen_acc, eth_acc, age_acc = gen_correct/total, eth_correct/total, age_correct/total

  # print out accuracy and loss for epoch
  print(f'Epoch : {epoch+1}/{num_epoch},    Age Accuracy : {age_acc*100},    Gender Accuracy : {gen_acc*100},    Ethnicity Accuracy : {eth_acc*100}\n')

  0%|          | 0/297 [00:00<?, ?it/s]

Saving Checkpoint


Epoch [1/30]: 100%|██████████| 297/297 [00:17<00:00, 17.33it/s, loss=3.85]


Epoch : 1/30,    Age Accuracy : 21.44062434085636,    Gender Accuracy : 58.13646909934613,    Ethnicity Accuracy : 43.78823033115376



Epoch [2/30]: 100%|██████████| 297/297 [00:09<00:00, 31.33it/s, loss=3.07]


Epoch : 2/30,    Age Accuracy : 33.44758489770091,    Gender Accuracy : 80.97975110736132,    Ethnicity Accuracy : 54.012866483864165



Epoch [3/30]: 100%|██████████| 297/297 [00:10<00:00, 29.45it/s, loss=2.82]


Epoch : 3/30,    Age Accuracy : 41.67369753216621,    Gender Accuracy : 85.82050200379668,    Ethnicity Accuracy : 65.55051676861422



Epoch [4/30]: 100%|██████████| 297/297 [00:10<00:00, 28.45it/s, loss=2.66]


Epoch : 4/30,    Age Accuracy : 44.61611474372496,    Gender Accuracy : 88.22505800464036,    Ethnicity Accuracy : 72.62180974477958



Epoch [5/30]: 100%|██████████| 297/297 [00:10<00:00, 29.37it/s, loss=1.96]


Epoch : 5/30,    Age Accuracy : 46.6568234549673,    Gender Accuracy : 89.48006749630879,    Ethnicity Accuracy : 75.60641214933558



Epoch [6/30]: 100%|██████████| 297/297 [00:09<00:00, 30.29it/s, loss=2.38]


Epoch : 6/30,    Age Accuracy : 48.13857835899599,    Gender Accuracy : 90.6137945581101,    Ethnicity Accuracy : 77.96878295718203



Epoch [7/30]: 100%|██████████| 297/297 [00:09<00:00, 29.88it/s, loss=1.79]


Epoch : 7/30,    Age Accuracy : 50.163467622864374,    Gender Accuracy : 91.5524151022991,    Ethnicity Accuracy : 79.42417211558744



Epoch [8/30]: 100%|██████████| 297/297 [00:09<00:00, 30.64it/s, loss=1.72]


Epoch : 8/30,    Age Accuracy : 51.45011600928074,    Gender Accuracy : 92.59122547985656,    Ethnicity Accuracy : 81.49124657245306



Epoch [9/30]: 100%|██████████| 297/297 [00:09<00:00, 30.37it/s, loss=2.43]


Epoch : 9/30,    Age Accuracy : 52.578569921957396,    Gender Accuracy : 93.13435983969627,    Ethnicity Accuracy : 82.33495043239823



Epoch [10/30]: 100%|██████████| 297/297 [00:09<00:00, 30.07it/s, loss=1.47]


Epoch : 10/30,    Age Accuracy : 53.98650073824087,    Gender Accuracy : 93.83041552415102,    Ethnicity Accuracy : 83.81143218730226



Epoch [11/30]:   0%|          | 1/297 [00:00<00:38,  7.64it/s, loss=1.44]

Saving Checkpoint


Epoch [11/30]: 100%|██████████| 297/297 [00:10<00:00, 28.67it/s, loss=1.86]


Epoch : 11/30,    Age Accuracy : 55.48934823876819,    Gender Accuracy : 94.67939253322083,    Ethnicity Accuracy : 85.37228432820079



Epoch [12/30]: 100%|██████████| 297/297 [00:09<00:00, 30.66it/s, loss=1.02]


Epoch : 12/30,    Age Accuracy : 57.34549673064754,    Gender Accuracy : 95.11179076144273,    Ethnicity Accuracy : 86.66947901286647



Epoch [13/30]: 100%|██████████| 297/297 [00:09<00:00, 29.96it/s, loss=1.58]


Epoch : 13/30,    Age Accuracy : 58.75342754693102,    Gender Accuracy : 95.66019827040708,    Ethnicity Accuracy : 87.7768403290445



Epoch [14/30]: 100%|██████████| 297/297 [00:09<00:00, 29.93it/s, loss=1.38]


Epoch : 14/30,    Age Accuracy : 60.42501581944737,    Gender Accuracy : 95.99767981438515,    Ethnicity Accuracy : 88.73655346973213



Epoch [15/30]: 100%|██████████| 297/297 [00:09<00:00, 31.15it/s, loss=1.35]


Epoch : 15/30,    Age Accuracy : 63.198692259017086,    Gender Accuracy : 96.56717991984813,    Ethnicity Accuracy : 90.02320185614849



Epoch [16/30]: 100%|██████████| 297/297 [00:09<00:00, 29.95it/s, loss=0.736]


Epoch : 16/30,    Age Accuracy : 64.85446108415947,    Gender Accuracy : 96.73064754271252,    Ethnicity Accuracy : 91.1832946635731



Epoch [17/30]: 100%|██████████| 297/297 [00:09<00:00, 29.81it/s, loss=1.16]


Epoch : 17/30,    Age Accuracy : 67.45939675174014,    Gender Accuracy : 97.02067074456866,    Ethnicity Accuracy : 91.86880404977853



Epoch [18/30]: 100%|██████████| 297/297 [00:09<00:00, 30.49it/s, loss=0.737]


Epoch : 18/30,    Age Accuracy : 69.46846656823456,    Gender Accuracy : 97.40561063066863,    Ethnicity Accuracy : 93.06580890107571



Epoch [19/30]: 100%|██████████| 297/297 [00:12<00:00, 24.61it/s, loss=1.04]


Epoch : 19/30,    Age Accuracy : 71.9784855515714,    Gender Accuracy : 97.47943471841384,    Ethnicity Accuracy : 93.65640160303734



Epoch [20/30]: 100%|██████████| 297/297 [00:09<00:00, 29.96it/s, loss=1.14]


Epoch : 20/30,    Age Accuracy : 74.41468044716305,    Gender Accuracy : 97.74836532377135,    Ethnicity Accuracy : 94.37882303311538



Epoch [21/30]:   0%|          | 1/297 [00:00<00:41,  7.09it/s, loss=0.693]

Saving Checkpoint


Epoch [21/30]: 100%|██████████| 297/297 [00:10<00:00, 29.06it/s, loss=1.17]


Epoch : 21/30,    Age Accuracy : 76.1970048512972,    Gender Accuracy : 98.01202278000422,    Ethnicity Accuracy : 94.9430499894537



Epoch [22/30]: 100%|██████████| 297/297 [00:09<00:00, 30.75it/s, loss=0.945]


Epoch : 22/30,    Age Accuracy : 78.30099135203542,    Gender Accuracy : 97.98565703438094,    Ethnicity Accuracy : 95.6074667791605



Epoch [23/30]: 100%|██████████| 297/297 [00:09<00:00, 29.98it/s, loss=0.474]


Epoch : 23/30,    Age Accuracy : 80.27314912465725,    Gender Accuracy : 98.47605990297406,    Ethnicity Accuracy : 96.00822611263446



Epoch [24/30]: 100%|██████████| 297/297 [00:09<00:00, 30.02it/s, loss=0.623]


Epoch : 24/30,    Age Accuracy : 82.29803838852563,    Gender Accuracy : 98.37587006960557,    Ethnicity Accuracy : 96.1875131828728



Epoch [25/30]: 100%|██████████| 297/297 [00:09<00:00, 30.74it/s, loss=0.515]


Epoch : 25/30,    Age Accuracy : 83.28939042396118,    Gender Accuracy : 98.54988399071925,    Ethnicity Accuracy : 96.6199114110947



Epoch [26/30]: 100%|██████████| 297/297 [00:09<00:00, 29.97it/s, loss=1.09]


Epoch : 26/30,    Age Accuracy : 84.63931659987345,    Gender Accuracy : 98.70280531533432,    Ethnicity Accuracy : 96.97321240244673



Epoch [27/30]: 100%|██████████| 297/297 [00:10<00:00, 28.19it/s, loss=0.432]


Epoch : 27/30,    Age Accuracy : 86.68529846024046,    Gender Accuracy : 98.59734233284118,    Ethnicity Accuracy : 97.21577726218096



Epoch [28/30]: 100%|██████████| 297/297 [00:10<00:00, 28.94it/s, loss=0.432]


Epoch : 28/30,    Age Accuracy : 86.7907614427336,    Gender Accuracy : 98.75026365745623,    Ethnicity Accuracy : 97.17886521830837



Epoch [29/30]: 100%|██████████| 297/297 [00:09<00:00, 30.70it/s, loss=0.719]


Epoch : 29/30,    Age Accuracy : 87.36553469732125,    Gender Accuracy : 98.91900442944527,    Ethnicity Accuracy : 97.28432820080151



Epoch [30/30]: 100%|██████████| 297/297 [00:09<00:00, 29.99it/s, loss=0.4]

Epoch : 30/30,    Age Accuracy : 88.65745623286226,    Gender Accuracy : 98.83463404345075,    Ethnicity Accuracy : 97.97511073613163






I manuall interrupted the training because I wanted everything to have a training accuracy > 90% and I didn't code that part in yet
<br> <br>
Now I am going to test the model

In [13]:
test(testloader, model)

Age Accuracy : 46.88884201645222%,     Gender Accuracy : 87.51318287281164,    Ethnicity Accuracy : 75.57477325458764

