In [1]:
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import json
import os
import numpy as np
import pandas as pd
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
import cv2

In [2]:
def init_weights(m):
    """Initialise a single module in place; intended for ``nn.Module.apply``.

    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU gain); bias, if any, set to 0.
    * Anything else is left untouched (note ``nn.ConvTranspose2d`` is not an
      ``nn.Conv2d`` subclass, so it keeps PyTorch's default init).
    """
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
        return

    if not isinstance(m, nn.Conv2d):
        return

    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

In [3]:
class ResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut.

    The bottleneck width is ``out_channels // 2``. The shortcut is an identity when
    the channel count is unchanged and a bias-free 1x1 conv otherwise. All convs use
    stride 1 (the 3x3 one is padded by 1), so the spatial size is preserved. The sum
    of the main path and the shortcut goes through a final ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Shortcut: project with a 1x1 conv only when the channel count changes.
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        else:
            self.skip = nn.Identity()

        width = out_channels // 2

        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.skip(x)

        # conv -> BN -> ReLU for the first two stages; the last stage has no ReLU
        # because the activation is applied after the residual addition.
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))

        return self.relu(out + shortcut)

class HourglassBlock(nn.Module):
    """One hourglass: encoder with three 2x max-pools, a residual bottleneck, and a
    decoder with three 2x transposed-conv upsamplings joined to the encoder by
    additive skip connections.

    The output has ``channels`` channels and the input's spatial size. Height and
    width must be divisible by 8, otherwise the upsampled maps and the skips differ
    in size and the additions fail.
    """

    def __init__(self, in_channels, channels):
        super().__init__()

        # Encoder: a residual block at each scale, 2x max-pool between scales.
        self.down1 = ResidualBlock(in_channels, channels)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = ResidualBlock(channels, channels)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = ResidualBlock(channels, channels)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = ResidualBlock(channels, channels)

        # Bottleneck at the lowest resolution (spatial size unchanged).
        self.center = nn.Sequential(*(ResidualBlock(channels, channels) for _ in range(3)))

        # Decoder refinement blocks, one before each upsampling step.
        self.up1 = ResidualBlock(channels, channels)
        self.up2 = ResidualBlock(channels, channels)
        self.up3 = ResidualBlock(channels, channels)

        # Learned 2x upsampling.
        self.upsample1 = nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2)
        self.upsample2 = nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2)
        self.upsample3 = nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2)

        # Normalisation applied after each skip addition.
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self.bn3 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Encoder: keep each of the first three scales' pre-pool activation as a skip.
        skips = []
        feat = x
        for down, pool in ((self.down1, self.pool1),
                           (self.down2, self.pool2),
                           (self.down3, self.pool3)):
            feat = down(feat)
            skips.append(feat)
            feat = pool(feat)

        # Lowest resolution.
        feat = self.center(self.down4(feat))

        # Decoder: refine -> upsample 2x -> add matching skip -> BN -> ReLU,
        # consuming the skips from the coarsest to the finest.
        decoder = ((self.up1, self.upsample1, self.bn1),
                   (self.up2, self.upsample2, self.bn2),
                   (self.up3, self.upsample3, self.bn3))
        for (refine, upsample, norm), skip in zip(decoder, reversed(skips)):
            feat = self.relu(norm(upsample(refine(feat)) + skip))

        return feat

class StackedHourglassNetwork(nn.Module):
    """Stacked hourglass network producing one set of keypoint heatmaps per stack.

    Args:
        num_stacks: number of chained hourglass modules (must be >= 1).
        num_keypoints: number of heatmap channels produced by every stack.
        upsample_outputs: if True, each stack's heatmaps are bilinearly resized to
            the input image size. Otherwise they are returned at the stem's output
            resolution (the stem's 7x7 conv has stride 2).

    ``forward`` returns a list with one sigmoid-activated
    ``(N, num_keypoints, H', W')`` tensor per stack. The last entry is the final
    prediction; the earlier ones are meant for intermediate supervision.
    """

    def __init__(self, num_stacks=2, num_keypoints=5, upsample_outputs=False):
        super().__init__()
        if num_stacks < 1:
            raise ValueError(f"num_stacks must be >= 1, got {num_stacks}")

        self.num_stacks = num_stacks
        self.num_keypoints = num_keypoints
        self.upsample_outputs = upsample_outputs

        # Stem: stride-2 7x7 conv, then residual blocks lifting features to 256 channels.
        self.initial = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64, 128),
            ResidualBlock(128, 128),
            ResidualBlock(128, 256)
        )

        self.hourglasses = nn.ModuleList()
        self.output_blocks = nn.ModuleList()  # per-stack heatmap heads
        self.merge_blocks = nn.ModuleList()   # feature remapping fed into the next stack

        for i in range(num_stacks):
            # Stacks after the first also receive the previous stack's heatmaps,
            # concatenated onto the merged features (see forward).
            if i == 0:
                self.hourglasses.append(HourglassBlock(256, 256))
            else:
                self.hourglasses.append(HourglassBlock(256 + num_keypoints, 256))

            self.output_blocks.append(nn.Sequential(
                ResidualBlock(256, 256),
                nn.Conv2d(256, num_keypoints, kernel_size=1)
            ))

            # The last stack has no successor, so it needs no merge block.
            if i < num_stacks - 1:
                self.merge_blocks.append(nn.Sequential(
                    nn.Conv2d(256, 256, kernel_size=1),
                    nn.BatchNorm2d(256),
                    nn.ReLU(inplace=True)
                ))

        # Must run after every submodule is registered: ``apply`` only visits the
        # modules that exist when it is called, so running it before the layers
        # are created would initialise nothing.
        self.apply(init_weights)

    def forward(self, x):
        original_size = x.shape[2:]

        x = self.initial(x)

        outputs = []
        for i in range(self.num_stacks):
            hourglass_output = self.hourglasses[i](x)

            low_res_heatmaps = torch.sigmoid(self.output_blocks[i](hourglass_output))

            if self.upsample_outputs:
                outputs.append(F.interpolate(
                    low_res_heatmaps,
                    size=original_size,
                    mode='bilinear',
                    align_corners=False
                ))
            else:
                outputs.append(low_res_heatmaps)

            # Input to the next stack: remapped features + this stack's low-res heatmaps.
            if i < self.num_stacks - 1:
                features = self.merge_blocks[i](hourglass_output)
                x = torch.cat([features, low_res_heatmaps], dim=1)

        return outputs

In [4]:
def load_processed_datasets(data_dir="processed_datasets"):
    """Read the preprocessed train/val/test CSV splits and report their sizes.

    Args:
        data_dir: directory containing ``train_dataset.csv``, ``val_dataset.csv``
            and ``test_dataset.csv``.

    Returns:
        ``(train_df, val_df, test_df)`` as pandas DataFrames, in that order.
    """
    print(f"Загрузка датасетов из {data_dir}...")

    splits = {
        name: pd.read_csv(f"{data_dir}/{name}_dataset.csv")
        for name in ("train", "val", "test")
    }

    print("Загружено:")
    for label, name in (("Train", "train"), ("Val", "val"), ("Test", "test")):
        print(f"   {label}: {len(splits[name])} записей")

    return splits["train"], splits["val"], splits["test"]

# Load the three preprocessed splits from the default directory.
train_df, val_df, test_df = load_processed_datasets(data_dir="processed_datasets")

Загрузка датасетов из processed_datasets...
Загружено:
   Train: 19100 записей
   Val: 2387 записей
   Test: 2388 записей


In [5]:
# Prefer the GPU but fall back to the CPU, so the notebook still runs on machines
# without CUDA (a hard-coded 'cuda' device fails at the first `.to(device)`).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Используемое устройство: {device}")

Используемое устройство: cuda


In [6]:
class FaceKeypointsDataset(Dataset):
    """Face crops with 5 keypoints rendered as Gaussian heatmaps.

    ``dataframe`` must provide the columns ``path``, ``image_id``, the face box
    ``x_1``, ``y_1``, ``width``, ``height`` and ``x{i}_bbox_norm`` /
    ``y{i}_bbox_norm`` for i = 1..5. The keypoints are expressed relative to the box
    (presumably normalised to [0, 1]; verify against the preprocessing step). They
    are multiplied by ``heatmap_size`` when the heatmaps are rendered.

    Each item is a dict with:
        image:     ``(3, image_size, image_size)`` normalised float tensor
        heatmaps:  ``(5, heatmap_size, heatmap_size)`` float32 tensor
        keypoints: ``(5, 2)`` float32 tensor of normalised (x, y)
        image_id:  the row's ``image_id`` value
    """

    def __init__(self, dataframe, image_size=256, heatmap_size=128, sigma=2):
        self.df = dataframe
        self.image_size = image_size
        self.heatmap_size = heatmap_size
        self.sigma = sigma

        # Pull per-row columns out of the DataFrame once; per-item lookups are slow.
        self.image_paths = self.df['path'].tolist()
        self.image_ids = self.df['image_id'].tolist()
        self.bboxes = self.df[['x_1', 'y_1', 'width', 'height']].values

        # [N, 10] laid out as x1, y1, x2, y2, ..., x5, y5, so a row reshapes to (5, 2).
        self.keypoints_norm = np.column_stack([
            self.df[f'{axis}{i}_bbox_norm'].values
            for i in range(1, 6)
            for axis in ('x', 'y')
        ])

        # Standard ImageNet channel mean/std.
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        print(f"Dataset создан: {len(self)} изображений")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        x1, y1, w, h = self.bboxes[idx]

        # Always crop to the face box, even when it extends past the image border
        # (PIL fills the out-of-range area with black). The keypoint targets are
        # relative to the box, so using the uncropped image would misalign image
        # and targets.
        with Image.open(self.image_paths[idx]) as src:
            img = src.convert('RGB')
        img = img.crop((x1, y1, x1 + w, y1 + h))
        img = img.resize((self.image_size, self.image_size))

        img_tensor = self.normalize(self.to_tensor(img))

        keypoints = self.keypoints_norm[idx].reshape(5, 2).astype(np.float32)
        heatmaps = self._create_heatmaps(keypoints)

        return {
            'image': img_tensor,
            'heatmaps': torch.from_numpy(heatmaps),
            'keypoints': torch.from_numpy(keypoints),
            'image_id': self.image_ids[idx]
        }

    def _create_heatmaps(self, keypoints_norm):
        """Render one Gaussian heatmap per keypoint.

        Args:
            keypoints_norm: ``(K, 2)`` array of normalised (x, y); scaled by
                ``heatmap_size`` to heatmap pixel coordinates.

        Returns:
            ``(K, heatmap_size, heatmap_size)`` float32 array. A keypoint whose
            truncated (``int``) coordinates fall outside the grid gets an all-zero map.
        """
        size = self.heatmap_size
        heatmaps = np.zeros((len(keypoints_norm), size, size), dtype=np.float32)

        # The pixel grid is the same for every keypoint, so build it once.
        xx, yy = np.meshgrid(np.arange(size), np.arange(size))
        denom = 2 * self.sigma**2

        for i, (x, y) in enumerate(keypoints_norm * size):
            if 0 <= int(x) < size and 0 <= int(y) < size:
                heatmaps[i] = np.exp(-((xx - x)**2 + (yy - y)**2) / denom)

        return heatmaps

In [7]:
def create_dataloaders(train_df, val_df, test_df, batch_size=8, pin_memory=True,
                       image_size=256, heatmap_size=128, sigma=2, num_workers=0):
    """Build train/val/test DataLoaders over ``FaceKeypointsDataset``.

    All three splits share the same image size, heatmap size and Gaussian sigma;
    only the training loader shuffles.

    Args:
        train_df, val_df, test_df: DataFrames for the three splits.
        batch_size: batch size for every loader.
        pin_memory: forwarded to ``DataLoader``.
        image_size: side length of the resized input crops.
        heatmap_size: side length of the target heatmaps.
        sigma: standard deviation of the target Gaussians, in heatmap pixels.
        num_workers: loader worker processes (0 = load in the main process,
            the ``DataLoader`` default).

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    def make_loader(df, shuffle):
        # One dataset + loader pair with the shared settings.
        dataset = FaceKeypointsDataset(
            df,
            image_size=image_size,
            heatmap_size=heatmap_size,
            sigma=sigma)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=pin_memory,
            num_workers=num_workers)

    train_loader = make_loader(train_df, shuffle=True)
    val_loader = make_loader(val_df, shuffle=False)
    test_loader = make_loader(test_df, shuffle=False)

    return train_loader, val_loader, test_loader

In [8]:
def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, path):
    """Save model/optimizer state plus loss bookkeeping to ``path``.

    The saved dict has the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``train_loss`` and ``val_loss``. The parent directory
    of ``path`` is created if missing, because ``torch.save`` does not create
    directories.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, path)

def load_checkpoint(model, optimizer, path, map_location='cpu'):
    """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: module whose ``load_state_dict`` receives ``model_state_dict``.
        optimizer: optimizer whose ``load_state_dict`` receives ``optimizer_state_dict``.
        path: checkpoint file.
        map_location: forwarded to ``torch.load``. The ``'cpu'`` default lets a
            checkpoint saved from a GPU run load on a machine without CUDA.
            ``load_state_dict`` copies the values into the existing parameters on
            whatever device they live on.

    Returns:
        ``(epoch, train_loss, val_loss)`` as stored in the checkpoint.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['train_loss'], checkpoint['val_loss']

def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The batch loss is the sum of ``criterion`` over every stack's output
    (intermediate supervision), so its scale grows with the number of stacks.
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(loader, desc="Training", leave=False)

    for batch in progress:
        images = batch['image'].to(device)
        heatmaps = batch['heatmaps'].to(device)

        outputs = model(images)
        loss = 0.0
        for stack_output in outputs:
            loss += criterion(stack_output, heatmaps)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress.set_postfix({'loss': f"{loss.item():.4f}"})

    return running_loss / len(loader)


def _validate(model, loader, criterion):
    """Return the mean batch loss of the final stack's output over ``loader``."""
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        progress = tqdm(loader, desc="Validation", leave=False)
        for batch in progress:
            images = batch['image'].to(device)
            heatmaps = batch['heatmaps'].to(device)

            loss = criterion(model(images)[-1], heatmaps)
            running_loss += loss.item()
            progress.set_postfix({'loss': f"{loss.item():.4f}"})

    return running_loss / len(loader)


def train(model, train_loader, val_loader, num_epochs=3, lr=1e-3,
          checkpoint_path=os.path.join('weights', 'hourglass_model.pth')):
    """Train ``model`` with Adam + MSE and checkpoint the best validation epoch.

    The training loss sums MSE over all stacks, while the validation loss uses only
    the last stack. The two numbers are therefore not directly comparable; with 2
    stacks the training loss is roughly twice as large for the same fit.

    Args:
        model: stacked hourglass network returning a list of heatmap tensors.
            It is moved to the global ``device``.
        train_loader, val_loader: loaders yielding dicts with 'image' and 'heatmaps'.
        num_epochs: number of epochs to run.
        lr: Adam learning rate.
        checkpoint_path: where the best model (lowest val loss) is saved. Built
            with ``os.path.join`` so it is a real sub-directory on every OS; its
            parent directory is created if missing.

    Returns:
        ``{'train_loss': [...], 'val_loss': [...]}`` with one entry per epoch.

    Raises:
        ValueError: if either loader is empty (the averages would divide by zero).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must yield at least one batch")

    print(f"Device: {device}")

    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # torch.save does not create missing directories.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    history = {'train_loss': [], 'val_loss': []}
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)

        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, criterion)
        history['train_loss'].append(avg_train_loss)

        avg_val_loss = _validate(model, val_loader, criterion)
        history['val_loss'].append(avg_val_loss)

        print(f"train_loss: {avg_train_loss:.4f} | val_loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_checkpoint(
                model, optimizer, epoch,
                avg_train_loss, avg_val_loss,
                checkpoint_path
            )
            print(f"Сохранена лучшая модель (val_loss: {avg_val_loss:.4f})")

    return history


def _heatmap_argmax_xy(maps):
    """Return the (x, y) grid position of each heatmap's maximum, on the CPU.

    ``maps`` is indexed (batch, keypoint, height, width); the result is a float64
    tensor of shape (batch, keypoint, 2). Ties resolve to the first maximum in
    row-major order, as with ``numpy.argmax`` on the flattened map.
    """
    width = maps.shape[-1]
    flat_idx = maps.flatten(start_dim=2).argmax(dim=2)
    rows = torch.div(flat_idx, width, rounding_mode='floor')
    cols = flat_idx % width
    return torch.stack((cols, rows), dim=-1).cpu().double()


def evaluate(model, test_loader, threshold_px=5, device=None):
    """Evaluate a stacked hourglass model on ``test_loader`` (last stack only).

    Reports:
      * the mean per-batch MSE between predicted and target heatmaps;
      * the mean distance, in input-image pixels, between the arg-max of the
        predicted and the target heatmap of each keypoint;
      * the fraction of keypoints whose distance is below ``threshold_px``.

    Keypoints with an all-zero target heatmap (the dataset renders nothing when a
    keypoint falls outside the grid) are excluded from the distance metrics,
    because their arg-max would be a meaningless (0, 0) target.

    Args:
        model: network returning a list of heatmap tensors, one per stack.
        test_loader: yields dicts with 'image' and 'heatmaps' tensors.
        threshold_px: distance threshold, in image pixels, for the accuracy.
        device: where to run; defaults to the device of the model's parameters.

    Returns:
        ``(avg_loss, avg_distance)``. ``avg_distance`` is NaN if no keypoint has a
        non-empty target heatmap.

    Raises:
        ValueError: if ``test_loader`` is empty.
    """
    if len(test_loader) == 0:
        raise ValueError("test_loader must yield at least one batch")

    if device is None:
        # Run where the model lives instead of relying on a global device.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()

    criterion = nn.MSELoss()
    test_loss = 0.0
    batch_distances = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            images = batch['image'].to(device)
            heatmaps = batch['heatmaps'].to(device)

            outputs = model(images)[-1]
            test_loss += criterion(outputs, heatmaps).item()

            # Heatmap grid -> input-image pixels, per axis (x, y). For 128-px
            # heatmaps of 256-px images this is 2.
            height, width = outputs.shape[-2:]
            scale = torch.tensor(
                [images.shape[-1] / width, images.shape[-2] / height],
                dtype=torch.float64,
            )

            pred_xy = _heatmap_argmax_xy(outputs) * scale
            gt_xy = _heatmap_argmax_xy(heatmaps) * scale

            has_target = (heatmaps.flatten(start_dim=2).amax(dim=2) > 0).cpu()
            dist = (pred_xy - gt_xy).pow(2).sum(dim=-1).sqrt()
            batch_distances.append(dist[has_target])

    avg_loss = test_loss / len(test_loader)

    distances = torch.cat(batch_distances)
    if distances.numel() > 0:
        avg_distance = distances.mean().item()
        accuracy = (distances < threshold_px).double().mean().item()
    else:
        avg_distance = float('nan')
        accuracy = float('nan')

    print(f"\nРезультаты теста:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Avg Distance: {avg_distance:.2f} px")
    print(f"  Accuracy ({threshold_px}px): {accuracy:.2%}")

    return avg_loss, avg_distance

In [9]:
# Seed torch's RNG (CPU and CUDA) so weight initialisation and the training-set
# shuffle are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

train_loader, val_loader, test_loader = create_dataloaders(train_df, val_df, test_df, batch_size=8)

model = StackedHourglassNetwork(num_stacks=2, num_keypoints=5)

history = train(model, train_loader, val_loader, num_epochs=2)

Dataset создан: 19100 изображений
Dataset создан: 2387 изображений
Dataset создан: 2388 изображений
Device: cuda

Epoch 1/2
--------------------------------------------------


                                                                         

KeyboardInterrupt: 

In [None]:
# Evaluate on the test split.
# NOTE(review): this uses whatever weights `model` currently holds in memory
# (the last epoch reached, or a partially trained model if training was
# interrupted), not the best checkpoint that train() saves — reload that
# checkpoint first if the best model is what should be evaluated.
evaluate(model, test_loader)