In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from ..utils.data_utils import get_svhn_loaders, SVHNCustomDataset
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os


In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """
    Run one training epoch and return the mean per-batch loss.

    Args:
        model: network whose output's last dimension holds the class logits
            (presumably ``(batch, max_len, num_classes)`` for digit sequences — confirm
            against the model definition).
        loader: iterable yielding ``(images, labels, _)`` triples; the third item is ignored.
        criterion: loss taking ``(N, num_classes)`` logits and ``(N,)`` integer targets.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.

    Returns:
        float: average of the per-batch losses, or ``nan`` if ``loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels, _ in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Labels are sequences (batch, max_len): flatten every position into its own
        # classification sample. `reshape` (unlike `view`) also accepts non-contiguous
        # logits, and the class count is read from the logits instead of hard-coding 11.
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Counting batches (instead of len(loader)) avoids ZeroDivisionError on an empty
    # loader and works for iterables without __len__.
    if num_batches == 0:
        return float("nan")
    return running_loss / num_batches

def validate_accuracy(model, loader, device, pad_token=10):
    """
    Sequence-level accuracy (in percent) of ``model`` over ``loader``.

    A sequence counts as correct only if every compared position matches.

    Args:
        model: network whose logits have the class dimension last; predictions are the
            argmax over it.
        loader: iterable yielding ``(images, labels, _)`` triples; the third item is ignored.
        device: device every batch is moved to.
        pad_token: label value marking padding positions. Those positions are skipped
            when comparing, because the training criterion uses it as ``ignore_index``.
            The model therefore never learns to emit it, and comparing it would make
            every padded sequence count as wrong. Pass ``None`` to compare the full
            sequence, including padding.

    Returns:
        float: percentage of correct sequences, or ``0.0`` if ``loader`` is empty.
    """
    model.eval()
    correct_sequences = 0
    total = 0

    with torch.no_grad():
        for images, labels, _ in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            # Most likely class for every sequence position -> [batch, max_len]
            preds = torch.argmax(outputs, dim=-1)

            matches = preds == labels
            if pad_token is not None:
                # Padding positions were never supervised during training; treat as matched.
                matches |= labels == pad_token

            correct_sequences += torch.all(matches, dim=1).sum().item()
            total += labels.size(0)

    if total == 0:
        return 0.0
    return 100 * correct_sequences / total

In [None]:
def run_active_learning_cycle(model, train_dataset, labeled_indices, test_loader, device, config, cycle):
    """
    Train ``model`` on the currently labeled subset for one active-learning cycle.

    Each epoch trains on ``Subset(train_dataset, labeled_indices)``, then evaluates on
    ``test_loader``. Loss and test accuracy are logged to TensorBoard under
    ``runs/AL_Cycle_{cycle}``.

    Args:
        model: network to train (modified in place and also returned).
        train_dataset: full training pool, indexed by ``labeled_indices``.
        labeled_indices: sequence of dataset indices that are labeled in this cycle.
        test_loader: loader used for per-epoch evaluation.
        device: device batches are moved to.
        config: dict with ``'batch_size'``, ``'lr'`` and ``'epochs'``. It may also hold
            ``'num_workers'`` (default 4).
        cycle: cycle number, used only for the TensorBoard log directory.

    Returns:
        tuple: ``(model, test_acc, optimizer)``. ``test_acc`` is the accuracy after the
        last epoch, or ``nan`` if ``config['epochs'] == 0``.
    """
    # DataLoader restricted to the selected (labeled) samples
    train_subset = Subset(train_dataset, labeled_indices)
    train_loader = DataLoader(
        train_subset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config.get('num_workers', 4),
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
        pin_memory=torch.device(device).type == "cuda",
    )

    optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.CrossEntropyLoss(ignore_index=10)  # 10 is the pad_token

    # Defined up front so the return value exists even when no epoch runs.
    test_acc = float("nan")

    writer = SummaryWriter(log_dir=f"runs/AL_Cycle_{cycle}")
    try:
        for epoch in range(config['epochs']):
            loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            test_acc = validate_accuracy(model, test_loader, device)
            print(f"Epoch {epoch+1}: Loss {loss:.4f} | Test Acc: {test_acc:.2f}%")
            writer.add_scalar('Loss/train', loss, epoch)
            writer.add_scalar('Accuracy/test', test_acc, epoch)
    finally:
        # Flush/close the event file even if training is interrupted or raises.
        writer.close()

    return model, test_acc, optimizer

In [None]:
# Hyperparameters and run configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = {
    'batch_size': 64,
    'lr': 0.001,
    'epochs': 20,
    'initial_budget': 1000,  # how many images start out labeled
    'cycle_budget': 500,     # how many images are added per AL cycle
    'num_workers': 4,        # DataLoader worker processes
    'seed': 42,              # seed for numpy / torch RNGs (pool split, weight init, shuffling)
}

# Seed the RNGs this notebook uses so Restart & Run All is reproducible.
# NOTE(review): GPU kernels (e.g. cuDNN) may still be nondeterministic.
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])


In [None]:
from torchvision import transforms

# Preprocessing shared by the train and test sets.
# NOTE(review): 640x640 is far larger than native SVHN crops; presumably sized to
# match SVHNCustomCNN's expected input — confirm before changing.
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.439, 0.434, 0.439], std=[0.204, 0.208, 0.208]),
])

# Dataset locations. The defaults are the original machine-specific paths; set these
# environment variables to run the notebook elsewhere without editing it.
SVHN_CSV_DIR = os.environ.get("SVHN_CSV_DIR", "C:/Repositorios/EasyDeepActiveLearning/csv_SVHN")
SVHN_TRAIN_IMG_DIR = os.environ.get("SVHN_TRAIN_IMG_DIR", "F:/SVHN/train/train/")
SVHN_TEST_IMG_DIR = os.environ.get("SVHN_TEST_IMG_DIR", "F:/SVHN/test/test/")

# Full training pool. It stays a plain Dataset so AL cycles can slice it by index via Subset.
train_ds = SVHNCustomDataset(f"{SVHN_CSV_DIR}/train.csv", SVHN_TRAIN_IMG_DIR, transform=transform)
test_ds = SVHNCustomDataset(f"{SVHN_CSV_DIR}/test.csv", SVHN_TEST_IMG_DIR, transform=transform)
test_loader = DataLoader(
    test_ds,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=config.get('num_workers', 4),
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    pin_memory=device.type == "cuda",
)

In [None]:
# Active-learning initialisation: split the training pool into labeled / unlabeled
num_train = len(train_ds)

# Dedicated seeded generator so the initial labeled set is reproducible across
# kernel restarts (the original used an unseeded np.random.shuffle).
split_rng = np.random.default_rng(config.get('seed', 42))
indices = split_rng.permutation(num_train).tolist()

# Indices of the images the model is allowed to "see" (labeled set)
labeled_indices = indices[:config['initial_budget']]
# Indices of the unlabeled pool
unlabeled_indices = indices[config['initial_budget']:]

In [None]:
def save_al_checkpoint(model, optimizer, cycle, labeled_indices, acc, path="checkpoints"):
    """
    Save model/optimizer state and active-learning metadata for one cycle.

    The file is written to ``{path}/model_cycle_{cycle}_acc_{acc:.2f}.pt``. ``path`` is
    created if it does not exist.

    Args:
        model: trained model whose ``state_dict`` is stored.
        optimizer: optimizer whose ``state_dict`` is stored.
        cycle: active-learning cycle number.
        labeled_indices: iterable of dataset indices used for training in this cycle
            (list or numpy array). Stored as a list of plain Python ints.
        acc: test accuracy reached in this cycle.
        path: output directory.

    Returns:
        str: path of the written checkpoint file.
    """
    # exist_ok avoids a race/failure if the directory appears between check and create.
    os.makedirs(path, exist_ok=True)

    checkpoint = {
        'cycle': cycle,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Essential to know which samples were used for training. Converted to plain
        # ints because numpy arrays/scalars are rejected by torch.load(weights_only=True),
        # the default in recent PyTorch, which would make the checkpoint unloadable.
        'labeled_indices': [int(i) for i in labeled_indices],
        'test_acc': float(acc),
    }

    filename = os.path.join(path, f"model_cycle_{cycle}_acc_{acc:.2f}.pt")
    torch.save(checkpoint, filename)
    print(f"Checkpoint salvo: {filename}")
    return filename

In [None]:
def init_weights(m):
    """
    Kaiming-normal initialisation (``mode='fan_out'``, ReLU gain) for Conv2d/Linear weights.

    Intended for ``model.apply(init_weights)``. Any other module type is skipped, and
    biases are never modified.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

In [None]:
from ..model.models import SVHNCustomCNN
from ..query_strategies.strategies import entropy_query_strategy
# NOTE(review): these relative imports only resolve when this code runs inside a package;
# in a plain Jupyter kernel they raise ImportError — confirm how the notebook is executed.

# Active-learning cycle loop
num_cycles = 5
checkpoint_dir = "al_checkpoints_svhn"
for cycle in range(num_cycles):
    print(f"\n--- Iniciando Ciclo de AL {cycle} | Labeled Size: {len(labeled_indices)} ---")

    # A fresh, re-initialised model each cycle: training restarts from scratch on the
    # current labeled set rather than continuing from the previous cycle.
    model = SVHNCustomCNN().to(device)
    model.apply(init_weights)

    # Train on everything labeled so far
    model, last_acc, opt = run_active_learning_cycle(model, train_ds, labeled_indices, test_loader, device, config, cycle)

    save_al_checkpoint(
        model=model,
        optimizer=opt,
        cycle=cycle,
        labeled_indices=labeled_indices,
        acc=last_acc,
        path=checkpoint_dir
    )

    print("--- Iniciando Query Strategy: Entropy Sampling ---")

    # Query step: pick `cycle_budget` samples from the unlabeled pool.
    # NOTE(review): the query also runs after the final cycle; its result only updates
    # labeled_indices / unlabeled_indices for any later cells.
    new_indices = entropy_query_strategy(
        model=model,
        dataset=train_ds,
        unlabeled_indices=unlabeled_indices,
        budget=config['cycle_budget'],
        batch_size=config['batch_size'],
        device=device
    )

    # Move the queried samples from the unlabeled pool into the labeled set.
    # A set gives O(1) membership tests while filtering the pool.
    new_indices_set = set(new_indices)
    # Sanity checks: the strategy must return distinct dataset indices that belong to the
    # current pool. Duplicates, positions into the pool, or unhashable-by-value items
    # (e.g. tensor elements) would otherwise silently desynchronise the two sets.
    assert len(new_indices_set) == len(new_indices), "query returned duplicate indices"
    pool_size_before = len(unlabeled_indices)
    unlabeled_indices = np.array([i for i in unlabeled_indices if i not in new_indices_set])
    assert len(unlabeled_indices) == pool_size_before - len(new_indices_set), \
        "query returned indices that are not in the unlabeled pool"

    # Add them to the labeled set
    labeled_indices = np.concatenate([labeled_indices, new_indices])