In [None]:
import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from tqdm import tqdm

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import mean_absolute_error
from datetime import datetime

from inference.models import AgeEstimatorModel

In [None]:
# --- Configuration: every tunable value of this notebook lives here. ---
SIZE = 48          # input images are resized to SIZE x SIZE pixels
BATCH_SIZE = 200
EPOCHS = 80
LR = 0.0005        # Adam learning rate

DATASET_PATH = "../data/UTKFace_48"
LOG_PATH = "../../logs/age"
MODEL_PATH = "models/video/age_model_torch.pth"
WEIGHTS_PATH = "models/video/age_model_weights.pth"
TEST_IMAGE_PATH = "../data/face_recognition_images/person1.1.jpg"
# Used to name the TensorBoard run directory, so it must be a valid path
# component on every OS (no ':' — rejected on Windows — no ';' or spaces).
# Year-first so run directories sort chronologically.
TIME_FORMAT = "%Y-%m-%d_%H-%M-%S"

In [None]:
# Preprocessing shared by training and inference: resize to SIZE x SIZE,
# convert to a float tensor in [0, 1], then shift/scale each channel to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((SIZE, SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Instantiate the age regressor and move its parameters to the chosen device.
model = AgeEstimatorModel()
model = model.to(device)

In [None]:
class UTKFaceDataset(Dataset):
    """Face images whose age label is encoded at the start of the filename.

    Every ``*.jpg`` file directly inside ``directory`` is one sample. The label
    is the integer before the first ``_`` of the file's base name
    (e.g. ``35_...jpg`` -> 35).

    Args:
        directory: folder containing the ``.jpg`` files.
        transform: optional callable applied to the PIL image; when omitted the
            PIL image itself is returned.
    """

    def __init__(self, directory, transform=None):
        # glob() order depends on the filesystem; sorting gives every machine the
        # same index -> file mapping, which index-based splits rely on.
        self.files = sorted(glob.glob(os.path.join(directory, '*.jpg')))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _age_from_path(path):
        """Return the integer age prefix of ``path``'s base name (before the first '_')."""
        # basename() handles both '/' and the OS separator ('\\' on Windows),
        # unlike a hard-coded split('/').
        return int(os.path.basename(path).split('_')[0])

    def __getitem__(self, idx):
        """Return ``(image, age)`` for sample ``idx``; raises ValueError on a non-numeric prefix."""
        img_name = self.files[idx]
        # Image.open is lazy and keeps the file open; copy the pixels inside a
        # context manager so the handle is released immediately.
        with Image.open(img_name) as img:
            image = img.copy()
        age = self._age_from_path(img_name)

        if self.transform:
            image = self.transform(image)

        return image, age

In [None]:
# Train/test split. The generator is seeded so the same split is drawn on every
# Restart & Run All (given the same file list), keeping MAE values comparable.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.85

dataset = UTKFaceDataset(directory=DATASET_PATH, transform=transform)
# A wrong DATASET_PATH yields an empty glob; fail here with a clear message
# instead of a ZeroDivisionError deep inside the training loop.
assert len(dataset) > 0, f"no .jpg files found in {DATASET_PATH}"

train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Age distribution of the training split, counted with torch.bincount.
# Labels are concatenated first and counted once: ages are parsed from filenames
# without an upper cap, and a per-batch bincount(minlength=90) returns a longer
# vector for any age >= 90, which cannot be added to a fixed 90-slot accumulator.
label_batches = [labels for _, labels in train_loader]
train_ages = torch.cat(label_batches) if label_batches else torch.empty(0, dtype=torch.long)
class_counts = torch.bincount(train_ages, minlength=90)

print("Количество экземпляров каждого класса:")
for age, count in enumerate(class_counts.tolist()):
    print(f"Класс {age}: {count} экземпляров")

In [None]:
# Adam over all model parameters at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=LR)

# Age is regressed directly, so the objective is mean-squared error.
loss_fun = nn.MSELoss()
loss_fun = loss_fun.to(device)

In [None]:
def _train_one_epoch(model, loader, loss_fun, optimizer, device):
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # Flatten the model output so it lines up with the 1-D label batch.
        loss = loss_fun(outputs.view(-1), labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate_mae(model, loader, device):
    """Return the mean absolute error of ``model`` over ``loader`` (no gradients)."""
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            # Flatten exactly as in training so predictions are a flat list of ages.
            all_preds.extend(outputs.view(-1).tolist())
            all_labels.extend(labels.tolist())
    return mean_absolute_error(all_labels, all_preds)


writer = SummaryWriter(log_dir=os.path.join(LOG_PATH, datetime.now().strftime(TIME_FORMAT)))
try:
    for epoch in tqdm(range(EPOCHS)):  # several passes over the dataset
        epoch_loss = _train_one_epoch(model, train_loader, loss_fun, optimizer, device)
        writer.add_scalar('Metrics/epoch_loss', epoch_loss, epoch)

        err = _evaluate_mae(model, test_loader, device)
        writer.add_scalar('Metrics/MAE', err, epoch)
finally:
    # SummaryWriter buffers events; closing flushes them so the final epochs
    # are not lost, even if training is interrupted.
    writer.close()

print('Finished Training')

In [None]:
# torch.save does not create missing parent directories, so create them first.
# (`or "."` covers a bare filename, whose dirname is the empty string.)
for output_path in (MODEL_PATH, WEIGHTS_PATH):
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

# Whole pickled module: loading it later requires AgeEstimatorModel to be importable.
torch.save(model, MODEL_PATH)
# Parameters only: the portable format, reloaded via load_state_dict.
torch.save(model.state_dict(), WEIGHTS_PATH)

In [None]:
# Reload the saved weights into the current model. map_location lets weights
# saved from a GPU run load on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
model.eval()

# Apply the training-time preprocessing while the file is still open, then
# release the handle; add a batch dimension for the model.
with Image.open(TEST_IMAGE_PATH) as test_image:
    image = transform(test_image)
image = image.to(device)
image = image.unsqueeze(0)

with torch.no_grad():
    output = model(image)
    predicted_age = output.item()

print(f'Predicted Age: {predicted_age}')