In [3]:
import os
import sys 
import torch
from torchvision import datasets, transforms
from collections import defaultdict
import numpy as np
import torch.nn as nn
import torch.optim as optim
sys.path.append(os.path.abspath('..'))
from base_trainer import Trainer
from model import SimpleCNN
from torch.utils.data import DataLoader, Subset, Dataset
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import f1_score, accuracy_score

In [57]:
class MNLPTrainer(Trainer):
    def select_samples(self):
        self.model.eval()  # Set the model to evaluation mode
        all_scores = []  # List to store MNLP scores for each sample
        all_indices = []  # List to store indices of the samples

        with torch.no_grad():
            for batch in self.pool_loader:
                inputs, indices = batch  # Assume loader returns data and their indices
                inputs = inputs.to(self.device)

                log_probs = torch.log_softmax(self.model(inputs), dim=-1)  # (batch_size, seq_len, num_classes)

                # Для каждого примера суммируем логарифмы вероятностей по длине последовательности
                sequence_log_probs = log_probs.sum(dim=-1)  # (batch_size, seq_len)

                # Усредняем значения логарифмов вероятностей по длине последовательности
                mean_log_probs = sequence_log_probs.mean(axis=0)  # (batch_size)

                # Нормализуем логарифмы вероятностей по длине последовательности
                normalized_log_probs = mean_log_probs / inputs.shape[1]  # (batch_size)

                # Save results
                all_scores.append(normalized_log_probs.cpu())
                all_indices.append(indices)
        
        print(all_scores)

        # Combine results across all batches
        all_scores = torch.stack(all_scores) # All MNLP scores
        all_indices = torch.cat(all_indices)  # All indices

        # Select top-K indices with the lowest MNLP scores (most uncertain samples)
        _, top_indices = torch.topk(-all_scores, 10)  # Negative for ascending sort
        informative_indices = all_indices[top_indices].tolist()

        # Update dataloaders for the next iteration
        self.update_dataloader(informative_indices)

        return informative_indices

    def fit(self, num_epochs):
        """
        Полный цикл обучения.
        :param num_epochs: Количество эпох
        """
        for epoch in range(num_epochs):
            train_loss = self.train_step()
            self.select_samples()
            val_loss, accuracy, f1  = self.val_step()
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Acc: {accuracy:.4f}, F1: {f1:.4f}")

    

In [31]:
NUM_CLASSES = 10           # CIFAR-10
NUM_EPOCH = 10
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
percentages = [0.01, 0.10, 0.20]
initial_datasets = {}
pool_data_indices = []
config = {}
config['split_label_unlabel'] = 0.1
config['batch_size']=64

initial_indices = np.random.choice(len(train_dataset), size=int(len(train_dataset)*config['split_label_unlabel']), replace=False) # 20%
initial_data = Subset(train_dataset, initial_indices)

unlabeled_indices = list(set(range(len(train_dataset))) - set(initial_indices))
unlabeled_data = Subset(train_dataset, unlabeled_indices)

train_dataloader = DataLoader(initial_data, batch_size=config['batch_size'], shuffle=True)
val_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)
pool_dataloader = DataLoader(unlabeled_data, batch_size=config['batch_size'], shuffle=True)

model = SimpleCNN(10)

Files already downloaded and verified
Files already downloaded and verified


In [58]:
optimizer = optim.AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()
train = MNLPTrainer(model=model.to(DEVICE), optimizer=optimizer,criterion=criterion, train_loader=train_dataloader, val_loader=val_dataloader,pool_loader = pool_dataloader, device=DEVICE)


In [59]:
train.fit(2)

Training: 100%|██████████| 79/79 [00:06<00:00, 12.70it/s]


[tensor(-14.4505), tensor(-14.0875), tensor(-14.7053), tensor(-14.1767), tensor(-14.4621), tensor(-14.1969), tensor(-14.3411), tensor(-14.9397), tensor(-13.8649), tensor(-13.8257), tensor(-14.8641), tensor(-14.7586), tensor(-14.1145), tensor(-13.9136), tensor(-14.3768), tensor(-13.5882), tensor(-13.9425), tensor(-14.2765), tensor(-14.0406), tensor(-14.7136), tensor(-13.3104), tensor(-14.3849), tensor(-14.5428), tensor(-14.3923), tensor(-13.7965), tensor(-14.3908), tensor(-14.0770), tensor(-14.3916), tensor(-14.3304), tensor(-14.2694), tensor(-14.2465), tensor(-14.1802), tensor(-13.2785), tensor(-15.2160), tensor(-14.7710), tensor(-14.0642), tensor(-14.2407), tensor(-14.5082), tensor(-15.4111), tensor(-14.6439), tensor(-14.1312), tensor(-14.5301), tensor(-14.5958), tensor(-13.7976), tensor(-14.6266), tensor(-14.8875), tensor(-14.0442), tensor(-15.2420), tensor(-14.8178), tensor(-14.2407), tensor(-14.2786), tensor(-14.8017), tensor(-13.9379), tensor(-14.5469), tensor(-14.1916), tensor(-1

Validating: 100%|██████████| 157/157 [00:06<00:00, 22.61it/s]


Validation Loss: 1.3038, Accuracy: 0.5425, F1 Score: 0.5327
Epoch 1/2 - Train Loss: 1.0464, Val Loss: 1.3038, Acc: 0.5425, F1: 0.5327


Training: 100%|██████████| 79/79 [00:05<00:00, 13.50it/s]


[tensor(-14.8107), tensor(-15.1199), tensor(-14.4863), tensor(-13.8317), tensor(-14.8667), tensor(-14.4588), tensor(-15.1447), tensor(-14.4861), tensor(-15.0273), tensor(-15.0570), tensor(-14.3123), tensor(-14.1142), tensor(-14.1187), tensor(-15.3073), tensor(-14.1111), tensor(-14.3584), tensor(-14.7774), tensor(-14.7364), tensor(-13.1282), tensor(-14.7112), tensor(-15.1688), tensor(-14.3176), tensor(-14.2650), tensor(-13.9763), tensor(-13.7741), tensor(-14.2896), tensor(-15.2843), tensor(-14.0473), tensor(-14.3878), tensor(-14.2902), tensor(-13.7745), tensor(-14.8722), tensor(-14.1277), tensor(-13.9393), tensor(-14.5727), tensor(-15.0498), tensor(-14.4147), tensor(-13.9490), tensor(-14.4789), tensor(-13.8333), tensor(-14.4014), tensor(-14.0738), tensor(-14.2900), tensor(-14.5123), tensor(-14.2212), tensor(-14.3049), tensor(-14.2302), tensor(-14.6350), tensor(-14.2610), tensor(-14.8602), tensor(-14.4360), tensor(-14.2967), tensor(-15.0408), tensor(-14.1543), tensor(-13.9970), tensor(-1

Validating: 100%|██████████| 157/157 [00:06<00:00, 22.52it/s]


Validation Loss: 1.2370, Accuracy: 0.5561, F1 Score: 0.5571
Epoch 2/2 - Train Loss: 1.0056, Val Loss: 1.2370, Acc: 0.5561, F1: 0.5571
