In [1]:
import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from tqdm import tqdm

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import precision_score, recall_score, f1_score
from datetime import datetime


2024-01-08 22:45:36.018687: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-08 22:45:36.042221: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-08 22:45:36.042253: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-08 22:45:36.042276: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-08 22:45:36.047710: I tensorflow/core/platform/cpu_feature_g

In [2]:
# Side length (pixels) of the square images fed to the network (used by Resize).
SIZE = 48
# Mini-batch size shared by the training and test DataLoaders.
BATCH_SIZE = 512
# Number of race categories, i.e. outputs of the classifier head.
NUM_CLASSES = 5

In [3]:
class UTKFaceDataset(Dataset):
    """UTKFace ``*.jpg`` images paired with the race label encoded in each file name.

    The label is the integer in the third ``_``-separated field of the file name
    (UTKFace convention: ``{age}_{gender}_{race}_{datetime}.jpg...``), e.g.
    ``2_0_2_20161219141143184.jpg.chip.jpg`` -> race ``2``.

    Args:
        directory: folder containing the ``*.jpg`` files (not searched recursively).
        transform: optional callable applied to the RGB ``PIL.Image`` before returning it.
    """

    def __init__(self, directory, transform=None):
        # Sorted so sample order (and therefore any seeded random_split) does not
        # depend on the filesystem's directory-listing order.
        self.files = sorted(glob.glob(os.path.join(directory, '*.jpg')))
        self.transform = transform

    @staticmethod
    def _parse_race(path):
        """Return the race label from the third '_' field of ``path``'s file name.

        Raises:
            ValueError: if the file name does not have an integer third field.
        """
        parts = os.path.basename(path).split('_')
        try:
            return int(parts[2])
        except (IndexError, ValueError) as exc:
            raise ValueError(f"cannot parse race label from file name {path!r}") from exc

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, race)`` for sample ``idx``; ``image`` is transformed if a transform is set."""
        img_name = self.files[idx]
        # The context manager closes the file handle; convert('RGB') loads the pixel
        # data before the file is closed and guarantees 3 channels for every sample.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        race = self._parse_race(img_name)

        if self.transform:
            image = self.transform(image)

        return image, race

In [4]:
# Resize to a SIZE x SIZE square and convert to a float tensor in [0, 1].
# NOTE(review): no mean/std normalization is applied although the backbone is
# created with pretrained weights; consider transforms.Normalize with the
# backbone's pretraining statistics (requires retraining).
transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
])

DATA_DIR = '../data/UTKFace_48'
TRAIN_FRACTION = 0.8
# Fixed seed: the split is then identical across kernel restarts (given a stable
# file order), so re-evaluating a saved model later never scores it on images it
# was trained on.
SPLIT_SEED = 42

dataset = UTKFaceDataset(directory=DATA_DIR, transform=transform)
assert len(dataset) > 0, f"no .jpg files found in {DATA_DIR!r}"

train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
class RaceEstimatorModel(nn.Module):
    """timm ``efficientnetv2_rw_s`` backbone with its classifier replaced by a new linear head.

    Args:
        num_classes: number of outputs of the new head; ``None`` (default) uses the
            notebook-level ``NUM_CLASSES``.
        pretrained: passed to ``timm.create_model`` to load pretrained backbone weights.
    """

    def __init__(self, num_classes=None, pretrained=True):
        super().__init__()
        if num_classes is None:
            # Resolved at call time (not as a default argument) so the class can be
            # defined before NUM_CLASSES exists.
            num_classes = NUM_CLASSES
        self.base_model = timm.create_model('efficientnetv2_rw_s', pretrained=pretrained)
        # Reuse the attribute name so state_dict keys stay "base_model.classifier.*"
        # and previously saved weights still load.
        self.base_model.classifier = nn.Linear(self.base_model.classifier.in_features, num_classes)

    def forward(self, x):
        """Return raw class scores (logits) for the image batch ``x``."""
        return self.base_model(x)

In [6]:
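# Unused alternative to nn.CrossEntropyLoss (focal loss), kept commented out for reference.
# Note: the variable `BCE_loss` below holds per-sample multi-class cross-entropy, not binary cross-entropy.
# class FocalLoss(nn.Module):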
#     def __init__(self, alpha=1, gamma=2, reduction='mean'):
#         super(FocalLoss, self).__init__()
#         self.alpha = alpha
#         self.gamma = gamma
#         self.reduction = reduction
# 
#     def forward(self, inputs, targets):
#         BCE_loss = F.cross_entropy(inputs, targets, reduction='none')
#         pt = torch.exp(-BCE_loss)
#         F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
# 
#         if self.reduction == 'mean':
#             return torch.mean(F_loss)
#         elif self.reduction == 'sum':
#             return torch.sum(F_loss)
#         else:
#             return F_loss

In [7]:
# Use the first CUDA GPU when one is present, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = RaceEstimatorModel().to(device)

# Multi-class objective computed on the model's raw scores.
loss_fun = nn.CrossEntropyLoss()

# Adam over every model parameter: backbone and new head are trained together.
optimizer = optim.Adam(model.parameters(), lr=5e-4)

In [8]:
NUM_EPOCHS = 100

# One log sub-directory per run: writing every run into the same "run/race"
# directory makes TensorBoard overlay all executions of this cell as one run.
run_name = datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(log_dir=os.path.join("run", "race", run_name))

try:
    for epoch in tqdm(range(NUM_EPOCHS)):  # several full passes over the training set
        # --- training ---
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fun(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Logged once per epoch, after all batches, so each step holds the
        # full-epoch mean rather than a partial running sum.
        epoch_loss = running_loss / max(len(train_loader), 1)
        writer.add_scalar('Metrics/epoch_loss', epoch_loss, epoch)

        # --- validation ---
        model.eval()
        all_preds = []
        all_labels = []
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_loss += loss_fun(outputs, labels).item()
                all_preds.extend(outputs.argmax(dim=1).tolist())
                all_labels.extend(labels.tolist())
        # Mean of per-batch validation losses, comparable with epoch_loss; a rising
        # val_loss while epoch_loss keeps falling indicates overfitting.
        writer.add_scalar('Metrics/val_loss', val_loss / max(len(test_loader), 1), epoch)

        # Support-weighted averages over classes; zero_division=0 reports the same
        # value sklearn uses for a never-predicted class, without the warning.
        precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
        recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
        f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)
        writer.add_scalar('Metrics/precision', precision, epoch)
        writer.add_scalar('Metrics/recall', recall, epoch)
        writer.add_scalar('Metrics/f1', f1, epoch)

        # tqdm.write prints without breaking the progress bar.
        tqdm.write(f'Epoch {epoch + 1}, loss: {epoch_loss}')
finally:
    # Flush buffered events to disk even if training is interrupted.
    writer.close()
print('Finished Training')

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1, loss: 1.2893175681432087


  1%|          | 1/100 [00:13<21:36, 13.10s/it]

Epoch 2, loss: 0.8608646906084485


  2%|▏         | 2/100 [00:25<20:53, 12.79s/it]

Epoch 3, loss: 0.6129182991054323


  3%|▎         | 3/100 [00:38<20:41, 12.79s/it]

Epoch 4, loss: 0.3917626639207204


  4%|▍         | 4/100 [00:51<20:23, 12.74s/it]

Epoch 5, loss: 0.292584665119648


  5%|▌         | 5/100 [01:04<20:21, 12.86s/it]

Epoch 6, loss: 0.21704095415771008


  6%|▌         | 6/100 [01:19<21:25, 13.67s/it]

Epoch 7, loss: 0.17404347006231546


  7%|▋         | 7/100 [01:33<21:36, 13.94s/it]

Epoch 8, loss: 0.18946616496476862


  8%|▊         | 8/100 [01:48<21:40, 14.13s/it]

Epoch 9, loss: 0.14151183091517952


  9%|▉         | 9/100 [02:03<21:37, 14.26s/it]

Epoch 10, loss: 0.07117806411244804


 10%|█         | 10/100 [02:17<21:30, 14.34s/it]

Epoch 11, loss: 0.20656033219873077


 11%|█         | 11/100 [02:32<21:21, 14.40s/it]

Epoch 12, loss: 0.1775914431653089


 12%|█▏        | 12/100 [02:46<21:10, 14.44s/it]

Epoch 13, loss: 0.07506247982382774


 13%|█▎        | 13/100 [03:01<21:00, 14.48s/it]

Epoch 14, loss: 0.13698453192288676


 14%|█▍        | 14/100 [03:15<20:31, 14.32s/it]

Epoch 15, loss: 0.14837185795315438


 15%|█▌        | 15/100 [03:27<19:29, 13.76s/it]

Epoch 16, loss: 0.17036243714392185


 16%|█▌        | 16/100 [03:40<18:44, 13.38s/it]

Epoch 17, loss: 0.23356894724484947


 17%|█▋        | 17/100 [03:52<18:09, 13.13s/it]

Epoch 18, loss: 0.18709010403189394


 18%|█▊        | 18/100 [04:05<17:45, 12.99s/it]

Epoch 19, loss: 0.08052876188109319


 19%|█▉        | 19/100 [04:17<17:24, 12.90s/it]

Epoch 20, loss: 0.14783253419833878


 20%|██        | 20/100 [04:30<17:06, 12.83s/it]

Epoch 21, loss: 0.11557703009910053


 21%|██        | 21/100 [04:43<16:47, 12.75s/it]

Epoch 22, loss: 0.08102665619096822


 22%|██▏       | 22/100 [04:55<16:31, 12.71s/it]

Epoch 23, loss: 0.12701963229725757


 23%|██▎       | 23/100 [05:08<16:14, 12.66s/it]

Epoch 24, loss: 0.1606495693946878


 24%|██▍       | 24/100 [05:21<16:03, 12.68s/it]

Epoch 25, loss: 0.06677230685535404


 25%|██▌       | 25/100 [05:33<15:48, 12.64s/it]

Epoch 26, loss: 0.11170771170873195


 26%|██▌       | 26/100 [05:46<15:33, 12.61s/it]

Epoch 27, loss: 0.10910134089903699


 27%|██▋       | 27/100 [05:58<15:19, 12.59s/it]

Epoch 28, loss: 0.07216010165090363


 28%|██▊       | 28/100 [06:11<15:05, 12.58s/it]

Epoch 29, loss: 0.10396802647867137


 29%|██▉       | 29/100 [06:23<14:51, 12.55s/it]

Epoch 30, loss: 0.18225225061178207


 30%|███       | 30/100 [06:36<14:38, 12.55s/it]

Epoch 31, loss: 0.10038576544158989


 31%|███       | 31/100 [06:48<14:25, 12.55s/it]

Epoch 32, loss: 0.09651382722788387


 32%|███▏      | 32/100 [07:01<14:13, 12.55s/it]

Epoch 33, loss: 0.07574807298887107


 33%|███▎      | 33/100 [07:13<14:00, 12.55s/it]

Epoch 34, loss: 0.1313553517166939


 34%|███▍      | 34/100 [07:26<13:48, 12.55s/it]

Epoch 35, loss: 0.04710060414961643


 35%|███▌      | 35/100 [07:39<13:37, 12.57s/it]

Epoch 36, loss: 0.07260201409614335


 36%|███▌      | 36/100 [07:51<13:24, 12.57s/it]

Epoch 37, loss: 0.08840728188968366


 37%|███▋      | 37/100 [08:04<13:10, 12.56s/it]

Epoch 38, loss: 0.07986944515465035


 38%|███▊      | 38/100 [08:16<12:58, 12.56s/it]

Epoch 39, loss: 0.036072731043936476


 39%|███▉      | 39/100 [08:29<12:46, 12.56s/it]

Epoch 40, loss: 0.04216685835530774


 40%|████      | 40/100 [08:41<12:33, 12.56s/it]

Epoch 41, loss: 0.054255741557830736


 41%|████      | 41/100 [08:54<12:19, 12.54s/it]

Epoch 42, loss: 0.1330650616178496


 42%|████▏     | 42/100 [09:06<12:07, 12.54s/it]

Epoch 43, loss: 0.05123134459265404


 43%|████▎     | 43/100 [09:19<11:56, 12.57s/it]

Epoch 44, loss: 0.01967629580758512


 44%|████▍     | 44/100 [09:32<11:44, 12.59s/it]

Epoch 45, loss: 0.02680591497078745


 45%|████▌     | 45/100 [09:44<11:32, 12.59s/it]

Epoch 46, loss: 0.08109139407881433


 46%|████▌     | 46/100 [09:57<11:19, 12.58s/it]

Epoch 47, loss: 0.10626410584275921


 47%|████▋     | 47/100 [10:09<11:06, 12.57s/it]

Epoch 48, loss: 0.06068484614499741


 48%|████▊     | 48/100 [10:22<10:53, 12.56s/it]

Epoch 49, loss: 0.036316983608735934


 49%|████▉     | 49/100 [10:35<10:40, 12.56s/it]

Epoch 50, loss: 0.027149450518966962


 50%|█████     | 50/100 [10:47<10:27, 12.55s/it]

Epoch 51, loss: 0.06698343478557137


 51%|█████     | 51/100 [11:00<10:17, 12.60s/it]

Epoch 52, loss: 0.062222924186951585


 52%|█████▏    | 52/100 [11:12<10:04, 12.58s/it]

Epoch 53, loss: 0.08853771472867164


 53%|█████▎    | 53/100 [11:25<09:53, 12.64s/it]

Epoch 54, loss: 0.049542089794865914


 54%|█████▍    | 54/100 [11:38<09:41, 12.64s/it]

Epoch 55, loss: 0.08039667325404783


 55%|█████▌    | 55/100 [11:50<09:28, 12.62s/it]

Epoch 56, loss: 0.0698635059977985


 56%|█████▌    | 56/100 [12:03<09:14, 12.60s/it]

Epoch 57, loss: 0.03709390276991245


 57%|█████▋    | 57/100 [12:16<09:02, 12.63s/it]

Epoch 58, loss: 0.014597761305695813


 58%|█████▊    | 58/100 [12:28<08:49, 12.61s/it]

Epoch 59, loss: 0.04649609527162587


 59%|█████▉    | 59/100 [12:41<08:37, 12.61s/it]

Epoch 60, loss: 0.0414418642823067


 60%|██████    | 60/100 [12:53<08:25, 12.65s/it]

Epoch 61, loss: 0.04478182990310921


 61%|██████    | 61/100 [13:06<08:12, 12.63s/it]

Epoch 62, loss: 0.12836036282695001


 62%|██████▏   | 62/100 [13:19<08:01, 12.68s/it]

Epoch 63, loss: 0.08967906867878304


 63%|██████▎   | 63/100 [13:31<07:48, 12.66s/it]

Epoch 64, loss: 0.09904626806059645


 64%|██████▍   | 64/100 [13:44<07:35, 12.66s/it]

Epoch 65, loss: 0.0487688292697486


 65%|██████▌   | 65/100 [13:57<07:24, 12.70s/it]

Epoch 66, loss: 0.10086692588972962


 66%|██████▌   | 66/100 [14:09<07:10, 12.65s/it]

Epoch 67, loss: 0.12113326410245565


 67%|██████▋   | 67/100 [14:22<06:56, 12.63s/it]

Epoch 68, loss: 0.08571908353931373


 68%|██████▊   | 68/100 [14:35<06:43, 12.61s/it]

Epoch 69, loss: 0.03265733283478767


 69%|██████▉   | 69/100 [14:47<06:29, 12.58s/it]

Epoch 70, loss: 0.05956760778402289


 70%|███████   | 70/100 [15:00<06:17, 12.58s/it]

Epoch 71, loss: 0.09721196931786835


 71%|███████   | 71/100 [15:12<06:04, 12.58s/it]

Epoch 72, loss: 0.09293437743973401


 72%|███████▏  | 72/100 [15:25<05:52, 12.58s/it]

Epoch 73, loss: 0.05481473413399524


 73%|███████▎  | 73/100 [15:37<05:39, 12.57s/it]

Epoch 74, loss: 0.037989356425694294


 74%|███████▍  | 74/100 [15:50<05:29, 12.66s/it]

Epoch 75, loss: 0.0653121055284929


 75%|███████▌  | 75/100 [16:03<05:17, 12.72s/it]

Epoch 76, loss: 0.06711873337109056


 76%|███████▌  | 76/100 [16:16<05:06, 12.76s/it]

Epoch 77, loss: 0.03311887939667536


 77%|███████▋  | 77/100 [16:29<04:54, 12.79s/it]

Epoch 78, loss: 0.09183639532420784


 78%|███████▊  | 78/100 [16:42<04:42, 12.84s/it]

Epoch 79, loss: 0.11135643156659272


 79%|███████▉  | 79/100 [16:54<04:28, 12.78s/it]

Epoch 80, loss: 0.0381552892892311


 80%|████████  | 80/100 [17:07<04:15, 12.77s/it]

Epoch 81, loss: 0.027877455021047756


 81%|████████  | 81/100 [17:20<04:02, 12.75s/it]

Epoch 82, loss: 0.043825612377582326


 82%|████████▏ | 82/100 [17:33<03:49, 12.72s/it]

Epoch 83, loss: 0.0647504118258237


 83%|████████▎ | 83/100 [17:45<03:36, 12.71s/it]

Epoch 84, loss: 0.058478988965766296


 84%|████████▍ | 84/100 [17:58<03:23, 12.72s/it]

Epoch 85, loss: 0.050262620465623006


 85%|████████▌ | 85/100 [18:11<03:10, 12.67s/it]

Epoch 86, loss: 0.031470947769573994


 86%|████████▌ | 86/100 [18:23<02:57, 12.68s/it]

Epoch 87, loss: 0.04295679964383857


 87%|████████▋ | 87/100 [18:36<02:44, 12.67s/it]

Epoch 88, loss: 0.031470424961298704


 88%|████████▊ | 88/100 [18:48<02:31, 12.65s/it]

Epoch 89, loss: 0.05585595563752577


 89%|████████▉ | 89/100 [19:01<02:18, 12.63s/it]

Epoch 90, loss: 0.054197172039291926


 90%|█████████ | 90/100 [19:14<02:06, 12.61s/it]

Epoch 91, loss: 0.05601652377905945


 91%|█████████ | 91/100 [19:26<01:53, 12.62s/it]

Epoch 92, loss: 0.044152371540096484


 92%|█████████▏| 92/100 [19:39<01:41, 12.67s/it]

Epoch 93, loss: 0.040414956754021764


 93%|█████████▎| 93/100 [19:52<01:28, 12.63s/it]

Epoch 94, loss: 0.08794722522401975


 94%|█████████▍| 94/100 [20:04<01:15, 12.63s/it]

Epoch 95, loss: 0.03267548520428439


 95%|█████████▌| 95/100 [20:17<01:03, 12.61s/it]

Epoch 96, loss: 0.023412375238775793


 96%|█████████▌| 96/100 [20:29<00:50, 12.59s/it]

Epoch 97, loss: 0.12497913352369021


 97%|█████████▋| 97/100 [20:42<00:37, 12.58s/it]

Epoch 98, loss: 0.04308495453248421


 98%|█████████▊| 98/100 [20:54<00:25, 12.57s/it]

Epoch 99, loss: 0.047836438830321036


 99%|█████████▉| 99/100 [21:07<00:12, 12.63s/it]

Epoch 100, loss: 0.08680661149426466


100%|██████████| 100/100 [21:20<00:00, 12.80s/it]

Finished Training





In [9]:
# Helper that collects model predictions over a whole DataLoader.
def get_predictions(model, loader, device=None):
    """Run ``model`` over every batch of ``loader`` and collect true and predicted labels.

    The model is switched to eval mode (and left there); gradients are disabled.

    Args:
        model: classifier whose output has class scores along dim 1.
        loader: iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        device: where to move the images; ``None`` uses the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        ``(all_labels, all_preds)``: two flat Python lists of ints in loader order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            all_preds.extend(outputs.argmax(dim=1).tolist())
            all_labels.extend(labels.tolist())
    return all_labels, all_preds

# Predictions on the held-out split.
val_labels, val_preds = get_predictions(model, test_loader)

# Per-class metrics. `labels=` pins the output order to class ids 0..NUM_CLASSES-1,
# so index k always refers to class k even if a class is absent from this split.
class_ids = list(range(NUM_CLASSES))
precision = precision_score(val_labels, val_preds, labels=class_ids, average=None, zero_division=0)
recall = recall_score(val_labels, val_preds, labels=class_ids, average=None, zero_division=0)
f1 = f1_score(val_labels, val_preds, labels=class_ids, average=None, zero_division=0)

print(f'Precision: {precision}')
print(f'Recall: {recall}')
print(f'F1 Score: {f1}')

Precision: [0.83517588 0.81269841 0.74839744 0.73513514 0.33152174]
Recall: [0.85362096 0.84674752 0.81217391 0.7129751  0.20962199]
F1 Score: [0.84429769 0.82937365 0.77898249 0.72388556 0.25684211]


In [10]:
# Save both the pickled full module and the portable state_dict. The pickled file
# can only be unpickled where RaceEstimatorModel is importable; the state_dict is
# the safer artifact to reload.
MODELS_DIR = '../../models'
os.makedirs(MODELS_DIR, exist_ok=True)  # torch.save does not create missing directories
torch.save(model, os.path.join(MODELS_DIR, 'race_model_torch.pth'))
torch.save(model.state_dict(), os.path.join(MODELS_DIR, 'race_model_weights.pth'))

In [11]:
# Rebuild the architecture and load the saved weights. Loading the state_dict
# (rather than unpickling the full module) avoids executing arbitrary pickled
# objects and works with torch.load(weights_only=True).
model = RaceEstimatorModel()
state_dict = torch.load('../../models/race_model_weights.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # inference mode (fixed batch-norm statistics, no dropout)

# Sample image from the dataset folder (relative path instead of a user-specific
# absolute one). Its file name encodes race 2 in the third '_' field.
# NOTE(review): this file may belong to the training split, so the result is a
# smoke test, not an evaluation.
image_path = os.path.join('../data/UTKFace_48', '2_0_2_20161219141143184.jpg.chip.jpg')
with Image.open(image_path) as img:
    image = transform(img.convert('RGB'))

# The model expects a batch, so add a leading batch dimension of size 1.
image = image.unsqueeze(0).to(device)

with torch.no_grad():
    output = model(image)
    predicted_race = output.argmax(dim=1).item()  # predicted class id as a Python int

print(f'Predicted Race: {predicted_race}')

Predicted Race: 2
