In [4]:
import numpy as np
import pandas as pd
from glob import glob
from os.path import join
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation
import torch.optim as optim

# Define the CNN model
class AgePredictionModel(nn.Module):
    """ResNet-18 backbone with a single-output regression head for age.

    torchvision's pretrained ResNet-18 is loaded and its final fully
    connected layer is replaced by ``nn.Linear(512, 1)`` so the network
    emits one scalar per input image.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        # ResNet-18's pooled feature vector has 512 channels; map it to a single value.
        backbone.fc = nn.Linear(512, 1)
        self.features = backbone

    def forward(self, x):
        """Return a 1-D tensor holding one prediction per sample in ``x``."""
        out = self.features(x)
        batch_size = out.size(0)
        return out.view(batch_size)

# Define dataset class
class AgeDataset(torch.utils.data.Dataset):
    """Face-image dataset driven by a CSV annotation file.

    The CSV must contain a ``file_id`` column (image file names relative to
    ``data_path``) and, when ``train`` is true, an ``age`` column used as
    the regression target.

    Args:
        data_path: Directory containing the image files.
        annot_path: Path to the CSV annotation file.
        train: If true, ``__getitem__`` returns ``(image, age)``; otherwise
            it returns only the image and no ``ages`` attribute is created.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, data_path, annot_path, train=True, transform=None):
        super().__init__()
        self.annot_path = annot_path
        self.data_path = data_path
        self.train = train
        self.ann = pd.read_csv(annot_path)
        self.files = self.ann['file_id']
        if train:
            self.ages = self.ann['age']
        self.transform = transform

    @staticmethod
    def _convert_image_to_rgb(image):
        """Return ``image`` converted to RGB mode (a new PIL image)."""
        return image.convert("RGB")

    def read_img(self, file_name):
        """Load ``file_name`` from ``data_path`` as an RGB PIL image.

        The file is opened in a ``with`` block so its handle is released as
        soon as the RGB copy exists, instead of relying on Pillow's lazy
        loading and garbage collection to close it.
        """
        im_path = join(self.data_path, file_name)
        with Image.open(im_path) as img:
            # convert() loads the pixel data and returns a new image, so the
            # result remains valid after the source file is closed.
            return self._convert_image_to_rgb(img)

    def __getitem__(self, index):
        # Positional lookup (.iloc) so the sample order follows row position
        # even if the annotation frame does not carry a default RangeIndex.
        file_name = self.files.iloc[index]
        img = self.read_img(file_name)
        if self.transform is not None:
            img = self.transform(img)
        if self.train:
            return img, self.ages.iloc[index]
        return img

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.files)

# Define function for training the model
def train_model(train_loader, model, criterion, optimizer, num_epochs=25, *, device=None):
    """Train ``model`` in place, printing and returning the loss of each epoch.

    Args:
        train_loader: DataLoader yielding ``(inputs, labels)`` batches.
        model: Network to optimise; it is moved to ``device`` first.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of full passes over ``train_loader``.
        device: Device to train on. Defaults to ``cuda:0`` when CUDA is
            available, otherwise the CPU.

    Returns:
        List with one loss per epoch: the batch losses weighted by batch size
        and divided by the dataset size (i.e. a per-sample mean, assuming
        ``criterion`` uses mean reduction).
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    num_samples = len(train_loader.dataset)
    history = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader):
            inputs = inputs.to(device)
            # Labels may arrive as integer tensors; cast to float for the regression loss.
            labels = labels.float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch is not over-counted.
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / num_samples
        history.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    return history

# Load data
train_path = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/train'
train_ann = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/train.csv'
test_path = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/test'
test_ann = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/submission.csv'

# ImageNet channel statistics, matching the pretrained ResNet-18 backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Define transformations for training and testing data.
# Both pipelines resize the shorter side to 256 and then take a 224x224 center
# crop, so faces are seen at the same scale and framing at train and test time.
# Only the training pipeline adds random flips and rotations for augmentation.
train_transform = Compose([
    Resize(256),
    RandomHorizontalFlip(),
    RandomRotation(10),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

test_transform = Compose([
    # Must mirror the train-time Resize(256) + CenterCrop(224). Resizing straight
    # to 224 would give the model a wider, smaller-scale view than it was trained on.
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Create datasets and loaders
train_dataset = AgeDataset(train_path, train_ann, train=True, transform=train_transform)
test_dataset = AgeDataset(test_path, test_ann, train=False, transform=test_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# No shuffling at test time: predictions must stay in CSV row order.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize the model, loss function, and optimizer
model = AgePredictionModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
train_model(train_loader, model, criterion, optimizer, num_epochs=25)

# Function for making predictions
@torch.no_grad()
def predict(loader, model):
    """Run ``model`` over every batch in ``loader`` and collect predictions.

    ``loader`` is expected to yield image batches only, without labels (for
    example a dataset built with ``train=False``). Each batch is sent to the
    device holding the model's parameters, so the function works without any
    module-level ``device`` variable.

    Returns:
        Flat list of Python numbers, one per input sample, in loader order.
    """
    model.eval()
    # Use the device the model already lives on; fall back to CPU for a
    # model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    predictions = []
    for img in tqdm(loader):
        pred = model(img.to(target_device))
        # Gradients are already disabled by @torch.no_grad(), so no detach() is needed.
        predictions.extend(pred.flatten().cpu().tolist())
    return predictions

# Generate predictions
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
preds = predict(test_loader, model)

# Save predictions to submission CSV file.
# Reuse test_ann: the test dataset was built from this same template, so its
# rows line up with the (unshuffled) test loader's predictions.
submit = pd.read_csv(test_ann)
if len(submit) != len(preds):
    raise ValueError(
        f"Got {len(preds)} predictions for {len(submit)} submission rows"
    )
submit['age'] = preds
submit.to_csv('submissions.csv', index=False)


100%|██████████| 334/334 [01:53<00:00,  2.95it/s]


Epoch 1/25, Loss: 163.9712


100%|██████████| 334/334 [01:52<00:00,  2.96it/s]


Epoch 2/25, Loss: 69.9095


100%|██████████| 334/334 [01:56<00:00,  2.88it/s]


Epoch 3/25, Loss: 62.8266


100%|██████████| 334/334 [01:54<00:00,  2.91it/s]


Epoch 4/25, Loss: 56.7412


100%|██████████| 334/334 [01:51<00:00,  3.00it/s]


Epoch 5/25, Loss: 53.3799


100%|██████████| 334/334 [01:51<00:00,  3.00it/s]


Epoch 6/25, Loss: 51.2282


100%|██████████| 334/334 [01:51<00:00,  3.00it/s]


Epoch 7/25, Loss: 47.6512


100%|██████████| 334/334 [01:51<00:00,  2.99it/s]


Epoch 8/25, Loss: 45.4417


100%|██████████| 334/334 [01:51<00:00,  3.00it/s]


Epoch 9/25, Loss: 43.2646


100%|██████████| 334/334 [01:51<00:00,  2.98it/s]


Epoch 10/25, Loss: 41.2698


100%|██████████| 334/334 [01:53<00:00,  2.96it/s]


Epoch 11/25, Loss: 40.3755


100%|██████████| 334/334 [01:53<00:00,  2.94it/s]


Epoch 12/25, Loss: 36.9715


100%|██████████| 334/334 [01:51<00:00,  2.99it/s]


Epoch 13/25, Loss: 35.8230


100%|██████████| 334/334 [01:52<00:00,  2.97it/s]


Epoch 14/25, Loss: 33.1554


100%|██████████| 334/334 [01:51<00:00,  2.99it/s]


Epoch 15/25, Loss: 32.3381


100%|██████████| 334/334 [01:52<00:00,  2.97it/s]


Epoch 16/25, Loss: 30.6750


100%|██████████| 334/334 [01:52<00:00,  2.98it/s]


Epoch 17/25, Loss: 29.4488


100%|██████████| 334/334 [01:52<00:00,  2.97it/s]


Epoch 18/25, Loss: 27.2747


100%|██████████| 334/334 [01:55<00:00,  2.90it/s]


Epoch 19/25, Loss: 26.0457


100%|██████████| 334/334 [01:53<00:00,  2.95it/s]


Epoch 20/25, Loss: 25.0704


100%|██████████| 334/334 [01:51<00:00,  3.01it/s]


Epoch 21/25, Loss: 22.3764


100%|██████████| 334/334 [01:53<00:00,  2.95it/s]


Epoch 22/25, Loss: 21.8617


100%|██████████| 334/334 [01:51<00:00,  2.99it/s]


Epoch 23/25, Loss: 20.5774


100%|██████████| 334/334 [01:51<00:00,  2.99it/s]


Epoch 24/25, Loss: 19.9180


100%|██████████| 334/334 [01:52<00:00,  2.98it/s]


Epoch 25/25, Loss: 19.3136


100%|██████████| 31/31 [00:08<00:00,  3.59it/s]
