In [None]:
import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from tqdm import tqdm

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import precision_score, recall_score, f1_score
from datetime import datetime


In [None]:
# Training configuration.
SIZE = 48          # images are resized to SIZE x SIZE pixels
BATCH_SIZE = 512
NUM_CLASSES = 7    # number of emotion classes (see FERDataset.label_dict)

# Fixed seed so the initialisation of newly created layers and DataLoader shuffling
# are repeatable between runs.
SEED = 42
torch.manual_seed(SEED)

In [None]:
class FERDataset(Dataset):
    """Image-folder dataset for FER-style emotion images.

    ``directory`` must contain one sub-directory per emotion whose name is a key of
    ``label_dict``. Every non-hidden regular file inside such a sub-directory is one
    sample. Hidden entries (names starting with ".", e.g. ``.ipynb_checkpoints`` or
    ``.DS_Store``) are skipped. Any other unknown class folder raises ``KeyError``.

    Args:
        directory: root folder with one sub-folder per class.
        transform: optional callable applied to the loaded single-channel PIL image.
    """

    def __init__(self, directory, transform=None):
        self.root_dir = directory
        self.transform = transform
        self.images = []
        self.labels = []

        self.label_dict = {"angry": 0, "disgusted": 1, "fearful": 2, "happy": 3, "natural": 4, "sadness": 5, "surprised": 6}

        # os.listdir returns entries in arbitrary order; sorting makes sample order reproducible.
        for label in sorted(os.listdir(directory)):
            label_path = os.path.join(directory, label)
            if label.startswith(".") or not os.path.isdir(label_path):
                continue
            for img_file in sorted(os.listdir(label_path)):
                img_path = os.path.join(label_path, img_file)
                # Only non-hidden regular files can be opened as images.
                if img_file.startswith(".") or not os.path.isfile(img_path):
                    continue
                self.images.append(img_path)
                self.labels.append(self.label_dict[label])

    def __len__(self):
        """Number of collected image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to single-channel ("L") mode because the model's stem
        convolution takes one input channel. It is then passed through ``transform``
        if one was given.
        """
        label = self.labels[idx]
        # The context manager releases the file handle right away. convert() returns
        # a fully loaded copy, so the image stays usable after the file is closed.
        with Image.open(self.images[idx]) as img:
            image = img.convert("L")

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# Preprocessing shared by training, evaluation and the single-image inference cell.
# Grayscale(1) guarantees the single input channel the model's stem expects, even
# when a source image is RGB.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
])

FER_DATA_DIR = '../data/data_fer'
train_dataset = FERDataset(directory=os.path.join(FER_DATA_DIR, 'train'), transform=transform)
test_dataset = FERDataset(directory=os.path.join(FER_DATA_DIR, 'test'), transform=transform)

# Fail early with a clear message instead of an obscure error inside the training loop.
assert len(train_dataset) > 0, f"no training images found under {FER_DATA_DIR}/train"
assert len(test_dataset) > 0, f"no test images found under {FER_DATA_DIR}/test"

# Shuffle only the training set; the evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class EmotionEstimatorModel(nn.Module):
    """EfficientNetV2 (timm ``efficientnetv2_rw_s``) adapted to single-channel input.

    The stem convolution is replaced by a 1-input-channel convolution, initialised
    from the pretrained RGB stem by summing its filters over the colour channels. The
    classifier head is replaced by a fresh ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        num_classes: number of output logits; ``None`` means the notebook-level
            ``NUM_CLASSES``.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, num_classes=None, pretrained=True):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES

        # Load EfficientNetV2.
        self.base_model = timm.create_model('efficientnetv2_rw_s', pretrained=pretrained)

        # Replace the first convolution with a single-input-channel one that keeps the
        # output channels, kernel size, stride and padding of the original.
        old_stem = self.base_model.conv_stem
        new_stem = nn.Conv2d(in_channels=1,
                             out_channels=old_stem.out_channels,
                             kernel_size=old_stem.kernel_size,
                             stride=old_stem.stride,
                             padding=old_stem.padding,
                             bias=False)
        # Convolution is linear, so summing the RGB filters gives the same response the
        # pretrained stem would give a grayscale image replicated into 3 channels. This
        # keeps the pretrained stem features instead of a random initialisation.
        with torch.no_grad():
            new_stem.weight.copy_(old_stem.weight.sum(dim=1, keepdim=True))
        self.base_model.conv_stem = new_stem

        # Replace the classifier to match the number of classes.
        self.base_model.classifier = nn.Linear(self.base_model.classifier.in_features, num_classes)

    def forward(self, x):
        """Return raw class logits for the image batch ``x`` (presumably (N, 1, H, W))."""
        return self.base_model(x)

In [None]:
# Prefer the first CUDA GPU; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model = EmotionEstimatorModel()
model = model.to(device)

# Cross-entropy on raw logits, optimised with Adam at a fixed learning rate.
loss_fun = nn.CrossEntropyLoss()
loss_fun = loss_fun.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Each run writes to its own timestamped sub-directory, so TensorBoard shows runs
# separately instead of mixing their scalars into one run.
run_name = datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(log_dir=os.path.join("log", "emotion", run_name))

NUM_EPOCHS = 100  # number of passes over the training set

try:
    for epoch in tqdm(range(NUM_EPOCHS), desc="epochs"):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fun(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Logged once per epoch, after all batches, so the value is the mean batch loss.
        writer.add_scalar('Metrics/epoch_loss', running_loss / max(len(train_loader), 1), epoch)

        # ---- evaluation pass ----
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, labels in test_loader:
                outputs = model(images.to(device))
                all_preds.extend(outputs.argmax(dim=1).tolist())
                all_labels.extend(labels.tolist())
        # zero_division=0: early epochs may never predict some classes. Their precision
        # is then reported as 0 instead of raising UndefinedMetricWarning every epoch.
        precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
        recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
        f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)
        writer.add_scalar('Metrics/precision', precision, epoch)
        writer.add_scalar('Metrics/recall', recall, epoch)
        writer.add_scalar('Metrics/f1', f1, epoch)
finally:
    # Flush pending events and release the log file, even if training is interrupted.
    writer.close()
print('Finished Training')

In [None]:
# Helper that computes predictions for a whole loader.
def get_predictions(model, loader, device=None):
    """Run ``model`` over every batch of ``loader`` and collect arg-max predictions.

    Switches the model to eval mode and disables gradient tracking.

    Args:
        model: classifier whose output has class scores along dim 1.
        loader: iterable of ``(images, labels)`` batches, e.g. a ``DataLoader``.
        device: where to move the images. Defaults to the device of the model's
            first parameter, or the CPU if the model has no parameters.

    Returns:
        ``(all_labels, all_preds)``: two lists of ints, in loader order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            all_preds.extend(outputs.argmax(dim=1).tolist())
            all_labels.extend(labels.tolist())
    return all_labels, all_preds

# Get predictions on the validation (test) set.
val_labels, val_preds = get_predictions(model, test_loader)

# Compute per-class metrics.
# labels= fixes the output order to class indices 0..NUM_CLASSES-1, so position i always
# refers to class i, even when a class is absent from both labels and predictions.
# zero_division=0 reports 0 for classes that were never predicted instead of warning.
class_ids = list(range(NUM_CLASSES))
precision = precision_score(val_labels, val_preds, labels=class_ids, average=None, zero_division=0)
recall = recall_score(val_labels, val_preds, labels=class_ids, average=None, zero_division=0)
f1 = f1_score(val_labels, val_preds, labels=class_ids, average=None, zero_division=0)

print(f'Precision: {precision}')
print(f'Recall: {recall}')
print(f'F1 Score: {f1}')

# Same numbers again, one row per named emotion class.
index_to_emotion = {index: name for name, index in test_dataset.label_dict.items()}
for class_id in class_ids:
    print(f'{index_to_emotion.get(class_id, class_id)!s:>10}: '
          f'precision={precision[class_id]:.3f} recall={recall[class_id]:.3f} f1={f1[class_id]:.3f}')

In [None]:
MODELS_DIR = '../../models'
# torch.save does not create missing parent directories.
os.makedirs(MODELS_DIR, exist_ok=True)
# Whole pickled module: loading it requires the EmotionEstimatorModel class definition.
torch.save(model, os.path.join(MODELS_DIR, 'emotion_model_torch.pth'))
# state_dict only: the portable format, restored with model.load_state_dict(...).
torch.save(model.state_dict(), os.path.join(MODELS_DIR, 'emotion_model_weights.pth'))

In [None]:
# Rebuild the architecture and load the saved weights. This uses the state_dict rather
# than the pickled full module, because recent torch.load (weights_only=True by default)
# refuses to unpickle arbitrary modules. map_location lets GPU-trained weights load on a
# CPU-only machine. Constructing the model may fetch pretrained timm weights (it passes
# pretrained=True); load_state_dict overwrites them.
model = EmotionEstimatorModel()
model.load_state_dict(torch.load('../../models/emotion_model_weights.pth', map_location=device))
model = model.to(device)
model.eval()  # switch to evaluation mode

# NOTE(review): absolute local path to a UTKFace sample; replace with a path relative to
# the project's data directory.
image_path = '/home/vorkov/Workspace/EDA/learning/data/UTKFace_48/2_0_2_20161219141143184.jpg.chip.jpg'
# The model's stem takes one input channel, and this image may be RGB, so convert to "L".
# The context manager closes the file; convert() returns a loaded copy.
with Image.open(image_path) as img:
    image = img.convert("L")

# Apply the training preprocessing, then add a batch dimension, since the model expects a batch.
image = transform(image).unsqueeze(0).to(device)

# Predict the emotion class index.
with torch.no_grad():
    output = model(image)
    predicted_emotion = output.argmax(dim=1).item()

# Map the class index back to its emotion folder name.
index_to_emotion = {index: name for name, index in test_dataset.label_dict.items()}
print(f'Predicted Emotion: {predicted_emotion} ({index_to_emotion.get(predicted_emotion, "unknown")})')