# Test Corrected SPMI Implementation
This notebook tests the corrected SPMI implementation that addresses the collapse issue.

In [4]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from dataset import SPMIDataset, get_transforms
from model import WideResNet, LeNet
from spmi import SPMI as SPMICorrected

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [5]:
# Configuration
config = {
    'dataset': 'cifar10',
    'num_labeled': 1000,  # Start smaller for testing
    'partial_rate': 0.3,
    'batch_size': 128,
    'lr': 0.1,
    'num_epochs': 50,
    'warmup_epochs': 5,
    'eval_interval': 2
}

# Create dataset
transform = get_transforms(config['dataset'], strong_aug=True)
dataset = SPMIDataset(
    dataset_name=config['dataset'],
    root='./data',
    num_labeled=config['num_labeled'],
    partial_rate=config['partial_rate'],
    transform=transform,
    download=True,
    seed=42
)

dataloader = DataLoader(
    dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=4,
    drop_last=True,
    pin_memory=True
)

test_loader = dataset.get_test_loader(batch_size=256)

print(f"Dataset: {config['num_labeled']} labeled, {len(dataset.unlabeled_indices)} unlabeled")

Files already downloaded and verified
Files already downloaded and verified
Dataset cifar10: 1000 labeled, 49000 unlabeled instances
Average number of candidates per labeled instance: 3.67
Dataset: 1000 labeled, 49000 unlabeled


In [6]:
# Create model
model = WideResNet(depth=28, widen_factor=2, num_classes=10).to(device)

# Optimizer with cosine annealing
optimizer = optim.SGD(
    model.parameters(), 
    lr=config['lr'],
    momentum=0.9, 
    weight_decay=5e-4,
    nesterov=True
)

scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config['num_epochs']
)

# Initialize corrected SPMI
spmi = SPMICorrected(model, num_classes=10)

# Get original masks
original_masks = dataset.get_candidate_masks().to(device)

In [None]:
# Training loop with monitoring
history = {'loss': [], 'accuracy': [], 'temperature': []}
best_accuracy = 0.0

for epoch in range(config['num_epochs']):
    # Initialize unlabeled data after warmup
    if epoch == config['warmup_epochs']:
        print("\nInitializing unlabeled data...")
        spmi.initialize_unlabeled(dataloader, device)
        
        # Check initialization quality
        unlabeled_sample = min(100, len(dataset.unlabeled_indices))
        correct = 0
        for idx in dataset.unlabeled_indices[:unlabeled_sample]:
            if dataset.targets[idx] in dataset.candidate_labels[idx]:
                correct += 1
        print(f"Initialization quality: {correct}/{unlabeled_sample} ({100*correct/unlabeled_sample:.1f}%)")
    
    # Train epoch
    model.train()
    total_loss = 0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)
    
    for batch_idx, batch_data in enumerate(pbar):
        imgs, candidate_masks, targets, indices, is_labeled = batch_data
        imgs = imgs.to(device)
        candidate_masks = candidate_masks.to(device)
        is_labeled = is_labeled.to(device)
        
        # Skip if no labeled data during warmup
        if epoch < config['warmup_epochs'] and not torch.any(is_labeled == 1):
            continue
        
        outputs = model(imgs)
        
        # Calculate loss
        if epoch < config['warmup_epochs']:
            labeled_mask = (is_labeled == 1)
            if labeled_mask.sum() > 0:
                loss = spmi.calculate_loss(outputs[labeled_mask], candidate_masks[labeled_mask])
            else:
                continue
        else:
            loss = spmi.calculate_loss(outputs, candidate_masks)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        # Update candidate labels
        if epoch >= config['warmup_epochs'] and batch_idx % 5 == 0:
            spmi.update_candidate_labels(
                outputs, candidate_masks, is_labeled, indices, 
                dataset, original_masks, device
            )
    
    scheduler.step()
    avg_loss = total_loss / len(dataloader)
    
    # Evaluate periodically
    if (epoch + 1) % config['eval_interval'] == 0 or epoch == 0:
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, targets in test_loader:
                images, targets = images.to(device), targets.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        
        accuracy = 100.0 * correct / total
        
        if accuracy > best_accuracy:
            best_accuracy = accuracy
        
        history['loss'].append(avg_loss)
        history['accuracy'].append(accuracy)
        history['temperature'].append(spmi.temperature)
        
        print(f"Epoch {epoch+1}/{config['num_epochs']}: Loss={avg_loss:.4f}, "
              f"Acc={accuracy:.2f}%, Best={best_accuracy:.2f}%, "
              f"Temp={spmi.temperature:.2f}")

Epoch 1:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 1/50: Loss=0.8622, Acc=10.00%, Best=10.00%, Temp=1.00


Epoch 2:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 2/50: Loss=0.8333, Acc=10.00%, Best=10.00%, Temp=1.00


Epoch 3:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 4/50: Loss=0.8089, Acc=10.00%, Best=10.00%, Temp=1.00


Epoch 5:   0%|          | 0/390 [00:00<?, ?it/s]


Initializing unlabeled data...
Initialization quality: 40/100 (40.0%)


Epoch 6:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 6/50: Loss=-0.0152, Acc=10.00%, Best=10.00%, Temp=1.00


Epoch 7:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 8/50: Loss=-0.0311, Acc=10.00%, Best=10.00%, Temp=1.00


Epoch 9:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 10:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 10/50: Loss=-0.0378, Acc=11.71%, Best=11.71%, Temp=1.00


Epoch 11:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 12:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 12/50: Loss=-0.0421, Acc=10.01%, Best=11.71%, Temp=1.00


Epoch 13:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 14:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 14/50: Loss=-0.0452, Acc=10.00%, Best=11.71%, Temp=1.00


Epoch 15:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 16:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 16/50: Loss=-0.0469, Acc=10.00%, Best=11.71%, Temp=1.00


Epoch 17:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 18:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 18/50: Loss=-0.0480, Acc=10.00%, Best=11.71%, Temp=1.00


Epoch 19:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 20:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 20/50: Loss=-0.0488, Acc=10.00%, Best=11.71%, Temp=1.00


Epoch 21:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 22:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 22/50: Loss=-0.0494, Acc=10.00%, Best=11.71%, Temp=1.00


Epoch 23:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 24:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 24/50: Loss=-0.0497, Acc=10.00%, Best=11.71%, Temp=1.00


Epoch 25:   0%|          | 0/390 [00:00<?, ?it/s]

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


: 

In [None]:
# Plot results
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Loss
axes[0].plot(history['loss'], marker='o')
axes[0].set_xlabel('Evaluation Step')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].grid(True)

# Accuracy
axes[1].plot(history['accuracy'], marker='o', color='green')
axes[1].set_xlabel('Evaluation Step')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Test Accuracy')
axes[1].grid(True)

# Temperature
axes[2].plot(history['temperature'], marker='o', color='red')
axes[2].set_xlabel('Evaluation Step')
axes[2].set_ylabel('Temperature')
axes[2].set_title('Temperature Schedule')
axes[2].grid(True)

plt.tight_layout()
plt.show()

print(f"\nFinal Results:")
print(f"Best accuracy: {best_accuracy:.2f}%")
print(f"Final accuracy: {history['accuracy'][-1]:.2f}%")

In [None]:
# Analyze prediction distribution
model.eval()
predictions = []
confidences = []

with torch.no_grad():
    for images, targets in test_loader:
        images = images.to(device)
        outputs = model(images)
        probs = F.softmax(outputs, dim=1)
        
        max_probs, preds = torch.max(probs, dim=1)
        predictions.extend(preds.cpu().numpy())
        confidences.extend(max_probs.cpu().numpy())

# Plot prediction distribution
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.hist(predictions, bins=10, alpha=0.7, edgecolor='black')
plt.xlabel('Predicted Class')
plt.ylabel('Count')
plt.title('Prediction Distribution')

plt.subplot(1, 2, 2)
plt.hist(confidences, bins=30, alpha=0.7, edgecolor='black')
plt.xlabel('Confidence')
plt.ylabel('Count')
plt.title('Confidence Distribution')

plt.tight_layout()
plt.show()

# Check if we have uniform predictions
pred_counts = np.bincount(predictions, minlength=10)
print("\nPrediction counts per class:")
for i, count in enumerate(pred_counts):
    print(f"Class {i}: {count} ({100*count/len(predictions):.1f}%)")

print(f"\nAverage confidence: {np.mean(confidences):.3f}")
print(f"Max confidence: {np.max(confidences):.3f}")
print(f"Min confidence: {np.min(confidences):.3f}")