In [67]:
import torch
import torch.cuda
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import requests, tarfile
from tqdm import tqdm

In [68]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using {device.type}")

Using cpu


In [69]:
# CONSTANTS
# Display names for the integer class labels. A label value is used directly as
# the list index (see genderToStr / ethnicityToStr), so the order here defines
# the label -> name mapping.
# NOTE(review): presumably mirrors the encoding in the UTKFace file names — confirm.
GENDER = ['M', 'F']
ETHNICITY = ['White', 'Black', 'Asian', 'Indian', 'Others']

# HELPER FUNCTIONS
def arrayAge(data):
    """Collect the age field of every (age, gender, ethnicity) record.

    Args:
        data: iterable of 3-item records whose first item is the age.

    Returns:
        NumPy array holding one age per record, in input order.
    """
    return np.array([age for age, _gender, _ethnicity in data])

def genderToStr(num):
    """Return the display name for a numeric gender label.

    Args:
        num: index into GENDER. Besides a plain int this accepts any scalar
            that int() can convert (float, NumPy scalar, one-element tensor),
            because labels in this notebook are stored as float tensors and a
            float cannot index a list directly.

    Returns:
        The matching entry of GENDER.

    Raises:
        IndexError: if the label is outside the range of GENDER.
    """
    return GENDER[int(num)]

def arrayGenderToStr(data):
    """Map the gender field of every (age, gender, ethnicity) record to its name.

    Args:
        data: iterable of 3-item records whose second item is the gender label.

    Returns:
        NumPy array of gender names (via genderToStr), in input order.
    """
    return np.array([genderToStr(gender) for _age, gender, _ethnicity in data])

def ethnicityToStr(num):
    """Return the display name for a numeric ethnicity label.

    Args:
        num: index into ETHNICITY. Besides a plain int this accepts any scalar
            that int() can convert (float, NumPy scalar, one-element tensor),
            because labels in this notebook are stored as float tensors and a
            float cannot index a list directly.

    Returns:
        The matching entry of ETHNICITY.

    Raises:
        IndexError: if the label is outside the range of ETHNICITY.
    """
    return ETHNICITY[int(num)]

def arrayEthnicityToStr(data):
    """Map the ethnicity field of every (age, gender, ethnicity) record to its name.

    Args:
        data: iterable of 3-item records whose third item is the ethnicity label.

    Returns:
        NumPy array of ethnicity names (via ethnicityToStr), in input order.
    """
    return np.array([ethnicityToStr(ethnicity) for _age, _gender, ethnicity in data])

def histPlot(labels, title, yLabel, xLabel, bins):
    """Draw and show a histogram of `labels`.

    Args:
        labels: values to bin along the x axis.
        title: figure title.
        yLabel: y-axis caption.
        xLabel: x-axis caption.
        bins: bin specification forwarded to seaborn's histplot.
    """
    titleSize, captionSize = 16, 12

    plt.title(title, size=titleSize)
    sns.histplot(x=labels, bins=bins)
    plt.xlabel(xLabel, size=captionSize)
    plt.ylabel(yLabel, size=captionSize)
    # Keep only the left and bottom spines.
    sns.despine(top=True, right=True, left=False, bottom=False)
    plt.show()

def countPlot(labels, title, yLabel, xLabel):
    """Draw and show a count plot of `labels`, annotating each bar with its share.

    Args:
        labels: categorical values to count.
        title: figure title.
        yLabel: y-axis caption.
        xLabel: x-axis caption.
    """
    titleSize, captionSize = 16, 12

    plt.title(title, size=titleSize)
    axes = sns.countplot(x=labels)
    plt.xlabel(xLabel, size=captionSize)
    plt.ylabel(yLabel, size=captionSize)
    # Keep only the left and bottom spines.
    sns.despine(top=True, right=True, left=False, bottom=False)

    # Write each bar's percentage of all labels just above the bar, centred on it.
    sampleCount = len(labels)
    for bar in axes.patches:
        barHeight = bar.get_height()
        barCentre = bar.get_x() + bar.get_width() / 2
        axes.text(barCentre, barHeight + 5, f'{100 * barHeight / sampleCount:.1f}%', ha='center')

    plt.show()

In [70]:
# Fetch and unpack the UTKFace archive once; later runs reuse the extracted folder.
dataPath = 'UTKFace'
if not os.path.isdir(dataPath):
    print("Downloading UTKFace...")
    # NOTE(review): the query string carries what look like session-specific
    # tokens (confirm/uuid/at) — confirm the link still resolves to the archive.
    url = "https://drive.google.com/uc?export=download&id=0BxYys69jI14kYVM3aVhKS1VhRUk&confirm=t&uuid=f981ca1d-ba0f-40c9-a4a0-8eaa887f3b6d&at=ANzk5s7e36SgjT0FlqBbRiijefRg:1681897584880"

    # The context managers close the connection and the tar stream even when
    # the download or the extraction fails part-way.
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail here on an HTTP error instead of feeding an error page to tarfile.
        response.raise_for_status()
        with tarfile.open(fileobj=response.raw, mode="r|gz") as archive:
            # The archive comes from the network: where the interpreter supports
            # extraction filters, use the 'data' filter so members cannot escape
            # the target directory.
            extractOptions = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
            archive.extractall(path=".", **extractOptions)
    print("Download complete.")
else:
    print("UTKFace already downloaded.")

UTKFace already downloaded.


In [71]:
#data = []
#labels = []

#for imagePath in os.listdir(dataPath):
#    try:
#        imageTensor = torchvision.io.read_image(f'{dataPath}/{imagePath}').float().half()
#        fileName = imagePath.split('_')
#        labels.append((int(fileName[0]), int(fileName[1]), int(fileName[2])))
#        data.append(imageTensor)
#    except:
#        pass
#data = torch.stack(data).to(device)
#labels = torch.Tensor(labels).to(device)


In [72]:
class MyDataset(torch.utils.data.Dataset):
    """Image dataset over a directory of '<age>_<gender>_<ethnicity>_...jpg' files.

    Each item is (imageTensor, label): imageTensor is the decoded image as a
    float tensor (after the optional transform) and label is a float tensor
    [age, gender, ethnicity] parsed from the file name.

    A file whose name cannot be parsed, or whose image cannot be decoded or
    transformed, is skipped: the next loadable file (wrapping around the end of
    the list) is returned in its place.
    """

    def __init__(self, dataPath, transform=None):
        """
        Args:
            dataPath: directory containing the .jpg files.
            transform: optional callable applied to each float image tensor.
        """
        self.dataPath = dataPath
        self.transform = transform
        # os.listdir order is arbitrary; sorting keeps index -> file stable between runs.
        self.imagePaths = sorted(f for f in os.listdir(self.dataPath) if f.endswith('.jpg'))

    def _load(self, imagePath):
        """Parse the label from the file name, then decode and transform the image."""
        fields = imagePath.split('_')
        # Parse first so a malformed name is rejected before the image is decoded.
        label = torch.tensor([int(fields[0]), int(fields[1]), int(fields[2])], dtype=torch.float32)
        imageTensor = torchvision.io.read_image(os.path.join(self.dataPath, imagePath)).float()
        if self.transform:
            imageTensor = self.transform(imageTensor)
        return imageTensor, label

    def __getitem__(self, index):
        """Return the sample at `index`, or the next loadable one if it is unusable.

        Raises:
            IndexError: if `index` is outside the dataset.
            RuntimeError: if no file in the dataset can be loaded.
        """
        total = len(self.imagePaths)
        if not -total <= index < total:
            raise IndexError(f'index {index} is out of range for a dataset of {total} images')

        start = index % total
        lastError = None
        # Bounded scan instead of unbounded recursion: every file is tried at most once.
        for offset in range(total):
            imagePath = self.imagePaths[(start + offset) % total]
            try:
                return self._load(imagePath)
            # Only data problems are skipped; KeyboardInterrupt and other
            # non-data errors propagate instead of being swallowed.
            except (ValueError, IndexError, OSError, RuntimeError) as error:
                lastError = error
        raise RuntimeError(f'no loadable image found in {self.dataPath}') from lastError

    def __len__(self):
        """Number of .jpg files found in the directory."""
        return len(self.imagePaths)

In [73]:
# Load and normalize the data
# MyDataset yields float image tensors in the raw 0-255 range, while the
# ImageNet-pretrained ResNet expects [0, 1] inputs standardized with the ImageNet
# channel statistics. Scaling those statistics by 255 does both in one Normalize.
imagenetMean = (0.485, 0.456, 0.406)
imagenetStd = (0.229, 0.224, 0.225)
transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.Normalize([255 * m for m in imagenetMean], [255 * s for s in imagenetStd])
     ])

dataset = MyDataset(dataPath, transform=transform)

batchSize = 100
testSplit = 0.1 # use 10% of dataset as test
validSplit = 0.2 / (1-testSplit) # use 20% of dataset as validation

# Round one side of each split and derive the other by subtraction, so the sizes
# always add up to the dataset length whatever the float rounding does
# (random_split rejects lengths that do not sum to the dataset size).
testSize = int(np.floor(len(dataset)*testSplit))
trainValidSize = len(dataset) - testSize
validSize = int(np.ceil(trainValidSize*validSplit))
trainSize = trainValidSize - validSize
print(len(dataset), testSize, trainValidSize)

# Seeded generator: the same images land in train/valid/test on every run.
splitSeed = 42
splitGenerator = torch.Generator().manual_seed(splitSeed)
trainValidSet, testSet = torch.utils.data.random_split(
    dataset, [trainValidSize, testSize], generator=splitGenerator)
trainSet, validSet = torch.utils.data.random_split(
    trainValidSet, [trainSize, validSize], generator=splitGenerator)

# Only the training data needs reshuffling each epoch.
trainLoader = torch.utils.data.DataLoader(trainSet, batch_size=batchSize, shuffle=True)
validLoader = torch.utils.data.DataLoader(validSet, batch_size=batchSize, shuffle=False)
testLoader = torch.utils.data.DataLoader(testSet, batch_size=batchSize, shuffle=False)

23708 2370 21338


In [74]:
class ResNetModel(nn.Module):
    """ResNet-18 feature extractor with three task heads: age, gender, ethnicity.

    The pretrained backbone weights are frozen. Its final fully connected layer
    is replaced by a new trainable 512-unit layer whose ReLU-activated output
    feeds one linear head per task.
    """

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet18(pretrained=True)
        # Freeze the pretrained weights; layers created below stay trainable.
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 512)
        self.ageFc = nn.Linear(512, 1)
        # One output: the sigmoid of it is the probability compared with the
        # (batch, 1) target by binary cross-entropy.
        self.genderFc = nn.Linear(512, 1)
        self.ethnicityFc = nn.Linear(512, 5)

    def forward(self, x):
        """Run the backbone and every head on a batch of images.

        Returns:
            (ageOut, genderOut, ethnicityOut) where
            ageOut is (batch, 1), the raw regression output;
            genderOut is (batch, 1), a sigmoid probability in (0, 1);
            ethnicityOut is (batch, 5), unnormalized class logits — no softmax
            is applied because F.cross_entropy applies log-softmax itself.
        """
        resOut = F.relu(self.resnet(x))
        # Each task uses its own head.
        ageOut = self.ageFc(resOut)
        genderOut = torch.sigmoid(self.genderFc(resOut))
        ethnicityOut = self.ethnicityFc(resOut)
        return ageOut, genderOut, ethnicityOut

In [75]:
# ResNetModel()

In [76]:

def lossAge(predictAge, targetAge):
    """Mean squared error between the predicted and the target ages."""
    return F.mse_loss(input=predictAge, target=targetAge)

def lossGender(predictGender, targetGender):
    """Binary cross-entropy between predicted gender probabilities and targets."""
    return F.binary_cross_entropy(input=predictGender, target=targetGender)

def lossEthnicity(predictEthnicity, targetEthnicity):
    """Cross-entropy between ethnicity logits and the target classes.

    Args:
        predictEthnicity: class logits, e.g. (batch, classes).
        targetEthnicity: one of
            - integer class indices of shape (batch,);
            - class probabilities with the same shape as predictEthnicity;
            - class ids stored one per row in a trailing column, e.g. a float
              (batch, 1) tensor, which is how trainNetwork passes them.

    Returns:
        Scalar loss tensor.
    """
    sameRank = targetEthnicity.dim() == predictEthnicity.dim()
    if sameRank and targetEthnicity.shape != predictEthnicity.shape:
        # A same-rank target that is not a probability table is a column of
        # class ids. F.cross_entropy would read float values as probabilities
        # and reject the shape, so turn the column into integer class indices.
        targetEthnicity = targetEthnicity.squeeze(-1).long()
    return F.cross_entropy(predictEthnicity, targetEthnicity)

def lossFunction(predictAge, predictGender, predictEthnicity, targetAge, targetGender, targetEthnicity,
                 alpha=1/3, beta=1/3, gamma=1/3):
    """Weighted sum of the age, gender and ethnicity losses.

    Args:
        predictAge, predictGender, predictEthnicity: model outputs for each task.
        targetAge, targetGender, targetEthnicity: matching targets.
        alpha: weight for age prediction.
        beta: weight for gender prediction.
        gamma: weight for ethnicity prediction.

    The weights default to 1/3 each. They are parameters because the age term is
    a squared error on a different scale from the two cross-entropy terms, so a
    caller may want to rebalance them.

    Returns:
        alpha * ageLoss + beta * genderLoss + gamma * ethnicityLoss.
    """
    ageLoss = lossAge(predictAge, targetAge)
    genderLoss = lossGender(predictGender, targetGender)
    ethnicityLoss = lossEthnicity(predictEthnicity, targetEthnicity)
    return alpha * ageLoss + beta * genderLoss + gamma * ethnicityLoss

In [77]:
def _predictedClasses(outputs):
    """Turn one classification head's outputs into integer class predictions."""
    if outputs.dim() > 1 and outputs.size(-1) > 1:
        # One score per class: pick the highest.
        return outputs.argmax(dim=-1)
    # Single probability per sample: threshold at 0.5.
    return (outputs.reshape(-1) > 0.5).long()


def _runEpoch(model, lossFunction, loader, device, optimizer=None):
    """Run one pass over `loader`; train when an optimizer is given, else evaluate.

    Returns:
        (meanLoss, accuracy): the loss averaged over batches and the percentage
        of correct gender and ethnicity predictions (both heads pooled).
    """
    training = optimizer is not None
    # train()/eval() matters for layers such as batch norm in the backbone.
    model.train(training)

    lossSum = 0.0
    batches = 0
    correct = 0
    total = 0
    # No autograd bookkeeping is needed while evaluating.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            # Move data to GPU (if exists)
            images, labels = images.to(device), labels.to(device)

            # Predict
            agePredictions, genderPredictions, ethnicityPredictions = model(images)

            # Label columns are [age, gender, ethnicity]; each is passed as (batch, 1).
            loss = lossFunction(agePredictions, genderPredictions, ethnicityPredictions,
                                labels[:, 0].view(-1, 1), labels[:, 1].view(-1, 1), labels[:, 2].view(-1, 1))

            if training:
                # Clear stale gradients, backpropagate, then update the parameters.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            lossSum += loss.item()
            batches += 1

            # Accuracy over the two classification heads.
            for predictions, column in ((genderPredictions, 1), (ethnicityPredictions, 2)):
                correct += (_predictedClasses(predictions) == labels[:, column].long()).sum().item()
                total += len(images)

    # Guard the averages so an empty loader reports 0 instead of dividing by zero.
    meanLoss = lossSum / batches if batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    return meanLoss, accuracy


def trainNetwork(model, optimizer, lossFunction, trainLoader, validLoader, epochs, device):
    """Train `model` for `epochs` epochs, validating and printing metrics after each.

    Args:
        model: network returning (agePredictions, genderPredictions, ethnicityPredictions).
        optimizer: optimizer over the model's parameters.
        lossFunction: callable taking the three predictions followed by the three
            targets, each target shaped (batch, 1).
        trainLoader: iterable of (images, labels) batches used for training.
        validLoader: iterable of (images, labels) batches used for validation.
        epochs: number of passes over trainLoader.
        device: device the batches are moved to.

    The printed accuracy is the percentage of correct gender and ethnicity
    predictions. The model is left in eval mode after the last validation pass.
    """
    for epoch in tqdm(range(1, epochs + 1)):
        ### TRAINING ###
        trainLoss, trainAccuracy = _runEpoch(model, lossFunction, trainLoader, device, optimizer)

        ### VALIDATION ###
        validLoss, validAccuracy = _runEpoch(model, lossFunction, validLoader, device)

        # Print result of epoch
        print(f'Epoch [{epoch}/{epochs}] \t Training Loss: {round(trainLoss, 4)} \t Validation Loss: {round(validLoss, 4)} \t Traning Acc: {round(trainAccuracy, 2)}% \t Validation Acc: {round(validAccuracy, 2)}%')

In [78]:
# Hyperparameters for this run.
epochs = 1
learningRate = 1e-3
# Build the model and move it to the selected device.
resnetModel = ResNetModel().to(device)

# Every parameter is handed to SGD, including the backbone weights that
# ResNetModel marks requires_grad=False.
optimizer = torch.optim.SGD(resnetModel.parameters(), lr=learningRate)
# Train on trainLoader and validate on validLoader; metrics are printed once per epoch.
trainNetwork(resnetModel, optimizer, lossFunction, trainLoader, validLoader, epochs, device)

KeyboardInterrupt: 

In [None]:
!nvcc --version

'nvcc' is not recognized as an internal or external command,
operable program or batch file.
