In [None]:
################ Code Description ##################

# The code below for age prediction performs the following tasks:
# Defines a custom dataset (AgeDataset) for loading images and their corresponding ages.
# Performs data augmentation using transforms for train data.
# Sets paths for the training and testing data.
# Splits the training data into training and validation sets.
# Creates datasets for training, validation and testing.
# Creates data loaders for training, validation, and testing.
# Loads a pre-trained ResNet18-based model (ResNetAgePredictor) for age prediction.
# Defines the loss criterion (MAE) and an Adam optimizer with weight decay.
# Defines an EarlyStopping class for stopping early during training.
# Trains and validates the model with early stopping.
# Saves the best model.
# Uses the saved model from validation to make predictions on the test dataset.
# Creates a submission CSV file with the predicted ages.

In [None]:
from glob import glob
from os.path import join
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import models
from torchvision.transforms import (
    CenterCrop, ColorJitter, Compose, Normalize, RandomHorizontalFlip,
    RandomRotation, Resize, ToTensor,
)
from tqdm import tqdm



In [None]:
class AgeDataset(torch.utils.data.Dataset):
    """Face-image dataset for age regression.

    Reads an annotation CSV with a ``file_id`` column (plus an ``age`` column when
    ``train=True``) and loads each listed image from ``data_path``.

    Args:
        data_path: Directory containing the image files named in ``file_id``.
        annot_path: Path to the annotation CSV.
        train: If True, ``__getitem__`` returns ``(image, age)``; otherwise it returns
            only the image and the CSV does not need an ``age`` column.
        augment: Whether to apply random training augmentation. Defaults to ``train``
            so existing callers keep the original behaviour; pass ``augment=False`` to
            get labelled samples with the deterministic eval pipeline (e.g. for a
            validation split).
    """

    def __init__(self, data_path, annot_path, train=True, augment=None):
        super().__init__()

        self.annot_path = annot_path
        self.data_path = data_path
        self.train = train

        self.ann = pd.read_csv(annot_path)
        self.files = self.ann['file_id']
        if train:
            self.ages = self.ann['age']
        if augment is None:
            augment = train
        self.transform = self._transform(224, train=augment)

    @staticmethod
    def _convert_image_to_rgb(image):
        """Return ``image`` converted to 3-channel RGB (grayscale/RGBA/palette inputs)."""
        return image.convert("RGB")

    def _transform(self, n_px, train=True):
        """Build the preprocessing pipeline producing normalized ``3 x n_px x n_px`` tensors.

        ``Resize(n_px)`` with an int only scales the *shorter* side, so ``CenterCrop``
        is added to guarantee square outputs that the default collate can stack into
        a batch; for already-square images it is a no-op. Conversion to RGB happens
        first so the colour augmentations always see an RGB image.

        Args:
            n_px: Output side length in pixels.
            train: If True, include random flip / rotation / colour-jitter augmentation.
        """
        # Standard ImageNet normalization statistics.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        if train:
            augmentation = [
                RandomHorizontalFlip(p=0.5),
                RandomRotation(15),
                ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
            ]
        else:
            augmentation = []  # validation/test: deterministic preprocessing only

        return Compose([
            self._convert_image_to_rgb,
            Resize(n_px),
            CenterCrop(n_px),
            *augmentation,
            ToTensor(),
            Normalize(mean, std),
        ])

    def read_img(self, file_name):
        """Load ``file_name`` from ``data_path`` and return the transformed tensor."""
        im_path = join(self.data_path, file_name)
        # The context manager closes the underlying file promptly; the RGB conversion
        # at the start of the pipeline loads the pixels before the file is closed.
        with Image.open(im_path) as img:
            return self.transform(img)

    def __getitem__(self, index):
        """Return ``(image, age)`` for labelled datasets, otherwise just ``image``."""
        file_name = self.files[index]
        img = self.read_img(file_name)
        if self.train:
            age = self.ages[index]
            return img, age
        else:
            return img

    def __len__(self):
        """Number of images listed in the annotation CSV."""
        return len(self.files)




In [None]:
# Define dataset paths, then build the train / validation / test datasets and loaders.

DATA_DIR = Path('/kaggle/input/smai-24-age-prediction/content/faces_dataset')
SPLIT_SEED = 42      # fixes the train/validation split so runs are comparable
VAL_FRACTION = 0.2   # share of labelled images held out for validation
BATCH_SIZE = 64

train_path = str(DATA_DIR / 'train')
train_ann = str(DATA_DIR / 'train.csv')
full_train_dataset = AgeDataset(train_path, train_ann, train=True)

test_path = str(DATA_DIR / 'test')
test_ann = str(DATA_DIR / 'submission.csv')
test_dataset = AgeDataset(test_path, test_ann, train=False)

# Split the labelled data into training and validation index sets (seeded).
num_train = len(full_train_dataset)
num_val = int(VAL_FRACTION * num_train)
train_dataset, val_split = random_split(
    full_train_dataset,
    [num_train - num_val, num_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# random_split returns views onto the *same* dataset object, so the validation
# subset would otherwise inherit the random training augmentation. Index a second
# copy of the labelled data that uses the deterministic test-time pipeline instead.
val_source = AgeDataset(train_path, train_ann, train=True)
val_source.transform = test_dataset.transform
val_dataset = Subset(val_source, val_split.indices)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



In [None]:
# ResNet Model for Age Prediction Task

class ResNetAgePredictor(nn.Module):
    """ResNet-18 backbone whose classifier head is replaced by a single-output regressor.

    Args:
        pretrained: If True, initialise the backbone with torchvision's default
            ImageNet weights; otherwise use random initialisation.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.model = models.resnet18(**self._weights_kwargs(pretrained))
        num_features = self.model.fc.in_features  # inputs of the original classifier head

        # Replace the fully connected layer with a new one with a single output (the age).
        self.model.fc = nn.Linear(num_features, 1)

    @staticmethod
    def _weights_kwargs(pretrained):
        """Translate the boolean flag into torchvision's weight-selection keyword.

        torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``;
        ``ResNet18_Weights.DEFAULT`` selects the same ImageNet checkpoint that
        ``pretrained=True`` did. Older releases without the weights enum fall back
        to the legacy keyword.
        """
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is None:
            return {'pretrained': pretrained}
        return {'weights': weights_enum.DEFAULT if pretrained else None}

    def forward(self, x):
        """Return one predicted value per input image from the 1-unit linear head."""
        return self.model(x)



In [None]:
# Define EarlyStopping class

class EarlyStopping:
    """Stop training once validation loss stops improving, checkpointing the best model.

    Call the instance once per epoch with that epoch's validation loss. It saves
    ``model.state_dict()`` to ``path`` whenever the loss decreases by at least
    ``delta`` (with ``delta=0`` an equal loss also counts as an improvement), and
    sets ``early_stop`` after ``patience`` consecutive calls without improvement.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 5
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pth'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; the lowercase `np.inf` works on all versions.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record ``val_loss`` for the current epoch and checkpoint ``model`` on improvement."""
        score = -val_loss  # higher score == lower loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save the model's state_dict to ``self.path`` when validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss




In [None]:
# Initialize model, optimizer and loss function.
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-6
TORCH_SEED = 42
num_epochs = 30  # maximum number of epochs; early stopping may end training earlier

# Seed torch's global RNG so the new regression head's initialisation, the
# DataLoader shuffle order and the random augmentations are reproducible.
torch.manual_seed(TORCH_SEED)

model = ResNetAgePredictor().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = nn.L1Loss()  # MAE



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 133MB/s] 


In [None]:
# Training function

def _per_sample_mean(loss_sum, n_samples, loader_name):
    """Divide an accumulated loss by its sample count, rejecting empty loaders."""
    if n_samples == 0:
        raise ValueError(f'{loader_name} yielded no samples')
    return loss_sum / n_samples


def _train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the per-sample mean loss."""
    model.train()
    loss_sum, n_samples = 0.0, 0
    for inputs, ages in loader:
        inputs, ages = inputs.to(device), ages.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, ages.float().unsqueeze(1))
        loss.backward()
        optimizer.step()
        # Weight each batch mean by its size so a short final batch does not skew the epoch mean.
        batch_size = inputs.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
    return _per_sample_mean(loss_sum, n_samples, 'train_loader')


def _validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return the per-sample mean loss of rounded predictions."""
    model.eval()
    loss_sum, n_samples = 0.0, 0
    with torch.no_grad():
        for inputs, ages in loader:
            inputs, ages = inputs.to(device), ages.to(device)
            # Round to whole years, the same way the test predictions are rounded for submission.
            outputs = torch.round(model(inputs))
            loss = criterion(outputs, ages.float().unsqueeze(1))
            batch_size = inputs.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
    return _per_sample_mean(loss_sum, n_samples, 'val_loader')


def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs,
                       early_stopping=None, device=None):
    """Train ``model`` for up to ``num_epochs`` epochs, validating after each one.

    Args:
        model: Network mapping an input batch to predictions of shape ``(batch, 1)``.
        train_loader: Iterable of ``(inputs, ages)`` batches used for optimisation.
        val_loader: Iterable of ``(inputs, ages)`` batches used for validation.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss ``criterion(pred, target)``; assumed to return the *mean* over
            the batch (as ``nn.L1Loss()`` does by default).
        num_epochs: Maximum number of epochs.
        early_stopping: Optional callable ``(val_loss, model)`` exposing an
            ``early_stop`` flag (e.g. ``EarlyStopping``); training stops once it is set.
        device: Device to move batches to. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        List of ``(avg_train_loss, avg_val_loss)`` tuples, one per completed epoch.

    Raises:
        ValueError: If either loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    history = []
    for epoch in range(num_epochs):
        avg_train_loss = _train_epoch(model, train_loader, optimizer, criterion, device)
        avg_val_loss = _validate_epoch(model, val_loader, criterion, device)
        history.append((avg_train_loss, avg_val_loss))
        print(f'Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

        if early_stopping is not None:
            early_stopping(avg_val_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break
    return history



In [None]:
# Train with early stopping; EarlyStopping checkpoints the best-validation weights to its `path`.
early_stopping = EarlyStopping(patience=6, verbose=True, path='best_model.pth')

# Call the train function with early stopping
train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs, early_stopping)

# Restore the best checkpoint rather than the last epoch's weights. `map_location`
# makes the load independent of the device the checkpoint was written from, and
# `weights_only=True` restricts unpickling to tensors, which is all a state_dict needs.
model.load_state_dict(torch.load(early_stopping.path, map_location=device, weights_only=True))

# Save the restored best model under its final name.
torch.save(model.state_dict(), 'age_prediction_resnet18_final.pth')



Epoch 1: Train Loss: 21.3804, Val Loss: 13.2454
Validation loss decreased (inf --> 13.245442).  Saving model ...
Epoch 2: Train Loss: 9.3542, Val Loss: 6.8581
Validation loss decreased (13.245442 --> 6.858061).  Saving model ...
Epoch 3: Train Loss: 5.8966, Val Loss: 5.5097
Validation loss decreased (6.858061 --> 5.509710).  Saving model ...
Epoch 4: Train Loss: 5.3805, Val Loss: 5.5974
EarlyStopping counter: 1 out of 5
Epoch 5: Train Loss: 5.1285, Val Loss: 5.8407
EarlyStopping counter: 2 out of 5
Epoch 6: Train Loss: 4.9864, Val Loss: 5.0883
Validation loss decreased (5.509710 --> 5.088280).  Saving model ...
Epoch 7: Train Loss: 4.7874, Val Loss: 4.9715
Validation loss decreased (5.088280 --> 4.971464).  Saving model ...
Epoch 8: Train Loss: 4.7901, Val Loss: 5.0844
EarlyStopping counter: 1 out of 5
Epoch 9: Train Loss: 4.6209, Val Loss: 4.9231
Validation loss decreased (4.971464 --> 4.923147).  Saving model ...
Epoch 10: Train Loss: 4.4666, Val Loss: 4.8694
Validation loss decrease

In [None]:
###### SUBMISSION CSV FILE #####

@torch.no_grad()
def predict(loader, model, device=None, show_progress=True):
    """Predict integer ages for every image batch yielded by ``loader``.

    Args:
        loader: Iterable of image-tensor batches (e.g. the test DataLoader, whose
            dataset returns images only).
        model: Trained model; its outputs are flattened to one value per image.
        device: Device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).
        show_progress: Wrap ``loader`` in a tqdm progress bar (default True).

    Returns:
        List of predictions rounded to the nearest integer, in loader order.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    batches = tqdm(loader) if show_progress else loader
    predictions = []
    for img in batches:
        pred = model(img.to(device))
        predictions.extend(pred.flatten().round().int().cpu().numpy())
    return predictions

preds = predict(test_loader, model)

# test_loader is built from this same CSV with shuffle=False, so rows line up with preds.
submit = pd.read_csv(test_ann)
assert len(preds) == len(submit), f'{len(preds)} predictions for {len(submit)} rows'
submit['age'] = preds

submit.to_csv('baseline.csv', index=False)
submit.head()

100%|██████████| 31/31 [00:14<00:00,  2.17it/s]
