In [None]:
import torch 
torch.cuda.empty_cache()
import torchvision 
import matplotlib.pyplot as plt 
import numpy as np 

import json 
import shutil 
import pandas as pd 
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import models
from PIL import Image
import os
import src.dataset_ as dl
import src.utility_ as utility
import src.model_ as cnn_models

In [None]:
# Load the image index. It is indexed below as data_dict[0][<image path>], so it is
# presumably a JSON list whose first element maps image path -> record (TODO confirm).
# `with` closes the file handle instead of leaking it for the lifetime of the kernel.
with open('./data/data_dict.json') as data:
    data_dict = json.load(data)

In [None]:
# The first entry of the index is keyed by image path; report how many images it lists
# and show the record stored for one sample image.
image_records = data_dict[0]
print('Total number of images', len(image_records))
image_records['30601258@N03/landmark_aligned_face.1.10399646885_67c7d20df9_o.jpg']

In [None]:
# visualize the images
images_list = list(data_dict[0].keys())

# shows 
utility.show_img(images_list,4)

In [None]:
# Per-image preprocessing: convert to a tensor, then resize to 52x52.
age_gender_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((52, 52)),
])

# Dataset built from the JSON index, reading images from ./data/aligned.
Age_Gender_Dataset = dl.Gender_Age_Classifier_dataset(
    json_file='./data/data_dict.json',
    root_dir='./data/aligned',
    transforms=age_gender_transform,
)

In [None]:
# 80/20 train/validation split with a fixed seed. Without a seeded generator every
# re-run draws a different split, so checkpoints saved by an earlier run would be
# compared against (and may have been trained on) today's validation images.
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42

n_train = int(len(Age_Gender_Dataset) * TRAIN_FRACTION)
lengths = [n_train, len(Age_Gender_Dataset) - n_train]
train_Dataset, val_Dataset = torch.utils.data.random_split(
    Age_Gender_Dataset,
    lengths,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
BATCH_SIZE = 2

# Only the training data is shuffled. Validation metrics are summed over the whole
# set, so shuffling it gives no benefit and only makes evaluation order (and the
# global RNG state consumed before the next training epoch) nondeterministic.
train_dataloader = DataLoader(train_Dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_Dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Device
# Getting gpu for training 
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
# (To force CPU training, assign device = "cpu" after this block.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# The architecture name is both the class name in src.model_ and the prefix of the
# checkpoint file names, so the class is looked up by that name to keep them in sync.
model_architecture = "Resnet_multi_task"
model_cls = getattr(cnn_models, model_architecture)
model = model_cls().to(device)

# Optional shape sanity check (images are resized to 52x52 when the dataset is built):
# model(torch.randn(1, 3, 52, 52).to(device))

In [None]:
# Loss function
gender_criterion = torch.nn.CrossEntropyLoss()
age_criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

# input parameters 
epochs = 5
best_accuracy = torch.tensor(0.0)
resume_weights = False

In [None]:
def save_checkpoint(state, is_best, checkpoint_path, best_model_path):
    """Serialize a checkpoint and, if it is the best so far, keep a copy of it.

    Args:
        state: object to write with ``torch.save`` (e.g. a dict holding the epoch,
            validation loss, model and optimizer state dicts).
        is_best: when True, the freshly written checkpoint file is also copied to
            ``best_model_path``.
        checkpoint_path: file the checkpoint is written to (overwritten each call).
        best_model_path: file the best checkpoint is copied to.

    Missing parent directories of both paths are created first, so the first save
    on a fresh checkout does not fail with FileNotFoundError.
    """
    # dirname() is "" for a bare file name; "." keeps makedirs valid in that case.
    os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
    torch.save(state, checkpoint_path)

    if is_best:
        os.makedirs(os.path.dirname(best_model_path) or ".", exist_ok=True)
        # Copy the file just written rather than serializing the state twice.
        shutil.copyfile(checkpoint_path, best_model_path)

def accuracy_metric(pred, age, gender):
    """Return the fraction of correct age and gender predictions in one batch.

    Args:
        pred: pair ``(age_scores, gender_scores)``; the predicted class is the
            ``argmax`` over dim 1, and ``len(age_scores)`` is the batch size.
        age: true age class per sample.
        gender: true gender class per sample.

    Returns:
        tuple[float, float]: ``(age_acc, gender_acc)``.
    """
    age_scores, gender_scores = pred[0], pred[1]
    batch_size = len(age_scores)

    def _fraction_correct(scores, target):
        hits = (scores.argmax(1) == target).type(torch.float).sum().item()
        return hits / batch_size

    return _fraction_correct(age_scores, age), _fraction_correct(gender_scores, gender)
	

In [None]:
def train(dataloader,model,age_criterion,gender_criterion,optimizer, device=None):
    """Run one training epoch of the two-head (age, gender) model.

    The optimized loss is the mean of ``age_criterion`` and ``gender_criterion``.
    Training accuracy is computed from the predictions made before each
    optimizer step.

    Args:
        dataloader: yields ``(img, age, gender)`` batches.
        model: returns ``(age_scores, gender_scores)`` for a batch of images.
        age_criterion, gender_criterion: loss functions for the two heads; the
            targets are cast to ``long`` before being passed in.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device batches are moved to; defaults to the device of the
            model's first parameter (the notebook moves the model to ``device``).

    Returns:
        tuple[float, float]: ``(age_acc, gender_acc)`` over the epoch.

    Raises:
        ValueError: if the dataset is empty.
    """
    if device is None:
        device = next(model.parameters()).device

    size = len(dataloader.dataset)
    if size == 0:
        raise ValueError("train() received an empty dataset")
    model.train()

    correct_age, correct_gender = 0, 0
    seen = 0

    for batch, (img, age, gender) in enumerate(dataloader):
        img, age, gender = img.to(device), age.to(device), gender.to(device)

        # compute prediction error on both heads
        pred = model(img)
        age_loss = age_criterion(pred[0], age.long())
        gender_loss = gender_criterion(pred[1], gender.long())
        loss = (age_loss + gender_loss) / 2

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct_age += (pred[0].argmax(1) == age).type(torch.float).sum().item()
        correct_gender += (pred[1].argmax(1) == gender).type(torch.float).sum().item()
        # Count samples actually processed (the old batch*len(img) lagged one batch).
        seen += len(img)

        if batch % 1000 == 0:
            print(f"Age loss: {age_loss.item():>7f} Gender loss: {gender_loss.item():>7f} [{seen:>5d}/{size:>5d}]")

    age_acc = correct_age / size
    gender_acc = correct_gender / size

    print(f" Train Age accuracy:{age_acc:>2f} Train Gender accuracy:{gender_acc:>2f}")
    return age_acc, gender_acc
            

    
def test(dataloader, model,valid_loss_min_input,optimizer,age_criterion,gender_criterion,epoch,checkpoint_path,best_model_path):

    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()

    # initialize tracker for minimum validation loss
    valid_loss_min = valid_loss_min_input 

    test_age_loss,test_gender_loss, correct_age,correct_gender = 0,0,0,0

    for batch,(img,age,gender) in enumerate(dataloader):

        img,age,gender = img.to(device), age.to(device), gender.to(device)
        pred = model(img)

        test_age_loss += age_criterion(pred[0],age.long())
        test_gender_loss += gender_criterion(pred[1],gender.long())

        correct_age += (pred[0].argmax(1) == age).type(torch.float).sum().item()
        correct_gender += (pred[1].argmax(1) == gender).type(torch.float).sum().item()
        
    test_age_loss/=num_batches
    test_gender_loss /= num_batches
    correct_age /= size
    correct_gender /= size

    print(f"Test Error \n Age Accuracy: {100*correct_age:>2f} Gender Accuracy: {100*correct_gender:>2f} \n Age loss: {test_age_loss:>7f} Gender loss: {test_gender_loss:>7f}")

    combined_loss = (test_age_loss + test_gender_loss)/ 2 

    # create checkpoint variable and add important data
    checkpoint = {
        'epoch': epoch + 1,
        'valid_loss':combined_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        }
        
    # save checkpoint   
    save_checkpoint(checkpoint, False, checkpoint_path, best_model_path)
        
    if combined_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,combined_loss))
        # save checkpoint as best model
        save_checkpoint(checkpoint, True, checkpoint_path, best_model_path)
        valid_loss_min = combined_loss


In [None]:


checkpoint_path = "./models/checkpoint/{}_checkpoint.pt".format(model_architecture)
best_model_path = "./models/best_model/{}_best_model.pt".format(model_architecture)

valid_loss_min_input = np.Inf

for i in range(epochs):
    print(f'Epoch:{i+1}\n ------------------------------')
    train(train_dataloader, model,age_criterion,gender_criterion,optimizer)
    test(val_dataloader, model,valid_loss_min_input,optimizer,age_criterion,gender_criterion,i,checkpoint_path,best_model_path)
