## Ablation Study: Several Other Hyperparameter Regimes

In [None]:
# Cell 1: Imports and Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
import math
from tqdm import tqdm

In [2]:
# Cell 2: Define Configurations
# Hyper-parameter regimes for the ablation. Per entry:
#   alpha, tau       - slope and midpoint of the uncertainty -> corruption sigmoid
#   lambda_reg       - weight of the mask-density regularizer
#   corruption_prob  - per-pixel corruption rate of the random-mask baseline
#   num_samples      - forward passes used to estimate uncertainty
configs = {
    'baseline': dict(
        alpha=10.0, tau=0.5, lambda_reg=1e-3, corruption_prob=0.5, num_samples=5,
    ),
    'high_uncertainty': dict(
        alpha=20.0, tau=0.3, lambda_reg=1e-3, corruption_prob=0.5, num_samples=5,
    ),
    'low_uncertainty': dict(
        alpha=5.0, tau=0.7, lambda_reg=1e-3, corruption_prob=0.5, num_samples=5,
    ),
    'high_corruption': dict(
        alpha=10.0, tau=0.5, lambda_reg=1e-3, corruption_prob=0.7, num_samples=5,
    ),
    'high_samples': dict(
        alpha=10.0, tau=0.5, lambda_reg=1e-3, corruption_prob=0.5, num_samples=10,
    ),
}



In [3]:
# Cell 3: Model and Training Functions
class ResidualBlock(nn.Module):
    """Conv-ReLU-Conv branch added back onto its input (identity skip).

    Both 3x3 convolutions keep ``channels`` and use padding 1, so the
    branch output has the input's shape and can be summed with it.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class SimpleUNet(nn.Module):
    """Small encoder/decoder used as the denoiser in these experiments.

    Two stride-2 convolutions halve the spatial size twice, a residual block
    processes the bottleneck, and two stride-2 transposed convolutions
    (kernel 4, padding 1) double it back; for 28x28 MNIST inputs the output
    has the input's shape. The final ``tanh`` keeps outputs in [-1, 1].

    Args:
        channels: width of every hidden feature map.
        dropout: channel-dropout probability applied to the bottleneck and to
            the last hidden feature map. With the default 0.0 the network is
            deterministic, so Monte-Carlo uncertainty (repeated train-mode
            forward passes) is identically zero; use a value in (0, 1) to get
            a non-trivial MC-dropout estimate.

    Raises:
        ValueError: if ``dropout`` is not in [0, 1).
    """

    def __init__(self, channels=32, dropout=0.0):
        super().__init__()
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")
        self.down1 = nn.Conv2d(1, channels, 3, padding=1, stride=2)
        self.down2 = nn.Conv2d(channels, channels, 3, padding=1, stride=2)
        self.res = ResidualBlock(channels)
        self.up1 = nn.ConvTranspose2d(channels, channels, 4, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(channels, 1, 4, stride=2, padding=1)
        # nn.Identity keeps the default path exactly as before: no extra
        # parameters, no extra RNG draws, unchanged state_dict keys.
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        x1 = F.relu(self.down1(x))
        x2 = self.drop(F.relu(self.down2(x1)))
        x3 = self.res(x2)
        x4 = self.drop(F.relu(self.up1(x3)))
        return torch.tanh(self.up2(x4))

def compute_uncertainty(model, x, num_samples=5):
    """Estimate per-element predictive uncertainty by Monte-Carlo sampling.

    Runs ``model`` ``num_samples`` times on the same input in training mode,
    takes the (unbiased) variance across the predictions, and averages it over
    dimension 1 (channels, with keepdim). The estimate is non-zero only if the
    model is stochastic in training mode (e.g. contains dropout); a
    deterministic network yields all zeros.

    Args:
        model: module to sample from; its train/eval mode is restored on exit.
        x: input batch, presumably (B, C, H, W) — TODO confirm for other callers.
        num_samples: number of forward passes; must be >= 2 because the
            unbiased variance of a single sample is NaN.

    Returns:
        Tensor shaped like the model output with dim 1 reduced to size 1,
        carrying no autograd history.

    Raises:
        ValueError: if ``num_samples`` < 2.
    """
    if num_samples < 2:
        raise ValueError(f"num_samples must be >= 2, got {num_samples}")
    was_training = model.training
    model.train()
    try:
        # Only a detached mask is derived from this, so skip graph construction
        # instead of keeping `num_samples` activation graphs alive.
        with torch.no_grad():
            preds = torch.stack([model(x) for _ in range(num_samples)])
    finally:
        # Restore the caller's mode rather than forcing eval(); forcing eval
        # left models that were being trained silently in eval mode.
        model.train(was_training)
    return preds.var(dim=0).mean(1, keepdim=True)

def generate_mask(uncertainty, alpha=10.0, tau=0.5):
    """Sample a binary corruption mask whose rate rises with uncertainty.

    Each element is an independent Bernoulli draw with probability
    ``sigmoid(alpha * (uncertainty - tau))``: ``tau`` is the uncertainty at
    which the probability is 0.5 and ``alpha`` sets how steep the switch is.
    The returned mask is detached from any autograd graph.
    """
    shifted = uncertainty - tau
    corruption_prob = torch.sigmoid(shifted * alpha)
    mask = torch.bernoulli(corruption_prob)
    return mask.detach()

def regularize_mask(mask, lambda_reg=1e-3):
    """Return ``lambda_reg`` times the mean of ``mask`` (its density).

    When ``mask`` is detached (as ``generate_mask`` returns it), this term
    only offsets the loss value and contributes no gradient.
    """
    density = mask.mean()
    return density * lambda_reg

def corrupt_input(x, mask):
    """Replace masked entries of ``x`` with fresh standard-normal noise.

    Where ``mask`` is 1 the output is N(0, 1) noise; where it is 0 the input
    passes through unchanged. Fractional mask values blend the two linearly.
    """
    keep = 1 - mask
    noise = torch.randn_like(x)
    return x * keep + noise * mask

In [4]:
# Cell 4: Training Function
def train_ablation(config_name, config, num_epochs=10, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Train an adaptive-corruption denoiser and a random-corruption baseline side by side.

    Both ``SimpleUNet`` models see the same MNIST training batches (scaled to
    [-1, 1]) and are trained with Adam (lr 1e-3) to reconstruct the clean
    input from a corrupted copy:

    * adaptive: the mask is sampled from the model's own uncertainty via
      ``generate_mask(alpha, tau)``, and ``regularize_mask(lambda_reg)`` is
      added to the MSE loss;
    * random: each element is corrupted independently with probability
      ``corruption_prob``.

    Every 5 epochs an 8-sample reconstruction grid is saved to
    ``ablation_{config_name}_epoch_{N}.png``.

    Args:
        config_name: label used in log messages and file names.
        config: dict with keys ``alpha``, ``tau``, ``lambda_reg``,
            ``corruption_prob`` and ``num_samples``.
        num_epochs: number of passes over the training set.
        device: torch device (or device string) to train on.

    Returns:
        dict with per-epoch mean ``adaptive_losses`` and ``random_losses``,
        plus the ``config`` that was used.
    """
    print(f"\nRunning experiment: {config_name}")

    # Setup data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2 - 1)  # Scale to [-1, 1]
    ])
    mnist = datasets.MNIST(root='.', train=True, download=True, transform=transform)
    dataloader = DataLoader(mnist, batch_size=128, shuffle=True)

    # One model and optimizer per corruption strategy
    model_adaptive = SimpleUNet().to(device)
    model_random = SimpleUNet().to(device)
    optimizer_adaptive = torch.optim.Adam(model_adaptive.parameters(), lr=1e-3)
    optimizer_random = torch.optim.Adam(model_random.parameters(), lr=1e-3)

    # Per-epoch mean training losses
    adaptive_losses = []
    random_losses = []

    for epoch in range(num_epochs):
        model_adaptive.train()
        model_random.train()
        epoch_adaptive_losses = []
        epoch_random_losses = []

        for x, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)

            # Adaptive corruption. The uncertainty only feeds a detached mask,
            # so the sampling passes need no autograd graph.
            with torch.no_grad():
                unc = compute_uncertainty(model_adaptive, x, config['num_samples'])
            # compute_uncertainty may leave the model in eval mode; switch back
            # so the update below is a genuine training-mode step.
            model_adaptive.train()
            mask_adaptive = generate_mask(unc, config['alpha'], config['tau'])
            x_corrupt_adaptive = corrupt_input(x, mask_adaptive)
            pred_adaptive = model_adaptive(x_corrupt_adaptive)
            # mask_adaptive is detached, so the regularizer shifts the reported
            # loss but does not contribute a gradient.
            loss_adaptive = F.mse_loss(pred_adaptive, x) + regularize_mask(mask_adaptive, config['lambda_reg'])

            optimizer_adaptive.zero_grad()
            loss_adaptive.backward()
            optimizer_adaptive.step()

            # Random corruption
            mask_random = torch.bernoulli(torch.full_like(x, config['corruption_prob']))
            x_corrupt_random = corrupt_input(x, mask_random)
            pred_random = model_random(x_corrupt_random)
            loss_random = F.mse_loss(pred_random, x)

            optimizer_random.zero_grad()
            loss_random.backward()
            optimizer_random.step()

            epoch_adaptive_losses.append(loss_adaptive.item())
            epoch_random_losses.append(loss_random.item())

        # Record epoch losses
        adaptive_losses.append(np.mean(epoch_adaptive_losses))
        random_losses.append(np.mean(epoch_random_losses))

        print(f"Epoch {epoch+1}: Adaptive Loss = {adaptive_losses[-1]:.4f}, Random Loss = {random_losses[-1]:.4f}")

        # Save samples
        if (epoch + 1) % 5 == 0:
            model_adaptive.eval()
            model_random.eval()
            with torch.no_grad():
                x, _ = next(iter(dataloader))
                x = x.to(device)[:8]

                unc = compute_uncertainty(model_adaptive, x, config['num_samples'])
                mask_adaptive = generate_mask(unc, config['alpha'], config['tau'])
                mask_random = torch.bernoulli(torch.full_like(x, config['corruption_prob']))

                x_corrupt_adaptive = corrupt_input(x, mask_adaptive)
                x_corrupt_random = corrupt_input(x, mask_random)

                recon_adaptive = model_adaptive(x_corrupt_adaptive)
                recon_random = model_random(x_corrupt_random)

                # Rows: clean, adaptive-corrupted, adaptive recon, random-corrupted, random recon
                images = torch.cat([x, x_corrupt_adaptive, recon_adaptive, x_corrupt_random, recon_random], dim=0)
                grid = utils.make_grid(images, nrow=8, normalize=True, value_range=(-1, 1))
                plt.figure(figsize=(20, 5))
                plt.imshow(grid.permute(1, 2, 0).cpu())
                plt.axis('off')
                plt.title(f"{config_name} - Epoch {epoch+1}")
                plt.savefig(f'ablation_{config_name}_epoch_{epoch+1}.png')
                plt.close()

    return {
        'adaptive_losses': adaptive_losses,
        'random_losses': random_losses,
        'config': config
    }

In [5]:
# Cell 5: Plotting Function
def plot_ablation_results(results):
    """Save loss curves and a plain-text summary for a set of ablation runs.

    ``results`` maps configuration name to the dict returned by
    ``train_ablation``. Writes ``ablation_losses.png`` (solid adaptive and
    dashed random curve per configuration) and ``ablation_summary.txt`` to
    the working directory.
    """
    plt.figure(figsize=(12, 6))
    for name, res in results.items():
        plt.plot(res['adaptive_losses'], label=f"{name} (Adaptive)")
        plt.plot(res['random_losses'], label=f"{name} (Random)", linestyle='--')
    plt.title('Training Losses Across Configurations')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('ablation_losses.png')
    plt.close()

    # Assemble the whole report first so a malformed result fails before the
    # summary file is opened.
    lines = [
        "Adaptive Corruption Ablation Study Summary\n",
        "========================================\n\n",
    ]
    for name, res in results.items():
        cfg = res['config']
        lines.extend([
            f"Configuration: {name}\n",
            f"Final Adaptive Loss: {res['adaptive_losses'][-1]:.4f}\n",
            f"Final Random Loss: {res['random_losses'][-1]:.4f}\n",
            f"Alpha: {cfg['alpha']}\n",
            f"Tau: {cfg['tau']}\n",
            f"Corruption Probability: {cfg['corruption_prob']}\n",
            f"Number of Samples: {cfg['num_samples']}\n",
            "\n",
        ])

    with open('ablation_summary.txt', 'w') as fh:
        fh.writelines(lines)



In [6]:
# Cell 6: Run Experiments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fill `results` one configuration at a time so that, if a run fails part-way,
# the configurations that already finished remain available in the notebook.
results = {}
for config_name, config in configs.items():
    results[config_name] = train_ablation(config_name, config, num_epochs=10, device=device)

# Write the loss plot and the text summary.
plot_ablation_results(results)




Running experiment: baseline


Epoch 1/10:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 1/10: 100%|██████████| 469/469 [00:11<00:00, 41.90it/s]


Epoch 1: Adaptive Loss = 0.4551, Random Loss = 0.1336


Epoch 2/10: 100%|██████████| 469/469 [00:10<00:00, 42.80it/s]


Epoch 2: Adaptive Loss = 0.4480, Random Loss = 0.0847


Epoch 3/10: 100%|██████████| 469/469 [00:11<00:00, 41.75it/s]


Epoch 3: Adaptive Loss = 0.4480, Random Loss = 0.0733


Epoch 4/10: 100%|██████████| 469/469 [00:11<00:00, 42.54it/s]


Epoch 4: Adaptive Loss = 0.4480, Random Loss = 0.0681


Epoch 5/10: 100%|██████████| 469/469 [00:11<00:00, 42.14it/s]


Epoch 5: Adaptive Loss = 0.4480, Random Loss = 0.0646


Epoch 6/10: 100%|██████████| 469/469 [00:09<00:00, 49.38it/s]


Epoch 6: Adaptive Loss = 0.4480, Random Loss = 0.0615


Epoch 7/10: 100%|██████████| 469/469 [00:09<00:00, 50.90it/s]


Epoch 7: Adaptive Loss = 0.4480, Random Loss = 0.0589


Epoch 8/10: 100%|██████████| 469/469 [00:09<00:00, 47.15it/s]


Epoch 8: Adaptive Loss = 0.4480, Random Loss = 0.0570


Epoch 9/10: 100%|██████████| 469/469 [00:10<00:00, 42.68it/s]


Epoch 9: Adaptive Loss = 0.4480, Random Loss = 0.0553


Epoch 10/10: 100%|██████████| 469/469 [00:10<00:00, 42.79it/s]


Epoch 10: Adaptive Loss = 0.4480, Random Loss = 0.0538

Running experiment: high_uncertainty


Epoch 1/10: 100%|██████████| 469/469 [00:11<00:00, 42.44it/s]


Epoch 1: Adaptive Loss = 0.0332, Random Loss = 0.1421


Epoch 2/10: 100%|██████████| 469/469 [00:10<00:00, 42.74it/s]


Epoch 2: Adaptive Loss = 0.0042, Random Loss = 0.0900


Epoch 3/10: 100%|██████████| 469/469 [00:11<00:00, 42.61it/s]


Epoch 3: Adaptive Loss = 0.0032, Random Loss = 0.0746


Epoch 4/10: 100%|██████████| 469/469 [00:10<00:00, 43.38it/s]


Epoch 4: Adaptive Loss = 0.0027, Random Loss = 0.0674


Epoch 5/10: 100%|██████████| 469/469 [00:11<00:00, 42.30it/s]


Epoch 5: Adaptive Loss = 0.0024, Random Loss = 0.0635


Epoch 6/10: 100%|██████████| 469/469 [00:11<00:00, 42.21it/s]


Epoch 6: Adaptive Loss = 0.0021, Random Loss = 0.0599


Epoch 7/10: 100%|██████████| 469/469 [00:10<00:00, 43.62it/s]


Epoch 7: Adaptive Loss = 0.0020, Random Loss = 0.0578


Epoch 8/10: 100%|██████████| 469/469 [00:10<00:00, 43.08it/s]


Epoch 8: Adaptive Loss = 0.0018, Random Loss = 0.0562


Epoch 9/10: 100%|██████████| 469/469 [00:10<00:00, 42.81it/s]


Epoch 9: Adaptive Loss = 0.0017, Random Loss = 0.0546


Epoch 10/10: 100%|██████████| 469/469 [00:10<00:00, 43.42it/s]


Epoch 10: Adaptive Loss = 0.0016, Random Loss = 0.0530

Running experiment: low_uncertainty


Epoch 1/10: 100%|██████████| 469/469 [00:10<00:00, 44.11it/s]


Epoch 1: Adaptive Loss = 0.0486, Random Loss = 0.4544


Epoch 2/10: 100%|██████████| 469/469 [00:10<00:00, 43.45it/s]


Epoch 2: Adaptive Loss = 0.0095, Random Loss = 0.4480


Epoch 3/10: 100%|██████████| 469/469 [00:10<00:00, 44.13it/s]


Epoch 3: Adaptive Loss = 0.0085, Random Loss = 0.4480


Epoch 4/10: 100%|██████████| 469/469 [00:10<00:00, 43.53it/s]


Epoch 4: Adaptive Loss = 0.0079, Random Loss = 0.4480


Epoch 5/10: 100%|██████████| 469/469 [00:10<00:00, 44.09it/s]


Epoch 5: Adaptive Loss = 0.0074, Random Loss = 0.4480


Epoch 6/10: 100%|██████████| 469/469 [00:10<00:00, 42.96it/s]


Epoch 6: Adaptive Loss = 0.0068, Random Loss = 0.4480


Epoch 7/10: 100%|██████████| 469/469 [00:11<00:00, 41.98it/s]


Epoch 7: Adaptive Loss = 0.0063, Random Loss = 0.4480


Epoch 8/10: 100%|██████████| 469/469 [00:11<00:00, 41.52it/s]


Epoch 8: Adaptive Loss = 0.0057, Random Loss = 0.4480


Epoch 9/10: 100%|██████████| 469/469 [00:10<00:00, 43.19it/s]


Epoch 9: Adaptive Loss = 0.0053, Random Loss = 0.4480


Epoch 10/10: 100%|██████████| 469/469 [00:10<00:00, 43.80it/s]


Epoch 10: Adaptive Loss = 0.0050, Random Loss = 0.4480

Running experiment: high_corruption


Epoch 1/10: 100%|██████████| 469/469 [00:11<00:00, 42.20it/s]


Epoch 1: Adaptive Loss = 0.0380, Random Loss = 0.4559


Epoch 2/10: 100%|██████████| 469/469 [00:10<00:00, 43.73it/s]


Epoch 2: Adaptive Loss = 0.0054, Random Loss = 0.4480


Epoch 3/10: 100%|██████████| 469/469 [00:09<00:00, 50.87it/s]


Epoch 3: Adaptive Loss = 0.0044, Random Loss = 0.4480


Epoch 4/10: 100%|██████████| 469/469 [00:10<00:00, 43.74it/s]


Epoch 4: Adaptive Loss = 0.0039, Random Loss = 0.4480


Epoch 5/10: 100%|██████████| 469/469 [00:10<00:00, 43.40it/s]


Epoch 5: Adaptive Loss = 0.0036, Random Loss = 0.4480


Epoch 6/10: 100%|██████████| 469/469 [00:10<00:00, 43.63it/s]


Epoch 6: Adaptive Loss = 0.0033, Random Loss = 0.4480


Epoch 7/10: 100%|██████████| 469/469 [00:10<00:00, 43.29it/s]


Epoch 7: Adaptive Loss = 0.0032, Random Loss = 0.4480


Epoch 8/10: 100%|██████████| 469/469 [00:11<00:00, 42.55it/s]


Epoch 8: Adaptive Loss = 0.0030, Random Loss = 0.4480


Epoch 9/10: 100%|██████████| 469/469 [00:11<00:00, 42.08it/s]


Epoch 9: Adaptive Loss = 0.0029, Random Loss = 0.4480


Epoch 10/10: 100%|██████████| 469/469 [00:10<00:00, 42.76it/s]


Epoch 10: Adaptive Loss = 0.0028, Random Loss = 0.4480

Running experiment: high_samples


Epoch 1/10: 100%|██████████| 469/469 [00:12<00:00, 37.69it/s]


Epoch 1: Adaptive Loss = 0.0417, Random Loss = 0.1301


Epoch 2/10: 100%|██████████| 469/469 [00:12<00:00, 38.38it/s]


Epoch 2: Adaptive Loss = 0.0054, Random Loss = 0.0823


Epoch 3/10: 100%|██████████| 469/469 [00:12<00:00, 37.96it/s]


Epoch 3: Adaptive Loss = 0.0044, Random Loss = 0.0704


Epoch 4/10: 100%|██████████| 469/469 [00:12<00:00, 38.14it/s]


Epoch 4: Adaptive Loss = 0.0039, Random Loss = 0.0658


Epoch 5/10: 100%|██████████| 469/469 [00:12<00:00, 38.32it/s]


Epoch 5: Adaptive Loss = 0.0036, Random Loss = 0.0627


Epoch 6/10: 100%|██████████| 469/469 [00:10<00:00, 46.69it/s]


Epoch 6: Adaptive Loss = 0.0033, Random Loss = 0.0595


Epoch 7/10: 100%|██████████| 469/469 [00:12<00:00, 37.53it/s]


Epoch 7: Adaptive Loss = 0.0031, Random Loss = 0.0569


Epoch 8/10: 100%|██████████| 469/469 [00:12<00:00, 37.99it/s]


Epoch 8: Adaptive Loss = 0.0030, Random Loss = 0.0554


Epoch 9/10: 100%|██████████| 469/469 [00:12<00:00, 38.51it/s]


Epoch 9: Adaptive Loss = 0.0029, Random Loss = 0.0542


Epoch 10/10: 100%|██████████| 469/469 [00:12<00:00, 37.87it/s]


Epoch 10: Adaptive Loss = 0.0028, Random Loss = 0.0529


In [7]:
# Cell 7: Display Results
print("\nAblation Study Results:")
print("======================")
for config_name, result in results.items():
    print(f"\nConfiguration: {config_name}")
    print(f"Final Adaptive Loss: {result['adaptive_losses'][-1]:.4f}")
    print(f"Final Random Loss: {result['random_losses'][-1]:.4f}")
    # Echo the hyper-parameters that produced these losses.
    cfg = result['config']
    for label, key in (("Alpha", 'alpha'),
                       ("Tau", 'tau'),
                       ("Corruption Probability", 'corruption_prob'),
                       ("Number of Samples", 'num_samples')):
        print(f"{label}: {cfg[key]}")


Ablation Study Results:

Configuration: baseline
Final Adaptive Loss: 0.4480
Final Random Loss: 0.0538
Alpha: 10.0
Tau: 0.5
Corruption Probability: 0.5
Number of Samples: 5

Configuration: high_uncertainty
Final Adaptive Loss: 0.0016
Final Random Loss: 0.0530
Alpha: 20.0
Tau: 0.3
Corruption Probability: 0.5
Number of Samples: 5

Configuration: low_uncertainty
Final Adaptive Loss: 0.0050
Final Random Loss: 0.4480
Alpha: 5.0
Tau: 0.7
Corruption Probability: 0.5
Number of Samples: 5

Configuration: high_corruption
Final Adaptive Loss: 0.0028
Final Random Loss: 0.4480
Alpha: 10.0
Tau: 0.5
Corruption Probability: 0.7
Number of Samples: 5

Configuration: high_samples
Final Adaptive Loss: 0.0028
Final Random Loss: 0.0529
Alpha: 10.0
Tau: 0.5
Corruption Probability: 0.5
Number of Samples: 10


In [9]:
pip install scikit-image

Collecting scikit-image
  Downloading scikit_image-0.25.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Collecting imageio!=2.35.0,>=2.33 (from scikit-image)
  Downloading imageio-2.37.0-py3-none-any.whl.metadata (5.2 kB)
Collecting tifffile>=2022.8.12 (from scikit-image)
  Downloading tifffile-2025.5.10-py3-none-any.whl.metadata (31 kB)
Collecting lazy-loader>=0.4 (from scikit-image)
  Downloading lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)
Downloading scikit_image-0.25.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.0/15.0 MB[0m [31m24.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading imageio-2.37.0-py3-none-any.whl (315 kB)
Downloading lazy_loader-0.4-py3-none-any.whl (12 kB)
Downloading tifffile-2025.5.10-py3-none-any.whl (226 kB)
Installing collected packages: tifffile, lazy-loader, imageio, scikit-image
Successfully installed imageio-2.3

In [None]:
# Cell 1: Imports and Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
import math
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Cell 2: Define Configurations
# Hyper-parameter regimes for the ablation. Per entry:
#   alpha, tau          - slope and midpoint of the uncertainty -> corruption sigmoid
#   lambda_reg          - weight of the mask-density regularizer
#   corruption_prob     - per-pixel corruption rate of the random-mask baseline
#   num_samples         - forward passes used to estimate uncertainty
#   invert_uncertainty  - if True, corrupt confident regions instead of uncertain ones
configs = {
    'baseline': dict(
        alpha=10.0, tau=0.5, lambda_reg=1e-3, corruption_prob=0.5,
        num_samples=5, invert_uncertainty=False,
    ),
    'high_uncertainty': dict(
        alpha=20.0, tau=0.3, lambda_reg=1e-3, corruption_prob=0.5,
        num_samples=5, invert_uncertainty=False,
    ),
    'low_uncertainty': dict(
        alpha=5.0, tau=0.7, lambda_reg=1e-3, corruption_prob=0.5,
        num_samples=5, invert_uncertainty=False,
    ),
    'high_corruption': dict(
        alpha=10.0, tau=0.5, lambda_reg=1e-3, corruption_prob=0.7,
        num_samples=5, invert_uncertainty=False,
    ),
    'high_samples': dict(
        alpha=10.0, tau=0.5, lambda_reg=1e-3, corruption_prob=0.5,
        num_samples=10, invert_uncertainty=False,
    ),
    'inverted_uncertainty': dict(
        alpha=10.0, tau=0.5, lambda_reg=1e-3, corruption_prob=0.5,
        num_samples=5, invert_uncertainty=True,
    ),
}

# Cell 3: Model and Training Functions
class ResidualBlock(nn.Module):
    """Conv-ReLU-Conv branch added back onto its input (identity skip).

    Both 3x3 convolutions keep ``channels`` and use padding 1, so the
    branch output has the input's shape and can be summed with it.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class SimpleUNet(nn.Module):
    """Small encoder/decoder used as the denoiser in these experiments.

    Two stride-2 convolutions halve the spatial size twice, a residual block
    processes the bottleneck, and two stride-2 transposed convolutions
    (kernel 4, padding 1) double it back; for 28x28 MNIST inputs the output
    has the input's shape. The final ``tanh`` keeps outputs in [-1, 1].

    Args:
        channels: width of every hidden feature map.
        dropout: channel-dropout probability applied to the bottleneck and to
            the last hidden feature map. With the default 0.0 the network is
            deterministic, so Monte-Carlo uncertainty (repeated train-mode
            forward passes) is identically zero; use a value in (0, 1) to get
            a non-trivial MC-dropout estimate.

    Raises:
        ValueError: if ``dropout`` is not in [0, 1).
    """

    def __init__(self, channels=32, dropout=0.0):
        super().__init__()
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")
        self.down1 = nn.Conv2d(1, channels, 3, padding=1, stride=2)
        self.down2 = nn.Conv2d(channels, channels, 3, padding=1, stride=2)
        self.res = ResidualBlock(channels)
        self.up1 = nn.ConvTranspose2d(channels, channels, 4, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(channels, 1, 4, stride=2, padding=1)
        # nn.Identity keeps the default path exactly as before: no extra
        # parameters, no extra RNG draws, unchanged state_dict keys.
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        x1 = F.relu(self.down1(x))
        x2 = self.drop(F.relu(self.down2(x1)))
        x3 = self.res(x2)
        x4 = self.drop(F.relu(self.up1(x3)))
        return torch.tanh(self.up2(x4))

def compute_uncertainty(model, x, num_samples=5):
    """Estimate per-element predictive uncertainty by Monte-Carlo sampling.

    Runs ``model`` ``num_samples`` times on the same input in training mode,
    takes the (unbiased) variance across the predictions, and averages it over
    dimension 1 (channels, with keepdim). The estimate is non-zero only if the
    model is stochastic in training mode (e.g. contains dropout); a
    deterministic network yields all zeros.

    Args:
        model: module to sample from; its train/eval mode is restored on exit.
        x: input batch, presumably (B, C, H, W) — TODO confirm for other callers.
        num_samples: number of forward passes; must be >= 2 because the
            unbiased variance of a single sample is NaN.

    Returns:
        Tensor shaped like the model output with dim 1 reduced to size 1,
        carrying no autograd history.

    Raises:
        ValueError: if ``num_samples`` < 2.
    """
    if num_samples < 2:
        raise ValueError(f"num_samples must be >= 2, got {num_samples}")
    was_training = model.training
    model.train()
    try:
        # Only a detached mask is derived from this, so skip graph construction
        # instead of keeping `num_samples` activation graphs alive.
        with torch.no_grad():
            preds = torch.stack([model(x) for _ in range(num_samples)])
    finally:
        # Restore the caller's mode rather than forcing eval(); forcing eval
        # left models that were being trained silently in eval mode.
        model.train(was_training)
    return preds.var(dim=0).mean(1, keepdim=True)

def generate_mask(uncertainty, alpha=10.0, tau=0.5, invert_uncertainty=False):
    """Sample a binary corruption mask from an uncertainty map.

    Each element is an independent Bernoulli draw with probability
    ``p = sigmoid(alpha * (uncertainty - tau))``, i.e. more uncertain elements
    are corrupted more often. With ``invert_uncertainty`` the probability is
    flipped to ``1 - p``, so confident elements are the likelier targets.
    The returned mask is detached from any autograd graph.
    """
    logits = (uncertainty - tau) * alpha
    p = torch.sigmoid(logits)
    p = 1 - p if invert_uncertainty else p
    mask = torch.bernoulli(p)
    return mask.detach()

def regularize_mask(mask, lambda_reg=1e-3):
    """Return ``lambda_reg`` times the mean of ``mask`` (its density).

    When ``mask`` is detached (as ``generate_mask`` returns it), this term
    only offsets the loss value and contributes no gradient.
    """
    density = mask.mean()
    return density * lambda_reg

def corrupt_input(x, mask):
    """Replace masked entries of ``x`` with fresh standard-normal noise.

    Where ``mask`` is 1 the output is N(0, 1) noise; where it is 0 the input
    passes through unchanged. Fractional mask values blend the two linearly.
    """
    keep = 1 - mask
    noise = torch.randn_like(x)
    return x * keep + noise * mask

# Evaluation metrics
def compute_psnr(img1, img2):
    """Return the PSNR in dB between two tensors scaled to [-1, 1].

    Both tensors are detached and copied to host memory, then scored with
    scikit-image's ``peak_signal_noise_ratio``; ``data_range=2.0`` is the
    width of the [-1, 1] interval.
    """
    first, second = (t.detach().cpu().numpy() for t in (img1, img2))
    return psnr(first, second, data_range=2.0)

def compute_ssim(img1, img2):
    """Return the SSIM between two image tensors scaled to [-1, 1].

    Accepted layouts:
      * ``(B, C, H, W)``: SSIM of channel 0 per image, averaged over the batch
        (the experiments use single-channel MNIST);
      * ``(C, H, W)``: SSIM of channel 0;
      * ``(H, W)``: SSIM of the image itself. The previous fallback indexed
        ``[0]`` here and so compared a single 1-D row.

    Scoring uses scikit-image's ``structural_similarity`` with
    ``data_range=2.0`` (width of [-1, 1]) and ``win_size=3``.

    Raises:
        ValueError: for tensors that are not 2-, 3- or 4-dimensional.
    """
    a = img1.detach().cpu().numpy()
    b = img2.detach().cpu().numpy()

    def _ssim_2d(p, q):
        return ssim(p, q, data_range=2.0, win_size=3)

    if a.ndim == 4:
        return np.mean([_ssim_2d(a[i, 0], b[i, 0]) for i in range(a.shape[0])])
    if a.ndim == 3:
        return _ssim_2d(a[0], b[0])
    if a.ndim == 2:
        return _ssim_2d(a, b)
    raise ValueError(f"expected a 2-D, 3-D or 4-D tensor, got shape {tuple(a.shape)}")

def analyze_corruption_coverage(mask, uncertainty):
    """Summarise how the corruption mask lines up with model uncertainty.

    Returns a dict with the mean uncertainty inside the corrupted region,
    the mean uncertainty outside it, and the overall fraction of corrupted
    elements. A 1e-6 term in each denominator avoids division by zero when
    a region is empty.
    """
    eps = 1e-6
    kept = 1 - mask
    inside = (mask * uncertainty).sum() / (mask.sum() + eps)
    outside = (kept * uncertainty).sum() / (kept.sum() + eps)
    return {
        'corrupted_uncertainty': inside.item(),
        'uncorrupted_uncertainty': outside.item(),
        'corruption_ratio': mask.mean().item(),
    }

# Visualization helpers
def visualize_uncertainty_and_corruption(x, uncertainty, mask, pred, config_name, epoch, save_dir='visualizations'):
    """Save a 2x2 diagnostic panel for the first sample of a batch.

    Panels: original image, uncertainty heat map (with colour bar), corruption
    mask, and reconstruction. Only element ``[0, 0]`` of each tensor is drawn
    (inputs presumably (B, C, H, W) — TODO confirm). The figure is written to
    ``{save_dir}/uncertainty_analysis_{config_name}_epoch_{epoch}.png``;
    ``save_dir`` is created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    def first_map(t):
        # Detach before moving to host so tensors with autograd history plot too.
        return t.detach().cpu().numpy()[0, 0]

    image = first_map(x)
    unc_map = first_map(uncertainty)
    mask_map = first_map(mask)
    recon = first_map(pred)

    fig, axes = plt.subplots(2, 2, figsize=(12, 12))

    # Fixed colour limits: with autoscaling, a constant panel (e.g. a mask that
    # is all ones) is normalised to the bottom of the colormap and renders
    # black, indistinguishable from an all-zero mask. Images use the [-1, 1]
    # data range; the mask is binary.
    axes[0, 0].imshow(image, cmap='gray', vmin=-1, vmax=1)
    axes[0, 0].set_title('Original Image')

    heat = axes[0, 1].imshow(unc_map, cmap='hot')
    axes[0, 1].set_title('Uncertainty Map')
    fig.colorbar(heat, ax=axes[0, 1])

    axes[1, 0].imshow(mask_map, cmap='gray', vmin=0, vmax=1)
    axes[1, 0].set_title('Corruption Mask')

    axes[1, 1].imshow(recon, cmap='gray', vmin=-1, vmax=1)
    axes[1, 1].set_title('Reconstruction')

    for ax in axes.flat:
        ax.axis('off')

    fig.suptitle(f'{config_name} - Epoch {epoch}')
    fig.tight_layout()
    fig.savefig(f'{save_dir}/uncertainty_analysis_{config_name}_epoch_{epoch}.png')
    # Close this specific figure so repeated calls don't accumulate open figures.
    plt.close(fig)

def visualize_batch_comparison(x, x_corrupt_adaptive, recon_adaptive, x_corrupt_random, recon_random, config_name, epoch, save_dir='visualizations', max_cols=8):
    """Save a grid comparing adaptive and random corruption on one batch.

    Row order: originals, adaptive-corrupted inputs, adaptive reconstructions,
    random-corrupted inputs, random reconstructions. Each column is one
    sample, and channel 0 of every tensor is drawn (inputs presumably
    (B, C, H, W) — TODO confirm for non-MNIST data). The figure is written to
    ``{save_dir}/batch_comparison_{config_name}_epoch_{epoch}.png``;
    ``save_dir`` is created if missing.

    Args:
        max_cols: maximum number of samples (columns) shown. The grid uses
            ``min(max_cols, smallest batch size)``, so batches with fewer than
            ``max_cols`` images no longer raise IndexError.

    Raises:
        ValueError: if any of the five tensors has no images.
    """
    os.makedirs(save_dir, exist_ok=True)

    rows = [
        (x, 'Original'),
        (x_corrupt_adaptive, 'Adaptive\nCorrupted'),
        (recon_adaptive, 'Adaptive\nReconstruction'),
        (x_corrupt_random, 'Random\nCorrupted'),
        (recon_random, 'Random\nReconstruction'),
    ]
    n_cols = min([max_cols] + [images.shape[0] for images, _ in rows])
    if n_cols < 1:
        raise ValueError("visualize_batch_comparison needs at least one image per row")

    # squeeze=False keeps `axes` two-dimensional even when n_cols == 1.
    fig, axes = plt.subplots(len(rows), n_cols, figsize=(20, 12), squeeze=False)

    for row_idx, (images, title) in enumerate(rows):
        for col_idx in range(n_cols):
            ax = axes[row_idx, col_idx]
            # Shared [-1, 1] grey scale (the data range, as in
            # make_grid(value_range=(-1, 1))) keeps panels comparable; noise
            # outside that range is clipped for display.
            ax.imshow(images[col_idx, 0].detach().cpu().numpy(), cmap='gray', vmin=-1, vmax=1)
            # Hide ticks and frame only: ax.axis('off') also hides the y-label,
            # which previously suppressed every row title.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        axes[row_idx, 0].set_ylabel(title)

    fig.suptitle(f'{config_name} - Epoch {epoch}')
    fig.tight_layout()
    fig.savefig(f'{save_dir}/batch_comparison_{config_name}_epoch_{epoch}.png')
    plt.close(fig)

# Cell 4: Training Function
def train_ablation(config_name, config, num_epochs=10, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Train adaptive- and random-corruption denoisers side by side and collect metrics.

    Both ``SimpleUNet`` models see the same MNIST training batches (scaled to
    [-1, 1]) and are trained with Adam (lr 1e-3) to reconstruct the clean
    input from a corrupted copy:

    * adaptive: the mask comes from ``generate_mask`` applied to the model's
      own uncertainty (optionally inverted), and ``regularize_mask`` is added
      to the MSE loss;
    * random: each element is corrupted independently with probability
      ``corruption_prob``.

    For the adaptive model, every batch also records PSNR and SSIM of the
    reconstruction against the clean input, plus corruption-coverage
    statistics. Every 5 epochs, 8 samples are visualised: detailed figures go
    under ``visualizations/`` and a grid goes to
    ``ablation_{config_name}_epoch_{N}.png``.

    Args:
        config_name: label used in log messages and file names.
        config: dict with keys ``alpha``, ``tau``, ``lambda_reg``,
            ``corruption_prob``, ``num_samples`` and ``invert_uncertainty``.
        num_epochs: number of passes over the training set.
        device: torch device (or device string) to train on.

    Returns:
        dict of per-epoch means: ``adaptive_losses``, ``random_losses``,
        ``adaptive_psnr``, ``adaptive_ssim``, ``corruption_metrics`` (a list
        of dicts), plus the ``config`` used.
    """
    print(f"\nRunning experiment: {config_name}")

    # Setup data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2 - 1)  # Scale to [-1, 1]
    ])
    mnist = datasets.MNIST(root='.', train=True, download=True, transform=transform)
    dataloader = DataLoader(mnist, batch_size=128, shuffle=True)

    # One model and optimizer per corruption strategy
    model_adaptive = SimpleUNet().to(device)
    model_random = SimpleUNet().to(device)
    optimizer_adaptive = torch.optim.Adam(model_adaptive.parameters(), lr=1e-3)
    optimizer_random = torch.optim.Adam(model_random.parameters(), lr=1e-3)

    # Per-epoch training metrics
    adaptive_losses = []
    random_losses = []
    adaptive_psnr = []
    adaptive_ssim = []
    corruption_metrics = []

    for epoch in range(num_epochs):
        model_adaptive.train()
        model_random.train()
        epoch_adaptive_losses = []
        epoch_random_losses = []
        epoch_adaptive_psnr = []
        epoch_adaptive_ssim = []
        epoch_corruption_metrics = []

        for x, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)

            # Adaptive corruption. The uncertainty only feeds a detached mask,
            # so the sampling passes need no autograd graph.
            with torch.no_grad():
                unc = compute_uncertainty(model_adaptive, x, config['num_samples'])
            # compute_uncertainty may leave the model in eval mode; switch back
            # so the update below is a genuine training-mode step.
            model_adaptive.train()
            mask_adaptive = generate_mask(unc, config['alpha'], config['tau'], config['invert_uncertainty'])
            x_corrupt_adaptive = corrupt_input(x, mask_adaptive)
            pred_adaptive = model_adaptive(x_corrupt_adaptive)
            # mask_adaptive is detached, so the regularizer shifts the reported
            # loss but does not contribute a gradient.
            loss_adaptive = F.mse_loss(pred_adaptive, x) + regularize_mask(mask_adaptive, config['lambda_reg'])

            optimizer_adaptive.zero_grad()
            loss_adaptive.backward()
            optimizer_adaptive.step()

            # Random corruption
            mask_random = torch.bernoulli(torch.full_like(x, config['corruption_prob']))
            x_corrupt_random = corrupt_input(x, mask_random)
            pred_random = model_random(x_corrupt_random)
            loss_random = F.mse_loss(pred_random, x)

            optimizer_random.zero_grad()
            loss_random.backward()
            optimizer_random.step()

            # Compute metrics
            epoch_adaptive_losses.append(loss_adaptive.item())
            epoch_random_losses.append(loss_random.item())
            epoch_adaptive_psnr.append(compute_psnr(pred_adaptive, x))
            epoch_adaptive_ssim.append(compute_ssim(pred_adaptive, x))
            epoch_corruption_metrics.append(analyze_corruption_coverage(mask_adaptive, unc))

        # Record epoch metrics
        adaptive_losses.append(np.mean(epoch_adaptive_losses))
        random_losses.append(np.mean(epoch_random_losses))
        adaptive_psnr.append(np.mean(epoch_adaptive_psnr))
        adaptive_ssim.append(np.mean(epoch_adaptive_ssim))

        # Average corruption metrics
        avg_corruption_metrics = {
            'corrupted_uncertainty': np.mean([m['corrupted_uncertainty'] for m in epoch_corruption_metrics]),
            'uncorrupted_uncertainty': np.mean([m['uncorrupted_uncertainty'] for m in epoch_corruption_metrics]),
            'corruption_ratio': np.mean([m['corruption_ratio'] for m in epoch_corruption_metrics])
        }
        corruption_metrics.append(avg_corruption_metrics)

        print(f"Epoch {epoch+1}:")
        print(f"  Adaptive Loss = {adaptive_losses[-1]:.4f}, Random Loss = {random_losses[-1]:.4f}")
        print(f"  PSNR = {adaptive_psnr[-1]:.2f}, SSIM = {adaptive_ssim[-1]:.4f}")
        print(f"  Corruption Ratio = {avg_corruption_metrics['corruption_ratio']:.4f}")
        print(f"  Corrupted/Uncorrupted Uncertainty = {avg_corruption_metrics['corrupted_uncertainty']:.4f}/{avg_corruption_metrics['uncorrupted_uncertainty']:.4f}")

        # Save samples and visualizations
        if (epoch + 1) % 5 == 0:
            model_adaptive.eval()
            model_random.eval()
            with torch.no_grad():
                x, _ = next(iter(dataloader))
                x = x.to(device)[:8]

                unc = compute_uncertainty(model_adaptive, x, config['num_samples'])
                mask_adaptive = generate_mask(unc, config['alpha'], config['tau'], config['invert_uncertainty'])
                mask_random = torch.bernoulli(torch.full_like(x, config['corruption_prob']))

                x_corrupt_adaptive = corrupt_input(x, mask_adaptive)
                x_corrupt_random = corrupt_input(x, mask_random)

                recon_adaptive = model_adaptive(x_corrupt_adaptive)
                recon_random = model_random(x_corrupt_random)

                # Generate detailed visualizations
                visualize_uncertainty_and_corruption(
                    x, unc, mask_adaptive, recon_adaptive,
                    config_name, epoch + 1
                )

                visualize_batch_comparison(
                    x, x_corrupt_adaptive, recon_adaptive,
                    x_corrupt_random, recon_random,
                    config_name, epoch + 1
                )

                # Rows: clean, adaptive-corrupted, adaptive recon, random-corrupted, random recon
                images = torch.cat([x, x_corrupt_adaptive, recon_adaptive, x_corrupt_random, recon_random], dim=0)
                grid = utils.make_grid(images, nrow=8, normalize=True, value_range=(-1, 1))
                plt.figure(figsize=(20, 5))
                plt.imshow(grid.permute(1, 2, 0).cpu())
                plt.axis('off')
                plt.title(f"{config_name} - Epoch {epoch+1}")
                plt.savefig(f'ablation_{config_name}_epoch_{epoch+1}.png')
                plt.close()

    return {
        'adaptive_losses': adaptive_losses,
        'random_losses': random_losses,
        'adaptive_psnr': adaptive_psnr,
        'adaptive_ssim': adaptive_ssim,
        'corruption_metrics': corruption_metrics,
        'config': config
    }

# Cell 5: Plotting Function
def plot_ablation_results(results):
    """Plot per-epoch ablation metrics and write a plain-text summary.

    Writes two files to the current working directory:

    * ``ablation_metrics.png`` -- a 2x2 figure with training losses
      (adaptive solid, random dashed), PSNR, SSIM and corruption ratio per
      epoch for every configuration.
    * ``ablation_summary.txt`` -- final-epoch metrics plus the
      hyper-parameters of every configuration.

    Args:
        results: Mapping from configuration name to the dict returned by
            ``train_ablation`` (keys ``adaptive_losses``, ``random_losses``,
            ``adaptive_psnr``, ``adaptive_ssim``, ``corruption_metrics`` and
            ``config``). The per-epoch lists must be non-empty because the
            summary reads their last element.
    """
    plt.figure(figsize=(15, 10))

    # Plot 1: Losses
    plt.subplot(2, 2, 1)
    for config_name, result in results.items():
        plt.plot(result['adaptive_losses'], label=f"{config_name} (Adaptive)")
        plt.plot(result['random_losses'], label=f"{config_name} (Random)", linestyle='--')
    plt.title('Training Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    # Plot 2: PSNR
    plt.subplot(2, 2, 2)
    for config_name, result in results.items():
        plt.plot(result['adaptive_psnr'], label=config_name)
    plt.title('PSNR')
    plt.xlabel('Epochs')
    plt.ylabel('PSNR (dB)')
    plt.legend()

    # Plot 3: SSIM
    plt.subplot(2, 2, 3)
    for config_name, result in results.items():
        plt.plot(result['adaptive_ssim'], label=config_name)
    plt.title('SSIM')
    plt.xlabel('Epochs')
    plt.ylabel('SSIM')
    plt.legend()

    # Plot 4: Corruption Ratio
    plt.subplot(2, 2, 4)
    for config_name, result in results.items():
        corruption_ratios = [m['corruption_ratio'] for m in result['corruption_metrics']]
        plt.plot(corruption_ratios, label=config_name)
    plt.title('Corruption Ratio')
    plt.xlabel('Epochs')
    plt.ylabel('Ratio')
    plt.legend()

    plt.tight_layout()
    plt.savefig('ablation_metrics.png')
    plt.close()

    # Create summary table (one dict of already-formatted values per config)
    summary = []
    for config_name, result in results.items():
        config = result['config']
        last = result['corruption_metrics'][-1]
        summary.append({
            'Configuration': config_name,
            'Final Adaptive Loss': f"{result['adaptive_losses'][-1]:.4f}",
            'Final Random Loss': f"{result['random_losses'][-1]:.4f}",
            'Final PSNR': f"{result['adaptive_psnr'][-1]:.2f}",
            'Final SSIM': f"{result['adaptive_ssim'][-1]:.4f}",
            'Final Corruption Ratio': f"{last['corruption_ratio']:.4f}",
            'Corrupted Uncertainty': f"{last['corrupted_uncertainty']:.4f}",
            'Uncorrupted Uncertainty': f"{last['uncorrupted_uncertainty']:.4f}",
            'Alpha': config['alpha'],
            'Tau': config['tau'],
            'Corruption Prob': config['corruption_prob'],
            'Num Samples': config['num_samples'],
            # Not every config dict defines this key (the Cell 2 configs do
            # not); treat a missing key as "not inverted" instead of raising.
            'Inverted Uncertainty': config.get('invert_uncertainty', False),
        })

    # Save summary to file
    with open('ablation_summary.txt', 'w') as f:
        f.write("Adaptive Corruption Ablation Study Summary\n")
        f.write("========================================\n\n")
        for entry in summary:
            f.write(f"Configuration: {entry['Configuration']}\n")
            f.write(f"Final Adaptive Loss: {entry['Final Adaptive Loss']}\n")
            f.write(f"Final Random Loss: {entry['Final Random Loss']}\n")
            f.write(f"Final PSNR: {entry['Final PSNR']}\n")
            f.write(f"Final SSIM: {entry['Final SSIM']}\n")
            f.write(f"Final Corruption Ratio: {entry['Final Corruption Ratio']}\n")
            f.write(f"Corrupted Uncertainty: {entry['Corrupted Uncertainty']}\n")
            f.write(f"Uncorrupted Uncertainty: {entry['Uncorrupted Uncertainty']}\n")
            f.write(f"Alpha: {entry['Alpha']}\n")
            f.write(f"Tau: {entry['Tau']}\n")
            f.write(f"Corruption Probability: {entry['Corruption Prob']}\n")
            f.write(f"Number of Samples: {entry['Num Samples']}\n")
            f.write(f"Inverted Uncertainty: {entry['Inverted Uncertainty']}\n")
            f.write("\n")

# Add a function to generate final comparison visualizations
def generate_final_comparison(results, device):
    """Save uncertainty/corruption figures for every configuration on one batch.

    All configurations share one shuffled batch of 8 MNIST training images,
    scaled to [-1, 1]. For each configuration:

    * the adaptive mask is sampled from the adaptive model's uncertainty map;
    * the random mask is drawn from ``Bernoulli(config['corruption_prob'])``;
    * both models reconstruct their corrupted inputs.

    Figures are tagged ``{config_name}_final`` with epoch 0.

    Models are taken from ``result['model_adaptive']`` and
    ``result['model_random']`` when the result dict provides them. The dicts
    returned by ``train_ablation`` do not include models, so in that case a
    freshly initialised (untrained) ``SimpleUNet`` is used for each role.
    The resulting reconstructions then reflect random weights, not training.

    Args:
        results: Mapping from configuration name to a result dict with at
            least a ``config`` entry.
        device: Torch device for the data and models.
    """
    # Setup data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2 - 1)
    ])
    mnist = datasets.MNIST(root='.', train=True, download=True, transform=transform)
    dataloader = DataLoader(mnist, batch_size=8, shuffle=True)

    # Get a batch of images
    x, _ = next(iter(dataloader))
    x = x.to(device)

    # Inference only: no autograd graphs are needed for any of these passes.
    with torch.no_grad():
        for config_name, result in results.items():
            config = result['config']

            model_adaptive = result.get('model_adaptive')
            if model_adaptive is None:
                model_adaptive = SimpleUNet().to(device)
            model_random = result.get('model_random')
            if model_random is None:
                model_random = SimpleUNet().to(device)

            # Compute uncertainty and generate masks. A missing
            # 'invert_uncertainty' key means "not inverted".
            unc = compute_uncertainty(model_adaptive, x, config['num_samples'])
            mask_adaptive = generate_mask(
                unc,
                config['alpha'],
                config['tau'],
                config.get('invert_uncertainty', False),
            )
            mask_random = torch.bernoulli(torch.full_like(x, config['corruption_prob']))

            # Generate corrupted and reconstructed images
            x_corrupt_adaptive = corrupt_input(x, mask_adaptive)
            x_corrupt_random = corrupt_input(x, mask_random)
            recon_adaptive = model_adaptive(x_corrupt_adaptive)
            recon_random = model_random(x_corrupt_random)

            # Save visualizations
            visualize_uncertainty_and_corruption(
                x, unc, mask_adaptive, recon_adaptive,
                f"{config_name}_final", 0
            )

            visualize_batch_comparison(
                x, x_corrupt_adaptive, recon_adaptive,
                x_corrupt_random, recon_random,
                f"{config_name}_final", 0
            )

# Cell 6: Run Experiments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results = {}

for config_name, config in configs.items():
    results[config_name] = train_ablation(
        config_name,
        config,
        num_epochs=10,
        device=device
    )

# Plot and save results
plot_ablation_results(results)

# Generate final comparison visualizations
generate_final_comparison(results, device)

# Cell 7: Display Results
print("\nAblation Study Results:")
print("======================")
for config_name, result in results.items():
    last_metrics = result['corruption_metrics'][-1]
    run_config = result['config']
    print(f"\nConfiguration: {config_name}")
    print(f"Final Adaptive Loss: {result['adaptive_losses'][-1]:.4f}")
    print(f"Final Random Loss: {result['random_losses'][-1]:.4f}")
    print(f"Final PSNR: {result['adaptive_psnr'][-1]:.2f}")
    print(f"Final SSIM: {result['adaptive_ssim'][-1]:.4f}")
    print(f"Final Corruption Ratio: {last_metrics['corruption_ratio']:.4f}")
    print(f"Corrupted/Uncorrupted Uncertainty: {last_metrics['corrupted_uncertainty']:.4f}/{last_metrics['uncorrupted_uncertainty']:.4f}")
    print(f"Alpha: {run_config['alpha']}")
    print(f"Tau: {run_config['tau']}")
    print(f"Corruption Probability: {run_config['corruption_prob']}")
    print(f"Number of Samples: {run_config['num_samples']}")
    # The Cell 2 configs define no 'invert_uncertainty' key; report False
    # instead of raising KeyError.
    print(f"Inverted Uncertainty: {run_config.get('invert_uncertainty', False)}")


Running experiment: baseline


Epoch 1/10: 100%|██████████| 469/469 [00:22<00:00, 20.99it/s]


Epoch 1:
  Adaptive Loss = 0.0457, Random Loss = 0.4521
  PSNR = 24.14, SSIM = 0.8617
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 2/10: 100%|██████████| 469/469 [00:21<00:00, 21.83it/s]


Epoch 2:
  Adaptive Loss = 0.0055, Random Loss = 0.4480
  PSNR = 28.62, SSIM = 0.9554
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 3/10: 100%|██████████| 469/469 [00:21<00:00, 21.88it/s]


Epoch 3:
  Adaptive Loss = 0.0045, Random Loss = 0.4480
  PSNR = 29.54, SSIM = 0.9619
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 4/10: 100%|██████████| 469/469 [00:21<00:00, 21.72it/s]


Epoch 4:
  Adaptive Loss = 0.0039, Random Loss = 0.4480
  PSNR = 30.12, SSIM = 0.9659
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 5/10: 100%|██████████| 469/469 [00:20<00:00, 22.42it/s]


Epoch 5:
  Adaptive Loss = 0.0035, Random Loss = 0.4480
  PSNR = 30.54, SSIM = 0.9686
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 6/10: 100%|██████████| 469/469 [00:20<00:00, 22.46it/s]


Epoch 6:
  Adaptive Loss = 0.0033, Random Loss = 0.4480
  PSNR = 30.87, SSIM = 0.9708
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 7/10: 100%|██████████| 469/469 [00:21<00:00, 22.17it/s]


Epoch 7:
  Adaptive Loss = 0.0031, Random Loss = 0.4480
  PSNR = 31.15, SSIM = 0.9726
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 8/10: 100%|██████████| 469/469 [00:21<00:00, 22.28it/s]


Epoch 8:
  Adaptive Loss = 0.0029, Random Loss = 0.4480
  PSNR = 31.36, SSIM = 0.9741
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 9/10: 100%|██████████| 469/469 [00:21<00:00, 22.00it/s]


Epoch 9:
  Adaptive Loss = 0.0028, Random Loss = 0.4480
  PSNR = 31.60, SSIM = 0.9758
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 10/10: 100%|██████████| 469/469 [00:20<00:00, 22.38it/s]


Epoch 10:
  Adaptive Loss = 0.0027, Random Loss = 0.4480
  PSNR = 31.79, SSIM = 0.9767
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000

Running experiment: high_uncertainty


Epoch 1/10: 100%|██████████| 469/469 [00:21<00:00, 22.21it/s]


Epoch 1:
  Adaptive Loss = 0.0355, Random Loss = 0.4542
  PSNR = 24.95, SSIM = 0.8809
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 2/10: 100%|██████████| 469/469 [00:21<00:00, 22.21it/s]


Epoch 2:
  Adaptive Loss = 0.0045, Random Loss = 0.4480
  PSNR = 29.50, SSIM = 0.9620
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 3/10: 100%|██████████| 469/469 [00:21<00:00, 21.91it/s]


Epoch 3:
  Adaptive Loss = 0.0035, Random Loss = 0.4480
  PSNR = 30.61, SSIM = 0.9676
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 4/10: 100%|██████████| 469/469 [00:21<00:00, 22.17it/s]


Epoch 4:
  Adaptive Loss = 0.0029, Random Loss = 0.4480
  PSNR = 31.38, SSIM = 0.9712
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 5/10: 100%|██████████| 469/469 [00:21<00:00, 21.81it/s]


Epoch 5:
  Adaptive Loss = 0.0026, Random Loss = 0.4480
  PSNR = 31.91, SSIM = 0.9734
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 6/10: 100%|██████████| 469/469 [00:21<00:00, 22.03it/s]


Epoch 6:
  Adaptive Loss = 0.0023, Random Loss = 0.4480
  PSNR = 32.33, SSIM = 0.9753
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 7/10: 100%|██████████| 469/469 [00:21<00:00, 22.07it/s]


Epoch 7:
  Adaptive Loss = 0.0022, Random Loss = 0.4480
  PSNR = 32.65, SSIM = 0.9768
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 8/10: 100%|██████████| 469/469 [00:21<00:00, 21.81it/s]


Epoch 8:
  Adaptive Loss = 0.0020, Random Loss = 0.4480
  PSNR = 33.02, SSIM = 0.9784
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 9/10: 100%|██████████| 469/469 [00:20<00:00, 22.86it/s]


Epoch 9:
  Adaptive Loss = 0.0019, Random Loss = 0.4480
  PSNR = 33.30, SSIM = 0.9799
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 10/10: 100%|██████████| 469/469 [00:21<00:00, 22.15it/s]


Epoch 10:
  Adaptive Loss = 0.0018, Random Loss = 0.4480
  PSNR = 33.58, SSIM = 0.9811
  Corruption Ratio = 0.0025
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000

Running experiment: low_uncertainty


Epoch 1/10: 100%|██████████| 469/469 [00:21<00:00, 21.90it/s]


Epoch 1:
  Adaptive Loss = 0.4569, Random Loss = 0.4532
  PSNR = 9.45, SSIM = 0.6275
  Corruption Ratio = 0.0293
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 2/10: 100%|██████████| 469/469 [00:20<00:00, 22.53it/s]


Epoch 2:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0293
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 3/10: 100%|██████████| 469/469 [00:21<00:00, 22.25it/s]


Epoch 3:
  Adaptive Loss = 0.4481, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0294
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 4/10: 100%|██████████| 469/469 [00:20<00:00, 22.46it/s]


Epoch 4:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0293
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 5/10: 100%|██████████| 469/469 [00:20<00:00, 23.34it/s]


Epoch 5:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0293
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 6/10: 100%|██████████| 469/469 [00:18<00:00, 24.92it/s]


Epoch 6:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0293
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 7/10: 100%|██████████| 469/469 [00:20<00:00, 22.65it/s]


Epoch 7:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0293
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 8/10: 100%|██████████| 469/469 [00:20<00:00, 22.43it/s]


Epoch 8:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0293
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 9/10: 100%|██████████| 469/469 [00:21<00:00, 22.28it/s]


Epoch 9:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0293
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 10/10: 100%|██████████| 469/469 [00:21<00:00, 22.16it/s]


Epoch 10:
  Adaptive Loss = 0.4481, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0293
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000

Running experiment: high_corruption


Epoch 1/10: 100%|██████████| 469/469 [00:20<00:00, 22.55it/s]


Epoch 1:
  Adaptive Loss = 0.0424, Random Loss = 0.2117
  PSNR = 24.09, SSIM = 0.8657
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 2/10: 100%|██████████| 469/469 [00:21<00:00, 22.17it/s]


Epoch 2:
  Adaptive Loss = 0.0056, Random Loss = 0.1685
  PSNR = 28.60, SSIM = 0.9561
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 3/10: 100%|██████████| 469/469 [00:21<00:00, 21.80it/s]


Epoch 3:
  Adaptive Loss = 0.0046, Random Loss = 0.1486
  PSNR = 29.45, SSIM = 0.9621
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 4/10: 100%|██████████| 469/469 [00:21<00:00, 22.23it/s]


Epoch 4:
  Adaptive Loss = 0.0040, Random Loss = 0.1402
  PSNR = 29.98, SSIM = 0.9656
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 5/10: 100%|██████████| 469/469 [00:21<00:00, 21.83it/s]


Epoch 5:
  Adaptive Loss = 0.0037, Random Loss = 0.1339
  PSNR = 30.39, SSIM = 0.9681
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 6/10: 100%|██████████| 469/469 [00:20<00:00, 22.44it/s]


Epoch 6:
  Adaptive Loss = 0.0034, Random Loss = 0.1289
  PSNR = 30.73, SSIM = 0.9705
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 7/10: 100%|██████████| 469/469 [00:20<00:00, 22.55it/s]


Epoch 7:
  Adaptive Loss = 0.0032, Random Loss = 0.1245
  PSNR = 31.03, SSIM = 0.9726
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 8/10: 100%|██████████| 469/469 [00:21<00:00, 22.09it/s]


Epoch 8:
  Adaptive Loss = 0.0030, Random Loss = 0.1215
  PSNR = 31.28, SSIM = 0.9743
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 9/10: 100%|██████████| 469/469 [00:21<00:00, 22.04it/s]


Epoch 9:
  Adaptive Loss = 0.0028, Random Loss = 0.1170
  PSNR = 31.52, SSIM = 0.9759
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 10/10: 100%|██████████| 469/469 [00:21<00:00, 21.60it/s]


Epoch 10:
  Adaptive Loss = 0.0027, Random Loss = 0.1112
  PSNR = 31.76, SSIM = 0.9773
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000

Running experiment: high_samples


Epoch 1/10: 100%|██████████| 469/469 [00:21<00:00, 21.72it/s]


Epoch 1:
  Adaptive Loss = 0.4551, Random Loss = 0.4532
  PSNR = 9.46, SSIM = 0.6286
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 2/10: 100%|██████████| 469/469 [00:22<00:00, 20.85it/s]


Epoch 2:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 3/10: 100%|██████████| 469/469 [00:21<00:00, 21.40it/s]


Epoch 3:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 4/10: 100%|██████████| 469/469 [00:22<00:00, 21.15it/s]


Epoch 4:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 5/10: 100%|██████████| 469/469 [00:22<00:00, 21.17it/s]


Epoch 5:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 6/10: 100%|██████████| 469/469 [00:22<00:00, 20.84it/s]


Epoch 6:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 7/10: 100%|██████████| 469/469 [00:21<00:00, 21.96it/s]


Epoch 7:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 8/10: 100%|██████████| 469/469 [00:20<00:00, 23.18it/s]


Epoch 8:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 9/10: 100%|██████████| 469/469 [00:22<00:00, 21.14it/s]


Epoch 9:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 10/10: 100%|██████████| 469/469 [00:22<00:00, 21.19it/s]


Epoch 10:
  Adaptive Loss = 0.4480, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000

Running experiment: inverted_uncertainty


Epoch 1/10: 100%|██████████| 469/469 [00:21<00:00, 22.16it/s]


Epoch 1:
  Adaptive Loss = 0.4560, Random Loss = 0.4524
  PSNR = 9.46, SSIM = 0.6294
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 2/10: 100%|██████████| 469/469 [00:21<00:00, 22.28it/s]


Epoch 2:
  Adaptive Loss = 0.4490, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 3/10: 100%|██████████| 469/469 [00:20<00:00, 22.39it/s]


Epoch 3:
  Adaptive Loss = 0.4490, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 4/10: 100%|██████████| 469/469 [00:21<00:00, 21.91it/s]


Epoch 4:
  Adaptive Loss = 0.4490, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 5/10: 100%|██████████| 469/469 [00:20<00:00, 22.82it/s]


Epoch 5:
  Adaptive Loss = 0.4490, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 6/10: 100%|██████████| 469/469 [00:21<00:00, 21.95it/s]


Epoch 6:
  Adaptive Loss = 0.4490, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 7/10: 100%|██████████| 469/469 [00:21<00:00, 22.24it/s]


Epoch 7:
  Adaptive Loss = 0.4490, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 8/10: 100%|██████████| 469/469 [00:20<00:00, 22.60it/s]


Epoch 8:
  Adaptive Loss = 0.4490, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 9/10: 100%|██████████| 469/469 [00:21<00:00, 22.31it/s]


Epoch 9:
  Adaptive Loss = 0.4490, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 10/10: 100%|██████████| 469/469 [00:19<00:00, 23.46it/s]


Epoch 10:
  Adaptive Loss = 0.4490, Random Loss = 0.4480
  PSNR = 9.51, SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000

Ablation Study Results:

Configuration: baseline
Final Adaptive Loss: 0.0027
Final Random Loss: 0.4480
Final PSNR: 31.79
Final SSIM: 0.9767
Final Corruption Ratio: 0.0067
Corrupted/Uncorrupted Uncertainty: 0.0000/0.0000
Alpha: 10.0
Tau: 0.5
Corruption Probability: 0.5
Number of Samples: 5
Inverted Uncertainty: False

Configuration: high_uncertainty
Final Adaptive Loss: 0.0018
Final Random Loss: 0.4480
Final PSNR: 33.58
Final SSIM: 0.9811
Final Corruption Ratio: 0.0025
Corrupted/Uncorrupted Uncertainty: 0.0000/0.0000
Alpha: 20.0
Tau: 0.3
Corruption Probability: 0.5
Number of Samples: 5
Inverted Uncertainty: False

Configuration: low_uncertainty
Final Adaptive Loss: 0.4481
Final Random Loss: 0.4480
Final PSNR: 9.51
Final SSIM: 0.6428
Final Corruption Ratio: 0.0293
Corrupted/Uncorrupted Uncertainty: 0.0000/0.0000
Alpha: 5.0


# MNIST Results

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
import math
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Define configurations
# Every configuration shares the same hyper-parameters and differs only in
# how the corruption mask is chosen: uncertainty-driven, inverted
# uncertainty, or uniformly random. Each entry is an independent dict.
configs = {
    name: {
        'alpha': 10.0,
        'tau': 0.5,
        'lambda_reg': 1e-3,
        'corruption_prob': 0.5,
        'num_samples': 5,
        'invert_uncertainty': invert,
        'random_corruption': randomize,
    }
    for name, invert, randomize in (
        ('baseline', False, False),
        ('inverted_uncertainty', True, False),
        ('random_corruption', False, True),
    )
}

# Model architecture
class ResidualBlock(nn.Module):
    """Conv-ReLU-Conv branch added back onto its input (identity skip).

    Both convolutions are 3x3 with padding 1 and keep the channel count, so
    the branch output has the input's shape and no projection is needed.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        ]
        # A single Sequential attribute named ``block`` keeps parameter names
        # (block.0.*, block.2.*) stable for saved state dicts.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class SimpleUNet(nn.Module):
    """Small encoder/decoder used as the denoising reconstructor.

    The architecture is:

    * two stride-2 3x3 convolutions, downsampling by 4x;
    * a ``ResidualBlock`` at the bottleneck;
    * two stride-2 4x4 transposed convolutions, upsampling back.

    For spatial sizes divisible by 4 (e.g. 28x28 MNIST) the output has the
    input's shape. The final ``tanh`` bounds outputs to [-1, 1].

    Args:
        channels: Feature width of every hidden layer.
        dropout: Channel-dropout probability applied after the bottleneck
            and after the first upsampling layer.
            The default ``0.0`` inserts no dropout at all, so the network is
            exactly the original deterministic model.
            Without any stochastic layer, every forward pass in training mode
            is identical. MC-dropout variance computed from repeated passes
            is then identically zero. Set ``dropout > 0`` to obtain a
            non-trivial uncertainty estimate.

    Raises:
        ValueError: If ``dropout`` is not in [0, 1).
    """

    def __init__(self, channels=32, dropout=0.0):
        super().__init__()
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")
        self.down1 = nn.Conv2d(1, channels, 3, padding=1, stride=2)
        self.down2 = nn.Conv2d(channels, channels, 3, padding=1, stride=2)
        self.res = ResidualBlock(channels)
        self.up1 = nn.ConvTranspose2d(channels, channels, 4, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(channels, 1, 4, stride=2, padding=1)
        # Dropout2d has no parameters, so state_dict keys are unchanged.
        # Identity when disabled keeps the default path bit-for-bit identical.
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        x1 = F.relu(self.down1(x))
        x2 = F.relu(self.down2(x1))
        x3 = self.drop(self.res(x2))
        x4 = self.drop(F.relu(self.up1(x3)))
        return torch.tanh(self.up2(x4))

# Utility functions
def compute_uncertainty(model, x, num_samples=5):
    """Estimate a per-pixel MC-dropout uncertainty map for ``x``.

    The model runs ``num_samples`` forward passes in training mode, so any
    dropout layers are active. The result is the (unbiased) variance across
    those passes, averaged over dim 1 with the dimension kept.

    A model with no active stochastic layers produces identical passes, so
    its uncertainty map is all zeros.

    The passes run under ``torch.no_grad()``: the map is only used to build
    detached masks and scalar diagnostics, so no autograd graph is needed.
    The model's previous train/eval mode is restored afterwards, even if a
    forward pass raises.

    Args:
        model: Module whose output has channels on dim 1.
        x: Input batch passed unchanged to every forward pass.
        num_samples: Number of stochastic passes. Must be >= 2, because the
            unbiased variance of a single sample is NaN.

    Returns:
        Tensor shaped like the model output, except that dim 1 has size 1.

    Raises:
        ValueError: If ``num_samples < 2``.
    """
    if num_samples < 2:
        raise ValueError(f"num_samples must be >= 2 to estimate variance, got {num_samples}")
    was_training = model.training
    model.train()
    try:
        with torch.no_grad():
            preds = torch.stack([model(x) for _ in range(num_samples)])
    finally:
        model.train(was_training)
    return torch.var(preds, dim=0).mean(1, keepdim=True)

def generate_mask(uncertainty, alpha=10.0, tau=0.5, invert_uncertainty=False,
                  random_corruption=False, random_prob=0.2):
    """Sample a binary corruption mask (1 = corrupt) shaped like ``uncertainty``.

    Adaptive mode corrupts each element with probability
    ``sigmoid(alpha * (uncertainty - tau))``, so elements above ``tau`` are
    corrupted more often. With ``invert_uncertainty`` the probability is
    ``1 - sigmoid(...)`` instead, which favours low-uncertainty elements.

    Random mode (``random_corruption=True``) corrupts every element
    independently with probability ``random_prob``. In this mode ``alpha``,
    ``tau`` and ``invert_uncertainty`` are ignored.

    Args:
        uncertainty: Tensor of per-element uncertainty scores.
        alpha: Sigmoid steepness (adaptive mode).
        tau: Uncertainty threshold at which the probability is 0.5.
        invert_uncertainty: Corrupt low- rather than high-uncertainty elements.
        random_corruption: Ignore ``uncertainty`` values and corrupt uniformly.
        random_prob: Corruption probability in random mode. The default 0.2
            is the value this function always used before this parameter
            existed. It is independent of any ``corruption_prob`` config
            entry unless the caller passes that value through.

    Returns:
        Detached float tensor of 0s and 1s with ``uncertainty``'s shape.
    """
    if random_corruption:
        prob = torch.full_like(uncertainty, random_prob)
    else:
        prob = torch.sigmoid(alpha * (uncertainty - tau))
        if invert_uncertainty:
            prob = 1 - prob
    return torch.bernoulli(prob).detach()

def regularize_mask(mask, lambda_reg=1e-3):
    """Return ``lambda_reg`` times the mean value of ``mask``.

    For a 0/1 mask this is ``lambda_reg`` times the corrupted fraction.
    ``generate_mask`` returns detached tensors, so with those masks this
    term only shifts the loss value and contributes no gradient.
    """
    corrupted_fraction = mask.mean()
    return lambda_reg * corrupted_fraction

def corrupt_input(x, mask):
    """Blend ``x`` with standard-normal noise according to ``mask``.

    Returns ``mask * noise + (1 - mask) * x``, where ``noise`` is drawn with
    ``torch.randn_like(x)``. For a 0/1 mask, entries with mask 1 become
    noise and entries with mask 0 keep their original value.
    """
    gaussian = torch.randn_like(x)
    keep = 1 - mask
    return keep * x + mask * gaussian

def compute_psnr(img1, img2):
    """Return the PSNR in dB between two tensors, scored as one array.

    ``data_range=2.0`` assumes values span [-1, 1]. ``img1`` is passed to
    skimage as the reference image. With an explicit data range the result
    depends only on the MSE, so argument order does not matter.
    """
    reference, candidate = (t.detach().cpu().numpy() for t in (img1, img2))
    return psnr(reference, candidate, data_range=2.0)

def compute_ssim(img1, img2):
    """Return the SSIM between two tensors whose values are assumed to be in [-1, 1].

    Accepted layouts:

    * 4-D ``(N, C, H, W)`` batches: SSIM of channel 0 per sample, averaged.
    * 2-D ``(H, W)`` images: SSIM of the whole image.
    * Any other rank: SSIM of ``img[0]``; for ``(C, H, W)`` this is channel 0.

    ``win_size=3`` keeps the sliding window valid on small images such as
    28x28 MNIST digits. ``data_range=2.0`` matches the [-1, 1] value range.
    """
    a = img1.detach().cpu().numpy()
    b = img2.detach().cpu().numpy()
    if a.ndim == 4:
        return np.mean([
            ssim(a[i, 0], b[i, 0], data_range=2.0, win_size=3)
            for i in range(a.shape[0])
        ])
    if a.ndim == 2:
        # Indexing [0] here would select a single row, so score the full image.
        return ssim(a, b, data_range=2.0, win_size=3)
    return ssim(a[0], b[0], data_range=2.0, win_size=3)

def analyze_corruption_coverage(mask, uncertainty):
    """Summarise how a corruption mask relates to the uncertainty map.

    Returns a dict of Python floats:

    * ``corrupted_uncertainty``: mean uncertainty over masked entries.
    * ``uncorrupted_uncertainty``: mean uncertainty over unmasked entries.
    * ``corruption_ratio``: mean of ``mask``, i.e. the masked fraction for a
      0/1 mask.

    A small epsilon in each denominator avoids division by zero when the
    mask is all zeros or all ones.
    """
    eps = 1e-6
    keep = 1 - mask
    mean_unc_masked = (mask * uncertainty).sum() / (mask.sum() + eps)
    mean_unc_kept = (keep * uncertainty).sum() / (keep.sum() + eps)
    ratio = mask.mean().item()
    return {
        'corrupted_uncertainty': mean_unc_masked.item(),
        'uncorrupted_uncertainty': mean_unc_kept.item(),
        'corruption_ratio': ratio,
    }

# Visualization functions
def visualize_uncertainty_and_corruption(x, uncertainty, mask, pred, config_name, epoch, save_dir='ablations'):
    """Save a 2x2 diagnostic figure for the first sample of a batch.

    The panels are: original image, uncertainty heat map (with colour bar),
    corruption mask and reconstruction. Every tensor is indexed ``[0, 0]``
    (first sample, first channel), which presumes a 4-D (N, C, H, W) layout.

    The figure is written to
    ``{save_dir}/uncertainty_analysis_{config_name}_epoch_{epoch}.png``;
    ``save_dir`` is created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    arrays = [t.detach().cpu().numpy() for t in (x, uncertainty, mask, pred)]
    # (row, col), title, colour map, whether to attach a colour bar
    layout = [
        ((0, 0), 'Original Image', 'gray', False),
        ((0, 1), 'Uncertainty Map', 'hot', True),
        ((1, 0), 'Corruption Mask', 'gray', False),
        ((1, 1), 'Reconstruction', 'gray', False),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
    for array, (pos, title, cmap, with_colorbar) in zip(arrays, layout):
        ax = axes[pos]
        shown = ax.imshow(array[0, 0], cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
        if with_colorbar:
            plt.colorbar(shown, ax=ax)

    plt.suptitle(f'{config_name} - Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'{save_dir}/uncertainty_analysis_{config_name}_epoch_{epoch}.png')
    plt.close()

def visualize_batch_comparison(x, x_corrupt, recon, config_name, epoch, save_dir='ablations'):
    """Save a 3-row grid of originals, corrupted inputs and reconstructions.

    Shows the first 8 samples, or fewer when the batch is smaller. Each
    tensor is indexed ``[i, 0]``, which presumes (N, C, H, W) with the image
    in channel 0.

    The figure is written to
    ``{save_dir}/batch_comparison_{config_name}_epoch_{epoch}.png``;
    ``save_dir`` is created if missing.

    Raises:
        ValueError: If ``x`` contains no samples.
    """
    os.makedirs(save_dir, exist_ok=True)

    n_cols = min(8, x.shape[0])
    if n_cols == 0:
        raise ValueError("visualize_batch_comparison needs at least one sample")

    # squeeze=False keeps ``axes`` 2-D even when only one column is drawn.
    fig, axes = plt.subplots(3, n_cols, figsize=(20, 8), squeeze=False)

    rows = [(x, 'Original'), (x_corrupt, 'Corrupted'), (recon, 'Reconstruction')]
    for row_idx, (images, title) in enumerate(rows):
        for col_idx in range(n_cols):
            ax = axes[row_idx, col_idx]
            ax.imshow(images[col_idx, 0].detach().cpu(), cmap='gray')
            ax.axis('off')
        # axis('off') also hides axis labels, so set_ylabel would never show;
        # draw the row label as text just left of the first image instead.
        first = axes[row_idx, 0]
        first.text(-0.05, 0.5, title, transform=first.transAxes,
                   rotation=90, ha='right', va='center')

    plt.suptitle(f'{config_name} - Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'{save_dir}/batch_comparison_{config_name}_epoch_{epoch}.png')
    plt.close(fig)

# Training function
def _mnist_train_loader(batch_size):
    """Return a shuffled DataLoader over the MNIST training set scaled to [-1, 1]."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2 - 1)
    ])
    mnist = datasets.MNIST(root='.', train=True, download=True, transform=transform)
    return DataLoader(mnist, batch_size=batch_size, shuffle=True)


def _mask_from_config(unc, config):
    """Call ``generate_mask`` with the mask-related entries of ``config``."""
    return generate_mask(
        unc,
        config['alpha'],
        config['tau'],
        config['invert_uncertainty'],
        config['random_corruption']
    )


def _mean_coverage(per_batch):
    """Average each ``analyze_corruption_coverage`` field over a list of batch dicts."""
    keys = ('corrupted_uncertainty', 'uncorrupted_uncertainty', 'corruption_ratio')
    return {key: np.mean([m[key] for m in per_batch]) for key in keys}


def _save_diagnostics(model, dataloader, config, config_name, epoch_num, device):
    """Write uncertainty and batch-comparison figures for 8 images from ``dataloader``."""
    model.eval()
    with torch.no_grad():
        x, _ = next(iter(dataloader))
        x = x.to(device)[:8]

        unc = compute_uncertainty(model, x, config['num_samples'])
        mask = _mask_from_config(unc, config)
        x_corrupt = corrupt_input(x, mask)
        pred = model(x_corrupt)

        visualize_uncertainty_and_corruption(
            x, unc, mask, pred,
            config_name, epoch_num
        )

        visualize_batch_comparison(
            x, x_corrupt, pred,
            config_name, epoch_num
        )


def train_model(config_name, config, num_epochs=10, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Train a ``SimpleUNet`` on MNIST with uncertainty-driven input corruption.

    Each training step:

    1. Estimates an uncertainty map with ``compute_uncertainty``.
    2. Samples a corruption mask from it with ``generate_mask``.
    3. Replaces masked pixels with Gaussian noise (``corrupt_input``).
    4. Minimises MSE(reconstruction, clean image) + ``regularize_mask``.

    Every 5 epochs, diagnostic figures for 8 images are written via the
    visualisation helpers.

    Args:
        config_name: Label used in log output and figure file names.
        config: Dict read for ``alpha``, ``tau``, ``lambda_reg``,
            ``num_samples``, ``invert_uncertainty`` and ``random_corruption``.
            ``corruption_prob`` is not read here.
        num_epochs: Number of passes over the MNIST training set.
        device: Torch device (or device string) to train on.

    Returns:
        Dict containing:

        * ``losses``, ``psnr``, ``ssim``: per-epoch means.
        * ``corruption_metrics``: per-epoch means of the
          ``analyze_corruption_coverage`` fields.
        * ``config``: the config that was passed in.
    """
    print(f"\nTraining {config_name} configuration")

    dataloader = _mnist_train_loader(batch_size=128)

    model = SimpleUNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    losses = []
    psnr_values = []
    ssim_values = []
    corruption_metrics = []

    for epoch in range(num_epochs):
        model.train()
        epoch_losses = []
        epoch_psnr = []
        epoch_ssim = []
        epoch_corruption_metrics = []

        for x, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)

            # The uncertainty map only feeds a detached mask and scalar
            # diagnostics, so skip building autograd graphs for the
            # num_samples extra forward passes.
            with torch.no_grad():
                unc = compute_uncertainty(model, x, config['num_samples'])
            mask = _mask_from_config(unc, config)
            x_corrupt = corrupt_input(x, mask)
            pred = model(x_corrupt)

            loss = F.mse_loss(pred, x) + regularize_mask(mask, config['lambda_reg'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())
            epoch_psnr.append(compute_psnr(pred, x))
            epoch_ssim.append(compute_ssim(pred, x))
            epoch_corruption_metrics.append(analyze_corruption_coverage(mask, unc))

        losses.append(np.mean(epoch_losses))
        psnr_values.append(np.mean(epoch_psnr))
        ssim_values.append(np.mean(epoch_ssim))

        avg_corruption_metrics = _mean_coverage(epoch_corruption_metrics)
        corruption_metrics.append(avg_corruption_metrics)

        print(f"Epoch {epoch+1}:")
        print(f"  Loss = {losses[-1]:.4f}")
        print(f"  PSNR = {psnr_values[-1]:.2f}")
        print(f"  SSIM = {ssim_values[-1]:.4f}")
        print(f"  Corruption Ratio = {avg_corruption_metrics['corruption_ratio']:.4f}")
        print(f"  Corrupted/Uncorrupted Uncertainty = {avg_corruption_metrics['corrupted_uncertainty']:.4f}/{avg_corruption_metrics['uncorrupted_uncertainty']:.4f}")

        if (epoch + 1) % 5 == 0:
            _save_diagnostics(model, dataloader, config, config_name, epoch + 1, device)

    return {
        'losses': losses,
        'psnr': psnr_values,
        'ssim': ssim_values,
        'corruption_metrics': corruption_metrics,
        'config': config
    }

# Plot results
def plot_comparison(
    results,
    save_path='/data/healthy-ml/scratch/abinitha/68300-final/ablations/ablation_results/uncertainty_comparison.png',
):
    """Plot per-epoch training curves of every configuration on a 2x2 grid.

    Panels: training loss, PSNR, SSIM and corruption ratio, one line per
    configuration.

    Args:
        results: Mapping ``config_name -> result`` where ``result`` is the dict
            returned by ``train_model`` (keys ``'losses'``, ``'psnr'``,
            ``'ssim'`` and ``'corruption_metrics'``, the latter a list of
            dicts holding ``'corruption_ratio'``).
        save_path: Destination PNG. Defaults to the previously hard-coded
            path. Its parent directory is created if missing so ``savefig``
            does not fail on a machine without that directory tree.
    """
    # (title, y-axis label, function extracting the per-epoch series)
    panels = [
        ('Training Losses', 'Loss', lambda r: r['losses']),
        ('PSNR', 'PSNR (dB)', lambda r: r['psnr']),
        ('SSIM', 'SSIM', lambda r: r['ssim']),
        ('Corruption Ratio', 'Ratio',
         lambda r: [m['corruption_ratio'] for m in r['corruption_metrics']]),
    ]

    fig = plt.figure(figsize=(15, 10))
    try:
        for idx, (title, ylabel, series) in enumerate(panels, start=1):
            plt.subplot(2, 2, idx)
            for config_name, result in results.items():
                plt.plot(series(result), label=config_name)
            plt.title(title)
            plt.xlabel('Epochs')
            plt.ylabel(ylabel)
            # legend() with no labelled artists only emits a warning.
            if results:
                plt.legend()

        plt.tight_layout()
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

# Driver: train one model per entry of `configs`, plot the curves, then print
# the last-epoch metrics of every configuration.
# NOTE(review): in the recorded output of this cell every configuration reports
# "Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000" and a constant corruption
# ratio (0.0067 for baseline, consistent with sigmoid(alpha * (0 - tau)) for
# alpha=10, tau=0.5). That means the sampled uncertainty is identically zero.
# Check that the model has a stochastic train-mode layer (e.g. dropout).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# config name -> dict returned by train_model (per-epoch metrics + config)
results = {}

# Each configuration is trained from scratch for 10 epochs. Results are stored
# as each run finishes, so earlier runs survive if a later one is interrupted.
for config_name, config in configs.items():
    results[config_name] = train_model(
        config_name,
        config,
        num_epochs=10,
        device=device
    )

plot_comparison(results)

# Print final results (value of each tracked metric at the last epoch)
print("\nFinal Results:")
print("=============")
for config_name, result in results.items():
    print(f"\nConfiguration: {config_name}")
    print(f"Final Loss: {result['losses'][-1]:.4f}")
    print(f"Final PSNR: {result['psnr'][-1]:.2f}")
    print(f"Final SSIM: {result['ssim'][-1]:.4f}")
    print(f"Final Corruption Ratio: {result['corruption_metrics'][-1]['corruption_ratio']:.4f}")
    print(f"Corrupted/Uncorrupted Uncertainty: {result['corruption_metrics'][-1]['corrupted_uncertainty']:.4f}/{result['corruption_metrics'][-1]['uncorrupted_uncertainty']:.4f}") 


Training baseline configuration


Epoch 1/10: 100%|██████████| 469/469 [00:21<00:00, 21.87it/s]


Epoch 1:
  Loss = 0.0494
  PSNR = 23.87
  SSIM = 0.8548
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 2/10: 100%|██████████| 469/469 [00:19<00:00, 23.82it/s]


Epoch 2:
  Loss = 0.0057
  PSNR = 28.50
  SSIM = 0.9540
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 3/10: 100%|██████████| 469/469 [00:19<00:00, 23.53it/s]


Epoch 3:
  Loss = 0.0045
  PSNR = 29.53
  SSIM = 0.9617
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 4/10: 100%|██████████| 469/469 [00:20<00:00, 22.73it/s]


Epoch 4:
  Loss = 0.0039
  PSNR = 30.11
  SSIM = 0.9660
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 5/10: 100%|██████████| 469/469 [00:20<00:00, 22.52it/s]


Epoch 5:
  Loss = 0.0036
  PSNR = 30.51
  SSIM = 0.9687
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 6/10: 100%|██████████| 469/469 [00:21<00:00, 21.93it/s]


Epoch 6:
  Loss = 0.0034
  PSNR = 30.75
  SSIM = 0.9705
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 7/10: 100%|██████████| 469/469 [00:20<00:00, 22.89it/s]


Epoch 7:
  Loss = 0.0032
  PSNR = 30.96
  SSIM = 0.9718
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 8/10: 100%|██████████| 469/469 [00:20<00:00, 22.93it/s]


Epoch 8:
  Loss = 0.0031
  PSNR = 31.14
  SSIM = 0.9730
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 9/10: 100%|██████████| 469/469 [00:20<00:00, 22.78it/s]


Epoch 9:
  Loss = 0.0030
  PSNR = 31.30
  SSIM = 0.9741
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 10/10: 100%|██████████| 469/469 [00:20<00:00, 23.18it/s]


Epoch 10:
  Loss = 0.0029
  PSNR = 31.47
  SSIM = 0.9750
  Corruption Ratio = 0.0067
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000

Training inverted_uncertainty configuration


Epoch 1/10: 100%|██████████| 469/469 [00:20<00:00, 23.14it/s]


Epoch 1:
  Loss = 0.4589
  PSNR = 9.45
  SSIM = 0.6280
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 2/10: 100%|██████████| 469/469 [00:20<00:00, 23.16it/s]


Epoch 2:
  Loss = 0.4490
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 3/10: 100%|██████████| 469/469 [00:20<00:00, 23.05it/s]


Epoch 3:
  Loss = 0.4490
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 4/10: 100%|██████████| 469/469 [00:19<00:00, 23.50it/s]


Epoch 4:
  Loss = 0.4490
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 5/10: 100%|██████████| 469/469 [00:20<00:00, 23.15it/s]


Epoch 5:
  Loss = 0.4490
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 6/10: 100%|██████████| 469/469 [00:20<00:00, 23.22it/s]


Epoch 6:
  Loss = 0.4490
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 7/10: 100%|██████████| 469/469 [00:20<00:00, 23.15it/s]


Epoch 7:
  Loss = 0.4490
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 8/10: 100%|██████████| 469/469 [00:20<00:00, 22.80it/s]


Epoch 8:
  Loss = 0.4490
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 9/10: 100%|██████████| 469/469 [00:19<00:00, 23.45it/s]


Epoch 9:
  Loss = 0.4490
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 10/10: 100%|██████████| 469/469 [00:20<00:00, 23.20it/s]


Epoch 10:
  Loss = 0.4490
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.9933
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000

Training random_corruption configuration


Epoch 1/10: 100%|██████████| 469/469 [00:20<00:00, 23.04it/s]


Epoch 1:
  Loss = 0.4561
  PSNR = 9.46
  SSIM = 0.6292
  Corruption Ratio = 0.2000
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 2/10: 100%|██████████| 469/469 [00:20<00:00, 22.80it/s]


Epoch 2:
  Loss = 0.4482
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.2001
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 3/10: 100%|██████████| 469/469 [00:20<00:00, 22.91it/s]


Epoch 3:
  Loss = 0.4482
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.1999
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 4/10: 100%|██████████| 469/469 [00:20<00:00, 23.13it/s]


Epoch 4:
  Loss = 0.4482
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.1999
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 5/10: 100%|██████████| 469/469 [00:20<00:00, 23.05it/s]


Epoch 5:
  Loss = 0.4482
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.1999
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 6/10: 100%|██████████| 469/469 [00:21<00:00, 22.09it/s]


Epoch 6:
  Loss = 0.4482
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.2001
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 7/10: 100%|██████████| 469/469 [00:20<00:00, 23.33it/s]


Epoch 7:
  Loss = 0.4482
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.2000
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 8/10: 100%|██████████| 469/469 [00:20<00:00, 22.71it/s]


Epoch 8:
  Loss = 0.4482
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.2000
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 9/10: 100%|██████████| 469/469 [00:20<00:00, 23.31it/s]


Epoch 9:
  Loss = 0.4482
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.1999
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000


Epoch 10/10: 100%|██████████| 469/469 [00:20<00:00, 23.34it/s]


Epoch 10:
  Loss = 0.4482
  PSNR = 9.51
  SSIM = 0.6428
  Corruption Ratio = 0.2001
  Corrupted/Uncorrupted Uncertainty = 0.0000/0.0000

Final Results:

Configuration: baseline
Final Loss: 0.0029
Final PSNR: 31.47
Final SSIM: 0.9750
Final Corruption Ratio: 0.0067
Corrupted/Uncorrupted Uncertainty: 0.0000/0.0000

Configuration: inverted_uncertainty
Final Loss: 0.4490
Final PSNR: 9.51
Final SSIM: 0.6428
Final Corruption Ratio: 0.9933
Corrupted/Uncorrupted Uncertainty: 0.0000/0.0000

Configuration: random_corruption
Final Loss: 0.4482
Final PSNR: 9.51
Final SSIM: 0.6428
Final Corruption Ratio: 0.2001
Corrupted/Uncorrupted Uncertainty: 0.0000/0.0000


# CIFAR-10 Results

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
import math
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Define configurations
# Ablation grid: all three runs share the same hyper-parameters and differ
# only in how the corruption mask is chosen:
#   baseline             - corrupt where uncertainty is high
#   inverted_uncertainty - corrupt where uncertainty is low
#   random_corruption    - ignore uncertainty, corrupt at random
# NOTE(review): 'corruption_prob' is not read by train_model in this cell.
configs = {
    'baseline': dict(
        alpha=10.0,
        tau=0.5,
        lambda_reg=1e-3,
        corruption_prob=0.5,
        num_samples=5,
        invert_uncertainty=False,
        random_corruption=False,
    ),
    'inverted_uncertainty': dict(
        alpha=10.0,
        tau=0.5,
        lambda_reg=1e-3,
        corruption_prob=0.5,
        num_samples=5,
        invert_uncertainty=True,
        random_corruption=False,
    ),
    'random_corruption': dict(
        alpha=10.0,
        tau=0.5,
        lambda_reg=1e-3,
        corruption_prob=0.5,
        num_samples=5,
        invert_uncertainty=False,
        random_corruption=True,
    ),
}

# Model architecture for CIFAR-10 (3 channels)
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + BN(Conv(ReLU(BN(Conv(x)))))``.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so the
    output has the input's shape and the identity shortcut needs no projection.
    """

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)


class SimpleUNet(nn.Module):
    """Conv encoder/decoder with a residual bottleneck for 3-channel images.

    Three stride-2 conv stages (channels -> 2x -> 4x) shrink the input. Two
    residual blocks process the bottleneck, and three stride-2 transposed conv
    stages expand it back. Inputs whose height and width are divisible by 8
    (e.g. 32x32 CIFAR-10) are returned at their original size. A final
    ``Tanh`` bounds the output to [-1, 1].

    Args:
        channels: width of the first encoder stage; deeper stages use 2x and 4x.
        dropout: probability of the channel dropout (``nn.Dropout2d``) applied
            to the bottleneck features.
            - It is active only in train mode and is the only stochastic layer
              in the network.
            - Repeated train-mode forward passes on the same batch therefore
              differ only because of it. With ``dropout=0.0`` those passes are
              identical, and a variance-based (MC) uncertainty estimate is
              exactly zero.
            - ``nn.Dropout2d`` has no parameters or buffers, so ``state_dict``
              keys are the same as without it, and eval-mode outputs are
              unaffected.
    """

    def __init__(self, channels=64, dropout=0.1):
        super().__init__()
        # Encoder: each stride-2 conv halves H and W.
        self.down1 = nn.Sequential(
            nn.Conv2d(3, channels, 3, padding=1, stride=2),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(channels, channels*2, 3, padding=1, stride=2),
            nn.BatchNorm2d(channels*2),
            nn.ReLU()
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(channels*2, channels*4, 3, padding=1, stride=2),
            nn.BatchNorm2d(channels*4),
            nn.ReLU()
        )

        # Residual bottleneck
        self.res1 = ResidualBlock(channels*4)
        self.res2 = ResidualBlock(channels*4)

        # Source of train-mode stochasticity for MC uncertainty sampling
        # (BatchNorm alone is deterministic for a fixed batch).
        self.dropout = nn.Dropout2d(p=dropout)

        # Decoder: each stride-2 transposed conv (kernel 4, padding 1) doubles H and W.
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(channels*4, channels*2, 4, stride=2, padding=1),
            nn.BatchNorm2d(channels*2),
            nn.ReLU()
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(channels*2, channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(channels, 3, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        # Encoder
        h = self.down3(self.down2(self.down1(x)))
        # Residual bottleneck, then dropout (identity in eval mode)
        h = self.dropout(self.res2(self.res1(h)))
        # Decoder
        return self.up3(self.up2(self.up1(h)))

# Utility functions
def compute_uncertainty(model, x, num_samples=5):
    """Estimate per-pixel predictive uncertainty from repeated train-mode passes.

    Runs ``model(x)`` ``num_samples`` times with the model in train mode. It
    returns the (unbiased) variance across those predictions, averaged over
    dimension 1 with ``keepdim=True``. For a model returning
    ``(B, C, H, W)`` that is a ``(B, 1, H, W)`` map.

    The variance is non-zero only if the model has stochastic train-mode layers
    (e.g. dropout). Otherwise every pass is identical and the map is all zeros.

    The model's train/eval mode on entry is restored before returning, so a
    training loop keeps training in train mode afterwards. Train-mode passes
    still update any BatchNorm running statistics.

    Args:
        model: ``torch.nn.Module`` to sample from.
        x: input batch passed to ``model``.
        num_samples: number of forward passes. Must be >= 2, because the
            unbiased variance of a single sample is undefined (NaN).

    Returns:
        A tensor computed without autograd (``requires_grad`` is False).

    Raises:
        ValueError: if ``num_samples < 2``.
    """
    if num_samples < 2:
        raise ValueError(
            f"num_samples must be >= 2 to estimate a variance, got {num_samples}"
        )

    was_training = model.training
    model.train()
    try:
        # No gradients are needed: the estimate only drives a detached mask and
        # logging, so skip building num_samples autograd graphs.
        with torch.no_grad():
            samples = torch.stack([model(x) for _ in range(num_samples)])
    finally:
        model.train(was_training)
    return torch.var(samples, dim=0).mean(1, keepdim=True)

def generate_mask(uncertainty, alpha=10.0, tau=0.5, invert_uncertainty=False,
                  random_corruption=False, random_prob=0.5):
    """Sample a binary corruption mask (1 = corrupt, 0 = keep).

    By default each entry is corrupted independently with probability
    ``sigmoid(alpha * (u - tau))``. Here ``u`` is ``uncertainty``, or
    ``1 - uncertainty`` when ``invert_uncertainty`` is set.

    Args:
        uncertainty: per-pixel uncertainty tensor; the mask has its shape and dtype.
        alpha: steepness of the sigmoid.
        tau: value of ``u`` at which the corruption probability is 0.5.
        invert_uncertainty: corrupt low-uncertainty entries instead of
            high-uncertainty ones.
            NOTE(review): ``compute_uncertainty`` returns a variance, which is
            not bounded to [0, 1], so ``1 - uncertainty`` is only a meaningful
            inversion if values are small -- confirm.
        random_corruption: ignore the uncertainty values and corrupt each entry
            independently with probability ``random_prob`` (control ablation).
        random_prob: corruption probability when ``random_corruption`` is set.
            Defaults to the previously hard-coded 0.5; callers can pass
            e.g. ``config['corruption_prob']``.

    Returns:
        A detached tensor of 0s and 1s.
    """
    if random_corruption:
        return torch.bernoulli(torch.ones_like(uncertainty) * random_prob).detach()

    if invert_uncertainty:
        uncertainty = 1.0 - uncertainty
    prob = torch.sigmoid(alpha * (uncertainty - tau))
    return torch.bernoulli(prob).detach()

def regularize_mask(mask, lambda_reg=1e-3):
    """Return the mask-density penalty ``lambda_reg * mean(mask)``.

    NOTE(review): ``generate_mask`` returns a detached Bernoulli sample, so when
    used in ``train_model`` this term carries no gradient. It only shifts the
    reported loss value.
    """
    density = mask.mean()
    return lambda_reg * density

def corrupt_input(x, mask):
    """Replace the masked entries of ``x`` with standard-normal noise.

    Computes ``mask * noise + (1 - mask) * x``. ``noise ~ N(0, 1)`` is drawn
    fresh with ``x``'s shape, and ``mask`` is broadcast against ``x`` (so a
    single-channel mask applies to every channel).
    """
    gaussian = torch.randn_like(x)
    keep = 1 - mask
    return mask * gaussian + keep * x

def compute_psnr(img1, img2):
    """Return the PSNR between two tensors using skimage.

    ``data_range=2.0`` matches the [-1, 1] scaling applied by this cell's
    CIFAR transform. Both tensors are detached and copied to CPU numpy first.
    """
    arrays = [t.detach().cpu().numpy() for t in (img1, img2)]
    return psnr(*arrays, data_range=2.0)

def compute_ssim(img1, img2):
    """Return the SSIM between two channel-first image tensors.

    A 4-D input is treated as a batch ``(N, C, H, W)``: SSIM is computed per
    image pair and the mean is returned. Any other input is treated as a
    single ``(C, H, W)`` image. Uses ``data_range=2.0`` (matching the [-1, 1]
    scaling of this cell's transform) and a 3x3 window.
    """
    a = img1.detach().cpu().numpy()
    b = img2.detach().cpu().numpy()

    def ssim_chw(p, q):
        # skimage wants channel-last arrays, so reorder CxHxW -> HxWxC.
        return ssim(p.transpose(1, 2, 0), q.transpose(1, 2, 0),
                    data_range=2.0, win_size=3, channel_axis=2)

    if a.ndim != 4:
        return ssim_chw(a, b)
    return np.mean([ssim_chw(a[i], b[i]) for i in range(a.shape[0])])

def analyze_corruption_coverage(mask, uncertainty):
    """Summarize how the corruption mask lines up with the uncertainty map.

    Returns a dict of Python floats:
      - ``'corrupted_uncertainty'``: mask-weighted mean uncertainty (the mean
        over masked entries for a binary mask),
      - ``'uncorrupted_uncertainty'``: the same over the complement ``1 - mask``,
      - ``'corruption_ratio'``: mean of the mask.
    A 1e-6 term in each denominator avoids division by zero when nothing (or
    everything) is masked.
    """
    eps = 1e-6
    keep = 1 - mask
    stats = {}
    stats['corrupted_uncertainty'] = ((mask * uncertainty).sum() / (mask.sum() + eps)).item()
    stats['uncorrupted_uncertainty'] = ((keep * uncertainty).sum() / (keep.sum() + eps)).item()
    stats['corruption_ratio'] = mask.mean().item()
    return stats

# Visualization functions
def visualize_uncertainty_and_corruption(x, uncertainty, mask, pred, config_name, epoch, save_dir='visualizations'):
    """Save a 2x2 diagnostic figure for the first sample in the batch.

    Panels, row-major:
      - the original image,
      - the uncertainty map (``hot`` colormap, with a colorbar),
      - the corruption mask (grayscale),
      - the reconstruction.
    Only channel 0 of ``uncertainty`` and ``mask`` is drawn. RGB tensors are
    mapped with ``(v + 1) / 2`` for display; this assumes the [-1, 1] scaling
    used by this cell's transform and Tanh output.

    The figure is written to
    ``{save_dir}/uncertainty_analysis_{config_name}_epoch_{epoch}.png``, and
    ``save_dir`` is created if needed.
    """
    os.makedirs(save_dir, exist_ok=True)

    def first_as_hwc(t):
        # Sample 0 of a (B, C, H, W) tensor, reordered to (H, W, C) for imshow.
        return t.detach().cpu().numpy()[0].transpose(1, 2, 0)

    original = (first_as_hwc(x) + 1) / 2
    unc_map = uncertainty.detach().cpu().numpy()[0, 0]
    mask_map = mask.detach().cpu().numpy()[0, 0]
    recon = (first_as_hwc(pred) + 1) / 2

    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
    # (axes, image, title, imshow kwargs, add colorbar?)
    panels = [
        (axes[0, 0], original, 'Original Image', {}, False),
        (axes[0, 1], unc_map, 'Uncertainty Map', {'cmap': 'hot'}, True),
        (axes[1, 0], mask_map, 'Corruption Mask', {'cmap': 'gray'}, False),
        (axes[1, 1], recon, 'Reconstruction', {}, False),
    ]
    for ax, image, title, imshow_kwargs, with_colorbar in panels:
        artist = ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')
        if with_colorbar:
            plt.colorbar(artist, ax=ax)

    plt.suptitle(f'{config_name} - Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'{save_dir}/uncertainty_analysis_{config_name}_epoch_{epoch}.png')
    plt.close()

def visualize_batch_comparison(x, x_corrupt, recon, config_name, epoch, save_dir='visualizations', num_images=8):
    """Save a 3-row grid comparing original, corrupted and reconstructed images.

    Rows are labelled 'Original', 'Corrupted' and 'Reconstruction'; columns are
    the first ``num_images`` samples (fewer if the batch is smaller).
    Channel-first tensors are mapped from [-1, 1] to [0, 1] with
    ``(v + 1) / 2`` and clipped, because corrupted pixels are N(0, 1) noise
    that can fall outside the displayable range.

    The figure is written to
    ``{save_dir}/batch_comparison_{config_name}_epoch_{epoch}.png``, and
    ``save_dir`` is created if needed.

    Raises:
        ValueError: if any of the three batches is empty.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Previously hard-coded to 8 columns, which raised IndexError on smaller batches.
    n_cols = min(num_images, x.shape[0], x_corrupt.shape[0], recon.shape[0])
    if n_cols < 1:
        raise ValueError("visualize_batch_comparison needs at least one image per batch")

    fig, axes = plt.subplots(3, n_cols, figsize=(20, 8), squeeze=False)
    rows = (('Original', x), ('Corrupted', x_corrupt), ('Reconstruction', recon))
    for row_idx, (title, images) in enumerate(rows):
        for col_idx in range(n_cols):
            img = images[col_idx].detach().cpu().numpy().transpose(1, 2, 0)
            img = np.clip((img + 1) / 2, 0.0, 1.0)
            ax = axes[row_idx, col_idx]
            ax.imshow(img)
            # Hide ticks and spines rather than calling ax.axis('off'):
            # axis('off') also hides axis labels, so the row titles set below
            # were never rendered.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        axes[row_idx, 0].set_ylabel(title)

    fig.suptitle(f'{config_name} - Epoch {epoch}')
    fig.tight_layout()
    fig.savefig(f'{save_dir}/batch_comparison_{config_name}_epoch_{epoch}.png')
    plt.close(fig)

# Training function
def _sample_corruption(model, x, config):
    """Estimate uncertainty for ``x``, draw a corruption mask and corrupt ``x``.

    Returns ``(unc, mask, x_corrupt)``. Uncertainty sampling runs without
    autograd, and the model's train/eval mode on entry is restored afterwards.
    """
    was_training = model.training
    try:
        with torch.no_grad():
            unc = compute_uncertainty(model, x, config['num_samples'])
    finally:
        # compute_uncertainty may switch the model to eval mode. Restore the
        # caller's mode so the training forward pass keeps BatchNorm in train mode.
        model.train(was_training)
    mask = generate_mask(
        unc,
        config['alpha'],
        config['tau'],
        config['invert_uncertainty'],
        config['random_corruption']
    )
    return unc, mask, corrupt_input(x, mask)


def _average_corruption_metrics(batch_metrics):
    """Average a list of per-batch ``analyze_corruption_coverage`` dicts key by key."""
    keys = ('corrupted_uncertainty', 'uncorrupted_uncertainty', 'corruption_ratio')
    return {key: np.mean([m[key] for m in batch_metrics]) for key in keys}


def _save_epoch_visuals(model, dataloader, config, config_name, epoch, device, num_images=8):
    """Render the diagnostic figures for the first ``num_images`` images of a fresh batch."""
    model.eval()
    with torch.no_grad():
        x, _ = next(iter(dataloader))
        x = x.to(device)[:num_images]
        unc, mask, x_corrupt = _sample_corruption(model, x, config)
        pred = model(x_corrupt)
        visualize_uncertainty_and_corruption(x, unc, mask, pred, config_name, epoch)
        visualize_batch_comparison(x, x_corrupt, pred, config_name, epoch)


def train_model(config_name, config, num_epochs=10, device='cuda' if torch.cuda.is_available() else 'cpu',
                batch_size=128, lr=1e-3, data_root='.', vis_every=5):
    """Train a fresh ``SimpleUNet`` on CIFAR-10 with uncertainty-guided input corruption.

    Each step does the following:
      1. Estimate per-pixel uncertainty for the batch.
      2. Sample a corruption mask from it according to ``config``.
      3. Replace the masked pixels with noise.
      4. Train the model to reconstruct the clean batch
         (MSE plus ``regularize_mask``).

    Args:
        config_name: label used in log lines and figure file names.
        config: dict with keys ``'num_samples'``, ``'alpha'``, ``'tau'``,
            ``'invert_uncertainty'``, ``'random_corruption'`` and ``'lambda_reg'``.
        num_epochs: number of passes over the training set.
        device: device for the model and batches.
        batch_size: DataLoader batch size (previously hard-coded to 128).
        lr: Adam learning rate (previously hard-coded to 1e-3).
        data_root: CIFAR-10 download/cache directory (previously hard-coded to '.').
        vis_every: save diagnostic figures every ``vis_every`` epochs
            (previously hard-coded to 5); a falsy value disables them.

    Returns:
        Dict with per-epoch lists ``'losses'``, ``'psnr'``, ``'ssim'`` and
        ``'corruption_metrics'`` (dicts of averaged coverage stats), plus the
        ``'config'`` used.
    """
    print(f"\nTraining {config_name} configuration")

    # Scale images from [0, 1] to [-1, 1] to match the model's Tanh output.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2 - 1)
    ])
    cifar = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
    dataloader = DataLoader(cifar, batch_size=batch_size, shuffle=True)

    model = SimpleUNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    losses, psnr_values, ssim_values, corruption_metrics = [], [], [], []

    for epoch in range(num_epochs):
        model.train()
        epoch_losses, epoch_psnr, epoch_ssim, epoch_corruption_metrics = [], [], [], []

        for x, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)

            unc, mask, x_corrupt = _sample_corruption(model, x, config)
            # The model is in train mode here (restored by _sample_corruption).
            pred = model(x_corrupt)

            loss = F.mse_loss(pred, x) + regularize_mask(mask, config['lambda_reg'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())
            epoch_psnr.append(compute_psnr(pred, x))
            epoch_ssim.append(compute_ssim(pred, x))
            epoch_corruption_metrics.append(analyze_corruption_coverage(mask, unc))

        losses.append(np.mean(epoch_losses))
        psnr_values.append(np.mean(epoch_psnr))
        ssim_values.append(np.mean(epoch_ssim))

        avg_corruption_metrics = _average_corruption_metrics(epoch_corruption_metrics)
        corruption_metrics.append(avg_corruption_metrics)

        print(f"Epoch {epoch+1}:")
        print(f"  Loss = {losses[-1]:.4f}")
        print(f"  PSNR = {psnr_values[-1]:.2f}")
        print(f"  SSIM = {ssim_values[-1]:.4f}")
        print(f"  Corruption Ratio = {avg_corruption_metrics['corruption_ratio']:.4f}")
        print(f"  Corrupted/Uncorrupted Uncertainty = {avg_corruption_metrics['corrupted_uncertainty']:.4f}/{avg_corruption_metrics['uncorrupted_uncertainty']:.4f}")

        if vis_every and (epoch + 1) % vis_every == 0:
            _save_epoch_visuals(model, dataloader, config, config_name, epoch + 1, device)

    return {
        'losses': losses,
        'psnr': psnr_values,
        'ssim': ssim_values,
        'corruption_metrics': corruption_metrics,
        'config': config
    }

# Plot results
def plot_comparison(results, save_path='cifar_uncertainty_comparison.png'):
    """Plot per-epoch training curves of every configuration on a 2x2 grid.

    Panels: training loss, PSNR, SSIM and corruption ratio, one line per
    configuration.

    Args:
        results: Mapping ``config_name -> result`` where ``result`` is the dict
            returned by ``train_model`` (keys ``'losses'``, ``'psnr'``,
            ``'ssim'`` and ``'corruption_metrics'``, the latter a list of
            dicts holding ``'corruption_ratio'``).
        save_path: Destination PNG (defaults to the previously hard-coded
            file name). A parent directory in the path is created if missing.
    """
    # (title, y-axis label, function extracting the per-epoch series)
    panels = [
        ('Training Losses', 'Loss', lambda r: r['losses']),
        ('PSNR', 'PSNR (dB)', lambda r: r['psnr']),
        ('SSIM', 'SSIM', lambda r: r['ssim']),
        ('Corruption Ratio', 'Ratio',
         lambda r: [m['corruption_ratio'] for m in r['corruption_metrics']]),
    ]

    fig = plt.figure(figsize=(15, 10))
    try:
        for idx, (title, ylabel, series) in enumerate(panels, start=1):
            plt.subplot(2, 2, idx)
            for config_name, result in results.items():
                plt.plot(series(result), label=config_name)
            plt.title(title)
            plt.xlabel('Epochs')
            plt.ylabel(ylabel)
            # legend() with no labelled artists only emits a warning.
            if results:
                plt.legend()

        plt.tight_layout()
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


# Driver: train one model per entry of `configs`, plot the curves, then print
# the last-epoch metrics of every configuration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# config name -> dict returned by train_model (per-epoch metrics + config)
results = {}

# Each configuration is trained from scratch for 10 epochs. Results are stored
# as each run finishes, so earlier runs survive if a later one is interrupted.
for config_name, config in configs.items():
    results[config_name] = train_model(
        config_name,
        config,
        num_epochs=10,
        device=device
    )

plot_comparison(results)

# Print final results (value of each tracked metric at the last epoch)
print("\nFinal Results:")
print("=============")
for config_name, result in results.items():
    print(f"\nConfiguration: {config_name}")
    print(f"Final Loss: {result['losses'][-1]:.4f}")
    print(f"Final PSNR: {result['psnr'][-1]:.2f}")
    print(f"Final SSIM: {result['ssim'][-1]:.4f}")
    print(f"Final Corruption Ratio: {result['corruption_metrics'][-1]['corruption_ratio']:.4f}")
    print(f"Corrupted/Uncorrupted Uncertainty: {result['corruption_metrics'][-1]['corrupted_uncertainty']:.4f}/{result['corruption_metrics'][-1]['uncorrupted_uncertainty']:.4f}") 