In [66]:
import os
import sys 
import torch
from torchvision import datasets, transforms
from collections import defaultdict
import numpy as np
import torch.nn as nn
import torch.optim as optim
sys.path.append(os.path.abspath('..'))
from base_trainer import Trainer
from model import SimpleCNN
from torch.utils.data import DataLoader, Subset, Dataset
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import f1_score, accuracy_score

In [85]:
class UpdatedTrainer(Trainer):
    def calculate_mnlp(self, outputs, targets):
        """
        Calculate Maximum Normalized Log-Probability (MNLP).
        :param outputs: Model outputs (logits)
        :param targets: True labels
        :return: MNLP value
        """
        probabilities = F.softmax(outputs, dim=1)
        log_probabilities = torch.log(probabilities)

        true_log_probs = log_probabilities[range(len(targets)), targets]
        num_classes = probabilities.size(1)

        # Normalize log probabilities
        normalized_log_probs = true_log_probs / torch.log(torch.tensor(1.0 / num_classes).to(self.device))
        return normalized_log_probs.mean().item()

    def fit(self, num_epochs):
        """
        Полный цикл обучения.
        :param num_epochs: Количество эпох
        """
        for epoch in range(num_epochs):
            train_loss = self.train_step()
            val_loss, accuracy, f1, avg_mnlp = self.val_step()
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Acc: {accuracy:.4f}, F1: {f1:.4f}, MNLP: {avg_mnlp:.4f}")

    def val_step(self):
        """
        Один шаг валидации.
        Вычисляет среднюю потерю, F1-score и точность.
        :return: Средняя потеря, F1-score и точность за эпоху
        """
        self.model.eval()
        running_loss = 0.0
        all_targets = []
        all_predictions = []
        mnlp_values = []

        with torch.no_grad():
            for inputs, targets in tqdm(self.val_loader, desc="Validating"):
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                # Прямой проход
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                running_loss += loss.item()

                # Сохранение предсказаний и истинных меток
                predictions = outputs.argmax(dim=1).cpu().numpy()
                all_predictions.extend(predictions)
                all_targets.extend(targets.cpu().numpy())

                # Сохранение MNLP
                mnlp = self.calculate_mnlp(outputs, targets)
                mnlp_values.append(mnlp)

        # Рассчитываем метрики
        avg_loss = running_loss / len(self.val_loader)
        accuracy = accuracy_score(all_targets, all_predictions)
        f1 = f1_score(all_targets, all_predictions, average="weighted")
        self.val_acc.append(accuracy)
        self.val_f1.append(f1)

        avg_mnlp = sum(mnlp_values) / len(mnlp_values) if mnlp_values else 0

        self.val_losses.append(avg_loss)

        print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, MNLP: {avg_mnlp:.4f}")
        return avg_loss, accuracy, f1, avg_mnlp

In [86]:
NUM_CLASSES = 10           # CIFAR-10
NUM_EPOCH = 10
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
percentages = [0.01, 0.10, 0.20]
initial_datasets = {}
pool_data_indices = []

# Создаем словарь для хранения индексов по классам
class_indices = defaultdict(list)

# Заполняем словарь индексами изображений по классам
for index, (_, label) in enumerate(train_dataset):
    class_indices[label].append(index)

# Формируем начальные наборы данных
for percentage in percentages:
    initial_indices = []
    num_samples_per_class = {label: int(len(indices) * percentage) for label, indices in class_indices.items()}
    
    for label, indices in class_indices.items():
        # Случайным образом выбираем индексы для каждого класса
        selected_indices = np.random.choice(indices, num_samples_per_class[label], replace=False)
        initial_indices.extend(selected_indices)
    
    # Сортируем индексы для создания подмножества
    initial_datasets[percentage] = sorted(initial_indices)

# Создаем pool_data с оставшимися данными
all_initial_indices = set(initial_datasets[0.01] + initial_datasets[0.10] + initial_datasets[0.20])
pool_data_indices = [i for i in range(len(train_dataset)) if i not in all_initial_indices]

# Создаем подмножества для начальных данных и оставшихся данных
initial_dataset_1_percent = torch.utils.data.Subset(train_dataset, initial_datasets[0.01])
initial_dataset_10_percent = torch.utils.data.Subset(train_dataset, initial_datasets[0.10])
initial_dataset_20_percent = torch.utils.data.Subset(train_dataset, initial_datasets[0.20])
pool_data = torch.utils.data.Subset(train_dataset, pool_data_indices)

# Проверяем размеры подмножеств
print(f"Initial dataset (1%): {len(initial_dataset_1_percent)}")
print(f"Initial dataset (10%): {len(initial_dataset_10_percent)}")
print(f"Initial dataset (20%): {len(initial_dataset_20_percent)}")
print(f"Pool data size: {len(pool_data)}")

train_dataloader = DataLoader(initial_dataset_1_percent, batch_size=64, shuffle=True)
pool_dataloader = DataLoader(pool_data, batch_size=64, shuffle=True)
val_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)
model = SimpleCNN(10)

Files already downloaded and verified
Files already downloaded and verified
Initial dataset (1%): 500
Initial dataset (10%): 5000
Initial dataset (20%): 10000
Pool data size: 35656


In [87]:
optimizer = optim.AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()
train = UpdatedTrainer(model=model.to(DEVICE), optimizer=optimizer,criterion=criterion, train_loader=train_dataloader, val_loader=val_dataloader,pool_loader = pool_dataloader, device=DEVICE)


In [88]:
train.fit(2)



alidating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 149.74it/s]

Validation Loss: 2.2108, Accuracy: 0.2428, F1 Score: 0.1777, MNLP: 0.9601
Epoch 1/2 - Train Loss: 2.2822, Val Loss: 2.2108, Acc: 0.2428, F1: 0.1777, MNLP: 0.9601



Validating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 149.38it/s]

Validation Loss: 2.0979, Accuracy: 0.2287, F1 Score: 0.1725, MNLP: 0.9111
Epoch 2/2 - Train Loss: 2.1209, Val Loss: 2.0979, Acc: 0.2287, F1: 0.1725, MNLP: 0.9111



