In [2]:
import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from tqdm import tqdm

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F


In [4]:
SIZE=48
BATCH_SIZE=512

In [5]:
class UTKFaceDataset(Dataset):
    def __init__(self, directory, transform=None):
        self.files = glob.glob(os.path.join(directory, '*.jpg'))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_name = self.files[idx]
        image = Image.open(img_name)
        filename = img_name.split('/')[-1]
        age = int(filename.split('_')[0])  # Предполагается, что имя файла начинается с возраста

        if self.transform:
            image = self.transform(image)

        return image, age

In [6]:
transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
])

dataset = UTKFaceDataset(directory='../data/UTKFace_48', transform=transform)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
class AgeEstimatorModel(nn.Module):
    def __init__(self):
        super(AgeEstimatorModel, self).__init__()
        self.base_model = timm.create_model('efficientnetv2_rw_s', pretrained=True)
        self.base_model.classifier = nn.Linear(self.base_model.classifier.in_features, 1)
        
    def forward(self, x):
        return self.base_model(x)

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AgeEstimatorModel().to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [22]:
for epoch in range(10):  # проход по датасету несколько раз
    running_loss = 0.0
    for i, data in tqdm(enumerate(train_loader)):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.view(-1), labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, loss: {running_loss / len(train_loader)}')
print('Finished Training')

68it [00:14,  4.75it/s]


Epoch 1, loss: 18.261461419217728


68it [00:13,  5.08it/s]


Epoch 2, loss: 6.775937879786772


68it [00:13,  4.96it/s]


Epoch 3, loss: 5.38443252619575


68it [00:13,  5.04it/s]


Epoch 4, loss: 4.569055662435644


68it [00:13,  5.03it/s]


Epoch 5, loss: 4.062251185669618


68it [00:13,  4.98it/s]


Epoch 6, loss: 3.61576588714824


68it [00:13,  5.02it/s]


Epoch 7, loss: 3.30544370062211


68it [00:13,  5.00it/s]


Epoch 8, loss: 3.048522752874038


68it [00:13,  4.89it/s]


Epoch 9, loss: 2.8573475690448986


68it [00:13,  5.01it/s]

Epoch 10, loss: 2.723175013766569
Finished Training





In [23]:
total_loss = 0.0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs.view(-1), labels.float())
        total_loss += loss.item()

print('Средняя ошибка возраста на тестовом наборе: ', total_loss / len(test_loader))

Средняя ошибка возраста на тестовом наборе:  5.724269221810734


In [ ]:
torch.save(model, '../../age_model_torch.pth')
torch.save(model.state_dict(), '../../age_model_weights.pth')

In [9]:
# model = AgeEstimatorModel()  # Создайте экземпляр вашей модели
# model.load_state_dict(torch.load('../../age_model_torch.pth'))

model = torch.load('../../age_model_torch.pth')

model.eval()  # Переведите модель в режим оценки

# Загрузите изображение
image_path = '../data/face_recognition_images/person2.1.jpg'
image = Image.open(image_path)

# Примените преобразования к изображению
image = transform(image)
image = image.to(device)
image = image.unsqueeze(0)  # Добавьте дополнительное измерение, так как модель ожидает пакет изображений

# Сделайте предсказание
with torch.no_grad():
    output = model(image)
    predicted_age = output.item()  # Получите предсказанный возраст как число

print(f'Predicted Age: {predicted_age}')

Predicted Age: 24.406728744506836
