## Overview:
### Age prediction as a computer-vision (CV) task is useful for various real-world applications, such as age-restricted content filtering, personalized marketing targeting specific age demographics, enhancing security systems with age verification, and assisting in medical diagnostics and age-related research. In this project, various CNN-based and ViT-based models are built and compared to predict a person's age from an image of their face.

In [None]:
import pandas as pd
from PIL import Image
from tqdm import tqdm
from os.path import join
import matplotlib.pyplot as plt

import torch
import wandb
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

In [None]:
class AgeDataset(torch.utils.data.Dataset):
    """Face-image dataset described by a CSV annotation file.

    The CSV at ``annot_path`` must contain a ``file_id`` column holding image
    file names relative to ``data_path``. When ``train`` is true it must also
    contain an ``age`` column, which is used as the regression target.

    Args:
        data_path: Directory that contains the image files.
        annot_path: Path of the CSV annotation file.
        train: If true, items are ``(image, age)`` pairs; otherwise items are
            the image tensor alone.
    """

    def __init__(self, data_path, annot_path, train=True):
        super().__init__()

        self.annot_path = annot_path
        self.data_path = data_path
        self.train = train

        self.ann = pd.read_csv(annot_path)
        self.files = self.ann['file_id']
        if train:
            # Labels only exist for the training split.
            self.ages = self.ann['age']
        self.transform = self._transform(224)

    @staticmethod
    def _convert_image_to_rgb(image):
        """Return ``image`` converted to 3-channel RGB."""
        return image.convert("RGB")

    def _transform(self, n_px):
        """Build the preprocessing pipeline: resize, RGB, tensor, normalize.

        Random flip / rotation / colour-jitter augmentation was tried in this
        pipeline and did not give better results, so it is not applied.

        NOTE(review): ``Resize(n_px)`` with an int scales the shorter side
        only, so non-square inputs stay non-square -- confirm the dataset
        images are square, otherwise batching will fail.
        """
        # Standard ImageNet channel statistics.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        return Compose([
            Resize(n_px),
            self._convert_image_to_rgb,
            ToTensor(),
            Normalize(mean, std),
        ])

    def read_img(self, file_name):
        """Load ``file_name`` from ``data_path`` and return the transformed tensor."""
        im_path = join(self.data_path, file_name)
        # The context manager closes the underlying file handle as soon as
        # the pixels have been turned into a tensor.
        with Image.open(im_path) as img:
            return self.transform(img)

    def __getitem__(self, index):
        """Return ``(image, age)`` in train mode, otherwise just ``image``."""
        img = self.read_img(self.files[index])
        if not self.train:
            return img
        return img, self.ages[index]

    def __len__(self):
        """Number of annotated files."""
        return len(self.files)

In [None]:
# Training split: image folder plus the CSV that carries the age labels.
train_path = 'faces_dataset/train'
train_ann = 'faces_dataset/train.csv'
train_dataset = AgeDataset(data_path=train_path, annot_path=train_ann, train=True)

# Test split: no labels, the CSV only lists the files to predict.
test_path = 'faces_dataset/test'
test_ann = 'faces_dataset/submission.csv'
test_dataset = AgeDataset(data_path=test_path, annot_path=test_ann, train=False)

# Shuffle only the training data; the test order must match the CSV rows.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=10,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=10,
)

In [None]:
class AgeModel(nn.Module):
    """Pretrained torchvision backbone followed by a two-layer regression head.

    The backbone's 1000-dimensional output is mapped to a single value
    through ``fc1`` (1000 -> 100) and ``fc2`` (100 -> 1).

    Args:
        model: Name of the backbone; must be one of the keys of
            ``_BACKBONE_WEIGHTS``.

    Raises:
        ValueError: If ``model`` is not a supported backbone name.
    """

    # Supported backbone name (also the torchvision builder name) -> weights.
    _BACKBONE_WEIGHTS = {
        'resnet34': "ResNet34_Weights.IMAGENET1K_V1",
        'resnet101': "ResNet101_Weights.IMAGENET1K_V2",
        'resnet18': "ResNet18_Weights.IMAGENET1K_V1",
        'resnet50': "ResNet50_Weights.IMAGENET1K_V1",
        'swin_v2_s': 'IMAGENET1K_V1',
        'swin_v2_t': 'IMAGENET1K_V1',
        'swin_v2_b': 'IMAGENET1K_V1',
        'convnext_small': 'IMAGENET1K_V1',
        'convnext_tiny': 'IMAGENET1K_V1',
        'vit_b_16': 'IMAGENET1K_V1',
    }

    def __init__(self, model):
        super().__init__()
        # Fail early: previously an unknown name left ``self.model`` unset and
        # the error only surfaced on the first forward pass.
        if model not in self._BACKBONE_WEIGHTS:
            supported = ', '.join(sorted(self._BACKBONE_WEIGHTS))
            raise ValueError(
                f"Unknown backbone {model!r}; expected one of: {supported}"
            )
        builder = getattr(models, model)
        self.model = builder(weights=self._BACKBONE_WEIGHTS[model])

        # NOTE(review): there is no non-linearity between fc1 and fc2, so the
        # head is equivalent to a single linear map. Left as is so existing
        # checkpoints keep producing the same predictions.
        self.fc1 = nn.Linear(1000, 100)
        self.fc2 = nn.Linear(100, 1)

    def forward(self, x):
        """Run the backbone and the regression head on the batch ``x``."""
        x = self.model(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

In [None]:
# Hyper-parameter grid: every (lr, backbone, weight_decay) combination is trained.
models_arr = ['swin_v2_s', 'swin_v2_t', 'swin_v2_b', 'convnext_small', 'convnext_tiny', 'vit_b_16', 'resnet18', 'resnet34', 'resnet50', 'resnet101']
weight_decays = [1e-4, 1e-3]
lrs = [1e-4, 1e-3, 1e-2]

# The device does not depend on the run, so select it once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

num_epochs = 100

for lr in lrs:
    for model_name in models_arr:
        for weight_decay in weight_decays:
            wandb.init(
                project='Age_Prediction',
                name=f'{model_name}',
                config={"model": model_name, "lr": lr, "weight_decay": weight_decay},
            )

            model = AgeModel(model=model_name)

            # Freezing the backbone and fine-tuning only `model.model.head`
            # was tried here; it started and ended with a bad loss in every
            # scenario, so the whole network is trained instead.

            model.to(device)

            criterion = nn.L1Loss()         # MAE Loss
            optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

            # Train the model
            model.train()

            min_loss = float('inf')
            for epoch in tqdm(range(num_epochs), desc="Epochs"):
                wandb.log({"epoch": epoch})
                running_loss = 0.0
                for inputs, labels in tqdm(train_loader, desc="Batches", total=len(train_loader)):
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    # Labels arrive as a 1-D batch; match the (batch, 1) output.
                    loss = criterion(outputs, labels.unsqueeze(1).float())
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()

                epoch_loss = running_loss / len(train_loader)
                wandb.log({"loss": epoch_loss, "epoch": epoch})

                # Keep only the best epoch of this run. The original assigned
                # to a misspelled `mn_loss`, so the threshold never moved and
                # the checkpoint was overwritten after every epoch.
                # NOTE(review): the file name contains only the backbone name,
                # so later (lr, weight_decay) runs overwrite earlier ones.
                if running_loss < min_loss:
                    min_loss = running_loss
                    torch.save(model.state_dict(), f'Age_Prediction_{model_name}.pth')

                scheduler.step()
                print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

            # Close this wandb run so the next grid point starts a fresh one.
            wandb.finish()
            

In [None]:
###### SUBMISSION CSV FILE #####

# Selected here as well so this cell also works in a fresh session where the
# training cell has not been run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def predict(loader, model):
    """Return the model's predictions for every batch in ``loader`` as a flat list.

    The model is switched to eval mode and run without gradient tracking.
    ``loader`` must yield image batches only (no labels).
    """
    model.eval()
    predictions = []

    for img in tqdm(loader):
        img = img.to(device)
        pred = model(img)
        predictions.extend(pred.flatten().tolist())

    return predictions


# One checkpoint exists per backbone (the file name has no lr / weight_decay),
# so one pass per backbone is enough; looping over the whole grid only redid
# identical inference and rewrote the same CSV.
for model_name in models_arr:
    # Load the model; map_location lets a GPU-trained checkpoint load on CPU.
    model = AgeModel(model=model_name)
    model.load_state_dict(
        torch.load(f'Age_Prediction_{model_name}.pth', map_location=device)
    )
    model.to(device)

    preds = predict(test_loader, model)

    # NOTE(review): the template is read from 'predictions.csv', while the test
    # dataset is built from 'faces_dataset/submission.csv' -- confirm both list
    # the same rows in the same order.
    submit = pd.read_csv('predictions.csv')
    submit['age'] = preds

    submit.to_csv(f'predictions_{model_name}.csv', index=False)

    torch.cuda.empty_cache()

wandb.finish()