# SPMI Debugging Notebook
This notebook helps identify why the model performance drops to 10% after initialization.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Import your modules
from dataset import SPMIDataset, get_transforms
from model import WideResNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Create small test dataset
dataset = SPMIDataset(
    dataset_name='cifar10',
    root='./data',
    num_labeled=100,  # Small for debugging
    partial_rate=0.3,
    transform=get_transforms('cifar10', strong_aug=False),
    download=True,
    seed=42
)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
test_loader = dataset.get_test_loader(batch_size=256)

print(f"Dataset: {len(dataset.labeled_indices)} labeled, {len(dataset.unlabeled_indices)} unlabeled")

Files already downloaded and verified
Files already downloaded and verified
Dataset cifar10: 100 labeled, 49900 unlabeled instances
Average number of candidates per labeled instance: 3.45
Dataset: 100 labeled, 49900 unlabeled


## Test 1: Standard Supervised Learning
First, let's verify that the model can learn with standard cross-entropy loss.

In [None]:
# Test standard supervised learning
model_ce = WideResNet(depth=16, widen_factor=2, num_classes=10).to(device)
optimizer_ce = optim.SGD(model_ce.parameters(), lr=0.1, momentum=0.9)
criterion_ce = nn.CrossEntropyLoss()

ce_losses = []
ce_accuracies = []

for epoch in range(10):
    model_ce.train()
    total_loss = 0
    for imgs, _, targets, _, _ in dataloader:
        imgs, targets = imgs.to(device), targets.to(device)
        
        optimizer_ce.zero_grad()
        outputs = model_ce(imgs)
        loss = criterion_ce(outputs, targets)
        loss.backward()
        optimizer_ce.step()
        
        total_loss += loss.item()
    
    # Evaluate
    model_ce.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in test_loader:
            images, targets = images.to(device), targets.to(device)
            outputs = model_ce(images)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    
    accuracy = 100.0 * correct / total
    avg_loss = total_loss / len(dataloader)
    
    ce_losses.append(avg_loss)
    ce_accuracies.append(accuracy)
    
    print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%")


print("Standard supervised learning works well!")

KeyboardInterrupt: 

: 

## Test 2: Different SPMI Loss Implementations
Now let's test different interpretations of the SPMI loss to see which one works.

In [4]:
def weighted_ce_loss_v1(outputs, candidate_masks):
    """Version 1: Weighted sum of cross-entropy losses"""
    batch_size = outputs.size(0)
    total_loss = 0
    
    for i in range(batch_size):
        candidates = candidate_masks[i].nonzero(as_tuple=True)[0]
        if len(candidates) == 0:
            continue
        
        probs = F.softmax(outputs[i], dim=0)
        candidate_probs = probs[candidates]
        weights = candidate_probs / candidate_probs.sum()
        
        instance_loss = 0
        for j, candidate in enumerate(candidates):
            ce_loss = F.cross_entropy(outputs[i:i+1], torch.tensor([candidate], device=outputs.device))
            instance_loss += weights[j] * ce_loss
        
        total_loss += instance_loss
    
    return total_loss / batch_size


def weighted_ce_loss_v2(outputs, candidate_masks):
    """Version 2: Negative log of weighted sum of probabilities"""
    batch_size = outputs.size(0)
    total_loss = 0
    
    log_probs = F.log_softmax(outputs, dim=1)
    probs = F.softmax(outputs, dim=1)
    
    for i in range(batch_size):
        candidates = candidate_masks[i].nonzero(as_tuple=True)[0]
        if len(candidates) == 0:
            continue
        
        candidate_probs = probs[i, candidates]
        weights = candidate_probs / (candidate_probs.sum() + 1e-10)
        
        # Negative log of weighted sum of probabilities
        weighted_prob_sum = 0
        for j, candidate in enumerate(candidates):
            weighted_prob_sum += weights[j] * torch.exp(log_probs[i, candidate])
        
        loss = -torch.log(weighted_prob_sum + 1e-10)
        total_loss += loss
    
    return total_loss / batch_size


def kl_loss(outputs, candidate_masks):
    """Version 3: KL divergence with uniform target"""
    batch_size = outputs.size(0)
    num_classes = outputs.size(1)
    total_loss = 0
    
    log_probs = F.log_softmax(outputs, dim=1)
    
    for i in range(batch_size):
        candidates = candidate_masks[i].nonzero(as_tuple=True)[0]
        if len(candidates) == 0:
            continue
        
        # Create uniform target distribution over candidates
        target_dist = torch.zeros(num_classes).to(outputs.device)
        target_dist[candidates] = 1.0 / len(candidates)
        
        loss = F.kl_div(log_probs[i], target_dist, reduction='sum')
        total_loss += loss
    
    return total_loss / batch_size

In [5]:
# Test different loss functions
loss_functions = {
    'V1: Weighted CE': weighted_ce_loss_v1,
    'V2: Neg Log Weighted Prob': weighted_ce_loss_v2,
    'V3: KL Divergence': kl_loss
}

results = {}

for loss_name, loss_fn in loss_functions.items():
    print(f"\nTesting {loss_name}")
    
    # Create fresh model
    model = WideResNet(depth=16, widen_factor=2, num_classes=10).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    
    losses = []
    accuracies = []
    
    for epoch in range(10):
        model.train()
        total_loss = 0
        
        for imgs, candidate_masks, targets, _, is_labeled in dataloader:
            # Only use labeled data for fair comparison
            labeled_mask = (is_labeled == 1)
            if not torch.any(labeled_mask):
                continue
            
            imgs = imgs[labeled_mask].to(device)
            candidate_masks = candidate_masks[labeled_mask].to(device)
            targets = targets[labeled_mask].to(device)
            
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = loss_fn(outputs, candidate_masks)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        # Evaluate
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, targets in test_loader:
                images, targets = images.to(device), targets.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        
        accuracy = 100.0 * correct / total
        avg_loss = total_loss / len(dataloader)
        
        losses.append(avg_loss)
        accuracies.append(accuracy)
        
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%")
    
    results[loss_name] = {
        'losses': losses,
        'accuracies': accuracies,
        'final_accuracy': accuracies[-1]
    }


Testing V1: Weighted CE
Epoch 1: Loss=0.1597, Accuracy=10.01%
Epoch 2: Loss=0.1413, Accuracy=10.00%
Epoch 3: Loss=0.1358, Accuracy=10.56%
Epoch 4: Loss=0.1380, Accuracy=10.21%
Epoch 5: Loss=0.1317, Accuracy=10.84%
Epoch 6: Loss=0.1299, Accuracy=10.03%
Epoch 7: Loss=0.1342, Accuracy=10.00%
Epoch 8: Loss=0.1323, Accuracy=10.00%
Epoch 9: Loss=0.1375, Accuracy=10.00%
Epoch 10: Loss=0.1298, Accuracy=10.63%

Testing V2: Neg Log Weighted Prob
Epoch 1: Loss=0.1425, Accuracy=10.00%
Epoch 2: Loss=0.1281, Accuracy=10.50%
Epoch 3: Loss=0.1197, Accuracy=10.00%
Epoch 4: Loss=0.1259, Accuracy=10.00%
Epoch 5: Loss=0.1321, Accuracy=10.00%
Epoch 6: Loss=0.1313, Accuracy=10.36%
Epoch 7: Loss=0.1321, Accuracy=10.03%
Epoch 8: Loss=0.1267, Accuracy=10.00%
Epoch 9: Loss=0.1190, Accuracy=10.00%
Epoch 10: Loss=0.1303, Accuracy=10.00%

Testing V3: KL Divergence
Epoch 1: Loss=0.0926, Accuracy=10.08%
Epoch 2: Loss=0.0753, Accuracy=9.93%
Epoch 3: Loss=0.0782, Accuracy=9.02%
Epoch 4: Loss=0.0743, Accuracy=10.00%
E

## Test 3: Analyze Gradient Flow
Let's see what happens to gradients with partial label loss.

In [7]:
# Analyze gradients
model = WideResNet(depth=16, widen_factor=2, num_classes=10).to(device)

# Get a batch
for imgs, candidate_masks, targets, _, is_labeled in dataloader:
    labeled_mask = (is_labeled == 1)
    if torch.any(labeled_mask):
        imgs = imgs[labeled_mask][:4].to(device)
        candidate_masks = candidate_masks[labeled_mask][:4].to(device)
        targets = targets[labeled_mask][:4].to(device)
        break

print(f"Analyzing gradients for {len(imgs)} samples")
print(f"True labels: {targets.cpu().numpy()}")
print(f"Candidate masks:")
for i in range(len(imgs)):
    candidates = candidate_masks[i].nonzero(as_tuple=True)[0].cpu().numpy()
    print(f"  Sample {i}: {candidates}")

# Forward pass
outputs = model(imgs)
probs = F.softmax(outputs, dim=1)

print("\nInitial predictions:")
for i in range(len(imgs)):
    pred_class = outputs[i].argmax().item()
    pred_prob = probs[i].max().item()
    print(f"Sample {i}: pred={pred_class} (prob={pred_prob:.3f}), true={targets[i].item()}")

# Test different losses
losses = {
    'Standard CE': F.cross_entropy(outputs, targets),
    'Weighted CE V1': weighted_ce_loss_v1(outputs, candidate_masks),
    'Weighted CE V2': weighted_ce_loss_v2(outputs, candidate_masks),
    'KL Divergence': kl_loss(outputs, candidate_masks)
}

for loss_name, loss in losses.items():
    model.zero_grad()
    loss.backward(retain_graph=True)
    
    # Calculate gradient statistics
    grad_norms = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms.append(param.grad.norm().item())
    
    print(f"\n{loss_name}:")
    print(f"  Loss value: {loss.item():.4f}")
    print(f"  Avg gradient norm: {np.mean(grad_norms):.6f}")
    print(f"  Max gradient norm: {np.max(grad_norms):.6f}")

Analyzing gradients for 1 samples
True labels: [0]
Candidate masks:
  Sample 0: [0 5]

Initial predictions:
Sample 0: pred=7 (prob=0.123), true=0

Standard CE:
  Loss value: 2.4846
  Avg gradient norm: 0.345685
  Max gradient norm: 4.307232

Weighted CE V1:
  Loss value: 2.2785
  Avg gradient norm: 0.245159
  Max gradient norm: 2.980722

Weighted CE V2:
  Loss value: 2.2639
  Avg gradient norm: 0.260286
  Max gradient norm: 3.176491

KL Divergence:
  Loss value: 1.6159
  Avg gradient norm: 0.232526
  Max gradient norm: 2.819656


## Test 4: Visualize Prediction Evolution
Let's see how predictions evolve during training with SPMI loss.

In [8]:
model = WideResNet(depth=16, widen_factor=2, num_classes=10).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Get test samples - make sure we have enough
test_imgs = None
test_masks = None
test_targets = None

for imgs, candidate_masks, targets, _, is_labeled in dataloader:
    labeled_mask = (is_labeled == 1)
    labeled_indices = torch.where(labeled_mask)[0]
    
    if len(labeled_indices) >= 4:
        # Select exactly 4 labeled samples
        selected_indices = labeled_indices[:4]
        test_imgs = imgs[selected_indices].to(device)
        test_masks = candidate_masks[selected_indices].to(device)
        test_targets = targets[selected_indices]
        break

if test_imgs is None:
    print("Not enough labeled samples in batch. Using whatever we have...")
    for imgs, candidate_masks, targets, _, is_labeled in dataloader:
        labeled_mask = (is_labeled == 1)
        if torch.any(labeled_mask):
            test_imgs = imgs[labeled_mask].to(device)
            test_masks = candidate_masks[labeled_mask].to(device)
            test_targets = targets[labeled_mask]
            break

num_test_samples = len(test_imgs)
print(f"Tracking {num_test_samples} samples")

prediction_history = []
loss_history = []

# Train and track predictions
for step in range(50):
    # Record predictions
    with torch.no_grad():
        outputs = model(test_imgs)
        probs = F.softmax(outputs, dim=1)
        prediction_history.append(probs.cpu().numpy())
    
    # Training step
    model.train()
    optimizer.zero_grad()
    outputs = model(test_imgs)
    
    # Use weighted CE v2 (the correct implementation)
    loss = weighted_ce_loss_v2(outputs, test_masks)
    loss.backward()
    optimizer.step()
    
    loss_history.append(loss.item())
    
    if step % 10 == 0:
        print(f"Step {step}: Loss={loss.item():.4f}")

# Plot evolution - adjust for actual number of samples
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.ravel()

for i in range(min(4, num_test_samples)):
    ax = axes[i]
    
    # Extract predictions for this sample
    sample_preds = np.array([hist[i] for hist in prediction_history])
    
    # Plot evolution of each class probability
    for class_idx in range(10):
        ax.plot(sample_preds[:, class_idx], label=f'Class {class_idx}')
    
    ax.axhline(y=0.1, color='r', linestyle='--', label='Uniform')
    ax.set_title(f'Sample {i} (True: {test_targets[i].item()})')
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Probability')
    ax.set_ylim(0, 1)
    
    # Mark candidate labels
    candidates = test_masks[i].nonzero(as_tuple=True)[0].cpu().numpy()
    ax.set_title(f'Sample {i} (True: {test_targets[i].item()}, Candidates: {candidates})')
    
    if i == 0:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# Hide unused subplots
for i in range(num_test_samples, 4):
    axes[i].set_visible(False)

plt.tight_layout()
plt.savefig('prediction_evolution.png', bbox_inches='tight', dpi=150)
plt.close()

# Plot loss evolution
plt.figure(figsize=(8, 4))
plt.plot(loss_history)
plt.xlabel('Training Step')
plt.ylabel('Loss')
plt.title('Loss Evolution')
plt.grid(True)
plt.savefig('loss_evolution.png', dpi=150)
plt.close()

# Print final predictions
print("\nFinal predictions:")
final_probs = prediction_history[-1]
for i in range(num_test_samples):
    max_prob = final_probs[i].max()
    pred_class = final_probs[i].argmax()
    candidates = test_masks[i].nonzero(as_tuple=True)[0].cpu().numpy()
    print(f"Sample {i}: pred={pred_class} (prob={max_prob:.3f}), true={test_targets[i].item()}, candidates={candidates}")

Not enough labeled samples in batch. Using whatever we have...
Tracking 1 samples
Step 0: Loss=2.1563
Step 10: Loss=0.0000
Step 20: Loss=0.0000
Step 30: Loss=0.0000
Step 40: Loss=0.0000

Final predictions:
Sample 0: pred=2 (prob=1.000), true=5, candidates=[0 1 2 3 5 6 7]


## Test 5: Why Does It Converge to Uniform?
Let's analyze why the weighted loss causes uniform predictions.

In [9]:
# Analyze the feedback loop
def analyze_weighted_loss_behavior():
    """Analyze why weighted loss leads to uniform predictions"""
    
    # Create synthetic example
    batch_size = 1
    num_classes = 10
    
    # Simulate different prediction scenarios
    scenarios = {
        'Peaked': torch.tensor([[5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]),
        'Semi-uniform': torch.tensor([[1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]),
        'Uniform': torch.zeros(1, 10)
    }
    
    # Create candidate mask (3 candidates)
    candidate_mask = torch.zeros(1, 10)
    candidate_mask[0, [0, 1, 2]] = 1
    
    print("Analyzing different prediction scenarios:")
    print("Candidate labels: [0, 1, 2]")
    print()
    
    for scenario_name, outputs in scenarios.items():
        probs = F.softmax(outputs, dim=1)
        log_probs = F.log_softmax(outputs, dim=1)
        
        # Get candidate probabilities
        candidates = candidate_mask[0].nonzero(as_tuple=True)[0]
        candidate_probs = probs[0, candidates]
        
        # Calculate weights
        weights = candidate_probs / candidate_probs.sum()
        
        # Calculate losses
        v1_loss = 0
        v2_loss = 0
        
        # V1: Weighted sum of CE losses
        for j, candidate in enumerate(candidates):
            ce_loss = F.cross_entropy(outputs, torch.tensor([candidate]))
            v1_loss += weights[j] * ce_loss
        
        # V2: Negative log of weighted sum of probabilities
        weighted_prob_sum = 0
        for j, candidate in enumerate(candidates):
            weighted_prob_sum += weights[j] * probs[0, candidate]
        v2_loss = -torch.log(weighted_prob_sum)
        
        print(f"{scenario_name} predictions:")
        print(f"  Outputs: {outputs[0][:3].numpy()}")
        print(f"  Probabilities: {probs[0][:3].numpy()}")
        print(f"  Candidate probs: {candidate_probs.numpy()}")
        print(f"  Weights: {weights.numpy()}")
        print(f"  V1 Loss: {v1_loss.item():.4f}")
        print(f"  V2 Loss: {v2_loss.item():.4f}")
        print()

analyze_weighted_loss_behavior()

Analyzing different prediction scenarios:
Candidate labels: [0, 1, 2]

Peaked predictions:
  Outputs: [5.  0.1 0.1]
  Probabilities: [0.93719035 0.00697887 0.00697887]
  Candidate probs: [0.93719035 0.00697887 0.00697887]
  Weights: [0.9853254  0.00733731 0.00733731]
  V1 Loss: 0.1368
  V2 Loss: 0.0795

Semi-uniform predictions:
  Outputs: [1. 1. 1.]
  Probabilities: [0.17105748 0.17105748 0.17105748]
  Candidate probs: [0.17105748 0.17105748 0.17105748]
  Weights: [0.3333333 0.3333333 0.3333333]
  V1 Loss: 1.7658
  V2 Loss: 1.7658

Uniform predictions:
  Outputs: [0. 0. 0.]
  Probabilities: [0.1 0.1 0.1]
  Candidate probs: [0.1 0.1 0.1]
  Weights: [0.3333333 0.3333333 0.3333333]
  V1 Loss: 2.3026
  V2 Loss: 2.3026



## Conclusion

The issue is that the weighted cross-entropy loss (V1) creates a feedback loop:
1. When predictions become uniform, all candidates get equal weight
2. This encourages the model to predict all candidates equally
3. The model converges to uniform predictions (10% accuracy)

The correct implementation (V2) calculates the negative log of the weighted sum of probabilities, which doesn't have this issue.

To fix your implementation, use the V2 loss function.