# Self-Supervised Pre-training of B-cos PIP-Net

This notebook runs self-supervised pre-training of a model that combines B-cos networks with PIP-Net for interpretable prototype learning.

## Features:
- 6-channel image encoding: [r,g,b,1-r,1-g,1-b]
- Contrastive pairs through data augmentation
- B-cos convolution layers for interpretability
- Combined loss: Align Loss (La) + Tanh Loss (Lt)
- CUDA-ready for CIFAR-10 and CUB datasets
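
As a quick illustration of the 6-channel encoding listed above, here is a minimal sketch. The helper name `to_six_channels` is hypothetical and is not the repository's `SixChannelTransform`; it assumes the input is a `(3, H, W)` tensor already scaled to `[0, 1]`.

```python
import torch

def to_six_channels(rgb: torch.Tensor) -> torch.Tensor:
    """Append the complement of each colour channel: [r, g, b, 1-r, 1-g, 1-b].

    Assumes `rgb` has shape (3, H, W) with values already scaled to [0, 1].
    """
    if rgb.shape[0] != 3:
        raise ValueError(f"expected 3 channels, got {rgb.shape[0]}")
    return torch.cat([rgb, 1.0 - rgb], dim=0)

torch.manual_seed(0)
image = torch.rand(3, 32, 32)
encoded = to_six_channels(image)
print(encoded.shape)  # torch.Size([6, 32, 32])
# Each channel and its complement sum to one at every pixel.
print(torch.allclose(encoded[:3] + encoded[3:], torch.ones(3, 32, 32)))
```

Because every pixel carries both a value and its complement, dark regions still produce a non-zero input vector, which is the usual motivation for this encoding in B-cos models.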

## 1. Setup and Installation

In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Clone the repository
!git clone https://github.com/Karo555/improved-Bcos-PIPNet
%cd improved-Bcos-PIPNet

# Initialize submodules
!git submodule update --init --recursive

In [None]:
import sys
import os

# Add source directories to Python path
sys.path.append('src')
sys.path.append('B-cos')
sys.path.append('PIPNet')

# Verify paths
print("Current directory:", os.getcwd())
print("Source files exist:", os.path.exists('src/train_pretraining.py'))
print("B-cos modules exist:", os.path.exists('B-cos/modules/bcosconv2d.py'))
print("PIPNet modules exist:", os.path.exists('PIPNet/pipnet/pipnet.py'))

## 2. Import Modules and Test Implementation

In [None]:
# Import our custom modules
from src.transforms import SixChannelTransform, ContrastiveAugmentationCIFAR
from src.bcos_features import bcos_simple_features, bcos_medium_features, bcos_large_features
from src.losses import CombinedPretrainingLoss, InfoNCELoss
from src.bcos_pipnet import create_bcos_pipnet
from src.datasets import get_dataset

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

print("All modules imported successfully!")

In [None]:
# Test 6-channel transformation
print("Testing 6-channel transformation...")

# Create a sample RGB image
rgb_tensor = torch.rand(3, 32, 32)  # Random RGB image
print(f"Original RGB shape: {rgb_tensor.shape}")

# Apply 6-channel transform
six_channel_transform = SixChannelTransform()
six_channel_tensor = six_channel_transform(rgb_tensor)
print(f"6-channel shape: {six_channel_tensor.shape}")

# Verify the transformation
r, g, b = rgb_tensor[0], rgb_tensor[1], rgb_tensor[2]
r_6ch, g_6ch, b_6ch = six_channel_tensor[0], six_channel_tensor[1], six_channel_tensor[2]
inv_r, inv_g, inv_b = six_channel_tensor[3], six_channel_tensor[4], six_channel_tensor[5]

print(f"R channel match: {torch.allclose(r, r_6ch)}")
print(f"Inverse R correct: {torch.allclose(1-r, inv_r)}")
print("6-channel transformation working correctly!")

In [None]:
# Test model creation
print("Testing model creation...")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create model with simple B-cos CNN backbone (no ResNet)
model = create_bcos_pipnet(
    num_prototypes=256,
    backbone='bcos_simple',  # Changed from 'bcos_resnet18'
    pretrained=False
).to(device)

print(f"Model created successfully!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test forward pass
batch_size = 4
dummy_input = torch.randn(batch_size, 6, 32, 32).to(device)
print(f"Input shape: {dummy_input.shape}")

with torch.no_grad():
    proto_features, pooled_features, projected_features = model(dummy_input, return_features=True)
    
print(f"Prototype features shape: {proto_features.shape}")
print(f"Pooled features shape: {pooled_features.shape}")
print(f"Projected features shape: {projected_features.shape}")
print("Forward pass successful!")

In [None]:
# Test forward pass with corrected B-cos implementation
print("Testing forward pass with corrected B-cos implementation...")

# Test forward pass
batch_size = 4
dummy_input = torch.randn(batch_size, 6, 32, 32).to(device)
print(f"Input shape: {dummy_input.shape}")

with torch.no_grad():
    try:
        proto_features, pooled_features, projected_features = model(dummy_input, return_features=True)
        
        print(f"✅ B-cos forward pass successful!")
        print(f"Prototype features shape: {proto_features.shape}")
        print(f"Pooled features shape: {pooled_features.shape}")
        print(f"Projected features shape: {projected_features.shape}")
        
        # Check activation patterns - B-cos should have different characteristics
        print(f"\nActivation statistics:")
        print(f"Pooled features - mean: {pooled_features.mean():.4f}, std: {pooled_features.std():.4f}")
        print(f"Prototype features - mean: {proto_features.mean():.4f}, std: {proto_features.std():.4f}")
        
        # Check for interpretability properties - activations should be meaningful
        active_prototypes = (pooled_features > 0.1).sum(dim=1)
        print(f"Average active prototypes per sample: {active_prototypes.float().mean():.1f}")
        
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        import traceback
        traceback.print_exc()

In [None]:
# Re-verify B-cos implementation after removing remaining ReLU layers
print("Re-verifying B-cos implementation after fixes...")

# Create a fresh model instance to test the fixes
model_fixed = create_bcos_pipnet(
    num_prototypes=256,
    backbone='bcos_simple',  # Using simple B-cos CNN instead of ResNet
    pretrained=False
).to(device)

print(f"Fixed B-cos Model parameters: {sum(p.numel() for p in model_fixed.parameters()):,}")


# Test forward pass with the fixed model
with torch.no_grad():
    try:
        proto_features, pooled_features, projected_features = model_fixed(dummy_input, return_features=True)
        
        print(f"✅ Fixed B-cos forward pass successful!")
        print(f"Prototype features shape: {proto_features.shape}")
        print(f"Pooled features shape: {pooled_features.shape}")
        print(f"Projected features shape: {projected_features.shape}")
        
        # Verify prototype activations are non-negative (using abs instead of ReLU)
        print(f"All prototype activations non-negative: {(pooled_features >= 0).all().item()}")
        
        # Update the model reference for training
        model = model_fixed
        print("✅ Model updated to use the corrected B-cos implementation")
        
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        import traceback
        traceback.print_exc()

## 3. Dataset Setup

In [None]:
# Setup CIFAR-10 dataset
print("Setting up CIFAR-10 dataset...")

dataset = get_dataset('cifar10', 
                     data_dir='./data',
                     batch_size=64,  # Smaller batch for Colab
                     num_workers=2)  # Fewer workers for Colab

train_loader, test_loader = dataset.get_dataloaders()
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
print(f"Train batches: {len(train_loader)}")

# Test data loading
view1, view2, labels = next(iter(train_loader))
print(f"View 1 shape: {view1.shape} (6-channel)")
print(f"View 2 shape: {view2.shape} (6-channel)")
print(f"Labels shape: {labels.shape}")
print("Dataset setup successful!")

In [None]:
# Visualize the 6-channel transformation
print("Visualizing 6-channel transformation...")

# Get a sample
view1, view2, labels = next(iter(train_loader))
sample = view1[0]  # First sample, shape: (6, 32, 32)

# Extract RGB and inverse channels
rgb_channels = sample[:3]
inv_channels = sample[3:]

# Denormalize for visualization (CIFAR-10 normalization)
mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)
rgb_denorm = rgb_channels * std + mean
rgb_denorm = torch.clamp(rgb_denorm, 0, 1)

# Create visualization
fig, axes = plt.subplots(2, 3, figsize=(12, 8))

# Original RGB channels
for i, (ax, title) in enumerate(zip(axes[0], ['Red', 'Green', 'Blue'])):
    ax.imshow(rgb_denorm[i], cmap='gray')
    ax.set_title(f'{title} Channel')
    ax.axis('off')

# Inverse channels
for i, (ax, title) in enumerate(zip(axes[1], ['1-Red', '1-Green', '1-Blue'])):
    # Denormalize inverse channels
    inv_denorm = inv_channels[i] * std[i] + mean[i]
    inv_denorm = torch.clamp(inv_denorm, 0, 1)
    ax.imshow(inv_denorm, cmap='gray')
    ax.set_title(f'{title} Channel')
    ax.axis('off')

plt.tight_layout()
plt.show()

print("6-channel visualization complete!")

## 4. Training Configuration

In [None]:
# Training configuration for Colab
class TrainingConfig:
    # Dataset
    dataset = 'cifar10'
    data_dir = './data'
    batch_size = 64  # Reduced for Colab
    num_workers = 2
    
    # Model
    # backbone options: 'bcos_simple', 'bcos_medium', 'bcos_large'
    backbone = 'bcos_simple'  # Simple B-cos CNN instead of ResNet
    num_prototypes = 256
    
    # Training
    epochs = 50  # Reduced for demo
    lr = 1e-3
    weight_decay = 1e-4
    warmup_epochs = 5
    
    # Loss weights
    align_weight = 1.0
    tanh_weight = 1.0
    contrastive_weight = 0.5  # Optional contrastive loss
    temperature = 0.07
    
    # Device and logging
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    log_interval = 50
    save_interval = 10
    
    # Directories
    log_dir = './logs/colab_training'
    save_dir = './checkpoints/colab_training'
    
    # Other
    seed = 42

args = TrainingConfig()
print(f"Training configuration:")
print(f"  Device: {args.device}")
print(f"  Epochs: {args.epochs}")
print(f"  Batch size: {args.batch_size}")
print(f"  Learning rate: {args.lr}")
print(f"  Prototypes: {args.num_prototypes}")
print(f"  Backbone: {args.backbone} (Simple B-cos CNN)")
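
One thing to keep in mind with a configuration written as class-level attributes: `vars(instance)` only returns attributes set on the instance itself, so it comes back empty here. The sketch below shows the behaviour and one way to collect the values; `DemoConfig` and `config_to_dict` are hypothetical names used only for illustration.

```python
class DemoConfig:
    # Class-level attributes, as in the configuration cell above.
    lr = 1e-3
    epochs = 50

def config_to_dict(cfg) -> dict:
    """Collect public, non-callable attributes from the instance and its class."""
    return {
        name: getattr(cfg, name)
        for name in dir(cfg)
        if not name.startswith("_") and not callable(getattr(cfg, name))
    }

cfg = DemoConfig()
print(vars(cfg))            # {} because nothing was set on the instance itself
print(config_to_dict(cfg))  # {'epochs': 50, 'lr': 0.001}
```

The resulting dictionary contains only plain Python values, so it can be stored in a checkpoint or dumped to JSON.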

In [None]:
# Create directories
os.makedirs(args.log_dir, exist_ok=True)
os.makedirs(args.save_dir, exist_ok=True)
print(f"Created directories: {args.log_dir}, {args.save_dir}")

## 5. Training Loop

In [None]:
import torch.optim as optim
import torch.nn.functional as F
from tqdm.auto import tqdm
import json
import time

# Set seed
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Create model
model = create_bcos_pipnet(
    num_prototypes=args.num_prototypes,
    backbone=args.backbone,
    pretrained=False
).to(args.device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Create loss functions
criterion = CombinedPretrainingLoss(
    align_weight=args.align_weight,
    tanh_weight=args.tanh_weight
)

contrastive_criterion = InfoNCELoss(temperature=args.temperature)

# Create optimizer
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

print("Training setup complete!")
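
For orientation, the terms of the pre-training objective can be written as below. This is a sketch based on the formulation in the PIP-Net paper; the exact reduction and the value of $\epsilon$ used by `CombinedPretrainingLoss` in this repository are assumptions and may differ.

$$
\mathcal{L}_A = -\frac{1}{HW}\sum_{h,w} \log\left(\mathbf{z}'_{h,w}\cdot\mathbf{z}''_{h,w}\right),
\qquad
\mathcal{L}_T = -\frac{1}{D}\sum_{d=1}^{D} \log\left(\tanh\left(\sum_{n=1}^{N} p_{n,d}\right) + \epsilon\right)
$$

Here $\mathbf{z}'_{h,w}$ and $\mathbf{z}''_{h,w}$ are the prototype vectors of the two augmented views at spatial location $(h, w)$, $p_{n,d}$ is the pooled presence score of prototype $d$ for sample $n$ in a batch of size $N$, and $D$ is the number of prototypes. The training loop then minimises

$$
\mathcal{L} = \lambda_A \mathcal{L}_A + \lambda_T \mathcal{L}_T + \lambda_C \mathcal{L}_C ,
$$

where $\lambda_A$, $\lambda_T$ and $\lambda_C$ correspond to `align_weight`, `tanh_weight` and `contrastive_weight`, and $\mathcal{L}_C$ is the optional InfoNCE term.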

In [None]:
def adjust_learning_rate(optimizer, epoch, args):
    """Cosine learning rate schedule with warmup"""
    if epoch < args.warmup_epochs:
        # (epoch + 1) so that epoch 0 does not train with a learning rate of zero
        lr = args.lr * (epoch + 1) / args.warmup_epochs
    else:
        lr = args.lr * 0.5 * (1. + np.cos(np.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    return lr
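
To see what a warmup-plus-cosine schedule produces, the standalone sketch below evaluates it for the values used in this notebook (`lr=1e-3`, 5 warmup epochs, 50 epochs). The function name `lr_at` is hypothetical, and this variant scales warmup by `(epoch + 1)` so that the first epoch runs at a non-zero rate.

```python
import math

def lr_at(epoch: int, base_lr: float, warmup_epochs: int, total_epochs: int) -> float:
    """Linear warmup followed by cosine decay, evaluated once per epoch (0-indexed)."""
    if epoch < warmup_epochs:
        # (epoch + 1) keeps the first epoch from running at a learning rate of zero.
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

schedule = [lr_at(e, base_lr=1e-3, warmup_epochs=5, total_epochs=50) for e in range(50)]
print([round(v, 6) for v in schedule[:6]])  # [0.0002, 0.0004, 0.0006, 0.0008, 0.001, 0.001]
print(f"last epoch: {schedule[-1]:.2e}")
```

The rate peaks at the end of warmup and then decays smoothly towards zero over the remaining epochs.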

In [None]:
# Training loop
print("Starting training...")

# Lists to store metrics
train_losses = []
align_losses = []
tanh_losses = []
contrastive_losses = []
learning_rates = []

# Training loop
for epoch in range(args.epochs):
    model.train()
    
    # Adjust learning rate
    lr = adjust_learning_rate(optimizer, epoch, args)
    learning_rates.append(lr)
    
    # Initialize metrics
    epoch_loss = 0.0
    epoch_align_loss = 0.0
    epoch_tanh_loss = 0.0
    epoch_contrastive_loss = 0.0
    
    # Progress bar
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{args.epochs}')
    
    for batch_idx, (view1, view2, labels) in enumerate(pbar):
        view1, view2 = view1.to(args.device), view2.to(args.device)
        
        optimizer.zero_grad()
        
        # Forward pass through both views
        proto_features1, pooled1, projected1 = model(view1, return_features=True)
        proto_features2, pooled2, projected2 = model(view2, return_features=True)
        
        # Compute combined pre-training loss (La + Lt)
        loss_dict = criterion(model, proto_features1, proto_features2)
        loss = loss_dict['total_loss']
        
        # Add contrastive loss if specified
        contrastive_loss = torch.tensor(0.0, device=args.device)
        if args.contrastive_weight > 0:
            # Normalize projected features
            projected1_norm = F.normalize(projected1, dim=1)
            projected2_norm = F.normalize(projected2, dim=1)
            contrastive_loss = contrastive_criterion(projected1_norm, projected2_norm)
            # Out-of-place add so loss_dict['total_loss'] is not modified in place
            loss = loss + args.contrastive_weight * contrastive_loss
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Update metrics
        epoch_loss += loss.item()
        epoch_align_loss += loss_dict['align_loss'].item()
        epoch_tanh_loss += loss_dict['tanh_loss'].item()
        epoch_contrastive_loss += contrastive_loss.item()
        
        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'La': f'{loss_dict["align_loss"].item():.4f}',
            'Lt': f'{loss_dict["tanh_loss"].item():.4f}',
            'Lc': f'{contrastive_loss.item():.4f}',
            'LR': f'{lr:.6f}'
        })
    
    # Calculate average losses for the epoch
    avg_loss = epoch_loss / len(train_loader)
    avg_align_loss = epoch_align_loss / len(train_loader)
    avg_tanh_loss = epoch_tanh_loss / len(train_loader)
    avg_contrastive_loss = epoch_contrastive_loss / len(train_loader)
    
    # Store metrics
    train_losses.append(avg_loss)
    align_losses.append(avg_align_loss)
    tanh_losses.append(avg_tanh_loss)
    contrastive_losses.append(avg_contrastive_loss)
    
    # Print epoch summary
    print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, La={avg_align_loss:.4f}, "
          f"Lt={avg_tanh_loss:.4f}, Lc={avg_contrastive_loss:.4f}, LR={lr:.6f}")
    
    # Save checkpoint
    if (epoch + 1) % args.save_interval == 0:
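        checkpoint = {
            'epoch': epoch + 1,  # number of completed epochs, matching the filename
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_loss,
            # vars(args) would be empty: TrainingConfig only defines class attributes
            'args': {k: getattr(args, k) for k in dir(args) if not k.startswith('_')}
        }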
        torch.save(checkpoint, os.path.join(args.save_dir, f'checkpoint_epoch_{epoch+1}.pth'))
        print(f"Checkpoint saved at epoch {epoch+1}")

print("Training completed!")
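
The optional contrastive term in the loop above is an InfoNCE loss over the projected features of the two views. The sketch below shows one common, one-directional form of it; `info_nce` is a hypothetical helper, and the repository's `InfoNCELoss` may differ (for example by being symmetric over both views).

```python
import torch
import torch.nn.functional as F

def info_nce(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """One-directional InfoNCE: row i of z1 should match row i of z2."""
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature  # (N, N) scaled cosine similarities
    targets = torch.arange(z1.size(0), device=z1.device)  # positives sit on the diagonal
    return F.cross_entropy(logits, targets)

torch.manual_seed(0)
a = torch.randn(8, 16)
aligned = info_nce(a, a + 0.01 * torch.randn(8, 16))  # nearly identical views
shuffled = info_nce(a, torch.randn(8, 16))            # unrelated views
print(f"aligned pairs: {aligned.item():.4f}, random pairs: {shuffled.item():.4f}")
```

Matching views should give a much lower loss than unrelated ones, which is a useful sanity check before a long training run.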

## 6. Results Visualization

In [None]:
# Plot training losses
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Total loss
axes[0,0].plot(train_losses)
axes[0,0].set_title('Total Loss')
axes[0,0].set_xlabel('Epoch')
axes[0,0].set_ylabel('Loss')
axes[0,0].grid(True)

# Align loss
axes[0,1].plot(align_losses, color='orange')
axes[0,1].set_title('Align Loss (La)')
axes[0,1].set_xlabel('Epoch')
axes[0,1].set_ylabel('Loss')
axes[0,1].grid(True)

# Tanh loss
axes[1,0].plot(tanh_losses, color='green')
axes[1,0].set_title('Tanh Loss (Lt)')
axes[1,0].set_xlabel('Epoch')
axes[1,0].set_ylabel('Loss')
axes[1,0].grid(True)

# Learning rate
axes[1,1].plot(learning_rates, color='red')
axes[1,1].set_title('Learning Rate')
axes[1,1].set_xlabel('Epoch')
axes[1,1].set_ylabel('Learning Rate')
axes[1,1].grid(True)

plt.tight_layout()
plt.show()

print(f"Final losses:")
print(f"  Total: {train_losses[-1]:.4f}")
print(f"  Align: {align_losses[-1]:.4f}")
print(f"  Tanh: {tanh_losses[-1]:.4f}")
if args.contrastive_weight > 0:
    print(f"  Contrastive: {contrastive_losses[-1]:.4f}")

In [None]:
# Analyze learned prototypes
print("Analyzing learned prototypes...")

model.eval()
with torch.no_grad():
    # Get a batch of data
    view1, view2, labels = next(iter(train_loader))
    view1 = view1.to(args.device)
    
    # Get prototype activations
    proto_features, pooled_features = model.get_prototype_activations(view1)
    
    print(f"Prototype features shape: {proto_features.shape}")
    print(f"Pooled features shape: {pooled_features.shape}")
    
    # Analyze prototype activation statistics
    proto_stats = {
        'mean': pooled_features.mean(dim=0),
        'std': pooled_features.std(dim=0),
        'max': pooled_features.max(dim=0)[0],
        'min': pooled_features.min(dim=0)[0]
    }
    
    # Plot prototype activation distribution
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Histogram of prototype activations
    axes[0].hist(pooled_features.cpu().numpy().flatten(), bins=50, alpha=0.7)
    axes[0].set_title('Distribution of Prototype Activations')
    axes[0].set_xlabel('Activation Value')
    axes[0].set_ylabel('Frequency')
    axes[0].grid(True)
    
    # Mean activation per prototype
    axes[1].plot(proto_stats['mean'].cpu().numpy())
    axes[1].set_title('Mean Activation per Prototype')
    axes[1].set_xlabel('Prototype Index')
    axes[1].set_ylabel('Mean Activation')
    axes[1].grid(True)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print(f"\nPrototype activation statistics:")
    print(f"  Overall mean: {pooled_features.mean():.4f}")
    print(f"  Overall std: {pooled_features.std():.4f}")
    print(f"  Max activation: {pooled_features.max():.4f}")
    print(f"  Min activation: {pooled_features.min():.4f}")
    
    # Count active prototypes (> 0.1 threshold like in PIP-Net)
    active_prototypes = (pooled_features > 0.1).sum(dim=1)
    print(f"  Average active prototypes per sample: {active_prototypes.float().mean():.1f}")

## 7. Save Final Model and Results

In [None]:
# Save final model
final_checkpoint = {
    'epoch': args.epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'align_losses': align_losses,
    'tanh_losses': tanh_losses,
    'contrastive_losses': contrastive_losses,
    'learning_rates': learning_rates,
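    # vars(args) would be empty: TrainingConfig only defines class attributes
    'args': {k: getattr(args, k) for k in dir(args) if not k.startswith('_')},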
    'num_prototypes': args.num_prototypes,
    'backbone': args.backbone
}

final_path = os.path.join(args.save_dir, 'final_model.pth')
torch.save(final_checkpoint, final_path)
print(f"Final model saved to: {final_path}")

# Save training metrics as JSON
metrics = {
    'train_losses': train_losses,
    'align_losses': align_losses,
    'tanh_losses': tanh_losses,
    'contrastive_losses': contrastive_losses,
    'learning_rates': learning_rates,
    'final_loss': train_losses[-1],
    'min_loss': min(train_losses),
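    # vars(args) would be empty: TrainingConfig only defines class attributes
    'training_config': {k: getattr(args, k) for k in dir(args) if not k.startswith('_')}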
}

with open(os.path.join(args.save_dir, 'training_metrics.json'), 'w') as f:
    json.dump(metrics, f, indent=2)

print("Training metrics saved!")

In [None]:
# Create downloadable archive of results
import zipfile

zip_path = 'bcos_pipnet_training_results.zip'
with zipfile.ZipFile(zip_path, 'w') as zipf:
    # Add model checkpoints
    for file in os.listdir(args.save_dir):
        if file.endswith('.pth') or file.endswith('.json'):
            zipf.write(os.path.join(args.save_dir, file), f'checkpoints/{file}')
    
    # Add this notebook
    if os.path.exists('BcosPIPNet_Training.ipynb'):
        zipf.write('BcosPIPNet_Training.ipynb', 'BcosPIPNet_Training.ipynb')

print(f"Results packaged in: {zip_path}")
print(f"Download this file to save your training results!")

# Show file sizes
if os.path.exists(final_path):
    size_mb = os.path.getsize(final_path) / (1024*1024)
    print(f"Final model size: {size_mb:.1f} MB")

if os.path.exists(zip_path):
    zip_size_mb = os.path.getsize(zip_path) / (1024*1024)
    print(f"Results archive size: {zip_size_mb:.1f} MB")

## 8. Next Steps

After completing the self-supervised pre-training, you can:

1. **Fine-tune for Classification**: Use the learned prototypes as initialization for supervised classification tasks
2. **Analyze Interpretability**: Visualize what each prototype has learned to represent
3. **Transfer Learning**: Apply the pre-trained model to other datasets
4. **Prototype Analysis**: Study the alignment properties and interpretability of the learned representations

### Loading the Pre-trained Model

```python
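# Load the pre-trained model (path matches save_dir used during training)
checkpoint = torch.load('./checkpoints/colab_training/final_model.pth', map_location='cpu')
model = create_bcos_pipnet(
    num_prototypes=checkpoint['num_prototypes'],
    backbone=checkpoint['backbone'],
    pretrained=False
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch to inference mode before analysis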
```

### Training Summary

This implementation successfully combines:
- **B-cos Networks**: For interpretable convolution operations
- **PIP-Net**: For prototype-based learning
- **Self-supervised Learning**: Using contrastive augmentations and alignment losses
- **6-channel Encoding**: Enhanced input representation with complementary channels

The model learns prototypes in a self-supervised manner, which makes it a starting point for applications that need both accuracy and interpretability. Downstream accuracy is not measured in this notebook.

## B-cos Implementation Updates Applied

### Key Changes Made:

1. **Replaced ResNet with Simple B-cos CNN**:
   - ❌ **NO ResNet structure** - Replaced with sequential BcosConv2d layers
   - ✅ **Simple CNN architecture** - Following B-cos CIFAR10 experiments
   - ✅ **Sequential layers** - [64→128→256→512] channels with stride=2 for downsampling

2. **Maintained Proper B-cos Components**:
   - ✅ **MaxOut in BcosConv2d** - All layers use `max_out=2` parameter
   - ✅ **Global Average Pooling** - `MyAdaptiveAvgPool2d((1, 1))` from B-cos utils
   - ✅ **Linear initialization** - 'linear' nonlinearity for B-cos weights
   - ✅ **Scale factor** - `scale_fact=100` following B-cos experiments

3. **Removed ALL Forbidden Components**:
   - ❌ **NO ReLU activations** - All `nn.ReLU()` layers removed
   - ❌ **NO BatchNorm layers** - All `nn.BatchNorm2d()` layers removed  
   - ❌ **NO MaxPooling** - All `nn.MaxPool2d()` layers removed
   - ❌ **NO Residual connections** - Simplified to pure sequential architecture

4. **New Backbone Options**:
   - **bcos_simple**: 7 layers, 512 output channels (replaces bcos_resnet18)
   - **bcos_medium**: 10 layers, 512 output channels (replaces bcos_resnet50)
   - **bcos_large**: 13 layers, 1024 output channels (for complex datasets)

### Architecture Details:

**Simple B-cos CNN Configuration**:
```
Input (6 channels) → BcosConv2d(6→64, k=3, s=1) → 
BcosConv2d(64→128, k=3, s=2) → BcosConv2d(128→128, k=3, s=1) →
BcosConv2d(128→256, k=3, s=2) → BcosConv2d(256→256, k=3, s=1) →
BcosConv2d(256→512, k=3, s=2) → BcosConv2d(512→512, k=3, s=1) →
Global Average Pool → Flatten (512 features)
```
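
To check the spatial resolution that reaches the pooling layer, the layer list above can be run through the standard convolution output-size formula. This sketch assumes `padding=1` for every 3x3 convolution and a 32x32 CIFAR-10 input; neither is stated in the configuration, so treat the result as illustrative.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Standard convolution output-size formula for one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

# (out_channels, stride) for the seven layers listed above; kernel size is 3.
LAYERS = [(64, 1), (128, 2), (128, 1), (256, 2), (256, 1), (512, 2), (512, 1)]

size = 32  # CIFAR-10 resolution
for out_channels, stride in LAYERS:
    size = conv_out_size(size, kernel=3, stride=stride, padding=1)  # padding=1 is assumed
    print(f"{out_channels:>4} channels, {size}x{size}")
```

Under these assumptions the three stride-2 layers reduce the feature map from 32x32 to 4x4 before global average pooling.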

### B-cos Compliance:
- **Pure B-cos layers**: Only BcosConv2d with cosine similarity
- **Built-in MaxOut**: `max_out=2` handles feature selection
- **No forbidden operations**: Complete removal of ReLU/BatchNorm/MaxPool
- **Interpretability preserved**: Maintains B-cos alignment properties
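
The cosine-similarity behaviour referred to above can be illustrated for a single input vector and a single weight vector. This is a plain-Python sketch of the B-cos transform with `b=2`, not the repository's `BcosConv2d`; it leaves out MaxOut and the scale factor, and `bcos_response` is a hypothetical name.

```python
import math

def bcos_response(x, w, b: float = 2.0) -> float:
    """B-cos transform of one input vector with one weight vector.

    Computes (w_hat . x) * |cos(x, w)|**(b - 1), where w_hat is w scaled to unit norm.
    """
    x_norm = math.sqrt(sum(v * v for v in x))
    w_norm = math.sqrt(sum(v * v for v in w))
    if x_norm == 0.0 or w_norm == 0.0:
        return 0.0
    linear = sum(xi * wi for xi, wi in zip(x, w)) / w_norm  # w_hat . x
    cosine = linear / x_norm
    return linear * abs(cosine) ** (b - 1.0)

print(bcos_response([3.0, 4.0], [3.0, 4.0]))   # aligned: equals the input norm, 5.0
print(bcos_response([3.0, 4.0], [4.0, -3.0]))  # orthogonal: 0.0
print(bcos_response([3.0, 4.0], [1.0, 0.0]))   # partial alignment: about 1.8
```

The output is bounded by the input norm and shrinks as the input and weight become less aligned, which is what pushes the weights to align with the inputs they respond to.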

This implementation is intended to follow the B-cos methodology as used in the original CIFAR-10 experiments, preserving interpretability while supporting prototype learning.