In [None]:
import torch
import sys

# Environment sanity check: report Python / PyTorch / CUDA versions and pick
# the compute device (`device`) used by later cells.
_BANNER = "=" * 60
_GIB = 1024 ** 3

print(_BANNER)
print("–ü–†–û–í–ï–†–ö–ê –£–°–¢–ê–ù–û–í–ö–ò PYTORCH")
print(_BANNER)

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch CUDA version: {torch.version.cuda}")
print(f"CUDA available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    device = torch.device('cpu')
    print(f"Device: {device} (CUDA –Ω–µ –¥–æ—Å—Ç—É–ø–Ω–∞!)")
else:
    device = torch.device('cuda')
    print(f"Device: {device}")
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    # Memory figures for GPU 0, converted from bytes to GiB.
    gpu0_props = torch.cuda.get_device_properties(0)
    print(f"GPU memory total: {gpu0_props.total_memory / _GIB:.2f} GB")
    print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / _GIB:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved(0) / _GIB:.2f} GB")

print(_BANNER)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# System summary: print versions and (re)select the compute device.
has_cuda = torch.cuda.is_available()
divider = "=" * 60

print(divider)
print("–ò–ù–§–û–†–ú–ê–¶–ò–Ø –û –°–ò–°–¢–ï–ú–ï")
print(divider)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")

device = torch.device('cuda' if has_cuda else 'cpu')
if has_cuda:
    # Total memory of GPU 0, in GiB.
    gpu_mem_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_mem_gib:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    print("–ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è CPU")

print(divider)

# A slightly larger CNN used to demonstrate GPU training on MNIST.
class CNNModel(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer classifier.

    The classifier head is sized for 28x28 single-channel inputs: the two
    2x2 max-pools reduce 28x28 to 7x7, with 64 channels after conv2.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.25)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw class logits of shape (batch, 10).

        ``x`` is expected to be (batch, 1, 28, 28). The flatten keeps the
        batch dimension fixed, so a wrongly sized input fails with a shape
        error in ``fc1``. It is not silently reshaped into a different batch size.
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Load data: convert images to tensors, then Normalize((0.5,), (0.5,))
# applies (x - 0.5) / 0.5 to the single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download MNIST (train and test splits) into ./data if not already present.
print("–ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö MNIST...")
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Create the DataLoaders: the training set is reshuffled each epoch, the test set is not.
batch_size = 128  # Can be increased; the author's note assumes a GPU with 16GB of memory.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"–†–∞–∑–º–µ—Ä –±–∞—Ç—á–∞: {batch_size}")
print(f"–û–±—É—á–∞—é—â–∏—Ö –ø—Ä–∏–º–µ—Ä–æ–≤: {len(train_dataset)}")
print(f"–¢–µ—Å—Ç–æ–≤—ã—Ö –ø—Ä–∏–º–µ—Ä–æ–≤: {len(test_dataset)}")

# Create the model and move its parameters to the selected device.
model = CNNModel().to(device)
print(f"\n–ú–æ–¥–µ–ª—å —Å–æ–∑–¥–∞–Ω–∞ –∏ –ø–µ—Ä–µ–º–µ—â–µ–Ω–∞ –Ω–∞ {device}")

# Helper: count the trainable parameters of a model.
def count_parameters(model):
    """Return the number of scalar parameters of *model* with ``requires_grad=True``."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

print(f"–ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –æ–±—É—á–∞–µ–º—ã—Ö –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤: {count_parameters(model):,}")

# Optimizer and loss function. CNNModel.forward returns raw logits (no
# softmax), which is what CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training function
def train_epoch(model, loader, criterion, optimizer, device):
    """Train *model* for one pass over *loader*.

    Args:
        model: the network; switched to train mode here.
        loader: iterable of ``(data, target)`` batches (e.g. a DataLoader).
        criterion: loss callable. The epoch average assumes it returns the
            mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample mean loss over
        the epoch. ``accuracy`` is the percentage of correct top-1 predictions.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(loader):
        # Move the batch to the same device as the model.
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        n_batch = target.size(0)
        # Weight each batch mean by its size. The final batch may be smaller,
        # and a plain mean of batch means would over-weight it.
        running_loss += loss.item() * n_batch
        _, predicted = output.max(1)
        total += n_batch
        correct += predicted.eq(target).sum().item()

        if batch_idx % 100 == 0:
            print(f'  Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}')

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")

    accuracy = 100. * correct / total
    avg_loss = running_loss / total
    return avg_loss, accuracy

# Validation function
def validate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without computing gradients.

    Args:
        model: the network; switched to eval mode here.
        loader: iterable of ``(data, target)`` batches.
        criterion: loss callable, assumed to return the batch-mean loss.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample mean loss and
        ``accuracy`` is the percentage of correct top-1 predictions.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            n_batch = target.size(0)
            # Size-weighted so a short final batch does not skew the mean.
            running_loss += loss.item() * n_batch
            _, predicted = output.max(1)
            total += n_batch
            correct += predicted.eq(target).sum().item()

    if total == 0:
        raise ValueError("validate: loader yielded no samples")

    accuracy = 100. * correct / total
    avg_loss = running_loss / total
    return avg_loss, accuracy

# Model training: run `epochs` epochs, recording per-epoch train and
# "validation" loss/accuracy. The validation metrics here are computed on the
# MNIST test split (test_loader).
print("\n" + "=" * 60)
print("–ù–ê–ß–ê–õ–û –û–ë–£–ß–ï–ù–ò–Ø")
print("=" * 60)

epochs = 5
train_losses, train_accs = [], []
val_losses, val_accs = [], []

for epoch in range(epochs):
    print(f'\n–≠–ø–æ—Ö–∞ {epoch+1}/{epochs}')

    # Training
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validation
    val_loss, val_acc = validate(model, test_loader, criterion, device)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Show GPU 0 memory usage (bytes converted to GiB) after each epoch.
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated(0) / 1024**3
        memory_reserved = torch.cuda.memory_reserved(0) / 1024**3
        print(f'  GPU Memory - Allocated: {memory_allocated:.2f} GB, Reserved: {memory_reserved:.2f} GB')

print("\n" + "=" * 60)
print("–û–ë–£–ß–ï–ù–ò–ï –ó–ê–í–ï–†–®–ï–ù–û")
print("=" * 60)

# Visualize results: loss (left) and accuracy (right) per epoch.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(range(1, epochs+1), train_losses, 'b-', label='Train Loss')
ax1.plot(range(1, epochs+1), val_losses, 'r-', label='Val Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Loss over epochs')
ax1.legend()
ax1.grid(True)

ax2.plot(range(1, epochs+1), train_accs, 'b-', label='Train Accuracy')
ax2.plot(range(1, epochs+1), val_accs, 'r-', label='Val Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Accuracy over epochs')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

# Show predictions on a few test examples.
print("\n–¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –Ω–∞ –Ω–µ—Å–∫–æ–ª—å–∫–∏—Ö –ø—Ä–∏–º–µ—Ä–∞—Ö:")
model.eval()
with torch.no_grad():
    # Take the first 8 images of the first test batch (test_loader is not
    # shuffled, so these are the first 8 test images).
    data, target = next(iter(test_loader))
    data, target = data[:8].to(device), target[:8].to(device)
    output = model(data)
    _, predicted = output.max(1)

    # Move back to the CPU so the tensors can be converted to numpy for plotting.
    data = data.cpu()
    predicted = predicted.cpu()
    target = target.cpu()

    # Show the images with predicted and true labels in a 2x4 grid
    # (idx < 8 always holds for the 8 axes).
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for idx, ax in enumerate(axes.flat):
        if idx < 8:
            image = data[idx].squeeze().numpy()
            ax.imshow(image, cmap='gray')
            ax.set_title(f'Pred: {predicted[idx]}, True: {target[idx]}')
            ax.axis('off')
    plt.tight_layout()
    plt.show()

# Final statistics
print("\n" + "=" * 60)
print("–§–ò–ù–ê–õ–¨–ù–ê–Ø –°–¢–ê–¢–ò–°–¢–ò–ö–ê")
print("=" * 60)
print(f"–§–∏–Ω–∞–ª—å–Ω–∞—è —Ç–æ—á–Ω–æ—Å—Ç—å –Ω–∞ –≤–∞–ª–∏–¥–∞—Ü–∏–∏: {val_accs[-1]:.2f}%")
print(f"–§–∏–Ω–∞–ª—å–Ω–∞—è –ø–æ—Ç–µ—Ä—è –Ω–∞ –≤–∞–ª–∏–¥–∞—Ü–∏–∏: {val_losses[-1]:.4f}")

# Save model and optimizer state plus the per-epoch history.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_accs': train_accs,
    'val_accs': val_accs,
}, 'mnist_cnn_model.pth')

print("–ú–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞ –≤ 'mnist_cnn_model.pth'")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
import numpy as np

# Device selection: use the GPU if PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"–ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {device}")

# CIFAR-10 model: takes 3 input channels (RGB) instead of 1.
class EnhancedCNN(nn.Module):
    """Three conv/BatchNorm/ReLU/max-pool stages followed by an MLP head.

    Sized for 32x32 RGB inputs; produces raw logits for 10 classes.
    """

    def __init__(self):
        super().__init__()
        # The first layer takes 3 channels (RGB) rather than 1.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.3)

        # Spatial size after the convolutions for 32x32 CIFAR-10 inputs
        # (padding=1 keeps the size, each 2x2 pool halves it):
        #   conv1 + pool: 32x32 -> 16x16
        #   conv2 + pool: 16x16 -> 8x8
        #   conv3 + pool: 8x8 -> 4x4
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self, x):
        """Return logits of shape (batch, 10) for ``x`` of shape (batch, 3, 32, 32).

        Flattening keeps the batch dimension fixed, so a wrongly sized input
        raises a shape error in ``fc1``. It is not silently reshaped into a
        different batch size.
        """
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))

        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Load CIFAR-10. The training pipeline adds random flip/crop augmentation.
# Both splits are normalized with the same per-channel mean/std values.
print("\n–ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö CIFAR-10...")
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # augmentation
    transforms.RandomCrop(32, padding=4),  # augmentation
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))  # standard per-channel values for CIFAR-10
])

# Use CIFAR-10 (downloaded into ./data if not already present).
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# Test split: same normalization, no augmentation.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])
)

# Larger batch size, intended for a GPU with plenty of memory.
batch_size = 256  # can be raised to 512 if memory allows
print(f"Batch size: {batch_size}")

# NOTE(review): pin_memory=True only helps host-to-GPU copies; presumably it
# has no benefit on CPU-only runs.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# CIFAR-10 class names, indexed by integer label.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"–û–±—É—á–∞—é—â–∏—Ö –ø—Ä–∏–º–µ—Ä–æ–≤: {len(train_dataset)}")
print(f"–¢–µ—Å—Ç–æ–≤—ã—Ö –ø—Ä–∏–º–µ—Ä–æ–≤: {len(test_dataset)}")
# Shape of one transformed training image (indexing runs the random augmentations).
print(f"–†–∞–∑–º–µ—Ä –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π: {train_dataset[0][0].shape}")

# Create the model on the selected device.
model = EnhancedCNN().to(device)
print(f"\n–ú–æ–¥–µ–ª—å —Å–æ–∑–¥–∞–Ω–∞ –∏ –ø–µ—Ä–µ–º–µ—â–µ–Ω–∞ –Ω–∞ {device}")

# Helper: count the trainable parameters of a model.
def count_parameters(model):
    """Count scalar parameters of *model* that will receive gradients."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total

total_params = count_parameters(model)
print(f"–ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –æ–±—É—á–∞–µ–º—ã—Ö –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤: {total_params:,}")

# Optimizer, loss function and learning-rate schedule.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# The scheduler is stepped once per epoch, so T_max must equal the number of
# training epochs (`epochs = 15` in the training loop below). A smaller T_max
# makes the cosine curve reach its minimum early and then climb back up for
# the remaining epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)

# Training function
def train_epoch(model, loader, criterion, optimizer, device, epoch):
    """Train *model* for one pass over *loader*, printing progress every 50 batches.

    Args:
        model: the network; switched to train mode here.
        loader: iterable of ``(data, target)`` batches.
        criterion: loss callable, assumed to return the batch-mean loss
            (default reduction of ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        epoch: epoch number, used only in the progress messages.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample mean loss and
        ``accuracy`` is the percentage of correct top-1 predictions.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(loader):
        # Move the batch to the same device as the model.
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        n_batch = target.size(0)
        # Weight each batch mean by its size so a short final batch does not
        # skew the epoch average.
        running_loss += loss.item() * n_batch
        _, predicted = output.max(1)
        total += n_batch
        correct += predicted.eq(target).sum().item()

        if batch_idx % 50 == 0:
            accuracy = 100. * correct / total
            print(f'Epoch: {epoch} | Batch: {batch_idx}/{len(loader)} | '
                  f'Loss: {loss.item():.4f} | Acc: {accuracy:.2f}%')

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")

    accuracy = 100. * correct / total
    avg_loss = running_loss / total
    return avg_loss, accuracy

# Validation function
def validate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without computing gradients.

    Args:
        model: the network; switched to eval mode here.
        loader: iterable of ``(data, target)`` batches.
        criterion: loss callable, assumed to return the batch-mean loss.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample mean loss and
        ``accuracy`` is the percentage of correct top-1 predictions.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            n_batch = target.size(0)
            # Size-weighted so a short final batch does not skew the mean.
            running_loss += loss.item() * n_batch
            _, predicted = output.max(1)
            total += n_batch
            correct += predicted.eq(target).sum().item()

    if total == 0:
        raise ValueError("validate: loader yielded no samples")

    accuracy = 100. * correct / total
    avg_loss = running_loss / total
    return avg_loss, accuracy

# Model training on CIFAR-10. "Validation" metrics are computed on the
# CIFAR-10 test split (test_loader).
print("\n" + "=" * 60)
print("–ù–ê–ß–ê–õ–û –û–ë–£–ß–ï–ù–ò–Ø –ù–ê CIFAR-10")
print("=" * 60)

epochs = 15
train_losses, train_accs = [], []
val_losses, val_accs = [], []

for epoch in range(epochs):
    print(f'\n{"-" * 50}')
    print(f'–≠–ø–æ—Ö–∞ {epoch+1}/{epochs}')
    print(f'{"-" * 50}')

    # Training
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, epoch+1)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validation
    val_loss, val_acc = validate(model, test_loader, criterion, device)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Record the learning rate used during this epoch *before* stepping the
    # scheduler. Reading it after step() would report the next epoch's LR.
    epoch_lr = optimizer.param_groups[0]["lr"]
    scheduler.step()

    print(f'\n–ò—Ç–æ–≥–∏ —ç–ø–æ—Ö–∏ {epoch+1}:')
    print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    print(f'  Learning Rate: {epoch_lr:.6f}')

    # GPU 0 memory usage, reported against the device's actual capacity
    # rather than a hard-coded total.
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated(0) / 1024**3
        memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f'  GPU Memory: {memory_allocated:.2f}GB / {memory_total:.2f}GB')

print("\n" + "=" * 60)
print("–û–ë–£–ß–ï–ù–ò–ï –ó–ê–í–ï–†–®–ï–ù–û")
print("=" * 60)

# Visualize the per-epoch history: loss on the left, accuracy on the right.
plt.figure(figsize=(15, 5))

# Loss plot
plt.subplot(1, 2, 1)
plt.plot(range(1, epochs+1), train_losses, 'b-', linewidth=2, label='Train Loss')
plt.plot(range(1, epochs+1), val_losses, 'r-', linewidth=2, label='Val Loss')
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Loss over epochs', fontsize=14)
plt.legend()
plt.grid(True)

# Accuracy plot
plt.subplot(1, 2, 2)
plt.plot(range(1, epochs+1), train_accs, 'b-', linewidth=2, label='Train Accuracy')
plt.plot(range(1, epochs+1), val_accs, 'r-', linewidth=2, label='Val Accuracy')
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Accuracy (%)', fontsize=12)
plt.title('Accuracy over epochs', fontsize=14)
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# Show predictions on examples from the test set.
print("\n–¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –Ω–∞ –ø—Ä–∏–º–µ—Ä–∞—Ö –∏–∑ —Ç–µ—Å—Ç–æ–≤–æ–≥–æ –Ω–∞–±–æ—Ä–∞:")
model.eval()
with torch.no_grad():
    # Take 10 examples from the first test batch (test_loader is not
    # shuffled, so these are the first 10 test images).
    data, target = next(iter(test_loader))
    data, target = data[:10].to(device), target[:10].to(device)
    output = model(data)
    _, predicted = output.max(1)

    # Move back to the CPU so the tensors can be converted to numpy for plotting.
    data = data.cpu()
    predicted = predicted.cpu()
    target = target.cpu()

    # Undo the Normalize transform (x * std + mean) and clamp to [0, 1] for imshow.
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
    data_denorm = data * std + mean
    data_denorm = torch.clamp(data_denorm, 0, 1)

    # Show images with predicted/true class names in a 2x5 grid. Titles are
    # green for correct predictions and red otherwise. CHW tensors are
    # permuted to HWC for imshow.
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    for idx, ax in enumerate(axes.flat):
        if idx < 10:
            image = data_denorm[idx].permute(1, 2, 0).numpy()
            ax.imshow(image)
            title = f'Pred: {classes[predicted[idx]]}\nTrue: {classes[target[idx]]}'
            ax.set_title(title, color='green' if predicted[idx] == target[idx] else 'red')
            ax.axis('off')

    plt.suptitle('–ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏—è –º–æ–¥–µ–ª–∏ –Ω–∞ —Ç–µ—Å—Ç–æ–≤—ã—Ö –¥–∞–Ω–Ω—ã—Ö CIFAR-10', fontsize=16)
    plt.tight_layout()
    plt.show()

# Confusion matrix over the full test split.
print("\n–ú–∞—Ç—Ä–∏—Ü–∞ –æ—à–∏–±–æ–∫ (Confusion Matrix):")
model.eval()
all_preds = []
all_targets = []

# Collect top-1 predictions and true labels for every test batch.
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, preds = output.max(1)
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(target.cpu().numpy())

# Build and plot the confusion matrix (rows: true label, columns: predicted label).
# NOTE(review): these imports sit mid-cell; consider moving them to the top of the cell.
from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(all_targets, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=classes, yticklabels=classes)
plt.title('Confusion Matrix', fontsize=16)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.show()

# Final statistics
print("\n" + "=" * 60)
print("–§–ò–ù–ê–õ–¨–ù–ê–Ø –°–¢–ê–¢–ò–°–¢–ò–ö–ê")
print("=" * 60)
print(f"–§–∏–Ω–∞–ª—å–Ω–∞—è —Ç–æ—á–Ω–æ—Å—Ç—å –Ω–∞ –≤–∞–ª–∏–¥–∞—Ü–∏–∏: {val_accs[-1]:.2f}%")
print(f"–§–∏–Ω–∞–ª—å–Ω–∞—è –ø–æ—Ç–µ—Ä—è –Ω–∞ –≤–∞–ª–∏–¥–∞—Ü–∏–∏: {val_losses[-1]:.4f}")
print(f"–õ—É—á—à–∞—è —Ç–æ—á–Ω–æ—Å—Ç—å –Ω–∞ –≤–∞–ª–∏–¥–∞—Ü–∏–∏: {max(val_accs):.2f}%")
print(f"–í—Å–µ–≥–æ –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤ –º–æ–¥–µ–ª–∏: {total_params:,}")

# Save model/optimizer/scheduler state together with the training history.
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_accs': train_accs,
    'val_accs': val_accs,
    'total_params': total_params,
}, 'cifar10_enhanced_cnn_corrected.pth')

print("\n–ú–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞ –≤ 'cifar10_enhanced_cnn_corrected.pth'")

# Release cached GPU memory. Skipped on CPU-only runs, so the
# "GPU memory cleared" message is not printed when there is no GPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("–ü–∞–º—è—Ç—å GPU –æ—á–∏—â–µ–Ω–∞")

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import math
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# including library deprecation warnings (possibly the one for the
# `pretrained=` argument used with torchvision models below). Consider
# narrowing it by category or module.
warnings.filterwarnings('ignore')


In [None]:
# Pick the GPU when available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Load the image/identity table. The bare expression on the last line shows
# the number of images per person_id as the cell output.
df_id = pd.read_csv('selected_images/selected_images_info.csv')
df_id['person_id'].value_counts()

In [None]:
# Keep only the columns the pipeline uses, then summarise the table.
df_id = df_id[['filename', 'person_id']]
print(df_id.head())
n_images = len(df_id)
n_persons = df_id['person_id'].nunique()
print(f"–í—Å–µ–≥–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π: {n_images}")
print(f"–£–Ω–∏–∫–∞–ª—å–Ω—ã—Ö –ø–µ—Ä—Å–æ–Ω: {n_persons}")

In [None]:
# Directory that the 'filename' values in df_id are resolved against
# (Config.img_dir -> CelebAClassificationDataset joins it with each filename).
img_dir = 'aligned_faces'

In [None]:
# ==================== CONFIGURATION ====================
class Config:
    """Hyper-parameters and data locations for the face-classification run.

    ``Config()`` with no arguments reproduces the defaults below and takes
    ``identity_df`` / ``img_dir`` from the notebook globals ``df_id`` and
    ``img_dir``. Any setting can be overridden by keyword, e.g.
    ``Config(num_epochs=5)`` or ``Config(identity_df=my_df, img_dir='faces')``.
    The globals are only looked up for settings that are not overridden.

    Raises:
        TypeError: if a keyword does not name an existing setting.
    """

    def __init__(self, **overrides):
        # Data sources: identity table (filename/person_id) and image directory.
        self.identity_df = overrides.pop('identity_df') if 'identity_df' in overrides else df_id
        self.img_dir = overrides.pop('img_dir') if 'img_dir' in overrides else img_dir
        # Class selection and train/val/test splitting.
        self.max_classes = 350
        self.min_samples_per_person = 26
        self.seed = 42
        self.val_ratio = 0.15
        self.test_ratio = 0.15
        # Data loading.
        self.batch_size = 128
        self.num_workers = 0
        # Model and optimisation.
        self.embedding_size = 512
        self.learning_rate = 0.001
        self.num_epochs = 25
        # NOTE(review): ArcFace scale/margin. They are not referenced by the
        # visible training code; presumably reserved for an ArcFace head.
        self.arcface_s = 32.0
        self.arcface_m = 0.5

        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f"Config() got an unexpected keyword argument {key!r}")
            setattr(self, key, value)

# ==================== DATA PROCESSING ====================
class CelebADataProcessor:
    """Select the identities to train on and split their images.

    ``config`` must provide ``identity_df`` (a DataFrame with ``filename`` and
    ``person_id`` columns), ``max_classes``, ``val_ratio``, ``test_ratio`` and
    ``seed``. ``min_samples_per_person`` is applied when present.

    Call ``filter_data()`` first, then ``split_data_by_images()``.
    """

    def __init__(self, config):
        self.config = config
        self.identity_df = config.identity_df
        print(f"–í—Å–µ–≥–æ –¥–∞–Ω–Ω—ã—Ö: {len(self.identity_df)} –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π")
        print(f"–£–Ω–∏–∫–∞–ª—å–Ω—ã—Ö –ª—é–¥–µ–π: {self.identity_df['person_id'].nunique()}")

    def filter_data(self):
        """Keep up to ``max_classes`` of the most frequent people that have at
        least ``min_samples_per_person`` images, and relabel them ``0..K-1``.

        Sets ``filtered_df`` (a copy with an added ``class_idx`` column),
        ``id_to_idx`` (person_id -> class index, in sorted person_id order)
        and ``idx_to_id`` (the inverse mapping).

        Returns:
            The number of classes ``K`` that were kept.

        Raises:
            ValueError: if no person meets ``min_samples_per_person``.
        """
        person_counts = self.identity_df['person_id'].value_counts()
        # Apply the configured minimum before picking the most frequent
        # people. The stratified splits below need several images per person.
        min_samples = getattr(self.config, 'min_samples_per_person', 1)
        eligible_counts = person_counts[person_counts >= min_samples]
        if eligible_counts.empty:
            raise ValueError(
                f"No person has at least {min_samples} images; "
                "lower min_samples_per_person."
            )
        top_persons = eligible_counts.nlargest(self.config.max_classes).index
        self.filtered_df = self.identity_df[self.identity_df['person_id'].isin(top_persons)].copy()

        # Map original person ids to contiguous class indices.
        unique_ids = sorted(self.filtered_df['person_id'].unique())
        self.id_to_idx = {old_id: idx for idx, old_id in enumerate(unique_ids)}
        self.idx_to_id = {idx: old_id for old_id, idx in self.id_to_idx.items()}

        self.filtered_df['class_idx'] = self.filtered_df['person_id'].map(self.id_to_idx)

        print(f"\n–ü–æ—Å–ª–µ —Ñ–∏–ª—å—Ç—Ä–∞—Ü–∏–∏:")
        print(f"  –í—Å–µ–≥–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π: {len(self.filtered_df)}")
        print(f"  –£–Ω–∏–∫–∞–ª—å–Ω—ã—Ö –ª—é–¥–µ–π: {self.filtered_df['person_id'].nunique()}")
        print(f"  –î–∏–∞–ø–∞–∑–æ–Ω –º–µ—Ç–æ–∫: {self.filtered_df['class_idx'].min()} - {self.filtered_df['class_idx'].max()}")

        return len(unique_ids)

    def split_data_by_images(self):
        """Split ``filtered_df`` into train/val/test sets, stratified by class.

        ``test_ratio + val_ratio`` of the rows are held out first, then that
        hold-out is divided between val and test in proportion to the two
        ratios. Sets ``train_df``, ``val_df`` and ``test_df``.

        Returns:
            A dict with the three split sizes and the number of people.

        Raises:
            RuntimeError: if ``filter_data()`` has not been called yet.
        """
        if not hasattr(self, 'filtered_df'):
            raise RuntimeError("filter_data() must be called before split_data_by_images()")

        print(f"\n–†–∞–∑–¥–µ–ª–µ–Ω–∏–µ –¥–∞–Ω–Ω—ã—Ö –ø–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏—è–º (—Å—Ç—Ä–∞—Ç–∏—Ñ–∏—Ü–∏—Ä–æ–≤–∞–Ω–æ)...")

        train_df, temp_df = train_test_split(
            self.filtered_df,
            test_size=self.config.val_ratio + self.config.test_ratio,
            random_state=self.config.seed,
            stratify=self.filtered_df['class_idx']
        )

        # Fraction of the hold-out that goes to the test split.
        val_df, test_df = train_test_split(
            temp_df,
            test_size=self.config.test_ratio/(self.config.val_ratio + self.config.test_ratio),
            random_state=self.config.seed,
            stratify=temp_df['class_idx']
        )

        self.train_df = train_df.reset_index(drop=True)
        self.val_df = val_df.reset_index(drop=True)
        self.test_df = test_df.reset_index(drop=True)

        print(f"Train: {len(self.train_df)} samples")
        print(f"Val: {len(self.val_df)} samples")
        print(f"Test: {len(self.test_df)} samples")

        return {
            'train': len(self.train_df),
            'val': len(self.val_df),
            'test': len(self.test_df),
            'num_persons': len(self.filtered_df['person_id'].unique())
        }

# ==================== DATASET ====================
class CelebAClassificationDataset(Dataset):
    """Map-style dataset yielding ``(image, class_idx)`` pairs.

    Args:
        df: DataFrame with ``filename`` and ``class_idx`` columns; its index
            is reset so positional indexing matches ``idx``.
        img_dir: directory that ``filename`` values are relative to.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['filename'])

        try:
            # The context manager closes the file once the RGB copy exists.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except OSError:
            # Best-effort fallback for missing or unreadable files (Pillow
            # reports both as OSError subclasses). Other errors, including
            # KeyboardInterrupt, now propagate instead of being swallowed.
            # NOTE(review): the black placeholder keeps the row's real label,
            # so it is trained on as that person.
            image = Image.new('RGB', (224, 224), color=(0, 0, 0))

        if self.transform:
            image = self.transform(image)

        label = int(row['class_idx'])
        return image, label

# ==================== MODELS ====================
class SimpleFaceModel(nn.Module):
    """Simple, stable model: ResNet-18 backbone, embedding head, linear classifier.

    ``forward(x)`` returns ``(logits, embedding)``. With
    ``return_embedding=True`` it returns only the embedding.

    NOTE(review): the embedding head contains BatchNorm1d, which cannot
    normalize a batch of size 1 in training mode. Consider ``drop_last=True``
    on the training DataLoader.

    Args:
        num_classes: number of identity classes for the classifier.
        embedding_size: width of the embedding layer.
    """

    def __init__(self, num_classes=350, embedding_size=512):
        super().__init__()

        # ResNet-18 backbone with pretrained weights. Its final fc layer is
        # replaced by Identity so the backbone outputs its pooled features.
        self.backbone = models.resnet18(pretrained=True)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Simple embedding layer
        self.embedding = nn.Sequential(
            nn.Linear(in_features, embedding_size),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        self.classifier = nn.Linear(embedding_size, num_classes)

        # Freeze only the first layers: the stem (conv1, bn1) and the first
        # residual stage (layer1). Match on the top-level module prefix. A
        # plain substring test for 'conv1'/'bn1' would also match every
        # block's conv1/bn1 inside layer2-layer4 (e.g. 'layer3.0.conv1.weight').
        # NOTE(review): frozen BatchNorm layers still update their running
        # statistics while the model is in train mode.
        frozen_prefixes = ('conv1.', 'bn1.', 'layer1.')
        for name, param in self.backbone.named_parameters():
            if name.startswith(frozen_prefixes):
                param.requires_grad = False

        print(f"\n–ú–æ–¥–µ–ª—å –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤:")
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"  –í—Å–µ–≥–æ: {total_params:,}")
        print(f"  –û–±—É—á–∞–µ–º—ã—Ö: {trainable_params:,}")

    def forward(self, x, return_embedding=False):
        """Return ``(logits, embedding)``, or just ``embedding`` if requested."""
        features = self.backbone(x)
        embedding = self.embedding(features)

        if return_embedding:
            return embedding

        logits = self.classifier(embedding)
        return logits, embedding

# ==================== –û–°–ù–û–í–ù–ê–Ø –§–£–ù–ö–¶–ò–Ø ====================
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    config = Config()
    processor = CelebADataProcessor(config)

    num_classes = processor.filter_data()
    stats = processor.split_data_by_images()

    # –£–º–µ—Ä–µ–Ω–Ω—ã–µ –∞—É–≥–º–µ–Ω—Ç–∞—Ü–∏–∏
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    # –î–∞—Ç–∞—Å–µ—Ç—ã
    train_dataset = CelebAClassificationDataset(
        processor.train_df, config.img_dir, train_transform
    )
    val_dataset = CelebAClassificationDataset(
        processor.val_df, config.img_dir, val_transform
    )
    test_dataset = CelebAClassificationDataset(
        processor.test_df, config.img_dir, val_transform
    )

    # –î–∞—Ç–∞–ª–æ–∞–¥–µ—Ä—ã
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size,
        shuffle=True, num_workers=config.num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size,
        shuffle=False, num_workers=config.num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=config.batch_size,
        shuffle=False, num_workers=config.num_workers
    )

    # –°–æ–∑–¥–∞–µ–º –º–æ–¥–µ–ª—å
    model = SimpleFaceModel(
        num_classes=num_classes,
        embedding_size=config.embedding_size
    ).to(device)

    # –ü—Ä–æ—Å—Ç–æ–π –æ–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=1e-4
    )

    # StepLR scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    criterion = nn.CrossEntropyLoss()

    os.makedirs('checkpoints', exist_ok=True)

    train_history = {
        'loss': [], 'acc': [],
        'val_loss': [], 'val_acc': [],
        'lr': []
    }

    best_val_acc = 0
    patience_counter = 0
    max_patience = 10

    print("\n" + "="*60)
    print("–ù–ê–ß–ê–õ–û –û–ë–£–ß–ï–ù–ò–Ø")
    print("="*60)

    for epoch in range(config.num_epochs):
        print(f"\nEpoch {epoch+1}/{config.num_epochs}")
        print("-" * 40)

        # –û–±—É—á–µ–Ω–∏–µ
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc="Training")
        for batch_idx, (images, labels) in enumerate(pbar):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            logits, _ = model(images, return_embedding=False)
            loss = criterion(logits, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            _, predicted = logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += images.size(0)

            current_acc = 100. * predicted.eq(labels).sum().item() / images.size(0)
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{current_acc:.1f}%'
            })

        train_loss = total_loss / total
        train_acc = 100. * correct / total

        # –í–∞–ª–∏–¥–∞—Ü–∏—è
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc="Validation"):
                images, labels = images.to(device), labels.to(device)
                logits, _ = model(images, return_embedding=False)
                loss = criterion(logits, labels)

                val_loss += loss.item() * images.size(0)
                _, predicted = logits.max(1)
                val_correct += predicted.eq(labels).sum().item()
                val_total += images.size(0)

        val_loss /= val_total
        val_acc = 100. * val_correct / val_total

        # –°–æ—Ö—Ä–∞–Ω—è–µ–º –∏—Å—Ç–æ—Ä–∏—é
        train_history['loss'].append(train_loss)
        train_history['acc'].append(train_acc)
        train_history['val_loss'].append(val_loss)
        train_history['val_acc'].append(val_acc)
        train_history['lr'].append(optimizer.param_groups[0]['lr'])

        print(f"\n–ò—Ç–æ–≥–∏ —ç–ø–æ—Ö–∏:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")

        # –°–æ—Ö—Ä–∞–Ω—è–µ–º –ª—É—á—à—É—é –º–æ–¥–µ–ª—å
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            checkpoint_path = f'checkpoints/best_model.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
            }, checkpoint_path)
            print(f"  ‚úì –°–æ—Ö—Ä–∞–Ω–µ–Ω–∞ –ª—É—á—à–∞—è –º–æ–¥–µ–ª—å (Acc: {val_acc:.2f}%)")
        else:
            patience_counter += 1
            print(f"  –ü–∞—Çience: {patience_counter}/{max_patience}")

            if patience_counter >= max_patience:
                print(f"  ‚ö† Early stopping –Ω–∞ —ç–ø–æ—Ö–µ {epoch+1}")
                break

        # –û–±–Ω–æ–≤–ª—è–µ–º scheduler
        scheduler.step()

    # ==================== –í–ò–ó–£–ê–õ–ò–ó–ê–¶–ò–Ø ====================
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    axes[0, 0].plot(train_history['loss'], label='Train Loss', linewidth=2)
    axes[0, 0].plot(train_history['val_loss'], label='Val Loss', linewidth=2)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Loss History')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    axes[0, 1].plot(train_history['acc'], label='Train Acc', linewidth=2)
    axes[0, 1].plot(train_history['val_acc'], label='Val Acc', linewidth=2)
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].set_title('Accuracy History')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    axes[1, 0].plot(train_history['lr'], linewidth=2)
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].set_title('Learning Rate Schedule')
    axes[1, 0].grid(True, alpha=0.3)

    # –†–∞—Å–ø—Ä–µ–¥–µ–ª–µ–Ω–∏–µ –º–µ—Ç–æ–∫
    axes[1, 1].hist(processor.train_df['class_idx'].values, bins=num_classes,
                    alpha=0.7, label='Train', density=True)
    axes[1, 1].hist(processor.val_df['class_idx'].values, bins=num_classes,
                    alpha=0.7, label='Val', density=True)
    axes[1, 1].set_xlabel('Class Index')
    axes[1, 1].set_ylabel('Density')
    axes[1, 1].set_title('Class Distribution')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_results_stable.png', dpi=150, bbox_inches='tight')
    plt.show()

    # ==================== –¢–ï–°–¢–ò–†–û–í–ê–ù–ò–ï ====================
    print("\n" + "="*60)
    print("–¢–ï–°–¢–ò–†–û–í–ê–ù–ò–ï –õ–£–ß–®–ï–ô –ú–û–î–ï–õ–ò")
    print("="*60)

    # –ó–∞–≥—Ä—É–∂–∞–µ–º –ª—É—á—à—É—é –º–æ–¥–µ–ª—å
    checkpoint = torch.load('checkpoints/best_model.pth')
    checkpoint = torch.load('checkpoints/best_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    test_loss = 0
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            logits, _ = model(images, return_embedding=False)
            loss = criterion(logits, labels)

            test_loss += loss.item() * images.size(0)
            _, predicted = logits.max(1)
            test_correct += predicted.eq(labels).sum().item()
            test_total += images.size(0)

    test_loss /= test_total
    test_acc = 100. * test_correct / test_total

    print(f"\n–†–µ–∑—É–ª—å—Ç–∞—Ç—ã –Ω–∞ —Ç–µ—Å—Ç–æ–≤–æ–π –≤—ã–±–æ—Ä–∫–µ:")
    print(f"  Test Loss: {test_loss:.4f}")
    print(f"  Test Accuracy: {test_acc:.2f}%")
    print(f"  Correct/Total: {test_correct}/{test_total}")

    print("\n" + "="*60)
    print("–û–ë–£–ß–ï–ù–ò–ï –ó–ê–í–ï–†–®–ï–ù–û!")
    print("="*60)
    print(f"–õ—É—á—à–∞—è —Ç–æ—á–Ω–æ—Å—Ç—å –Ω–∞ –≤–∞–ª–∏–¥–∞—Ü–∏–∏: {best_val_acc:.2f}%")
    print(f"–¢–æ—á–Ω–æ—Å—Ç—å –Ω–∞ —Ç–µ—Å—Ç–µ: {test_acc:.2f}%")

# Run the training pipeline defined by main() when this code is executed as a script.
if __name__ == "__main__":
    main()

In [None]:
# ==================== CONFIGURATION ====================
class Config:
    """Hyper-parameters and data locations for the CE -> ArcFace pipeline.

    ``identity_df`` and ``img_dir`` are read from the notebook globals
    ``df_id`` and ``img_dir``, which must exist before instantiation.
    """

    def __init__(self):
        # Data sources (notebook-level globals).
        self.identity_df, self.img_dir = df_id, img_dir

        # Identity selection and train/val/test splitting.
        self.seed = 42
        self.max_classes = 350
        self.min_samples_per_person = 26
        self.val_ratio = self.test_ratio = 0.15

        # DataLoader settings.
        self.batch_size, self.num_workers = 64, 0

        # Model size and optimisation: CE stage epochs, then ArcFace fine-tuning epochs.
        self.embedding_size = 512
        self.learning_rate = 1e-4
        self.num_epochs_ce, self.num_epochs_arcface = 25, 15

        # ArcFace head: logit scale s, angular margin m (radians), margin variant.
        self.arcface_s, self.arcface_m = 32.0, 0.5
        self.arcface_easy_margin = True

# ==================== DATA PROCESSING ====================
class CelebADataProcessor:
    """Select frequent identities from an identity table and split their images.

    ``config`` must provide ``identity_df`` (a DataFrame with a ``person_id``
    column), ``max_classes``, ``seed``, ``val_ratio`` and ``test_ratio``.
    ``min_samples_per_person`` is applied when the config defines it.
    """

    def __init__(self, config):
        self.config = config
        self.identity_df = config.identity_df

        print(f"–í—Å–µ–≥–æ –¥–∞–Ω–Ω—ã—Ö: {len(self.identity_df)} –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π")
        print(f"–£–Ω–∏–∫–∞–ª—å–Ω—ã—Ö –ª—é–¥–µ–π: {self.identity_df['person_id'].nunique()}")

    def filter_data(self):
        """Keep the ``max_classes`` most frequent identities that have at least
        ``min_samples_per_person`` images, and assign dense class labels.

        Sets ``filtered_df`` (with a new ``class_idx`` column), ``id_to_idx``
        and ``idx_to_id``. Labels follow the sorted order of ``person_id``.

        Returns:
            Number of classes kept.
        """
        person_counts = self.identity_df['person_id'].value_counts()
        top_counts = person_counts.nlargest(self.config.max_classes)

        # The config declares a per-identity minimum; honour it. Applying it after
        # nlargest gives the same result as before it, because every identity
        # outside the top-k has no more images than the k-th one.
        min_samples = getattr(self.config, 'min_samples_per_person', None) or 0
        kept_counts = top_counts[top_counts >= min_samples]
        dropped = len(top_counts) - len(kept_counts)
        if dropped:
            print(f"Dropped {dropped} identities with fewer than {min_samples} images")

        top_persons = kept_counts.index
        self.filtered_df = self.identity_df[self.identity_df['person_id'].isin(top_persons)].copy()

        unique_ids = sorted(self.filtered_df['person_id'].unique())
        self.id_to_idx = {old_id: idx for idx, old_id in enumerate(unique_ids)}
        self.idx_to_id = {idx: old_id for old_id, idx in self.id_to_idx.items()}

        self.filtered_df['class_idx'] = self.filtered_df['person_id'].map(self.id_to_idx)

        print(f"\n–ü–æ—Å–ª–µ —Ñ–∏–ª—å—Ç—Ä–∞—Ü–∏–∏:")
        print(f"  –í—Å–µ–≥–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π: {len(self.filtered_df)}")
        print(f"  –£–Ω–∏–∫–∞–ª—å–Ω—ã—Ö –ª—é–¥–µ–π: {self.filtered_df['person_id'].nunique()}")
        print(f"  –î–∏–∞–ø–∞–∑–æ–Ω –º–µ—Ç–æ–∫: {self.filtered_df['class_idx'].min()} - {self.filtered_df['class_idx'].max()}")

        return len(unique_ids)

    def split_data_by_images(self):
        """Split ``filtered_df`` into train/val/test by image, stratified on ``class_idx``.

        The same identities are shared across splits, since the split is by image
        rather than by person. ``filter_data`` must be called first.

        Returns:
            dict with ``train``/``val``/``test`` sizes and ``num_persons``.

        Raises:
            RuntimeError: if ``filter_data`` has not been called yet.
        """
        if not hasattr(self, 'filtered_df'):
            raise RuntimeError("filter_data() must be called before split_data_by_images()")

        print(f"\n–†–∞–∑–¥–µ–ª–µ–Ω–∏–µ –¥–∞–Ω–Ω—ã—Ö –ø–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏—è–º (—Å—Ç—Ä–∞—Ç–∏—Ñ–∏—Ü–∏—Ä–æ–≤–∞–Ω–æ)...")

        holdout_ratio = self.config.val_ratio + self.config.test_ratio

        # First carve off val+test, then split that holdout into val and test.
        train_df, temp_df = train_test_split(
            self.filtered_df,
            test_size=holdout_ratio,
            random_state=self.config.seed,
            stratify=self.filtered_df['class_idx']
        )

        val_df, test_df = train_test_split(
            temp_df,
            test_size=self.config.test_ratio / holdout_ratio,
            random_state=self.config.seed,
            stratify=temp_df['class_idx']
        )

        self.train_df = train_df.reset_index(drop=True)
        self.val_df = val_df.reset_index(drop=True)
        self.test_df = test_df.reset_index(drop=True)

        print(f"Train: {len(self.train_df)} samples")
        print(f"Val: {len(self.val_df)} samples")
        print(f"Test: {len(self.test_df)} samples")

        return {
            'train': len(self.train_df),
            'val': len(self.val_df),
            'test': len(self.test_df),
            'num_persons': len(self.filtered_df['person_id'].unique())
        }

# ==================== DATASET ====================
class CelebAClassificationDataset(Dataset):
    """Map-style dataset yielding ``(image, class_idx)`` pairs from a DataFrame.

    Args:
        df: frame with at least ``filename`` and ``class_idx`` columns; its
            index is reset so positional ``idx`` values map to rows.
        img_dir: directory that ``filename`` values are relative to.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['filename'])

        try:
            # Context manager closes the file handle once the RGB copy is made.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except (OSError, ValueError):
            # Missing or corrupt file (FileNotFoundError and PIL's
            # UnidentifiedImageError are OSError subclasses): fall back to a
            # black image so one bad file does not abort the epoch. Unlike a
            # bare ``except:``, this no longer hides KeyboardInterrupt or bugs.
            image = Image.new('RGB', (224, 224), color=(0, 0, 0))

        if self.transform:
            image = self.transform(image)

        label = int(row['class_idx'])
        return image, label

# ==================== MODELS ====================
class SimpleFaceModel(nn.Module):
    """ResNet-18 + embedding head + linear classifier, trained with cross-entropy.

    ``forward(x, labels=None)`` always returns ``(logits, embedding)``. ``labels``
    is accepted only for call-compatibility with ``ArcFaceModel`` and is ignored.
    The ResNet stem (``conv1``/``bn1``) and ``layer1`` are frozen.
    """

    # Backbone parameter-name prefixes to freeze. Prefix (not substring) matching
    # is essential: every BasicBlock in layer2..layer4 also has parameters named
    # ``...conv1.weight`` / ``...bn1.weight``, which a substring test would freeze too.
    _FROZEN_PREFIXES = ('conv1.', 'bn1.', 'layer1.')

    def __init__(self, num_classes=300, embedding_size=512):
        super().__init__()

        self.backbone = models.resnet18(pretrained=True)
        in_features = self.backbone.fc.in_features

        # Drop the ImageNet head so the backbone returns pooled features.
        self.backbone.fc = nn.Identity()

        # Embedding head. It has the same layout as ArcFaceModel.embedding, so the
        # weights can be transferred by name.
        self.embedding = nn.Sequential(
            nn.Linear(in_features, embedding_size),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        # Classifier
        self.classifier = nn.Linear(embedding_size, num_classes)

        # Freeze only the early layers.
        for name, param in self.backbone.named_parameters():
            if self.is_frozen_param_name(name):
                param.requires_grad = False

    @classmethod
    def is_frozen_param_name(cls, name):
        """Return True if backbone parameter ``name`` belongs to the frozen stem or layer1."""
        return name.startswith(cls._FROZEN_PREFIXES)

    def forward(self, x, labels=None):
        features = self.backbone(x)
        embedding = self.embedding(features)
        logits = self.classifier(embedding)
        return logits, embedding


class ArcMarginProduct(nn.Module):
    """ArcFace (additive angular margin) logit layer.

    ``forward(input, label)`` returns ``s * cos(theta + m)`` for each sample's
    target class and ``s * cos(theta)`` for every other class. ``theta`` is the
    angle between the L2-normalised input row and the L2-normalised class weight.

    Args:
        in_features: embedding dimension.
        out_features: number of classes.
        s: logit scale. It is a plain attribute, so callers may change it between epochs.
        m: additive angular margin in radians.
        easy_margin: if True, the margin is applied only where ``cos(theta) > 0``.
            Otherwise the standard ArcFace fallback ``cos(theta) - sin(m) * m``
            is used wherever ``theta + m`` would exceed pi.
    """
    def __init__(self, in_features, out_features, s=32.0, m=0.3, easy_margin=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.easy_margin = easy_margin

        # One class-centre vector per row, initialised once.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weight)

        # Trig constants for cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

        # Used only when easy_margin is False. cos(theta) > th  <=>  theta + m < pi;
        # beyond that point the linear penalty mm is subtracted instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        print(f"ArcFace –∏–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä–æ–≤–∞–Ω: s={s}, m={m}")

    def forward(self, input, label):
        # Cosine similarity between normalised embeddings and normalised class centres.
        cosine = F.linear(F.normalize(input, p=2, dim=1),
                          F.normalize(self.weight, p=2, dim=1))

        # Keep strictly inside (-1, 1) so the sqrt below stays finite and differentiable.
        cosine = cosine.clamp(-1 + 1e-7, 1 - 1e-7)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2) + 1e-7)

        # cos(theta + m)
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Apply the margin only at each sample's target class.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

class ArcFaceModel(nn.Module):
    """ResNet-18 backbone + embedding head + ArcFace margin layer.

    ``forward(x, labels)`` returns ``(arcface_logits, embeddings)`` for training.
    ``forward(x)`` returns only the L2-normalised embeddings, for inference.
    The backbone/embedding layout mirrors ``SimpleFaceModel``, so CE weights
    can be transferred by state_dict key (see ``load_from_ce_model``).
    """

    # state_dict key prefixes shared with SimpleFaceModel checkpoints.
    _TRANSFER_PREFIXES = ('backbone.', 'embedding.')

    def __init__(self, num_classes=300, embedding_size=512, s=8.0, m=0.5):
        super().__init__()

        # Backbone (ResNet-18) with the ImageNet head removed.
        self.backbone = models.resnet18(pretrained=True)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Embedding head with BatchNorm (same layout as SimpleFaceModel.embedding).
        self.embedding = nn.Sequential(
            nn.Linear(in_features, embedding_size),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        # ArcFace margin layer.
        self.arcface = ArcMarginProduct(
            embedding_size,
            num_classes,
            s=s,
            m=m,
            easy_margin=True
        )

        print(f"\nArcFace Model —Å–æ–∑–¥–∞–Ω–∞:")
        print(f"  –ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –∫–ª–∞—Å—Å–æ–≤: {num_classes}")
        print(f"  Embedding size: {embedding_size}")
        print(f"  Scale s: {s}")
        print(f"  Margin m: {m}")

    def load_from_ce_model(self, ce_model_path, device='cpu'):
        """Initialise from a SimpleFaceModel checkpoint (cross-entropy stage).

        Copies every shape-compatible ``backbone.*`` / ``embedding.*`` entry of
        the checkpoint, including both parameters and buffers (BatchNorm running
        statistics). It then seeds the ArcFace class centres from
        ``classifier.weight``. This is best-effort: a missing file or a load
        error is reported and never raised.

        Args:
            ce_model_path: path to a checkpoint dict with ``model_state_dict``,
                or to a bare state_dict.
            device: ``map_location`` passed to ``torch.load``.
        """
        print(f"\n{'='*50}")
        print("–ó–ê–ì–†–£–ó–ö–ê –í–ï–°–û–í –ò–ó CE –ú–û–î–ï–õ–ò")
        print(f"{'='*50}")

        if not os.path.exists(ce_model_path):
            print(f"‚ö† –§–∞–π–ª {ce_model_path} –Ω–µ –Ω–∞–π–¥–µ–Ω")
            return

        try:
            ce_checkpoint = torch.load(ce_model_path, map_location=device)

            # Accept both a full training checkpoint and a bare state_dict.
            if isinstance(ce_checkpoint, dict) and 'model_state_dict' in ce_checkpoint:
                ce_state_dict = ce_checkpoint['model_state_dict']
            else:
                ce_state_dict = ce_checkpoint

            # Transfer backbone + embedding parameters AND buffers. Copying only
            # named_parameters() would leave the ResNet BatchNorm running mean/var
            # at their ImageNet values while the affine weights came from the
            # fine-tuned CE model.
            own_state = self.state_dict()
            transferable = {
                key: value for key, value in ce_state_dict.items()
                if key.startswith(self._TRANSFER_PREFIXES)
                and key in own_state
                and own_state[key].shape == value.shape
            }
            self.load_state_dict(transferable, strict=False)

            if transferable:
                print(f"‚úì Backbone –∑–∞–≥—Ä—É–∂–µ–Ω")
                print(f"‚úì Embedding —Å–ª–æ–π –∑–∞–≥—Ä—É–∂–µ–Ω")
                print(f"  ({len(transferable)} backbone/embedding tensors transferred, incl. BatchNorm buffers)")
            else:
                print("No compatible backbone/embedding tensors found in the CE checkpoint")

            # Seed ArcFace class centres from the CE classifier rows.
            ce_weight = ce_state_dict.get('classifier.weight')
            if ce_weight is None:
                print("‚ö† classifier.weight –Ω–µ –Ω–∞–π–¥–µ–Ω, –∏–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä—É–µ–º ArcFace —Å–ª—É—á–∞–π–Ω–æ")
                self.arcface.s = 8.0  # smaller scale for a randomly initialised head
                return

            if ce_weight.shape != self.arcface.weight.shape:
                print(f"classifier.weight shape {tuple(ce_weight.shape)} does not match "
                      f"arcface.weight shape {tuple(self.arcface.weight.shape)}; "
                      f"keeping random ArcFace init")
                self.arcface.s = 8.0  # smaller scale for a randomly initialised head
                return

            # Only the direction matters to ArcMarginProduct (it re-normalises the
            # weights in forward); the 0.1 factor only changes the stored norm.
            ce_weight_norm = F.normalize(ce_weight, p=2, dim=1)
            ce_weight_norm = ce_weight_norm * 0.1
            with torch.no_grad():
                self.arcface.weight.copy_(ce_weight_norm)

            print(f"‚úì ArcFace –∏–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä–æ–≤–∞–Ω –∏–∑ classifier.weight")
            print(f"‚úì Scale s = {self.arcface.s}")

            # Shape check
            print(f"\n–ü—Ä–æ–≤–µ—Ä–∫–∞ —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–µ–π:")
            print(f"  classifier.weight: {ce_weight.shape}")
            print(f"  arcface.weight: {self.arcface.weight.shape}")
            print(f"  –ù–æ—Ä–º–∞–ª–∏–∑–æ–≤–∞–Ω–Ω—ã–µ –≤–µ—Å–∞: {ce_weight_norm.norm(dim=1).mean():.4f}")

        except Exception as e:  # best-effort boundary: report and keep going
            print(f"‚ùå –û—à–∏–±–∫–∞ –ø—Ä–∏ –∑–∞–≥—Ä—É–∑–∫–µ: {e}")
            print("–û–±—É—á–∞–µ–º ArcFace —Å –Ω—É–ª—è...")
            self.arcface.s = 16.0  # smaller scale when training from scratch

    def forward(self, x, labels=None):
        # Backbone features
        features = self.backbone(x)

        # Embeddings, L2-normalised as ArcFace requires.
        embeddings = self.embedding(features)
        embeddings = F.normalize(embeddings, p=2, dim=1)

        if labels is not None:
            # Training: margin logits (label-dependent) plus embeddings.
            logits = self.arcface(embeddings, labels)
            return logits, embeddings

        # Inference: embeddings only.
        return embeddings

# ==================== TRAINING FUNCTIONS ====================
def train_epoch_ce(model, loader, optimizer, criterion, device):
    """Run one cross-entropy training epoch over ``loader``.

    Gradients are clipped to a global norm of 1.0 before each optimizer step.

    Returns:
        ``(mean per-sample loss, accuracy in percent)``; ``(0, 0)`` when the
        loader yields no samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    progress = tqdm(loader, desc="Training CE")
    for batch_x, batch_y in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_n = batch_x.size(0)

        optimizer.zero_grad()
        out_logits, _ = model(batch_x)
        batch_loss = criterion(out_logits, batch_y)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_hits = out_logits.max(1)[1].eq(batch_y).sum().item()
        loss_sum += batch_loss.item() * batch_n
        n_correct += batch_hits
        n_seen += batch_n

        progress.set_postfix({
            'loss': f'{batch_loss.item():.4f}',
            'acc': f'{100. * batch_hits / batch_n:.1f}%',
        })

    if n_seen == 0:
        return 0, 0
    return loss_sum / n_seen, 100. * n_correct / n_seen


def train_epoch_arcface(model, loader, optimizer, criterion, device, epoch):
    """Run one ArcFace training epoch with a progressive logit-scale warm-up.

    Before iterating, ``model.arcface.s`` is set from the 0-based ``epoch``:
    8 for epochs 0-2, 16 for 3-5, 24 for 6-7 and 32 afterwards. The first batch
    of epoch 0 prints logit statistics for debugging.

    Returns:
        ``(mean per-sample loss, accuracy in percent)``; ``(0, 0)`` when the
        loader yields no samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    # (exclusive epoch bound, scale) pairs; epochs past the last bound use 32.
    for epoch_bound, scale in ((3, 8.0), (6, 16.0), (8, 24.0)):
        if epoch < epoch_bound:
            model.arcface.s = scale
            break
    else:
        model.arcface.s = 32.0

    progress = tqdm(loader, desc=f"Training ArcFace (s={model.arcface.s:.1f})")
    for step, (batch_x, batch_y) in enumerate(progress):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        # Labels are passed so the margin is applied to each sample's true class.
        out_logits, _ = model(batch_x, batch_y)

        if step == 0 and epoch == 0:
            print(f"\n[DEBUG] –ü–µ—Ä–≤—ã–π –±–∞—Ç—á —ç–ø–æ—Ö–∏ {epoch+1}:")
            print(f"  Logits min/max: {out_logits.min():.4f}/{out_logits.max():.4f}")
            print(f"  Logits mean/std: {out_logits.mean():.4f}/{out_logits.std():.4f}")
            print(f"  Scale s: {model.arcface.s}")

        batch_loss = criterion(out_logits, batch_y)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Bookkeeping only; no graph needed.
        with torch.no_grad():
            guesses = F.softmax(out_logits, dim=1).argmax(dim=1)
            batch_n = batch_x.size(0)
            batch_hits = guesses.eq(batch_y).sum().item()

            loss_sum += batch_loss.item() * batch_n
            n_correct += batch_hits
            n_seen += batch_n

            progress.set_postfix({
                'loss': f'{batch_loss.item():.4f}',
                'acc': f'{100. * batch_hits / batch_n:.1f}%',
            })

    if n_seen == 0:
        return 0, 0
    return loss_sum / n_seen, 100. * n_correct / n_seen


def arcface_cosine_predictions(arcface_layer, embeddings):
    """Predict classes from label-free cosine similarity to the ArcFace class centres.

    Unlike ``ArcMarginProduct.forward``, no margin is applied to any class, so
    the result does not depend on the ground-truth labels.

    Args:
        arcface_layer: object with a ``weight`` tensor of shape
            ``(num_classes, embedding_dim)`` (e.g. ``ArcMarginProduct``).
        embeddings: ``(batch, embedding_dim)`` tensor.

    Returns:
        ``(batch,)`` tensor of predicted class indices.
    """
    cosine = F.linear(F.normalize(embeddings, p=2, dim=1),
                      F.normalize(arcface_layer.weight, p=2, dim=1))
    return cosine.argmax(dim=1)


def validate(model, loader, criterion, device, model_type='ce'):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: with ``model_type='ce'``, ``model(x)`` must return
            ``(logits, embedding)``. Otherwise (ArcFace), ``model(x, labels)``
            must return ``(margin_logits, embeddings)`` and the model must have
            an ``arcface`` layer with a ``weight``.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss applied to the logits.
        device: device the batches are moved to.
        model_type: ``'ce'``, or any other value for ArcFace.

    Returns:
        ``(mean per-sample loss, accuracy in percent)``; ``(0, 0)`` when the
        loader yields no samples.
    """
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)

            if model_type == 'ce':
                logits, _ = model(images)
                _, predicted = logits.max(1)
            else:  # arcface
                # Margin logits depend on the labels (the true class is penalised
                # by m), so they are used only for the loss, matching training.
                logits, embeddings = model(images, labels)
                # Predictions must not see the labels: use plain cosine scores,
                # as at inference time.
                predicted = arcface_cosine_predictions(model.arcface, embeddings)

            loss = criterion(logits, labels)
            val_loss += loss.item() * images.size(0)
            val_correct += predicted.eq(labels).sum().item()
            val_total += images.size(0)

    if val_total == 0:
        return 0, 0

    avg_loss = val_loss / val_total
    accuracy = 100. * val_correct / val_total

    return avg_loss, accuracy

def main():
    """Run the pipeline: data prep, CE training (if no checkpoint), ArcFace fine-tuning, test, plots.

    Uses the notebook-level ``Config``, ``CelebADataProcessor``,
    ``CelebAClassificationDataset``, ``SimpleFaceModel``, ``ArcFaceModel``,
    ``train_epoch_ce``, ``train_epoch_arcface`` and ``validate``. Writes
    checkpoints to ``checkpoints/`` and the figure to
    ``arcface_training_results.png``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    config = Config()

    # ==================== 1. DATA PREPARATION ====================
    print("\n" + "="*60)
    print("–ü–û–î–ì–û–¢–û–í–ö–ê –î–ê–ù–ù–´–•")
    print("="*60)

    processor = CelebADataProcessor(config)
    num_classes = processor.filter_data()
    # Populates processor.train_df / val_df / test_df; the returned counts are not needed.
    processor.split_data_by_images()

    # Augmentation for training only. Evaluation just resizes and normalises
    # with the ImageNet mean/std used by the pretrained ResNet.
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Datasets
    train_dataset = CelebAClassificationDataset(
        processor.train_df, config.img_dir, train_transform
    )
    val_dataset = CelebAClassificationDataset(
        processor.val_df, config.img_dir, val_transform
    )
    test_dataset = CelebAClassificationDataset(
        processor.test_df, config.img_dir, val_transform
    )

    # Data loaders. Pinned host memory only helps when batches are copied to a GPU.
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size,
        shuffle=True, num_workers=config.num_workers,
        pin_memory=(device.type == 'cuda')
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size,
        shuffle=False, num_workers=config.num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=config.batch_size,
        shuffle=False, num_workers=config.num_workers
    )

    os.makedirs('checkpoints', exist_ok=True)

    # ==================== 2. LOAD (OR TRAIN) THE CE MODEL ====================
    print("\n" + "="*60)
    print("–ó–ê–ì–†–£–ó–ö–ê –û–ë–£–ß–ï–ù–ù–û–ô CE –ú–û–î–ï–õ–ò")
    print("="*60)

    # Reuse an existing CE checkpoint if there is one.
    ce_checkpoint_path = 'checkpoints/best_model_ce.pth'

    if not os.path.exists(ce_checkpoint_path):
        print(f"–§–∞–π–ª {ce_checkpoint_path} –Ω–µ –Ω–∞–π–¥–µ–Ω")
        print("–û–±—É—á–∞–µ–º CE –º–æ–¥–µ–ª—å —Å–Ω–∞—á–∞–ª–∞...")

        model_ce = SimpleFaceModel(
            num_classes=num_classes,
            embedding_size=config.embedding_size
        ).to(device)

        optimizer_ce = optim.Adam(
            model_ce.parameters(),
            lr=0.001,
            weight_decay=1e-4
        )

        scheduler_ce = optim.lr_scheduler.StepLR(optimizer_ce, step_size=10, gamma=0.5)
        criterion_ce = nn.CrossEntropyLoss()

        # -inf (not 0) so the first epoch always writes a checkpoint, even at 0% accuracy.
        best_ce_acc = float('-inf')
        for epoch in range(config.num_epochs_ce):
            print(f"\nEpoch {epoch+1}/{config.num_epochs_ce} (CE)")
            print("-" * 40)

            train_loss, train_acc = train_epoch_ce(
                model_ce, train_loader, optimizer_ce, criterion_ce, device
            )

            val_loss, val_acc = validate(
                model_ce, val_loader, criterion_ce, device, 'ce'
            )

            print(f"\n–ò—Ç–æ–≥–∏ —ç–ø–æ—Ö–∏:")
            print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            if val_acc > best_ce_acc:
                best_ce_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_ce.state_dict(),
                    'val_acc': val_acc,
                    'optimizer_state_dict': optimizer_ce.state_dict(),
                }, ce_checkpoint_path)
                print(f"  ‚úì –°–æ—Ö—Ä–∞–Ω–µ–Ω–∞ CE –º–æ–¥–µ–ª—å (Acc: {val_acc:.2f}%)")

            scheduler_ce.step()

    # ==================== 3. ARCFACE FINE-TUNING ====================
    print("\n" + "="*60)
    print("FINE-TUNING –° ARCFACE LOSS")
    print("="*60)

    model_arcface = ArcFaceModel(
        num_classes=num_classes,
        embedding_size=config.embedding_size,
        s=8.0,  # start with a small scale
        m=config.arcface_m
    ).to(device)

    model_arcface.load_from_ce_model(ce_checkpoint_path, device)

    # Sanity check on one training batch after the weight transfer.
    print("\n" + "="*50)
    print("–ü–†–û–í–ï–†–ö–ê –ü–û–°–õ–ï –ó–ê–ì–†–£–ó–ö–ò –í–ï–°–û–í")
    print("="*50)

    model_arcface.eval()
    with torch.no_grad():
        test_images, test_labels = next(iter(train_loader))
        test_images = test_images[:4].to(device)
        test_labels = test_labels[:4].to(device)

        logits, embeddings = model_arcface(test_images, test_labels)

        print(f"\n1. –ü—Ä–æ–≤–µ—Ä–∫–∞ —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–µ–π:")
        print(f"   Logits shape: {logits.shape}")
        print(f"   Embeddings shape: {embeddings.shape}")

        print(f"\n2. –ü—Ä–æ–≤–µ—Ä–∫–∞ —á–∏—Å–ª–µ–Ω–Ω—ã—Ö –∑–Ω–∞—á–µ–Ω–∏–π:")
        print(f"   Logits range: [{logits.min():.4f}, {logits.max():.4f}]")
        print(f"   Logits mean/std: {logits.mean():.4f} / {logits.std():.4f}")

        print(f"\n3. –ü—Ä–æ–≤–µ—Ä–∫–∞ accuracy:")
        probs = F.softmax(logits, dim=1)
        predicted = probs.argmax(dim=1)
        accuracy = (predicted == test_labels).float().mean().item()
        print(f"   Accuracy –Ω–∞ —Ç–µ—Å—Ç–æ–≤–æ–º –±–∞—Ç—á–µ: {accuracy:.2%}")

        test_loss = nn.CrossEntropyLoss()(logits, test_labels)
        print(f"   Loss: {test_loss.item():.4f}")

    # Small learning rate for fine-tuning.
    optimizer_arc = optim.AdamW(
        model_arcface.parameters(),
        lr=config.learning_rate,
        weight_decay=0.0005
    )

    scheduler_arc = optim.lr_scheduler.CosineAnnealingLR(
        optimizer_arc,
        T_max=config.num_epochs_arcface,
        eta_min=1e-6
    )

    criterion_arc = nn.CrossEntropyLoss(label_smoothing=0.1)

    print("\n" + "="*60)
    print("–ù–ê–ß–ê–õ–û –û–ë–£–ß–ï–ù–ò–Ø ARCFACE")
    print("="*60)

    arc_checkpoint_path = 'checkpoints/best_model_arcface.pth'
    # -inf (not 0) so the first epoch always writes a checkpoint.
    best_arc_acc = float('-inf')
    # 'lr' and 'scale' are recorded per epoch so the plots show the real schedules.
    arc_history = {'loss': [], 'acc': [], 'val_loss': [], 'val_acc': [], 'lr': [], 'scale': []}

    for epoch in range(config.num_epochs_arcface):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch+1}/{config.num_epochs_arcface} - ArcFace Fine-tuning")
        print(f"{'='*50}")

        train_loss, train_acc = train_epoch_arcface(
            model_arcface, train_loader, optimizer_arc, criterion_arc, device, epoch
        )

        val_loss, val_acc = validate(
            model_arcface, val_loader, criterion_arc, device, 'arcface'
        )

        arc_history['loss'].append(train_loss)
        arc_history['acc'].append(train_acc)
        arc_history['val_loss'].append(val_loss)
        arc_history['val_acc'].append(val_acc)
        # LR used during this epoch (the scheduler steps below) and the scale
        # train_epoch_arcface set for it.
        arc_history['lr'].append(optimizer_arc.param_groups[0]['lr'])
        arc_history['scale'].append(model_arcface.arcface.s)

        print(f"\nüìä –ò—Ç–æ–≥–∏ —ç–ø–æ—Ö–∏:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"  Scale s: {model_arcface.arcface.s}")
        print(f"  Learning Rate: {optimizer_arc.param_groups[0]['lr']:.2e}")

        if val_acc > best_arc_acc:
            best_arc_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model_arcface.state_dict(),
                'optimizer_state_dict': optimizer_arc.state_dict(),
                'val_acc': val_acc,
                'scale': model_arcface.arcface.s,
            }, arc_checkpoint_path)
            print(f"  üíæ –°–æ—Ö—Ä–∞–Ω–µ–Ω–∞ –ª—É—á—à–∞—è ArcFace –º–æ–¥–µ–ª—å (Val Acc: {val_acc:.2f}%)")

        scheduler_arc.step()

    # ==================== 4. TESTING ====================
    print("\n" + "="*60)
    print("–¢–ï–°–¢–ò–†–û–í–ê–ù–ò–ï ARCFACE –ú–û–î–ï–õ–ò")
    print("="*60)

    if os.path.exists(arc_checkpoint_path):
        arc_checkpoint = torch.load(arc_checkpoint_path, map_location=device)
        model_arcface.load_state_dict(arc_checkpoint['model_state_dict'])
        # ``s`` is a plain attribute, not part of state_dict. Restore it so the
        # test loss uses the scale the best checkpoint was trained at.
        model_arcface.arcface.s = arc_checkpoint.get('scale', model_arcface.arcface.s)
    else:
        print(f"Checkpoint {arc_checkpoint_path} not found; testing the current model")

    test_loss, test_acc = validate(
        model_arcface, test_loader, criterion_arc, device, 'arcface'
    )

    print(f"\nüìä –†–µ–∑—É–ª—å—Ç–∞—Ç—ã ArcFace –Ω–∞ —Ç–µ—Å—Ç–æ–≤–æ–π –≤—ã–±–æ—Ä–∫–µ:")
    print(f"  Test Loss: {test_loss:.4f}")
    print(f"  Test Accuracy: {test_acc:.2f}%")
    print(f"  Final Scale s: {model_arcface.arcface.s}")

    # ==================== 5. VISUALISATION ====================
    print("\n" + "="*60)
    print("–í–ò–ó–£–ê–õ–ò–ó–ê–¶–ò–Ø –†–ï–ó–£–õ–¨–¢–ê–¢–û–í")
    print("="*60)

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Loss history
    axes[0, 0].plot(arc_history['loss'], label='Train Loss', linewidth=2)
    axes[0, 0].plot(arc_history['val_loss'], label='Val Loss', linewidth=2)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('ArcFace: Loss History')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy history
    axes[0, 1].plot(arc_history['acc'], label='Train Acc', linewidth=2)
    axes[0, 1].plot(arc_history['val_acc'], label='Val Acc', linewidth=2)
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].set_title('ArcFace: Accuracy History')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Learning-rate history, one value per epoch.
    axes[1, 0].plot(arc_history['lr'], linewidth=2)
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].set_title('Learning Rate Schedule')
    axes[1, 0].grid(True, alpha=0.3)

    # Scale s actually used each epoch.
    axes[1, 1].plot(arc_history['scale'], linewidth=2, color='green')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Scale s')
    axes[1, 1].set_title('ArcFace Scale Progression')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('arcface_training_results.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("\n" + "="*60)
    print("–û–ë–£–ß–ï–ù–ò–ï ARCFACE –ó–ê–í–ï–†–®–ï–ù–û!")
    print("="*60)
    print(f"–õ—É—á—à–∞—è —Ç–æ—á–Ω–æ—Å—Ç—å –Ω–∞ –≤–∞–ª–∏–¥–∞—Ü–∏–∏: {best_arc_acc:.2f}%")
    print(f"–¢–æ—á–Ω–æ—Å—Ç—å –Ω–∞ —Ç–µ—Å—Ç–µ: {test_acc:.2f}%")


if __name__ == "__main__":
    main()

ЕЩЕ ОДИН ВАРИАНТ (another variant)

In [None]:
import os
import math
import logging
from datetime import datetime
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# ==================== –ö–û–ù–§–ò–ì–£–†–ê–¶–ò–Ø ====================
class Config:
    """Hyper-parameters and data locations for the CelebA training pipeline.

    ``identity_df`` and ``img_dir`` have no sensible default and must be
    supplied, either as constructor arguments or by assigning the attributes
    afterwards. They default to ``None`` so a missing value can be detected
    with an ``is None`` check instead of surfacing as ``AttributeError``.

    Args:
        identity_df: DataFrame describing the images; code in this file reads
            its ``filename`` and ``person_id`` columns.
        img_dir: Directory containing the image files.
    """

    def __init__(self, identity_df=None, img_dir=None):
        # Data sources (must be provided by the environment before training)
        self.identity_df = identity_df
        self.img_dir = img_dir

        # Data selection and splitting
        self.max_classes = 350              # number of identities to keep
        self.min_samples_per_person = 26    # minimum number of images per identity
        self.seed = 42                      # random seed for the splits
        self.val_ratio = 0.15               # fraction of images for validation
        self.test_ratio = 0.15              # fraction of images for test

        # Data loading
        self.batch_size = 64
        self.num_workers = 0                # DataLoader worker processes (0 = main process)

        # Model and optimisation
        self.embedding_size = 512
        self.learning_rate = 0.0001
        self.num_epochs_ce = 15             # epochs of cross-entropy training
        self.num_epochs_arcface = 10        # epochs of ArcFace fine-tuning

        # ArcFace head
        self.arcface_s = 32.0               # logit scale
        self.arcface_m = 0.5                # additive angular margin (radians)
        self.arcface_easy_margin = True

# ==================== –û–ë–†–ê–ë–û–¢–ö–ê –î–ê–ù–ù–´–• ====================
class CelebADataProcessor:
    """Select frequent CelebA identities and split their images.

    Reads ``config.identity_df`` (expected to have at least ``filename`` and
    ``person_id`` columns), keeps up to ``config.max_classes`` identities that
    have at least ``config.min_samples_per_person`` images, remaps their ids to
    contiguous class indices and performs a stratified train/val/test split.
    """

    def __init__(self, config):
        self.config = config
        self.identity_df = config.identity_df

        print(f"–í—Å–µ–≥–æ –¥–∞–Ω–Ω—ã—Ö: {len(self.identity_df)} –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π")
        print(f"–£–Ω–∏–∫–∞–ª—å–Ω—ã—Ö –ª—é–¥–µ–π: {self.identity_df['person_id'].nunique()}")

    def filter_data(self):
        """Keep the top identities and add a contiguous ``class_idx`` column.

        Identities with fewer than ``config.min_samples_per_person`` images
        (1 when the config has no such attribute) are discarded first; of the
        remaining ones, the ``config.max_classes`` largest are kept.

        Returns:
            The number of retained classes.

        Raises:
            ValueError: if no identity has enough images.
        """
        person_counts = self.identity_df['person_id'].value_counts()

        # Honour the configured minimum: each class has to survive two
        # stratified splits, and train_test_split rejects classes that are
        # too small to be stratified.
        min_samples = getattr(self.config, 'min_samples_per_person', 1)
        eligible_counts = person_counts[person_counts >= min_samples]
        if eligible_counts.empty:
            raise ValueError(
                f"No identity has at least {min_samples} images "
                "(config.min_samples_per_person)"
            )

        top_persons = eligible_counts.nlargest(self.config.max_classes).index
        self.filtered_df = self.identity_df[self.identity_df['person_id'].isin(top_persons)].copy()

        # Map the original person ids to 0..K-1 in sorted-id order.
        unique_ids = sorted(self.filtered_df['person_id'].unique())
        self.id_to_idx = {old_id: idx for idx, old_id in enumerate(unique_ids)}
        self.idx_to_id = {idx: old_id for old_id, idx in self.id_to_idx.items()}

        self.filtered_df['class_idx'] = self.filtered_df['person_id'].map(self.id_to_idx)

        print(f"\n–ü–æ—Å–ª–µ —Ñ–∏–ª—å—Ç—Ä–∞—Ü–∏–∏:")
        print(f"  –í—Å–µ–≥–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π: {len(self.filtered_df)}")
        print(f"  –£–Ω–∏–∫–∞–ª—å–Ω—ã—Ö –ª—é–¥–µ–π: {self.filtered_df['person_id'].nunique()}")
        print(f"  –î–∏–∞–ø–∞–∑–æ–Ω –º–µ—Ç–æ–∫: {self.filtered_df['class_idx'].min()} - {self.filtered_df['class_idx'].max()}")

        return len(unique_ids)

    def split_data_by_images(self):
        """Stratified split of ``filtered_df`` into train/val/test by image.

        The first split holds out ``val_ratio + test_ratio`` of the images; the
        held-out part is then divided so val and test get their configured
        shares. Both splits are stratified by ``class_idx`` and seeded with
        ``config.seed``. :meth:`filter_data` must have been called first.

        Returns:
            dict with the sample counts ``train``/``val``/``test`` and ``num_persons``.
        """
        print(f"\n–†–∞–∑–¥–µ–ª–µ–Ω–∏–µ –¥–∞–Ω–Ω—ã—Ö –ø–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏—è–º (—Å—Ç—Ä–∞—Ç–∏—Ñ–∏—Ü–∏—Ä–æ–≤–∞–Ω–æ)...")

        train_df, temp_df = train_test_split(
            self.filtered_df,
            test_size=self.config.val_ratio + self.config.test_ratio,
            random_state=self.config.seed,
            stratify=self.filtered_df['class_idx']
        )

        # Share of the held-out part that goes to test (the rest is val).
        val_df, test_df = train_test_split(
            temp_df,
            test_size=self.config.test_ratio / (self.config.val_ratio + self.config.test_ratio),
            random_state=self.config.seed,
            stratify=temp_df['class_idx']
        )

        self.train_df = train_df.reset_index(drop=True)
        self.val_df = val_df.reset_index(drop=True)
        self.test_df = test_df.reset_index(drop=True)

        print(f"Train: {len(self.train_df)} samples")
        print(f"Val: {len(self.val_df)} samples")
        print(f"Test: {len(self.test_df)} samples")

        return {
            'train': len(self.train_df),
            'val': len(self.val_df),
            'test': len(self.test_df),
            'num_persons': self.filtered_df['person_id'].nunique()
        }

# ==================== –î–ê–¢–ê–°–ï–¢ ====================
class CelebAClassificationDataset(Dataset):
    """Map-style dataset returning ``(image, class_idx)`` pairs.

    Images that cannot be opened are replaced by a 224x224 black RGB image and
    a warning is logged, so a single bad file does not abort an epoch; the
    sample keeps its label.

    Args:
        df: DataFrame with ``filename`` and ``class_idx`` columns (index is reset).
        img_dir: Directory the ``filename`` values are relative to.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        # Only obtain a logger here: configuring logging (basicConfig) is the
        # application's job. With no configuration at all, WARNING records still
        # reach stderr through logging's last-resort handler.
        self.logger = logging.getLogger(__name__)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['filename'])

        try:
            # The context manager closes the file handle; convert() returns a
            # separate, fully loaded copy that outlives it.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')

            # Degenerate (tiny) images are upscaled to the working resolution.
            if image.size[0] < 10 or image.size[1] < 10:
                self.logger.warning("Image %s too small: %s", img_path, image.size)
                image = image.resize((224, 224), Image.BILINEAR)

        except Exception as e:
            # Deliberate best effort: substitute a black image instead of failing.
            self.logger.warning("Cannot load image %s: %s", img_path, e)
            image = Image.new('RGB', (224, 224), color=(0, 0, 0))

        if self.transform:
            image = self.transform(image)

        label = int(row['class_idx'])
        return image, label

# ==================== –ú–û–î–ï–õ–ò ====================
class SimpleFaceModel(nn.Module):
    """ResNet-18 classifier with an explicit embedding layer (cross-entropy training).

    ``forward`` returns ``(logits, embedding)``; the embedding is not
    L2-normalised.

    Args:
        num_classes: number of identity classes of the linear classifier.
        embedding_size: dimensionality of the embedding layer.
    """

    # Parameter-name prefixes of the ResNet stem and first stage, kept frozen.
    # Prefix (not substring) matching matters: every residual block also has
    # parameters named ``...conv1...`` / ``...bn1...`` (e.g. ``layer4.0.conv1.weight``).
    _FROZEN_PREFIXES = ('conv1.', 'bn1.', 'layer1.')

    def __init__(self, num_classes=350, embedding_size=512):
        super().__init__()

        self.backbone = models.resnet18(pretrained=True)
        in_features = self.backbone.fc.in_features

        # Replace the ImageNet head so the backbone outputs pooled features.
        self.backbone.fc = nn.Identity()

        # Embedding layer
        self.embedding = nn.Sequential(
            nn.Linear(in_features, embedding_size),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        # Classifier
        self.classifier = nn.Linear(embedding_size, num_classes)

        # Freeze only the early layers (stem + layer1).
        for name, param in self.backbone.named_parameters():
            if self._should_freeze(name):
                param.requires_grad = False

    @classmethod
    def _should_freeze(cls, name):
        """Return True if backbone parameter *name* belongs to the stem or ``layer1``."""
        return name.startswith(cls._FROZEN_PREFIXES)

    def forward(self, x):
        features = self.backbone(x)
        embedding = self.embedding(features)
        logits = self.classifier(embedding)
        return logits, embedding


class ArcMarginProduct(nn.Module):
    """Additive angular margin (ArcFace) logit layer.

    Computes cosine similarities between L2-normalised inputs and L2-normalised
    class weights, replaces the target-class entry with ``cos(theta + m)`` and
    multiplies everything by ``s``. It returns logits; the loss (e.g.
    cross-entropy) is applied by the caller.

    Args:
        in_features: embedding dimensionality.
        out_features: number of classes.
        s: scale applied to every logit.
        m: additive angular margin in radians.
        easy_margin: if True, keep the margin only where ``cos(theta) > 0``;
            otherwise use ``cos(theta) - mm`` wherever ``cos(theta) <= th``.
    """

    def __init__(self, in_features, out_features, s=32.0, m=0.5, easy_margin=True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.s, self.m = s, m
        self.easy_margin = easy_margin

        # One row of class-centre weights per class.
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)

        # Constants for cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m).
        self.cos_m, self.sin_m = math.cos(m), math.sin(m)

        # Threshold and linear fallback used when easy_margin is False.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        print(f"ArcFace –∏–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä–æ–≤–∞–Ω: s={s}, m={m}")

    def forward(self, input, label):
        unit_inputs = F.normalize(input, p=2, dim=1)
        unit_weights = F.normalize(self.weight, p=2, dim=1)

        # Clamp away from +-1 so the square root below stays well-behaved.
        cos_theta = F.linear(unit_inputs, unit_weights).clamp(-1 + 1e-7, 1 - 1e-7)
        sin_theta = torch.sqrt(1.0 - cos_theta.pow(2) + 1e-7)

        cos_theta_m = cos_theta * self.cos_m - sin_theta * self.sin_m
        if not self.easy_margin:
            keep_margin = cos_theta > self.th
            fallback = cos_theta - self.mm
        else:
            keep_margin = cos_theta > 0
            fallback = cos_theta
        cos_theta_m = torch.where(keep_margin, cos_theta_m, fallback)

        # 1.0 in the target-class column of each row, 0.0 elsewhere.
        target_mask = torch.zeros_like(cos_theta)
        target_mask.scatter_(1, label.view(-1, 1).long(), 1)

        blended = (target_mask * cos_theta_m) + ((1.0 - target_mask) * cos_theta)
        return blended * self.s

class ArcFaceModel(nn.Module):
    """ResNet-18 backbone + embedding layer + ArcFace margin head.

    ``forward(x)`` returns L2-normalised embeddings; ``forward(x, labels)``
    returns ``(logits, embeddings)`` with ArcFace logits.

    Args:
        num_classes: number of identity classes in the ArcFace head.
        embedding_size: embedding dimensionality.
        s: initial ArcFace scale; also the value restored by the
            re-initialisation helpers.
        m: ArcFace angular margin (radians).
    """

    def __init__(self, num_classes=350, embedding_size=512, s=8.0, m=0.5):
        super().__init__()

        # Remembered so re-initialisation restores the scale the caller asked
        # for rather than a hard-coded constant.
        self._initial_s = s

        # Backbone (ResNet18) with its classification head removed
        self.backbone = models.resnet18(pretrained=True)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Embedding layer: same layout as SimpleFaceModel.embedding, so CE
        # checkpoints can be loaded into it key by key.
        # NOTE(review): the ReLU makes every component non-negative, confining
        # embeddings to one orthant before normalisation; confirm this is
        # intended for a cosine-margin loss.
        self.embedding = nn.Sequential(
            nn.Linear(in_features, embedding_size),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        # ArcFace head
        self.arcface = ArcMarginProduct(
            embedding_size,
            num_classes,
            s=s,
            m=m,
            easy_margin=True
        )

        print(f"\nArcFace Model —Å–æ–∑–¥–∞–Ω–∞:")
        print(f"  –ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –∫–ª–∞—Å—Å–æ–≤: {num_classes}")
        print(f"  Embedding size: {embedding_size}")
        print(f"  Scale s: {s}")
        print(f"  Margin m: {m}")

    def load_from_ce_model(self, ce_model_path, device='cpu'):
        """Initialise backbone and embedding weights from a cross-entropy checkpoint.

        Accepts a raw state dict or a dict holding it under ``model_state_dict``.
        Keys prefixed ``backbone.`` / ``embedding.`` are loaded non-strictly into
        the matching sub-modules; the ArcFace head is always freshly initialised.
        Any failure (missing file, unreadable checkpoint, ...) is reported and
        falls back to :meth:`_initialize_random`.

        Returns:
            True if the checkpoint was used, False if the fallback was taken.
        """
        print(f"\n{'='*50}")
        print("–ó–ê–ì–†–£–ó–ö–ê –í–ï–°–û–í –ò–ó CE –ú–û–î–ï–õ–ò")
        print(f"{'='*50}")

        try:
            if not os.path.exists(ce_model_path):
                print(f"‚ö† –§–∞–π–ª {ce_model_path} –Ω–µ –Ω–∞–π–¥–µ–Ω")
                print("–û–±—É—á–∞–µ–º ArcFace —Å –Ω—É–ª—è...")
                self._initialize_random()
                return False

            # Load the CE checkpoint
            ce_checkpoint = torch.load(ce_model_path, map_location=device)

            # Support both a full training checkpoint and a bare state dict.
            if isinstance(ce_checkpoint, dict) and 'model_state_dict' in ce_checkpoint:
                ce_state_dict = ce_checkpoint['model_state_dict']
            else:
                ce_state_dict = ce_checkpoint

            # 1. Backbone weights: strip the 'backbone.' prefix from the keys.
            backbone_dict = {}
            for key, value in ce_state_dict.items():
                if key.startswith('backbone.'):
                    new_key = key.replace('backbone.', '', 1)
                    backbone_dict[new_key] = value

            if backbone_dict:
                missing_keys, unexpected_keys = self.backbone.load_state_dict(backbone_dict, strict=False)
                print("‚úì Backbone –∑–∞–≥—Ä—É–∂–µ–Ω")
                if missing_keys:
                    print(f"  –û—Ç—Å—É—Ç—Å—Ç–≤—É—é—â–∏–µ –∫–ª—é—á–∏: {missing_keys[:5]}{'...' if len(missing_keys) > 5 else ''}")
                if unexpected_keys:
                    print(f"  –ù–µ–æ–∂–∏–¥–∞–Ω–Ω—ã–µ –∫–ª—é—á–∏: {unexpected_keys[:5]}{'...' if len(unexpected_keys) > 5 else ''}")
            else:
                print("‚ö† Backbone –≤–µ—Å–∞ –Ω–µ –Ω–∞–π–¥–µ–Ω—ã")

            # 2. Embedding layer weights: strip the 'embedding.' prefix.
            embedding_dict = {}
            for key, value in ce_state_dict.items():
                if key.startswith('embedding.'):
                    new_key = key.replace('embedding.', '', 1)
                    embedding_dict[new_key] = value

            if embedding_dict:
                self.embedding.load_state_dict(embedding_dict, strict=False)
                print("‚úì Embedding —Å–ª–æ–π –∑–∞–≥—Ä—É–∂–µ–Ω")
            else:
                print("‚ö† Embedding –≤–µ—Å–∞ –Ω–µ –Ω–∞–π–¥–µ–Ω—ã")

            # 3. The ArcFace head has no CE counterpart: initialise it fresh.
            self._initialize_arcface()
            print("‚úì ArcFace –∏–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä–æ–≤–∞–Ω —Å–ª—É—á–∞–π–Ω–æ")
            print(f"‚úì Scale s = {self.arcface.s}")

            return True

        except Exception as e:
            # Deliberate best effort: fall back to training the head from scratch.
            print(f"‚ùå –û—à–∏–±–∫–∞ –ø—Ä–∏ –∑–∞–≥—Ä—É–∑–∫–µ: {e}")
            print("–û–±—É—á–∞–µ–º ArcFace —Å –Ω—É–ª—è...")
            self._initialize_random()
            return False

    def _initialize_arcface(self):
        """Redraw the ArcFace class weights (Xavier-uniform) and reset its scale."""
        nn.init.xavier_uniform_(self.arcface.weight)
        # 8.0 (the previous hard-coded value) if ``_initial_s`` is absent.
        self.arcface.s = getattr(self, '_initial_s', 8.0)

    def _initialize_random(self):
        """Re-initialise the embedding layer and the ArcFace head.

        The backbone is left untouched. Linear layers get Xavier-uniform
        weights and zero bias, BatchNorm layers unit weight and zero bias.
        ArcMarginProduct is not an ``nn.Linear``, so the loop below does not
        reach its weight; it is reset explicitly via :meth:`_initialize_arcface`.
        """
        for module in [self.embedding, self.arcface]:
            for m in module.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm1d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

        self._initialize_arcface()

    def forward(self, x, labels=None):
        """Return normalised embeddings, plus ArcFace logits when *labels* is given."""
        features = self.backbone(x)
        embeddings = self.embedding(features)

        # Unit-length embeddings: ArcFace operates on cosine similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)

        if labels is not None:
            # Training: the margin needs the target labels.
            logits = self.arcface(embeddings, labels)
            return logits, embeddings

        # Inference: embeddings only.
        return embeddings

# ==================== –£–¢–ò–õ–ò–¢–´ –î–õ–Ø –í–ê–õ–ò–î–ê–¶–ò–ò ====================
def _extract_normalized_embeddings(model, loader, device, num_samples):
    """Collect L2-normalised embeddings and labels until ``num_samples`` are gathered."""
    embedding_chunks, label_chunks = [], []
    collected = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Extracting embeddings"):
            # Count what was actually gathered rather than ``i * loader.batch_size``:
            # batch_size is None for loaders built with a custom batch_sampler.
            if collected >= num_samples:
                break

            output = model(images.to(device))
            # Models returning (logits, embedding) are supported as well.
            if isinstance(output, (tuple, list)):
                output = output[-1]
            # Idempotent for already-normalised embeddings (ArcFaceModel).
            embedding_chunks.append(F.normalize(output, p=2, dim=1).cpu())
            label_chunks.append(labels)
            collected += len(labels)

    if not embedding_chunks:
        return None, None
    return torch.cat(embedding_chunks), torch.cat(label_chunks)


def _sample_pair_similarities(cosine_sim, labels, max_classes=20, pairs_per_class=5):
    """Sample intra-class and adjacent-label inter-class cosine similarities."""
    intra, inter = [], []
    unique_labels = torch.unique(labels)

    for pos, label_i in enumerate(unique_labels[:max_classes]):  # limited for speed
        indices_i = torch.where(labels == label_i)[0]

        # Intra-class: every pair among the first ``pairs_per_class`` samples.
        head = indices_i[:pairs_per_class]
        for j in range(len(head)):
            for k in range(j + 1, len(head)):
                intra.append(cosine_sim[head[j], head[k]].item())

        # Inter-class: random pairs against the next label in sorted order.
        if pos < len(unique_labels) - 1:
            indices_j = torch.where(labels == unique_labels[pos + 1])[0]
            if len(indices_j) > 0:
                for _ in range(min(pairs_per_class, len(indices_i))):
                    idx_i = indices_i[torch.randint(0, len(indices_i), (1,))]
                    idx_j = indices_j[torch.randint(0, len(indices_j), (1,))]
                    inter.append(cosine_sim[idx_i, idx_j].item())

    return intra, inter


def _plot_similarity_histograms(intra, inter, intra_mean, inter_mean):
    """Save to 'embedding_analysis.png' and show intra/inter-class similarity histograms."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (axes[0], intra, intra_mean, 'blue', 'Intra-class'),
        (axes[1], inter, inter_mean, 'green', 'Inter-class'),
    )
    for ax, values, mean, color, kind in panels:
        ax.hist(values, bins=30, alpha=0.7, color=color, density=True)
        ax.axvline(x=mean, color='red', linestyle='--', label=f'Mean: {mean:.3f}')
        ax.set_xlabel(f'Cosine Similarity ({kind})')
        ax.set_ylabel('Density')
        ax.set_title(f'{kind} Similarity Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('embedding_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def analyze_embeddings(model, loader, device, num_samples=1000):
    """Report how well embeddings separate identities.

    Extracts embeddings for roughly ``num_samples`` items (whole batches, so
    possibly slightly more), keeps a random subset of at most 5000, and
    compares sampled intra-class vs inter-class cosine similarities. The
    model may return embeddings directly or a ``(logits, embedding)`` tuple;
    embeddings are L2-normalised either way.

    Returns:
        ``(embeddings, labels)`` tensors, or ``(None, None)`` if the loader
        produced no batches.
    """
    model.eval()

    print("\n–ê–Ω–∞–ª–∏–∑ –∫–∞—á–µ—Å—Ç–≤–∞ —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤...")
    all_embeddings, all_labels = _extract_normalized_embeddings(model, loader, device, num_samples)
    if all_embeddings is None:
        return None, None

    # Cap the similarity matrix size with a random subset.
    if len(all_embeddings) > 5000:
        indices = torch.randperm(len(all_embeddings))[:5000]
        all_embeddings = all_embeddings[indices]
        all_labels = all_labels[indices]

    # Rows are unit vectors, so the Gram matrix holds pairwise cosine similarities.
    cosine_sim = torch.mm(all_embeddings, all_embeddings.t())
    intra_distances, inter_distances = _sample_pair_similarities(cosine_sim, all_labels)

    if intra_distances and inter_distances:
        intra_mean = np.mean(intra_distances)
        inter_mean = np.mean(inter_distances)

        print(f"\nüìä –ö–∞—á–µ—Å—Ç–≤–æ —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤:")
        print(f"  –°—Ä–µ–¥–Ω—è—è –∫–æ—Å–∏–Ω—É—Å–Ω–∞—è —Å—Ö–æ–∂–µ—Å—Ç—å –≤–Ω—É—Ç—Ä–∏ –∫–ª–∞—Å—Å–∞: {intra_mean:.4f}")
        print(f"  –°—Ä–µ–¥–Ω—è—è –∫–æ—Å–∏–Ω—É—Å–Ω–∞—è —Å—Ö–æ–∂–µ—Å—Ç—å –º–µ–∂–¥—É –∫–ª–∞—Å—Å–∞–º–∏: {inter_mean:.4f}")
        print(f"  Ratio (inter/intra): {inter_mean/intra_mean if intra_mean > 0 else float('inf'):.2f}")

        _plot_similarity_histograms(intra_distances, inter_distances, intra_mean, inter_mean)

    return all_embeddings, all_labels

# ==================== –§–£–ù–ö–¶–ò–ò –û–ë–£–ß–ï–ù–ò–Ø ====================
def train_epoch_ce(model, loader, optimizer, criterion, device):
    """Run one cross-entropy training epoch.

    The model must return ``(logits, embedding)``; only the logits feed
    *criterion*. Gradients are clipped to a global L2 norm of 1.0 before every
    optimizer step.

    Returns:
        ``(average loss per sample, accuracy in percent)`` over the epoch, or
        ``(0, 0)`` if the loader yields no samples.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc="Training CE")
    for batch_x, batch_y in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits, _embedding = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        size = batch_x.size(0)
        predicted = logits.max(1).indices
        hits = predicted.eq(batch_y).sum().item()

        loss_sum += batch_loss.item() * size
        n_correct += hits
        n_seen += size

        batch_acc = 100. * hits / size
        progress.set_postfix({
            'loss': f'{batch_loss.item():.4f}',
            'acc': f'{batch_acc:.1f}%'
        })

    if n_seen == 0:
        return 0, 0
    return loss_sum / n_seen, 100. * n_correct / n_seen

def train_epoch_arcface(model, loader, optimizer, criterion, device, epoch, total_epochs,
                        scheduler=None, min_scale=8.0, max_scale=32.0):
    """Run one ArcFace training epoch with a progressively increasing scale ``s``.

    Before the epoch, ``model.arcface.s`` is set to
    ``min_scale + (max_scale - min_scale) * sqrt(progress)`` where
    ``progress = epoch / (total_epochs - 1)`` is clamped to [0, 1]
    (``max_scale`` when ``total_epochs <= 1``). The square root makes ``s``
    grow quickly in early epochs and level off later.

    Args:
        model: model with an ``arcface`` head (``s``, ``m``, ``weight``) whose
            ``model(images, labels)`` returns ``(logits, embeddings)`` with
            L2-normalised embeddings.
        loader, optimizer, criterion, device: the usual training objects.
        epoch: 0-based index of the current epoch.
        total_epochs: number of ArcFace epochs.
        scheduler: optional per-batch LR scheduler (e.g. ``OneCycleLR`` built
            with ``steps_per_epoch=len(loader)``); stepped after every optimizer
            step when given.
        min_scale, max_scale: range of the scale schedule.

    Returns:
        ``(average loss per sample, accuracy in percent)``; accuracy is the
        argmax of the plain cosine to the class weights (no margin).
        ``(0, 0)`` for an empty loader.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    # Progressive scaling with a square-root ramp.
    if total_epochs <= 1:
        current_scale = max_scale
    else:
        # Clamp so an out-of-range epoch cannot push s outside [min, max]
        # (a negative base would even make ``** 0.5`` return a complex number).
        progress = min(max(epoch / (total_epochs - 1), 0.0), 1.0)
        current_scale = min_scale + (max_scale - min_scale) * (progress ** 0.5)

    model.arcface.s = current_scale

    pbar = tqdm(loader, desc=f"Training ArcFace (s={model.arcface.s:.1f}, m={model.arcface.m:.2f})")
    for batch_idx, (images, labels) in enumerate(pbar):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits, embeddings = model(images, labels)

        # Debug information for the very first batch of training
        if batch_idx == 0 and epoch == 0:
            print(f"\n[DEBUG] –ü–µ—Ä–≤—ã–π –±–∞—Ç—á:")
            print(f"  Logits shape: {logits.shape}")
            print(f"  Logits range: [{logits.min():.4f}, {logits.max():.4f}]")
            print(f"  Logits mean/std: {logits.mean():.4f}/{logits.std():.4f}")
            print(f"  Embeddings norm: {embeddings.norm(p=2, dim=1).mean():.4f} ¬± {embeddings.norm(p=2, dim=1).std():.4f}")
            print(f"  Scale s: {model.arcface.s}")

        loss = criterion(logits, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Accuracy from the plain cosine similarity (no margin); embeddings
        # are already normalised by the model.
        with torch.no_grad():
            weights_norm = F.normalize(model.arcface.weight.data, p=2, dim=1)
            cosine = torch.mm(embeddings, weights_norm.t())
            predicted = cosine.argmax(dim=1)
            hits = predicted.eq(labels).sum().item()

            total_loss += loss.item() * images.size(0)
            correct += hits
            total += images.size(0)

            current_acc = 100. * hits / images.size(0)
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{current_acc:.1f}%',
                'scale': f'{model.arcface.s:.1f}'
            })

    avg_loss = total_loss / total if total > 0 else 0
    accuracy = 100. * correct / total if total > 0 else 0

    return avg_loss, accuracy

def validate(model, loader, criterion, device, model_type='ce'):
    """Evaluate *model* on *loader* without tracking gradients.

    With ``model_type == 'ce'`` the model must return ``(logits, embedding)``.
    Any other value is treated as ArcFace: the model returns L2-normalised
    embeddings (as ``ArcFaceModel.forward`` does without labels), predictions
    are the argmax of the cosine to the normalised ArcFace class weights, and
    the loss is computed on ``s * cosine`` (no angular margin).

    Returns:
        ``(average loss per sample, accuracy in percent)``, or ``(0, 0)`` for
        an empty loader.
    """
    model.eval()
    loss_sum = 0
    hits_total = 0
    seen = 0

    with torch.no_grad():
        progress = tqdm(loader, desc="Validation")
        for batch_x, batch_y in progress:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            if model_type != 'ce':
                unit_embeddings = model(batch_x)
                unit_weights = F.normalize(model.arcface.weight.data, p=2, dim=1)
                cosine = torch.mm(unit_embeddings, unit_weights.t())
                logits = model.arcface.s * cosine
                predicted = cosine.argmax(dim=1)
            else:
                logits, _embedding = model(batch_x)
                predicted = logits.max(1).indices

            batch_n = batch_x.size(0)
            batch_loss = criterion(logits, batch_y)
            hits = predicted.eq(batch_y).sum().item()

            loss_sum += batch_loss.item() * batch_n
            hits_total += hits
            seen += batch_n

            progress.set_postfix({'acc': f'{100. * hits / batch_n:.1f}%'})

    if seen == 0:
        return 0, 0
    return loss_sum / seen, 100. * hits_total / seen

# ==================== –û–°–ù–û–í–ù–ê–Ø –§–£–ù–ö–¶–ò–Ø ====================
def main():
    # –ù–∞—Å—Ç—Ä–æ–π–∫–∞ –ª–æ–≥–∏—Ä–æ–≤–∞–Ω–∏—è
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(f'training_{timestamp}.log'),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger(__name__)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    # –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –∫–æ–Ω—Ñ–∏–≥—É—Ä–∞—Ü–∏–∏
    config = Config()

    # ==================== 1. –ü–û–î–ì–û–¢–û–í–ö–ê –î–ê–ù–ù–´–• ====================
    print("\n" + "="*60)
    print("–ü–û–î–ì–û–¢–û–í–ö–ê –î–ê–ù–ù–´–•")
    print("="*60)

    # –ó–¥–µ—Å—å –Ω—É–∂–Ω–æ –∑–∞–≥—Ä—É–∑–∏—Ç—å –≤–∞—à–∏ –¥–∞–Ω–Ω—ã–µ
    # config.identity_df = df_id  # –≤–∞—à DataFrame
    # config.img_dir = img_dir    # –ø—É—Ç—å –∫ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏—è–º

    if not hasattr(config, 'identity_df') or config.identity_df is None:
        print("‚ùå –û—à–∏–±–∫–∞: identity_df –Ω–µ –∑–∞–¥–∞–Ω –≤ –∫–æ–Ω—Ñ–∏–≥—É—Ä–∞—Ü–∏–∏")
        print("–£—Å—Ç–∞–Ω–æ–≤–∏—Ç–µ config.identity_df = –≤–∞—à_dataframe")
        return

    if not hasattr(config, 'img_dir') or config.img_dir is None:
        print("‚ùå –û—à–∏–±–∫–∞: img_dir –Ω–µ –∑–∞–¥–∞–Ω –≤ –∫–æ–Ω—Ñ–∏–≥—É—Ä–∞—Ü–∏–∏")
        print("–£—Å—Ç–∞–Ω–æ–≤–∏—Ç–µ config.img_dir = '–ø—É—Ç—å/–∫/–∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏—è–º'")
        return

    processor = CelebADataProcessor(config)
    num_classes = processor.filter_data()
    stats = processor.split_data_by_images()

    # –¢—Ä–∞–Ω—Å—Ñ–æ—Ä–º–∞—Ü–∏–∏
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    # –î–∞—Ç–∞—Å–µ—Ç—ã
    train_dataset = CelebAClassificationDataset(
        processor.train_df, config.img_dir, train_transform
    )
    val_dataset = CelebAClassificationDataset(
        processor.val_df, config.img_dir, val_transform
    )
    test_dataset = CelebAClassificationDataset(
        processor.test_df, config.img_dir, val_transform
    )

    # –î–∞—Ç–∞–ª–æ–∞–¥–µ—Ä—ã
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size,
        shuffle=True, num_workers=config.num_workers,
        pin_memory=True, drop_last=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size,
        shuffle=False, num_workers=config.num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=config.batch_size,
        shuffle=False, num_workers=config.num_workers
    )

    # –°–æ–∑–¥–∞–µ–º –ø–∞–ø–∫—É –¥–ª—è —á–µ–∫–ø–æ–∏–Ω—Ç–æ–≤
    os.makedirs('checkpoints', exist_ok=True)

    # ==================== 2. –û–ë–£–ß–ï–ù–ò–ï CE –ú–û–î–ï–õ–ò ====================
    print("\n" + "="*60)
    print("–û–ë–£–ß–ï–ù–ò–ï CE –ú–û–î–ï–õ–ò")
    print("="*60)

    ce_checkpoint_path = 'checkpoints/best_model_ce.pth'
    ce_model_exists = os.path.exists(ce_checkpoint_path)

    if not ce_model_exists:
        print("–û–±—É—á–µ–Ω–∏–µ CE –º–æ–¥–µ–ª–∏ —Å –Ω—É–ª—è...")

        model_ce = SimpleFaceModel(
            num_classes=num_classes,
            embedding_size=config.embedding_size
        ).to(device)

        # –†–∞–∑–º–æ—Ä–∞–∂–∏–≤–∞–µ–º –±–æ–ª—å—à–µ —Å–ª–æ–µ–≤
        for param in model_ce.backbone.layer2.parameters():
            param.requires_grad = True
        for param in model_ce.backbone.layer3.parameters():
            param.requires_grad = True

        optimizer_ce = optim.AdamW(
            filter(lambda p: p.requires_grad, model_ce.parameters()),
            lr=0.001,
            weight_decay=1e-4
        )

        scheduler_ce = optim.lr_scheduler.CosineAnnealingLR(
            optimizer_ce,
            T_max=config.num_epochs_ce,
            eta_min=1e-6
        )

        criterion_ce = nn.CrossEntropyLoss(label_smoothing=0.1)

        best_ce_acc = 0
        ce_history = {'loss': [], 'acc': [], 'val_loss': [], 'val_acc': []}

        for epoch in range(config.num_epochs_ce):
            print(f"\nEpoch {epoch+1}/{config.num_epochs_ce} (CE)")
            print("-" * 40)

            train_loss, train_acc = train_epoch_ce(
                model_ce, train_loader, optimizer_ce, criterion_ce, device
            )

            val_loss, val_acc = validate(
                model_ce, val_loader, criterion_ce, device, 'ce'
            )

            # –°–æ—Ö—Ä–∞–Ω—è–µ–º –∏—Å—Ç–æ—Ä–∏—é
            ce_history['loss'].append(train_loss)
            ce_history['acc'].append(train_acc)
            ce_history['val_loss'].append(val_loss)
            ce_history['val_acc'].append(val_acc)

            print(f"\n–ò—Ç–æ–≥–∏ —ç–ø–æ—Ö–∏:")
            print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
            print(f"  Learning Rate: {optimizer_ce.param_groups[0]['lr']:.2e}")

            if val_acc > best_ce_acc:
                best_ce_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_ce.state_dict(),
                    'val_acc': val_acc,
                    'optimizer_state_dict': optimizer_ce.state_dict(),
                    'history': ce_history,
                    'num_classes': num_classes
                }, ce_checkpoint_path)
                print(f"  ‚úì –°–æ—Ö—Ä–∞–Ω–µ–Ω–∞ CE –º–æ–¥–µ–ª—å (Acc: {val_acc:.2f}%)")

            scheduler_ce.step()

        print(f"\nüéØ –õ—É—á—à–∞—è —Ç–æ—á–Ω–æ—Å—Ç—å CE –º–æ–¥–µ–ª–∏: {best_ce_acc:.2f}%")
    else:
        print(f"‚úÖ CE –º–æ–¥–µ–ª—å —É–∂–µ –æ–±—É—á–µ–Ω–∞: {ce_checkpoint_path}")

    # ==================== 3. FINE-TUNING –° ARCFACE ====================
    print("\n" + "="*60)
    print("FINE-TUNING –° ARCFACE LOSS")
    print("="*60)

    # –°–æ–∑–¥–∞–µ–º ArcFace –º–æ–¥–µ–ª—å
    model_arcface = ArcFaceModel(
        num_classes=num_classes,
        embedding_size=config.embedding_size,
        s=8.0,  # –ù–∞—á–∏–Ω–∞–µ–º —Å –º–µ–Ω—å—à–µ–≥–æ scale
        m=config.arcface_m
    ).to(device)

    # –ó–∞–≥—Ä—É–∂–∞–µ–º –≤–µ—Å–∞ –∏–∑ CE –º–æ–¥–µ–ª–∏
    print("\n–ó–∞–≥—Ä—É–∑–∫–∞ –≤–µ—Å–æ–≤ –∏–∑ CE –º–æ–¥–µ–ª–∏...")
    load_success = model_arcface.load_from_ce_model(ce_checkpoint_path, device)

    # –†–∞–∑–º–æ—Ä–∞–∂–∏–≤–∞–µ–º –≤—Å–µ —Å–ª–æ–∏ –¥–ª—è fine-tuning
    for param in model_arcface.parameters():
        param.requires_grad = True

    # –ü—Ä–æ–≤–µ—Ä–∫–∞ –º–æ–¥–µ–ª–∏
    print("\n" + "="*50)
    print("–ü–†–û–í–ï–†–ö–ê –ú–û–î–ï–õ–ò")
    print("="*50)

    model_arcface.eval()
    with torch.no_grad():
        test_images, test_labels = next(iter(train_loader))
        test_images = test_images[:4].to(device)
        test_labels = test_labels[:4].to(device)

        logits, embeddings = model_arcface(test_images, test_labels)

        print(f"\n1. –†–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–∏:")
        print(f"   Logits shape: {logits.shape}")
        print(f"   Embeddings shape: {embeddings.shape}")

        print(f"\n2. –ß–∏—Å–ª–µ–Ω–Ω—ã–µ –∑–Ω–∞—á–µ–Ω–∏—è:")
        print(f"   Logits range: [{logits.min():.4f}, {logits.max():.4f}]")
        print(f"   Logits mean/std: {logits.mean():.4f} / {logits.std():.4f}")
        print(f"   Embeddings norm: {embeddings.norm(p=2, dim=1).mean():.4f} ¬± {embeddings.norm(p=2, dim=1).std():.4f}")

        print(f"\n3. –ü—Ä–æ–≤–µ—Ä–∫–∞ accuracy:")
        embeddings_norm = embeddings  # –£–∂–µ –Ω–æ—Ä–º–∞–ª–∏–∑–æ–≤–∞–Ω—ã
        weights_norm = F.normalize(model_arcface.arcface.weight.data, p=2, dim=1)
        cosine = torch.mm(embeddings_norm, weights_norm.t())
        predicted = cosine.argmax(dim=1)
        accuracy = (predicted == test_labels).float().mean().item()
        print(f"   Accuracy –Ω–∞ —Ç–µ—Å—Ç–æ–≤–æ–º –±–∞—Ç—á–µ: {accuracy:.2%}")

        test_loss = nn.CrossEntropyLoss()(logits, test_labels)
        print(f"   Loss: {test_loss.item():.4f}")

    # –û–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä –¥–ª—è fine-tuning
    optimizer_arc = optim.AdamW(
        model_arcface.parameters(),
        lr=config.learning_rate,
        weight_decay=0.0005
    )

    # Scheduler —Å warmup
    scheduler_arc = optim.lr_scheduler.OneCycleLR(
        optimizer_arc,
        max_lr=config.learning_rate * 10,
        epochs=config.num_epochs_arcface,
        steps_per_epoch=len(train_loader),
        pct_start=0.3
    )

    # Loss function —Å label smoothing
    criterion_arc = nn.CrossEntropyLoss(label_smoothing=0.1)

    # –û–±—É—á–µ–Ω–∏–µ ArcFace
    print("\n" + "="*60)
    print("–ù–ê–ß–ê–õ–û –û–ë–£–ß–ï–ù–ò–Ø ARCFACE")
    print("="*60)

    best_arc_acc = 0
    arc_history = {'loss': [], 'acc': [], 'val_loss': [], 'val_acc': [], 'scale': []}

    for epoch in range(config.num_epochs_arcface):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch+1}/{config.num_epochs_arcface} - ArcFace Fine-tuning")
        print(f"{'='*50}")

        # –û–±—É—á–µ–Ω–∏–µ
        train_loss, train_acc = train_epoch_arcface(
            model_arcface, train_loader, optimizer_arc, criterion_arc, device,
            epoch, config.num_epochs_arcface
        )

        # –í–∞–ª–∏–¥–∞—Ü–∏—è
        val_loss, val_acc = validate(
            model_arcface, val_loader, criterion_arc, device, 'arcface'
        )

        # –°–æ—Ö—Ä–∞–Ω—è–µ–º –∏—Å—Ç–æ—Ä–∏—é
        arc_history['loss'].append(train_loss)
        arc_history['acc'].append(train_acc)
        arc_history['val_loss'].append(val_loss)
        arc_history['val_acc'].append(val_acc)
        arc_history['scale'].append(model_arcface.arcface.s)

        print(f"\nüìä –ò—Ç–æ–≥–∏ —ç–ø–æ—Ö–∏:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"  Scale s: {model_arcface.arcface.s:.1f}")
        print(f"  Margin m: {model_arcface.arcface.m:.2f}")
        print(f"  Learning Rate: {optimizer_arc.param_groups[0]['lr']:.2e}")

        # –°–æ—Ö—Ä–∞–Ω—è–µ–º –ª—É—á—à—É—é –º–æ–¥–µ–ª—å
        if val_acc > best_arc_acc:
            best_arc_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model_arcface.state_dict(),
                'optimizer_state_dict': optimizer_arc.state_dict(),
                'val_acc': val_acc,
                'scale': model_arcface.arcface.s,
                'margin': model_arcface.arcface.m,
                'history': arc_history,
                'num_classes': num_classes
            }, 'checkpoints/best_model_arcface.pth')
            print(f"  üíæ –°–æ—Ö—Ä–∞–Ω–µ–Ω–∞ –ª—É—á—à–∞—è ArcFace –º–æ–¥–µ–ª—å (Val Acc: {val_acc:.2f}%)")

        # –û–±–Ω–æ–≤–ª—è–µ–º scheduler
        scheduler_arc.step()

    # ==================== 4. –¢–ï–°–¢–ò–†–û–í–ê–ù–ò–ï ====================
    print("\n" + "="*60)
    print("–¢–ï–°–¢–ò–†–û–í–ê–ù–ò–ï ARCFACE –ú–û–î–ï–õ–ò")
    print("="*60)

    # –ó–∞–≥—Ä—É–∂–∞–µ–º –ª—É—á—à—É—é ArcFace –º–æ–¥–µ–ª—å
    arc_checkpoint_path = 'checkpoints/best_model_arcface.pth'
    if os.path.exists(arc_checkpoint_path):
        arc_checkpoint = torch.load(arc_checkpoint_path, map_location=device)
        model_arcface.load_state_dict(arc_checkpoint['model_state_dict'])
        print(f"‚úÖ –ó–∞–≥—Ä—É–∂–µ–Ω–∞ –ª—É—á—à–∞—è ArcFace –º–æ–¥–µ–ª—å (Val Acc: {arc_checkpoint['val_acc']:.2f}%)")
    else:
        print("‚ö† –õ—É—á—à–∞—è –º–æ–¥–µ–ª—å –Ω–µ –Ω–∞–π–¥–µ–Ω–∞, –∏—Å–ø–æ–ª—å–∑—É–µ–º –ø–æ—Å–ª–µ–¥–Ω—é—é")

    # –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ
    test_loss, test_acc = validate(
        model_arcface, test_loader, criterion_arc, device, 'arcface'
    )

    print(f"\nüéØ –†–ï–ó–£–õ–¨–¢–ê–¢–´ ARCFACE –ù–ê –¢–ï–°–¢–û–í–û–ô –í–´–ë–û–†–ö–ï:")
    print(f"  Test Loss: {test_loss:.4f}")
    print(f"  Test Accuracy: {test_acc:.2f}%")
    print(f"  Final Scale s: {model_arcface.arcface.s:.1f}")
    print(f"  Final Margin m: {model_arcface.arcface.m:.2f}")

    # –ê–Ω–∞–ª–∏–∑ —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤
    print("\n" + "="*60)
    print("–ê–ù–ê–õ–ò–ó –ö–ê–ß–ï–°–¢–í–ê –≠–ú–ë–ï–î–î–ò–ù–ì–û–í")
    print("="*60)

    test_embeddings, test_labels = analyze_embeddings(
        model_arcface, test_loader, device, num_samples=2000
    )

    # ==================== 5. –í–ò–ó–£–ê–õ–ò–ó–ê–¶–ò–Ø ====================
    print("\n" + "="*60)
    print("–í–ò–ó–£–ê–õ–ò–ó–ê–¶–ò–Ø –†–ï–ó–£–õ–¨–¢–ê–¢–û–í –û–ë–£–ß–ï–ù–ò–Ø")
    print("="*60)

    # –°–æ–∑–¥–∞–µ–º –≥—Ä–∞—Ñ–∏–∫–∏
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Loss history
    axes[0, 0].plot(arc_history['loss'], label='Train Loss', linewidth=2, marker='o', markersize=4)
    axes[0, 0].plot(arc_history['val_loss'], label='Val Loss', linewidth=2, marker='s', markersize=4)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('ArcFace: Loss History')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy history
    axes[0, 1].plot(arc_history['acc'], label='Train Acc', linewidth=2, marker='o', markersize=4)
    axes[0, 1].plot(arc_history['val_acc'], label='Val Acc', linewidth=2, marker='s', markersize=4)
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].set_title('ArcFace: Accuracy History')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].set_ylim([0, 100])

    # Scale s history
    axes[1, 0].plot(arc_history['scale'], linewidth=2, marker='o', markersize=4, color='green')
    axes[1, 0].fill_between(range(len(arc_history['scale'])),
                           arc_history['scale'],
                           alpha=0.3, color='green')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Scale s')
    axes[1, 0].set_title('ArcFace Scale Progression')
    axes[1, 0].grid(True, alpha=0.3)

    # Learning rate history (—Å–∏–º—É–ª–∏—Ä—É–µ–º)
    lr_values = []
    temp_optimizer = optim.AdamW([torch.zeros(1)], lr=config.learning_rate)
    temp_scheduler = optim.lr_scheduler.OneCycleLR(
        temp_optimizer,
        max_lr=config.learning_rate * 10,
        epochs=config.num_epochs_arcface,
        steps_per_epoch=len(train_loader),
        pct_start=0.3
    )

    for epoch in range(config.num_epochs_arcface):
        for _ in range(len(train_loader)):
            lr_values.append(temp_optimizer.param_groups[0]['lr'])
            temp_scheduler.step()

    # –ë–µ—Ä–µ–º –∑–Ω–∞—á–µ–Ω–∏—è –Ω–∞ –Ω–∞—á–∞–ª–æ –∫–∞–∂–¥–æ–π —ç–ø–æ—Ö–∏
    epoch_lr = [lr_values[i * len(train_loader)] for i in range(config.num_epochs_arcface)]

    axes[1, 1].plot(epoch_lr, linewidth=2, marker='o', markersize=4, color='orange')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Learning Rate')
    axes[1, 1].set_title('Learning Rate Schedule (OneCycle)')
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].set_yscale('log')

    plt.tight_layout()
    plt.savefig(f'arcface_training_results_{timestamp}.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("\n" + "="*60)
    print("üéâ –û–ë–£–ß–ï–ù–ò–ï ARCFACE –ó–ê–í–ï–†–®–ï–ù–û!")
    print("="*60)
    print(f"üìä –ò—Ç–æ–≥–æ–≤—ã–µ —Ä–µ–∑—É–ª—å—Ç–∞—Ç—ã:")
    print(f"  –õ—É—á—à–∞—è —Ç–æ—á–Ω–æ—Å—Ç—å –Ω–∞ –≤–∞–ª–∏–¥–∞—Ü–∏–∏: {best_arc_acc:.2f}%")
    print(f"  –¢–æ—á–Ω–æ—Å—Ç—å –Ω–∞ —Ç–µ—Å—Ç–µ: {test_acc:.2f}%")
    print(f"  –§–∏–Ω–∞–ª—å–Ω—ã–π Scale s: {model_arcface.arcface.s:.1f}")
    print(f"  –§–∏–Ω–∞–ª—å–Ω—ã–π Margin m: {model_arcface.arcface.m:.2f}")
    print(f"\nüíæ –§–∞–π–ª—ã —Å–æ—Ö—Ä–∞–Ω–µ–Ω—ã:")
    print(f"  –ß–µ–∫–ø–æ–∏–Ω—Ç—ã: 'checkpoints/'")
    print(f"  –õ–æ–≥: 'training_{timestamp}.log'")
    print(f"  –ì—Ä–∞—Ñ–∏–∫–∏: 'arcface_training_results_{timestamp}.png'")
    print(f"  –ê–Ω–∞–ª–∏–∑ —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤: 'embedding_analysis.png'")

if "__main__" == __name__:
    # Script entry point: run main() only when this file is executed directly,
    # not when it is imported as a module.
    main()