In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

from data_utils import get_svhn_loaders, SVHNCustomDataset


In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device, num_classes=11):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(images)``. Its output is flattened to
            ``(-1, num_classes)`` before being handed to ``criterion``.
        loader: Iterable yielding ``(images, labels, extra)`` triples; the
            third element is ignored.
        criterion: Loss callable invoked as ``criterion(flat_outputs, flat_labels)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device the image and label tensors are moved to.
        num_classes: Size of the class dimension used when flattening the
            model output. Defaults to 11, the value that used to be hard-coded.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches
        seen, or ``0.0`` when ``loader`` yields no batch at all.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels, _ in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Original author's note: labels are sequences (batch, max_len), so the
        # loss depends on how the model output is laid out.
        # reshape (instead of view) also accepts non-contiguous tensors, e.g.
        # outputs produced by a transpose/permute inside the model.
        loss = criterion(outputs.reshape(-1, num_classes), labels.reshape(-1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Count batches ourselves: avoids ZeroDivisionError on an empty loader and
    # does not require the loader to implement __len__.
    if num_batches == 0:
        return 0.0
    return running_loss / num_batches

def validate(model, loader, criterion, device, ignore_index=None):
    """Return the position-wise accuracy of ``model`` over ``loader``, in percent.

    Args:
        model: Module called as ``model(images)``. The prediction is the
            argmax over dimension 2 of its output, so the output needs at
            least three dimensions.
        loader: Iterable yielding ``(images, labels, extra)`` triples; the
            third element is ignored.
        criterion: Not used by this function; kept so existing call sites
            keep working.
        device: Device the image and label tensors are moved to.
        ignore_index: Optional label value to leave out of both the hit count
            and the total (for example a padding token). ``None`` (default)
            counts every position, which is the historical behaviour.

    Returns:
        float: ``100 * correct / total``, or ``0.0`` when no position was
        counted (empty loader, or every label equal to ``ignore_index``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels, _ in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Simplified prediction logic: most likely class at each position.
            predicted = outputs.argmax(dim=2)
            hits = predicted == labels
            if ignore_index is None:
                total += labels.size(0) * labels.size(1)
            else:
                counted = labels != ignore_index
                hits = hits & counted
                total += counted.sum().item()
            correct += hits.sum().item()

    # Nothing was scored: report 0.0 instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return 100 * correct / total

In [None]:
def run_active_learning_cycle(model, train_dataset, labeled_indices, test_loader, device, config):
    """Train ``model`` on the currently labeled subset and evaluate it every epoch.

    Args:
        model: Model to train; it is updated in place and also returned.
        train_dataset: Full training dataset the labeled subset is drawn from.
        labeled_indices: Sequence (list or NumPy array) of integer positions
            in ``train_dataset`` that may be used for training.
        test_loader: Loader passed to ``validate`` after every epoch.
        device: Device used for training and evaluation.
        config: Mapping with the keys ``'batch_size'``, ``'lr'`` and
            ``'epochs'``. The optional keys ``'num_workers'`` (default 4) and
            ``'pad_token'`` (default 10) override the previously hard-coded
            values.

    Returns:
        tuple: ``(model, acc, optimizer)`` where ``acc`` is the accuracy
        returned by ``validate`` after the last epoch. When
        ``config['epochs']`` is 0 the untrained model is evaluated once so
        that ``acc`` is always defined.

    Raises:
        ValueError: If ``labeled_indices`` is empty.
    """
    # Plain Python ints keep Subset/dataset indexing independent of whether the
    # caller passed a list or a NumPy array.
    subset_indices = [int(i) for i in labeled_indices]
    if not subset_indices:
        raise ValueError("labeled_indices is empty; nothing to train on")

    # DataLoader restricted to the selected samples (the labeled set).
    train_subset = Subset(train_dataset, subset_indices)
    train_loader = DataLoader(
        train_subset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config.get('num_workers', 4),
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=torch.device(device).type == "cuda",
    )

    optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    # Positions labeled with the pad token do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=config.get('pad_token', 10))

    acc = None
    for epoch in range(config['epochs']):
        loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        acc = validate(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1}: Loss {loss:.4f} | Test Acc: {acc:.2f}%")

    if acc is None:
        # No epoch ran: previously this path hit an unbound `acc` at the return.
        acc = validate(model, test_loader, criterion, device)

    return model, acc, optimizer

In [None]:
# HYPERPARAMETERS
# Use the GPU when torch reports one as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
config = dict(
    batch_size=64,
    lr=0.001,
    epochs=20,
    initial_budget=1000,  # how many images start out labeled
    cycle_budget=500,     # how many images are added per active-learning cycle
)


In [None]:
# Load the complete datasets (the training set is the active-learning pool).
# The plain Dataset objects are kept so that index-based Subsets can be built
# from them later on.
train_ds = SVHNCustomDataset(
    'train.csv',
    'train_images/',
    transform=None,  # Add transforms here
)
test_ds = SVHNCustomDataset(
    'test.csv',
    'test_images/',
    transform=None,
)
# Evaluation loader: fixed order, same batch size as training.
test_loader = DataLoader(
    test_ds,
    batch_size=config['batch_size'],
    shuffle=False,
)

In [None]:
# Active-learning initialisation
num_train = len(train_ds)
# A dedicated, seeded generator makes the labeled/unlabeled split identical on
# every Restart & Run All; set config['seed'] to draw a different split.
split_rng = np.random.default_rng(config.get('seed', 0))
# tolist() keeps `indices` a list of plain Python ints, as it was before.
indices = split_rng.permutation(num_train).tolist()

# Indices of the images the model is allowed to see (labeled set)
labeled_indices = indices[:config['initial_budget']]
# Indices of the unlabeled pool
unlabeled_indices = indices[config['initial_budget']:]

In [None]:
def save_al_checkpoint(model, optimizer, cycle, labeled_indices, acc, path="checkpoints"):
    """Save the model, optimizer state and active-learning metadata of one cycle.

    The checkpoint is written with ``torch.save`` to
    ``<path>/model_cycle_<cycle>_acc_<acc:.2f>.pt`` and contains the keys
    ``'cycle'``, ``'model_state_dict'``, ``'optimizer_state_dict'``,
    ``'labeled_indices'`` and ``'test_acc'``.

    Args:
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        cycle: Cycle identifier; stored and used in the file name.
        labeled_indices: Iterable of integer indices used for training in this
            cycle (list or NumPy array). Stored as a list of plain ints.
        acc: Accuracy value; stored and formatted with two decimals into the
            file name, so it must be a number.
        path: Target directory; created (including parents) when missing.
    """
    # exist_ok=True replaces the check-then-create sequence, which could fail if
    # the directory appeared between the check and the creation.
    os.makedirs(path, exist_ok=True)

    checkpoint = {
        'cycle': cycle,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Essential for knowing what was used for training. Plain ints keep
        # pickled NumPy objects out of the file, so it remains loadable with
        # torch.load(..., weights_only=True).
        'labeled_indices': [int(i) for i in labeled_indices],
        'test_acc': acc
    }

    filename = os.path.join(path, f"model_cycle_{cycle}_acc_{acc:.2f}.pt")
    torch.save(checkpoint, filename)
    print(f"Checkpoint salvo: {filename}")

In [None]:


# Active-learning cycle loop
num_cycles = 5
checkpoint_dir = "al_checkpoints_svhn"
for cycle in range(num_cycles):
    print(f"\n--- Iniciando Ciclo de AL {cycle} | Labeled Size: {len(labeled_indices)} ---")

    # Initialise/reset the model for every cycle (or continue training instead).
    # NOTE(review): YourModelArchitecture is a placeholder that is not defined
    # in the visible cells; it must be defined or imported before this cell runs.
    model = YourModelArchitecture().to(device)

    # Train the model on everything that is labeled so far
    model, last_acc, opt = run_active_learning_cycle(model, train_ds, labeled_indices, test_loader, device, config)

    save_al_checkpoint(
        model=model,
        optimizer=opt,
        cycle=cycle,
        labeled_indices=labeled_indices,
        acc=last_acc,
        path=checkpoint_dir
    )

    # --- PLACE FOR YOUR QUERY STRATEGY ---
    # This is where your selection function would be called:
    # new_indices = your_strategy.query(model, train_ds, unlabeled_indices, config['cycle_budget'])

    # Simplified manual example (random sampling): take the next slice of the
    # unlabeled pool.
    new_indices = unlabeled_indices[:config['cycle_budget']]

    # Update both sets. Plain list concatenation is used on purpose:
    # np.concatenate promotes the result to float64 as soon as `new_indices`
    # is empty (pool exhausted), and float indices cannot index the dataset.
    labeled_indices = list(labeled_indices) + list(new_indices)
    unlabeled_indices = unlabeled_indices[config['cycle_budget']:]