In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Clone RIFE repository
!git clone https://github.com/hzwer/arXiv2020-RIFE.git
%cd arXiv2020-RIFE
!pip install git+https://github.com/rk-exxec/scikit-video.git@numpy_deprecation
!pip install thop  # For FLOPs calculation

# Copy the zip file WITH destination specified
!cp "/content/drive/MyDrive/RIFE_weights/RIFE_trained_v6.zip" ./

# Extract the zip file (using unzip as alternative if 7z has issues)
!unzip -q RIFE_trained_v6.zip || 7z x RIFE_trained_v6.zip

# Fix nested train_log structure
!if [ -d "RIFE_trained_v6/train_log" ]; then mv RIFE_trained_v6/train_log ./train_log; fi
!if [ -d "train_log/train_log" ]; then mv train_log/train_log ./train_log_fixed && rm -rf train_log && mv train_log_fixed train_log; fi

# Alternative: if train_log is directly in RIFE_trained_v6 root
!if [ -d "RIFE_trained_v6" ] && [ ! -d "train_log" ]; then mv RIFE_trained_v6 train_log; fi

print("\nContents of train_log:")
!ls -la train_log || echo "train_log folder not found!"

print("\nRequired model files:")
!ls -la train_log/*.pkl 2>/dev/null || echo "Model .pkl files not found!"

# Verify dataset path
UCF_PATH = "/content/drive/MyDrive/UCF-101/ucf101_interp_ours"

import os
if not os.path.exists(UCF_PATH):
    raise FileNotFoundError(f"Dataset path not found: {UCF_PATH}")
else:
    print(f"\nDataset path found: {UCF_PATH}")

# Load required libraries
import cv2, math, torch, numpy as np
import time
from tqdm import tqdm
from thop import profile, clever_format
from torch.nn import functional as F
from model.pytorch_msssim import ssim_matlab
from model.RIFE import Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load model
model = Model()
model.load_model('train_log')
model.eval()
model.device()

print("\n" + "="*60)
print("MODEL STATISTICS")
print("="*60)

# Calculate model parameters - access the actual PyTorch modules
def count_parameters(model):
    total = 0
    # RIFE Model has flownet attribute which is the actual nn.Module
    if hasattr(model, 'flownet'):
        total += sum(p.numel() for p in model.flownet.parameters())
    # Check for other submodules
    for attr_name in dir(model):
        attr = getattr(model, attr_name)
        if isinstance(attr, torch.nn.Module) and attr_name != 'flownet':
            total += sum(p.numel() for p in attr.parameters())
    return total

total_params = count_parameters(model)
print(f"Total Parameters: {total_params:,} ({total_params/1e6:.2f}M)")

# Calculate FLOPs (using a sample input)
sample_img0 = torch.randn(1, 3, 256, 256).to(device)
sample_img1 = torch.randn(1, 3, 256, 256).to(device)

flops_str = "N/A"
try:
    # For RIFE, we need to profile the flownet specifically
    with torch.no_grad():
        # RIFE's inference method
        flops, params = profile(model.flownet, inputs=(torch.cat([sample_img0, sample_img1], 1),), verbose=False)
    flops_str, params_str = clever_format([flops, params], "%.3f")
    print(f"FLOPs (256x256 input): {flops_str}")
    print(f"Params (from profiler): {params_str}")
except Exception as e:
    print(f"FLOPs calculation note: Using model.flownet for profiling")
    try:
        # Alternative: just count FLOPs for flownet
        with torch.no_grad():
            flops = profile(model.flownet, inputs=(torch.cat([sample_img0, sample_img1], 1),), verbose=False)[0]
        flops_str = clever_format([flops], "%.3f")[0]
        print(f"FLOPs (256x256 input): {flops_str}")
    except Exception as e2:
        print(f"FLOPs calculation failed: {e2}")
        flops_str = "N/A"

print("="*60 + "\n")

# Evaluate on UCF-101
dirs = os.listdir(UCF_PATH)
psnr_list, ssim_list, time_list = [], [], []
memory_list = []

print(f"Starting evaluation on {len(dirs)} sequences...\n")

# Warm-up runs
print("Performing warm-up runs...")
sample_img0 = torch.randn(1, 3, 256, 256).to(device)
sample_img1 = torch.randn(1, 3, 256, 256).to(device)
for _ in range(3):
    with torch.no_grad():
        _ = model.inference(sample_img0, sample_img1)

# Clear cache and measure baseline memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    baseline_memory = torch.cuda.memory_allocated() / (1024**2)  # MB
    print(f"Baseline GPU memory: {baseline_memory:.2f} MB")

print("Warm-up complete.\n")

# Use tqdm for progress bar
for d in tqdm(dirs, desc="Evaluating", unit="seq"):
    img0_path = os.path.join(UCF_PATH, d, 'frame_00.png')
    img1_path = os.path.join(UCF_PATH, d, 'frame_02.png')
    gt_path   = os.path.join(UCF_PATH, d, 'frame_01_gt.png')

    if not all(map(os.path.exists, [img0_path, img1_path, gt_path])):
        tqdm.write(f"Missing frames in {d}, skipping.")
        continue

    img0 = torch.tensor(cv2.imread(img0_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)
    img1 = torch.tensor(cv2.imread(img1_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)
    gt   = torch.tensor(cv2.imread(gt_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)

    # Reset peak memory stats before inference
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated() / (1024**2)  # MB

    # Measure inference time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.time()

    with torch.no_grad():
        pred = model.inference(img0, img1)[0]

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.time()

    # Measure peak memory usage during inference
    if torch.cuda.is_available():
        mem_after = torch.cuda.memory_allocated() / (1024**2)  # MB
        peak_mem = torch.cuda.max_memory_allocated() / (1024**2)  # MB
        memory_used = peak_mem - mem_before
        memory_list.append(memory_used)

    inference_time = (end_time - start_time) * 1000  # Convert to milliseconds
    time_list.append(inference_time)

    # Calculate SSIM - FIXED: removed extra parenthesis
    ssim = ssim_matlab(gt, torch.round(pred*255).unsqueeze(0)/255).detach().cpu().numpy()

    # Calculate PSNR
    out = pred.detach().cpu().numpy().transpose(1,2,0)
    out = np.round(out*255)/255.
    gt_np = gt[0].cpu().numpy().transpose(1,2,0)
    psnr = -10 * math.log10(((gt_np - out)**2).mean())

    psnr_list.append(psnr)
    ssim_list.append(ssim)

# Calculate memory statistics
if torch.cuda.is_available():
    total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**2)  # MB
    current_memory = torch.cuda.memory_allocated() / (1024**2)  # MB
    peak_memory = torch.cuda.max_memory_allocated() / (1024**2)  # MB
    reserved_memory = torch.cuda.memory_reserved() / (1024**2)  # MB

# Print results
print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)
print(f"Dataset: UCF-101")
print(f"Total Sequences Evaluated: {len(psnr_list)}")
print("-"*60)
print("Quality Metrics:")
print(f"  Average PSNR: {np.mean(psnr_list):.3f} dB")
print(f"  Average SSIM: {np.mean(ssim_list):.4f}")
print("-"*60)
print("Speed Metrics:")
print(f"  Average Inference Time: {np.mean(time_list):.2f} ms")
print(f"  FPS: {1000/np.mean(time_list):.2f}")
print(f"  Min Inference Time: {np.min(time_list):.2f} ms")
print(f"  Max Inference Time: {np.max(time_list):.2f} ms")
print(f"  Std Dev Inference Time: {np.std(time_list):.2f} ms")
print("-"*60)
print("Memory Usage:")
if torch.cuda.is_available() and len(memory_list) > 0:
    print(f"  Average Memory per Frame: {np.mean(memory_list):.2f} MB")
    print(f"  Peak Memory per Frame: {np.max(memory_list):.2f} MB")
    print(f"  Min Memory per Frame: {np.min(memory_list):.2f} MB")
    print(f"  Current GPU Memory: {current_memory:.2f} MB")
    print(f"  Peak GPU Memory: {peak_memory:.2f} MB")
    print(f"  Reserved GPU Memory: {reserved_memory:.2f} MB")
    print(f"  Total GPU Memory: {total_memory:.2f} MB")
else:
    print(f"  Memory tracking not available (CPU mode)")
print("-"*60)
print("Model Complexity:")
print(f"  Parameters: {total_params:,} ({total_params/1e6:.2f}M)")
print(f"  FLOPs (256x256): {flops_str}")
print("="*60)

Mounted at /content/drive
Cloning into 'arXiv2020-RIFE'...
remote: Enumerating objects: 2037, done.[K
remote: Counting objects: 100% (461/461), done.[K
remote: Compressing objects: 100% (101/101), done.[K
remote: Total 2037 (delta 423), reused 360 (delta 360), pack-reused 1576 (from 2)[K
Receiving objects: 100% (2037/2037), 4.12 MiB | 11.01 MiB/s, done.
Resolving deltas: 100% (1293/1293), done.
/content/arXiv2020-RIFE
Collecting git+https://github.com/rk-exxec/scikit-video.git@numpy_deprecation
  Cloning https://github.com/rk-exxec/scikit-video.git (to revision numpy_deprecation) to /tmp/pip-req-build-q6accp29
  Running command git clone --filter=blob:none --quiet https://github.com/rk-exxec/scikit-video.git /tmp/pip-req-build-q6accp29
  Running command git checkout -b numpy_deprecation --track origin/numpy_deprecation
  Switched to a new branch 'numpy_deprecation'
  Branch 'numpy_deprecation' set up to track remote branch 'numpy_deprecation' from 'origin'.
  Resolved https://githu



FLOPs (256x256 input): 11.684G
Params (from profiler): 10.072M

Starting evaluation on 379 sequences...

Performing warm-up runs...
Baseline GPU memory: 43.80 MB
Warm-up complete.



Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 379/379 [11:25<00:00,  1.81s/seq]


EVALUATION RESULTS
Dataset: UCF-101
Total Sequences Evaluated: 379
------------------------------------------------------------
Quality Metrics:
  Average PSNR: 35.292 dB
  Average SSIM: 0.9690
------------------------------------------------------------
Speed Metrics:
  Average Inference Time: 14.16 ms
  FPS: 70.60
  Min Inference Time: 13.45 ms
  Max Inference Time: 23.23 ms
  Std Dev Inference Time: 1.32 ms
------------------------------------------------------------
Memory Usage:
  Average Memory per Frame: 33.93 MB
  Peak Memory per Frame: 33.93 MB
  Min Memory per Frame: 33.93 MB
  Current GPU Memory: 46.80 MB
  Peak GPU Memory: 80.73 MB
  Reserved GPU Memory: 104.00 MB
  Total GPU Memory: 15095.06 MB
------------------------------------------------------------
Model Complexity:
  Parameters: 10,708,215 (10.71M)
  FLOPs (256x256): 11.684G





In [None]:
# ============================================================
# COMPLETE PRUNING PIPELINE - CORRECTED VERSION
# ============================================================

import torch
import torch.nn as nn
import numpy as np
import copy
import os
from tqdm import tqdm
import time
import json

# ============================================================
# HELPER FUNCTIONS FOR PRUNING (FIXED)
# ============================================================

def compute_channel_importance(module, layer_types=[nn.Conv2d]):
    """
    Compute L1-norm importance for each channel in convolutional layers.

    Args:
        module: Neural network module (e.g., model.flownet)
        layer_types: Types of layers to analyze

    Returns:
        importance_dict: Dictionary mapping layer names to importance scores
    """
    importance_dict = {}

    for name, layer in module.named_modules():
        if isinstance(layer, tuple(layer_types)):
            # For Conv2d: weight shape is [out_channels, in_channels, H, W]
            # Importance = L1 norm across all dimensions except out_channels
            weight = layer.weight.data
            importance = torch.norm(weight.view(weight.size(0), -1), p=1, dim=1)
            importance_dict[name] = importance.cpu()

    print(f"Analyzed {len(importance_dict)} convolutional layers")
    return importance_dict


def global_channel_pruning(importance_dict, sparsity=0.5):
    """
    Create pruning masks based on global importance ranking.

    Args:
        importance_dict: Channel importance scores from compute_channel_importance()
        sparsity: Fraction of channels to prune (0-1)

    Returns:
        prune_masks: Dictionary of boolean masks for each layer
    """
    # Flatten all importance scores
    all_importance = torch.cat([imp for imp in importance_dict.values()])

    # Determine threshold (prune channels below this)
    num_to_prune = int(len(all_importance) * sparsity)
    threshold = torch.sort(all_importance)[0][num_to_prune]

    print(f"Sparsity: {sparsity:.1%}, Threshold: {threshold:.6f}")

    # Create masks (True = keep, False = prune)
    prune_masks = {}
    total_channels = 0
    pruned_channels = 0

    for name, importance in importance_dict.items():
        mask = (importance > threshold)
        prune_masks[name] = mask

        total_channels += len(mask)
        pruned_channels += (~mask).sum().item()

    actual_sparsity = pruned_channels / total_channels
    print(f"Actual pruning: {pruned_channels}/{total_channels} channels ({actual_sparsity:.1%})")

    return prune_masks


def apply_soft_pruning(model, prune_masks):
    """
    Apply soft pruning by zeroing out weights (doesn't change model structure).
    This is fast for testing different sparsity levels.

    Args:
        model: RIFE model
        prune_masks: Dictionary of pruning masks

    Returns:
        pruned_model: Model with zeroed weights
    """
    from model.RIFE import Model as RIFEModel

    # Create a deep copy
    pruned_model = RIFEModel()
    pruned_model.load_model('train_log')  # Reload from checkpoint
    pruned_model.eval()

    # Apply masks
    for name, module in pruned_model.flownet.named_modules():
        if name in prune_masks:
            mask = prune_masks[name]

            # Get device from module
            device = next(module.parameters()).device

            # Expand mask to match weight dimensions [out_ch, in_ch, H, W]
            weight_mask = mask.view(-1, 1, 1, 1).expand_as(module.weight).to(device)

            # Zero out pruned channels
            module.weight.data *= weight_mask

            if module.bias is not None:
                bias_mask = mask.to(device)
                module.bias.data *= bias_mask

    return pruned_model


def quick_evaluate(model, dataset_path, num_samples=50):
    """
    Quick evaluation on a subset of data for rapid testing.
    FIXED: Works with RIFE's custom Model class

    Args:
        model: RIFE model to evaluate
        dataset_path: Path to dataset (e.g., UCF-101)
        num_samples: Number of sequences to test

    Returns:
        results: Dictionary with PSNR, SSIM, inference time
    """
    import cv2
    import math
    from model.pytorch_msssim import ssim_matlab

    # RIFE model is already on correct device
    model.eval()

    # Detect device from model parameters
    device = next(model.flownet.parameters()).device

    # Get list of sequences
    dirs = os.listdir(dataset_path)
    dirs = dirs[:num_samples]  # Limit to num_samples

    psnr_list, ssim_list, time_list = [], [], []

    print(f"Quick evaluation on {len(dirs)} samples...")

    for d in tqdm(dirs, desc="Evaluating", leave=False):
        img0_path = os.path.join(dataset_path, d, 'frame_00.png')
        img1_path = os.path.join(dataset_path, d, 'frame_02.png')
        gt_path = os.path.join(dataset_path, d, 'frame_01_gt.png')

        if not all(map(os.path.exists, [img0_path, img1_path, gt_path])):
            continue

        # Load images
        img0 = torch.tensor(cv2.imread(img0_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)
        img1 = torch.tensor(cv2.imread(img1_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)
        gt = torch.tensor(cv2.imread(gt_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)

        # Inference with timing
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = time.time()

        with torch.no_grad():
            pred = model.inference(img0, img1)[0]

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        inference_time = (time.time() - start_time) * 1000
        time_list.append(inference_time)

        # Compute SSIM
        ssim_val = ssim_matlab(gt, torch.round(pred*255).unsqueeze(0)/255).detach().cpu().numpy()
        ssim_list.append(float(ssim_val))

        # Compute PSNR
        out = pred.detach().cpu().numpy().transpose(1,2,0)
        out = np.round(out*255)/255.
        gt_np = gt[0].cpu().numpy().transpose(1,2,0)
        mse = ((gt_np - out)**2).mean()
        psnr = -10 * math.log10(mse + 1e-8)
        psnr_list.append(psnr)

    if len(psnr_list) == 0:
        raise ValueError(f"No valid samples found in {dataset_path}")

    results = {
        'PSNR': np.mean(psnr_list),
        'SSIM': np.mean(ssim_list),
        'Inference_Time_ms': np.mean(time_list),
        'FPS': 1000 / np.mean(time_list) if len(time_list) > 0 else 0
    }

    return results


def count_parameters(model):
    """Count number of parameters in model"""
    total = 0
    if hasattr(model, 'flownet'):
        total += sum(p.numel() for p in model.flownet.parameters())
    # Check for other submodules
    for attr_name in dir(model):
        attr = getattr(model, attr_name)
        if isinstance(attr, torch.nn.Module) and attr_name != 'flownet':
            total += sum(p.numel() for p in attr.parameters())
    return total


def select_best_sparsity(pruning_results, target_psnr_drop=0.3, baseline_psnr=35.292):
    """
    Select best sparsity level based on quality-efficiency trade-off.

    Args:
        pruning_results: Dictionary mapping sparsity to results
        target_psnr_drop: Maximum acceptable PSNR drop
        baseline_psnr: Baseline PSNR to compare against

    Returns:
        best_sparsity: Selected sparsity level
    """
    print("\n" + "="*60)
    print("SPARSITY SELECTION")
    print("="*60)

    acceptable_sparsities = []

    for sparsity, results in sorted(pruning_results.items()):
        psnr_drop = baseline_psnr - results['PSNR']
        print(f"Sparsity {sparsity:.1%}: PSNR={results['PSNR']:.3f} dB (drop: {psnr_drop:.3f} dB), SSIM={results['SSIM']:.4f}")

        if psnr_drop <= target_psnr_drop:
            acceptable_sparsities.append(sparsity)

    if len(acceptable_sparsities) == 0:
        print(f"‚ö†Ô∏è  No sparsity level meets target drop of {target_psnr_drop} dB")
        print("    Selecting least aggressive option (30%)")
        return 0.3
    else:
        # Select highest acceptable sparsity (most compression)
        best = max(acceptable_sparsities)
        print(f"‚úÖ Selected sparsity: {best:.1%}")
        return best


# ============================================================
# MAIN PRUNING PIPELINE
# ============================================================

print("="*60)
print("PHASE 1: PRUNING EXPERIMENTS")
print("="*60)

# Verify model is already loaded from previous cell
if 'model' not in globals():
    print("‚ö†Ô∏è  Model not found! Please run the first cell to load the model.")
    raise RuntimeError("Model not loaded. Run the baseline evaluation cell first.")

print("‚úÖ Using model from previous cell")

# Verify dataset path
UCF_PATH = "/content/drive/MyDrive/UCF-101/ucf101_interp_ours"
if not os.path.exists(UCF_PATH):
    raise FileNotFoundError(f"Dataset path not found: {UCF_PATH}")
print(f"‚úÖ Dataset path verified: {UCF_PATH}")

# Get baseline PSNR (from your previous evaluation)
baseline_psnr = 35.292  # From your full evaluation
baseline_ssim = 0.9690

print(f"\nBaseline metrics (from full evaluation):")
print(f"  PSNR: {baseline_psnr:.3f} dB")
print(f"  SSIM: {baseline_ssim:.4f}")

# Quick verification on subset
print("\nVerifying baseline on subset (20 samples)...")
baseline_quick = quick_evaluate(model, UCF_PATH, num_samples=20)
print(f"Baseline quick check - PSNR: {baseline_quick['PSNR']:.3f} dB, SSIM: {baseline_quick['SSIM']:.4f}")

# Step 1: Compute channel importance
print("\n" + "="*60)
print("STEP 1: Computing Channel Importance")
print("="*60)
importance_dict = compute_channel_importance(model.flownet)

# Step 2: Test different sparsity levels
print("\n" + "="*60)
print("STEP 2: Testing Different Sparsity Levels")
print("="*60)

sparsity_levels = [0.3, 0.5, 0.7]
pruning_results = {}

for sparsity in sparsity_levels:
    print(f"\n{'‚îÄ'*60}")
    print(f"Testing sparsity: {sparsity:.1%}")
    print(f"{'‚îÄ'*60}")

    # Create pruning masks
    prune_masks = global_channel_pruning(importance_dict, sparsity)

    # Apply soft pruning
    print("Applying pruning masks...")
    pruned_model = apply_soft_pruning(model, prune_masks)
    pruned_model.device()  # Ensure on correct device

    # Quick evaluation
    print(f"Evaluating pruned model (50 samples)...")
    results = quick_evaluate(pruned_model, UCF_PATH, num_samples=50)
    pruning_results[sparsity] = results

    # Calculate metrics
    psnr_drop = baseline_psnr - results['PSNR']
    ssim_drop = baseline_ssim - results['SSIM']
    param_reduction = sparsity * 100
    speedup = results['FPS'] / baseline_quick['FPS'] if baseline_quick['FPS'] > 0 else 1.0

    print(f"\nüìä Results:")
    print(f"  PSNR: {results['PSNR']:.3f} dB (drop: {psnr_drop:.3f} dB)")
    print(f"  SSIM: {results['SSIM']:.4f} (drop: {ssim_drop:.4f})")
    print(f"  Inference: {results['Inference_Time_ms']:.2f} ms ({results['FPS']:.1f} FPS)")
    print(f"  Est. Param Reduction: ~{param_reduction:.0f}%")
    print(f"  Speedup vs baseline: {speedup:.2f}x")

    # Clear GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Delete pruned model to free memory
    del pruned_model

# Step 3: Select best sparsity
best_sparsity = select_best_sparsity(pruning_results, target_psnr_drop=0.5, baseline_psnr=baseline_psnr)

# Step 4: Create final pruned model
print(f"\n{'='*60}")
print(f"STEP 3: Creating Final Pruned Model")
print(f"{'='*60}")
print(f"Selected sparsity: {best_sparsity:.1%}")

final_prune_masks = global_channel_pruning(importance_dict, best_sparsity)
pruned_model_final = apply_soft_pruning(model, final_prune_masks)
pruned_model_final.device()

# Step 5: Comprehensive evaluation
print(f"\n{'='*60}")
print("STEP 4: Comprehensive Evaluation")
print(f"{'='*60}")
print("Evaluating on FULL UCF-101 dataset (379 sequences)...")
print("‚è±Ô∏è  This may take 5-10 minutes...")

final_results = quick_evaluate(pruned_model_final, UCF_PATH, num_samples=379)

# Step 6: Print final comparison
print(f"\n{'='*60}")
print("FINAL RESULTS")
print(f"{'='*60}")

print(f"\nüìä Baseline Model:")
print(f"  PSNR: {baseline_psnr:.3f} dB")
print(f"  SSIM: {baseline_ssim:.4f}")
print(f"  Parameters: {count_parameters(model):,} ({count_parameters(model)/1e6:.2f}M)")

print(f"\nüìä Pruned Model (Sparsity: {best_sparsity:.1%}):")
print(f"  PSNR: {final_results['PSNR']:.3f} dB (Œî: {baseline_psnr - final_results['PSNR']:.3f} dB)")
print(f"  SSIM: {final_results['SSIM']:.4f} (Œî: {baseline_ssim - final_results['SSIM']:.4f})")
print(f"  Inference Time: {final_results['Inference_Time_ms']:.2f} ms")
print(f"  FPS: {final_results['FPS']:.1f}")
print(f"  Est. Param Reduction: ~{best_sparsity*100:.0f}%")

# Step 7: Save results
results_summary = {
    'baseline': {
        'PSNR': baseline_psnr,
        'SSIM': baseline_ssim,
        'Parameters': count_parameters(model)
    },
    'pruned': {
        'sparsity': best_sparsity,
        'PSNR': final_results['PSNR'],
        'SSIM': final_results['SSIM'],
        'Inference_Time_ms': final_results['Inference_Time_ms'],
        'FPS': final_results['FPS'],
        'PSNR_drop': baseline_psnr - final_results['PSNR'],
        'SSIM_drop': baseline_ssim - final_results['SSIM']
    },
    'all_sparsity_experiments': {k: v for k, v in pruning_results.items()}
}

# Save to file
with open('/content/pruning_results.json', 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f"\nüíæ Results saved to '/content/pruning_results.json'")

# Save pruned model state
torch.save({
    'flownet_state_dict': pruned_model_final.flownet.state_dict(),
    'sparsity': best_sparsity,
    'prune_masks': {k: v.numpy() for k, v in final_prune_masks.items()},
    'results': final_results
}, '/content/pruned_model_checkpoint.pth')

print(f"üíæ Pruned model saved to '/content/pruned_model_checkpoint.pth'")

print("\n" + "="*60)
print("‚úÖ PRUNING PHASE COMPLETE!")
print("="*60)
print("\nNext steps:")
print("  1. Review results in 'pruning_results.json'")
print("  2. Proceed to fine-tuning (if PSNR drop > 0.3 dB)")
print("  3. Or proceed to quantization phase")

PHASE 1: PRUNING EXPERIMENTS
‚úÖ Using model from previous cell
‚úÖ Dataset path verified: /content/drive/MyDrive/UCF-101/ucf101_interp_ours

Baseline metrics (from full evaluation):
  PSNR: 35.292 dB
  SSIM: 0.9690

Verifying baseline on subset (20 samples)...
Quick evaluation on 20 samples...




Baseline quick check - PSNR: 36.641 dB, SSIM: 0.9810

STEP 1: Computing Channel Importance
Analyzed 57 convolutional layers

STEP 2: Testing Different Sparsity Levels

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 30.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Sparsity: 30.0%, Threshold: 40.140980
Actual pruning: 2058/6858 channels (30.0%)
Applying pruning masks...
Evaluating pruned model (50 samples)...
Quick evaluation on 50 samples...





üìä Results:
  PSNR: 25.319 dB (drop: 9.973 dB)
  SSIM: 0.8772 (drop: 0.0918)
  Inference: 11.63 ms (86.0 FPS)
  Est. Param Reduction: ~30%
  Speedup vs baseline: 1.16x

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 50.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Sparsity: 50.0%, Threshold: 73.872589
Actual pruning: 3430/6858 channels (50.0%)
Applying pruning masks...
Evaluating pruned model (50 samples)...
Quick evaluation on 50 samples...





üìä Results:
  PSNR: 32.364 dB (drop: 2.928 dB)
  SSIM: 0.9623 (drop: 0.0067)
  Inference: 10.24 ms (97.6 FPS)
  Est. Param Reduction: ~50%
  Speedup vs baseline: 1.32x

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 70.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Sparsity: 70.0%, Threshold: 104.645721
Actual pruning: 4801/6858 channels (70.0%)
Applying pruning masks...
Evaluating pruned model (50 samples)...
Quick evaluation on 50 samples...





üìä Results:
  PSNR: 32.663 dB (drop: 2.629 dB)
  SSIM: 0.9629 (drop: 0.0061)
  Inference: 10.39 ms (96.2 FPS)
  Est. Param Reduction: ~70%
  Speedup vs baseline: 1.30x

SPARSITY SELECTION
Sparsity 30.0%: PSNR=25.319 dB (drop: 9.973 dB), SSIM=0.8772
Sparsity 50.0%: PSNR=32.364 dB (drop: 2.928 dB), SSIM=0.9623
Sparsity 70.0%: PSNR=32.663 dB (drop: 2.629 dB), SSIM=0.9629
‚ö†Ô∏è  No sparsity level meets target drop of 0.5 dB
    Selecting least aggressive option (30%)

STEP 3: Creating Final Pruned Model
Selected sparsity: 30.0%
Sparsity: 30.0%, Threshold: 40.140980
Actual pruning: 2058/6858 channels (30.0%)

STEP 4: Comprehensive Evaluation
Evaluating on FULL UCF-101 dataset (379 sequences)...
‚è±Ô∏è  This may take 5-10 minutes...
Quick evaluation on 379 samples...





FINAL RESULTS

üìä Baseline Model:
  PSNR: 35.292 dB
  SSIM: 0.9690
  Parameters: 10,708,215 (10.71M)

üìä Pruned Model (Sparsity: 30.0%):
  PSNR: 24.607 dB (Œî: 10.685 dB)
  SSIM: 0.8563 (Œî: 0.1127)
  Inference Time: 10.86 ms
  FPS: 92.1
  Est. Param Reduction: ~30%

üíæ Results saved to '/content/pruning_results.json'
üíæ Pruned model saved to '/content/pruned_model_checkpoint.pth'

‚úÖ PRUNING PHASE COMPLETE!

Next steps:
  1. Review results in 'pruning_results.json'
  2. Proceed to fine-tuning (if PSNR drop > 0.3 dB)
  3. Or proceed to quantization phase


In [None]:
# ============================================================
# COMPLETE CORRECTED PRUNING IMPLEMENTATION
# ============================================================

import torch
import torch.nn as nn
import numpy as np
import copy
import os
from tqdm import tqdm
import time
import json

# ============================================================
# CORRECTED HELPER FUNCTIONS
# ============================================================

def compute_channel_importance_v2(module, layer_types=[nn.Conv2d]):
    """
    Compute L1-norm importance for each channel (VERIFIED CORRECT)
    """
    importance_dict = {}

    for name, layer in module.named_modules():
        if isinstance(layer, tuple(layer_types)):
            weight = layer.weight.data
            # L1 norm per output channel: sum over [in_channels, H, W]
            importance = torch.norm(weight.view(weight.size(0), -1), p=1, dim=1)
            importance_dict[name] = importance.cpu()

    print(f"‚úÖ Analyzed {len(importance_dict)} convolutional layers")

    # Debug statistics
    all_imp = torch.cat([imp for imp in importance_dict.values()])
    print(f"   Importance range: [{all_imp.min():.2f}, {all_imp.max():.2f}], Mean: {all_imp.mean():.2f}")

    return importance_dict


def global_channel_pruning_v2(importance_dict, sparsity=0.5):
    """
    CORRECTED: Prune channels with SMALLEST importance (least important)

    Args:
        importance_dict: Channel importance scores
        sparsity: Fraction of channels to REMOVE (0-1)

    Returns:
        prune_masks: Boolean masks (True = KEEP, False = PRUNE)
    """
    # Flatten all importance scores
    all_importance = torch.cat([imp for imp in importance_dict.values()])
    total_channels = len(all_importance)

    # Calculate how many to KEEP
    num_to_keep = int(total_channels * (1 - sparsity))

    # Sort importance in DESCENDING order and get threshold
    # Channels with importance >= threshold will be KEPT
    sorted_importance = torch.sort(all_importance, descending=True)[0]
    threshold = sorted_importance[num_to_keep - 1] if num_to_keep > 0 else sorted_importance[-1]

    print(f"   Sparsity: {sparsity:.1%}")
    print(f"   Total channels: {total_channels}")
    print(f"   Channels to keep: {num_to_keep}")
    print(f"   Threshold: {threshold:.6f}")
    print(f"   Logic: KEEP channels with importance >= {threshold:.6f}")

    # Create masks: True = KEEP, False = PRUNE
    prune_masks = {}
    total_kept = 0

    for name, importance in importance_dict.items():
        # CORRECTED: Keep channels with importance >= threshold
        mask = (importance >= threshold)
        prune_masks[name] = mask
        total_kept += mask.sum().item()

    actual_sparsity = 1 - (total_kept / total_channels)
    actual_kept_pct = (total_kept / total_channels) * 100

    print(f"   Result: {total_kept}/{total_channels} channels kept ({actual_kept_pct:.1f}%)")
    print(f"   Actual sparsity: {actual_sparsity:.1%}")

    return prune_masks


def apply_soft_pruning_v2(model, prune_masks):
    """
    Apply soft pruning by zeroing out weights (VERIFIED CORRECT)
    """
    from model.RIFE import Model as RIFEModel

    # Create fresh model instance
    pruned_model = RIFEModel()
    pruned_model.load_model('train_log')
    pruned_model.eval()

    # Apply masks
    total_channels = 0
    zeroed_channels = 0
    layers_modified = 0

    for name, module in pruned_model.flownet.named_modules():
        if name in prune_masks:
            mask = prune_masks[name]
            device = next(module.parameters()).device

            # Create weight mask: expand to [out_ch, in_ch, H, W]
            weight_mask = mask.view(-1, 1, 1, 1).expand_as(module.weight).to(device)

            # Apply mask: multiply weights by mask (zeros out pruned channels)
            module.weight.data *= weight_mask

            # Apply to bias if exists
            if module.bias is not None:
                bias_mask = mask.to(device)
                module.bias.data *= bias_mask

            # Track statistics
            zeroed = (~mask).sum().item()
            total = len(mask)
            zeroed_channels += zeroed
            total_channels += total
            layers_modified += 1

    print(f"   Applied masks to {layers_modified} layers")
    print(f"   Zeroed {zeroed_channels}/{total_channels} channels ({zeroed_channels/total_channels*100:.1f}%)")

    return pruned_model


def quick_evaluate_v2(model, dataset_path, num_samples=50):
    """
    Quick evaluation (same as before, just renamed for consistency)
    """
    import cv2
    import math
    from model.pytorch_msssim import ssim_matlab

    model.eval()
    device = next(model.flownet.parameters()).device

    dirs = os.listdir(dataset_path)[:num_samples]
    psnr_list, ssim_list, time_list = [], [], []

    for d in tqdm(dirs, desc="Evaluating", leave=False):
        img0_path = os.path.join(dataset_path, d, 'frame_00.png')
        img1_path = os.path.join(dataset_path, d, 'frame_02.png')
        gt_path = os.path.join(dataset_path, d, 'frame_01_gt.png')

        if not all(map(os.path.exists, [img0_path, img1_path, gt_path])):
            continue

        img0 = torch.tensor(cv2.imread(img0_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)
        img1 = torch.tensor(cv2.imread(img1_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)
        gt = torch.tensor(cv2.imread(gt_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = time.time()

        with torch.no_grad():
            pred = model.inference(img0, img1)[0]

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        time_list.append((time.time() - start_time) * 1000)

        ssim_val = ssim_matlab(gt, torch.round(pred*255).unsqueeze(0)/255).detach().cpu().numpy()
        ssim_list.append(float(ssim_val))

        out = pred.detach().cpu().numpy().transpose(1,2,0)
        out = np.round(out*255)/255.
        gt_np = gt[0].cpu().numpy().transpose(1,2,0)
        mse = ((gt_np - out)**2).mean()
        psnr = -10 * math.log10(mse + 1e-8)
        psnr_list.append(psnr)

    return {
        'PSNR': np.mean(psnr_list),
        'SSIM': np.mean(ssim_list),
        'Inference_Time_ms': np.mean(time_list),
        'FPS': 1000 / np.mean(time_list) if len(time_list) > 0 else 0
    }


# ============================================================
# STEP 1: VERIFY THE FIX WITH DIAGNOSTIC
# ============================================================

print("="*60)
print("STEP 1: DIAGNOSTIC - COMPARE OLD VS NEW PRUNING")
print("="*60)

if 'model' not in globals():
    raise RuntimeError("Model not found. Run baseline cell first.")

UCF_PATH = "/content/drive/MyDrive/UCF-101/ucf101_interp_ours"
baseline_psnr = 35.292
baseline_ssim = 0.9690

# Compute importance with new function
print("\nComputing channel importance...")
importance_dict_v2 = compute_channel_importance_v2(model.flownet)

# Test both methods on same sparsity
sparsity_test = 0.3

print(f"\n{'='*60}")
print(f"COMPARING METHODS AT {sparsity_test:.0%} SPARSITY")
print(f"{'='*60}")

# Old method (from your previous run - we know it gives 25.699 dB)
print("\nüî¥ OLD METHOD (inverted logic):")
masks_old = global_channel_pruning(importance_dict, sparsity_test)  # Your old function
model_old = apply_soft_pruning(model, masks_old)
model_old.device()
results_old = quick_evaluate(model_old, UCF_PATH, num_samples=20)
print(f"   PSNR: {results_old['PSNR']:.3f} dB")
print(f"   SSIM: {results_old['SSIM']:.4f}")

# New corrected method
print("\n‚úÖ NEW METHOD (corrected logic):")
masks_new = global_channel_pruning_v2(importance_dict_v2, sparsity_test)
model_new = apply_soft_pruning_v2(model, masks_new)
model_new.device()
results_new = quick_evaluate_v2(model_new, UCF_PATH, num_samples=20)
print(f"   PSNR: {results_new['PSNR']:.3f} dB")
print(f"   SSIM: {results_new['SSIM']:.4f}")

# Comparison
print(f"\n{'='*60}")
print("DIAGNOSTIC RESULT:")
print(f"{'='*60}")
improvement = results_new['PSNR'] - results_old['PSNR']
if improvement > 5.0:
    print(f"‚úÖ MAJOR IMPROVEMENT: +{improvement:.2f} dB")
    print(f"   Old logic was definitely inverted!")
    print(f"   Proceeding with corrected pruning...")
elif improvement > 1.0:
    print(f"‚úÖ IMPROVEMENT: +{improvement:.2f} dB")
    print(f"   New logic is better")
else:
    print(f"‚ö†Ô∏è  MINIMAL DIFFERENCE: {improvement:.2f} dB")
    print(f"   Issue might be elsewhere")

# Cleanup
del model_old, model_new
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# ============================================================
# STEP 2: RUN CORRECTED PRUNING EXPERIMENTS
# ============================================================

print("\n" + "="*60)
print("STEP 2: CORRECTED PRUNING EXPERIMENTS")
print("="*60)

sparsity_levels = [0.1, 0.2, 0.3, 0.4, 0.5]
pruning_results_v2 = {}

for sparsity in sparsity_levels:
    print(f"\n{'‚îÄ'*60}")
    print(f"Testing sparsity: {sparsity:.1%}")
    print(f"{'‚îÄ'*60}")

    # Create masks with corrected method
    prune_masks = global_channel_pruning_v2(importance_dict_v2, sparsity)

    # Apply pruning
    pruned_model = apply_soft_pruning_v2(model, prune_masks)
    pruned_model.device()

    # Evaluate
    print(f"\nEvaluating on 50 samples...")
    results = quick_evaluate_v2(pruned_model, UCF_PATH, num_samples=50)
    pruning_results_v2[sparsity] = results

    # Calculate metrics
    psnr_drop = baseline_psnr - results['PSNR']
    ssim_drop = baseline_ssim - results['SSIM']

    print(f"\nüìä Results:")
    print(f"   PSNR: {results['PSNR']:.3f} dB (drop: {psnr_drop:.3f} dB)")
    print(f"   SSIM: {results['SSIM']:.4f} (drop: {ssim_drop:.4f})")
    print(f"   Inference: {results['Inference_Time_ms']:.2f} ms ({results['FPS']:.1f} FPS)")

    # Clear memory
    del pruned_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ============================================================
# STEP 3: SELECT OPTIMAL SPARSITY
# ============================================================

print("\n" + "="*60)
print("STEP 3: SELECTING OPTIMAL SPARSITY")
print("="*60)

print(f"\n{'Sparsity':<10} {'PSNR':<10} {'Drop':<10} {'SSIM':<10} {'Status':<15}")
print("‚îÄ"*60)

acceptable = {}
for sparsity in sorted(pruning_results_v2.keys()):
    results = pruning_results_v2[sparsity]
    psnr_drop = baseline_psnr - results['PSNR']

    if psnr_drop <= 0.5:
        status = "‚úÖ Excellent"
        acceptable[sparsity] = results
    elif psnr_drop <= 1.0:
        status = "‚úÖ Good"
        acceptable[sparsity] = results
    elif psnr_drop <= 1.5:
        status = "‚ö†Ô∏è  Marginal"
    else:
        status = "‚ùå Too high"

    print(f"{sparsity:<10.1%} {results['PSNR']:<10.3f} {psnr_drop:<10.3f} {results['SSIM']:<10.4f} {status:<15}")

print("\n" + "‚îÄ"*60)

if len(acceptable) > 0:
    # Choose highest acceptable sparsity (most compression)
    best_sparsity = max(acceptable.keys())
    best_results = acceptable[best_sparsity]
    psnr_drop = baseline_psnr - best_results['PSNR']

    print(f"‚úÖ SELECTED: {best_sparsity:.1%} sparsity")
    print(f"   PSNR: {best_results['PSNR']:.3f} dB (drop: {psnr_drop:.3f} dB)")
    print(f"   SSIM: {best_results['SSIM']:.4f}")
    print(f"   Quality: {'Excellent' if psnr_drop <= 0.5 else 'Good'}")
else:
    # Use minimum sparsity
    best_sparsity = min(pruning_results_v2.keys())
    best_results = pruning_results_v2[best_sparsity]
    psnr_drop = baseline_psnr - best_results['PSNR']

    print(f"‚ö†Ô∏è  SELECTED: {best_sparsity:.1%} sparsity (minimum)")
    print(f"   PSNR: {best_results['PSNR']:.3f} dB (drop: {psnr_drop:.3f} dB)")
    print(f"   Note: All sparsity levels had high degradation")

print("="*60)

# ============================================================
# STEP 4: FINAL EVALUATION ON FULL DATASET
# ============================================================

print("\n" + "="*60)
print("STEP 4: FULL EVALUATION ON UCF-101")
print("="*60)

print(f"\nCreating final pruned model ({best_sparsity:.1%} sparsity)...")
final_masks = global_channel_pruning_v2(importance_dict_v2, best_sparsity)
final_pruned_model = apply_soft_pruning_v2(model, final_masks)
final_pruned_model.device()

print(f"\nEvaluating on full UCF-101 dataset (379 sequences)...")
print("‚è±Ô∏è  This will take 5-10 minutes...")

final_results = quick_evaluate_v2(final_pruned_model, UCF_PATH, num_samples=379)

# ============================================================
# STEP 5: FINAL REPORT
# ============================================================

print("\n" + "="*60)
print("FINAL RESULTS - CORRECTED PRUNING")
print("="*60)

print(f"\nüìä BASELINE MODEL:")
print(f"   PSNR: {baseline_psnr:.3f} dB")
print(f"   SSIM: {baseline_ssim:.4f}")
print(f"   Parameters: {count_parameters(model):,} ({count_parameters(model)/1e6:.2f}M)")

print(f"\nüìä PRUNED MODEL (Sparsity: {best_sparsity:.1%}):")
psnr_drop_final = baseline_psnr - final_results['PSNR']
ssim_drop_final = baseline_ssim - final_results['SSIM']

print(f"   PSNR: {final_results['PSNR']:.3f} dB (Œî: {psnr_drop_final:.3f} dB)")
print(f"   SSIM: {final_results['SSIM']:.4f} (Œî: {ssim_drop_final:.4f})")
print(f"   Inference: {final_results['Inference_Time_ms']:.2f} ms")
print(f"   FPS: {final_results['FPS']:.1f}")
print(f"   Est. Param Reduction: ~{best_sparsity*100:.0f}%")

# Quality assessment
if psnr_drop_final <= 0.5:
    quality_verdict = "‚úÖ EXCELLENT - No fine-tuning needed"
elif psnr_drop_final <= 1.0:
    quality_verdict = "‚úÖ GOOD - Optional fine-tuning"
elif psnr_drop_final <= 2.0:
    quality_verdict = "‚ö†Ô∏è  ACCEPTABLE - Fine-tuning recommended"
else:
    quality_verdict = "‚ùå POOR - Fine-tuning required"

print(f"\n{quality_verdict}")

# Save results
results_summary = {
    'baseline': {
        'PSNR': baseline_psnr,
        'SSIM': baseline_ssim,
        'Parameters': count_parameters(model)
    },
    'pruned_corrected': {
        'sparsity': best_sparsity,
        'PSNR': final_results['PSNR'],
        'SSIM': final_results['SSIM'],
        'PSNR_drop': psnr_drop_final,
        'SSIM_drop': ssim_drop_final,
        'Inference_Time_ms': final_results['Inference_Time_ms'],
        'FPS': final_results['FPS']
    },
    'all_experiments': {str(k): v for k, v in pruning_results_v2.items()}
}

with open('/content/pruning_results_corrected.json', 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f"\nüíæ Results saved to '/content/pruning_results_corrected.json'")

# Save model
torch.save({
    'flownet_state_dict': final_pruned_model.flownet.state_dict(),
    'sparsity': best_sparsity,
    'masks': {k: v.numpy() for k, v in final_masks.items()},
    'results': final_results
}, '/content/pruned_model_corrected.pth')

print(f"üíæ Model saved to '/content/pruned_model_corrected.pth'")

print("\n" + "="*60)
print("‚úÖ CORRECTED PRUNING COMPLETE!")
print("="*60)

# Next steps recommendation
print("\nüìã NEXT STEPS:")
if psnr_drop_final <= 1.0:
    print("   1. ‚úÖ Proceed to quantization (quality is good)")
    print("   2. Optional: Fine-tune for further improvement")
    print("   3. Implement video processing pipeline")
elif psnr_drop_final <= 2.0:
    print("   1. ‚ö†Ô∏è  Fine-tune the pruned model (recommended)")
    print("   2. Then proceed to quantization")
    print("   3. Implement video processing pipeline")
else:
    print("   1. ‚ùå Fine-tune the pruned model (required)")
    print("   2. Re-evaluate after fine-tuning")
    print("   3. Then consider quantization")

print("="*60)

STEP 1: DIAGNOSTIC - COMPARE OLD VS NEW PRUNING

Computing channel importance...
‚úÖ Analyzed 57 convolutional layers
   Importance range: [1.89, 188.22], Mean: 73.84

COMPARING METHODS AT 30% SPARSITY

üî¥ OLD METHOD (inverted logic):
Sparsity: 30.0%, Threshold: 40.140980
Actual pruning: 2058/6858 channels (30.0%)
Quick evaluation on 20 samples...




   PSNR: 25.699 dB
   SSIM: 0.8843

‚úÖ NEW METHOD (corrected logic):
   Sparsity: 30.0%
   Total channels: 6858
   Channels to keep: 4800
   Threshold: 40.152420
   Logic: KEEP channels with importance >= 40.152420
   Result: 4800/6858 channels kept (70.0%)
   Actual sparsity: 30.0%
   Applied masks to 57 layers
   Zeroed 2058/6858 channels (30.0%)




   PSNR: 25.699 dB
   SSIM: 0.8843

DIAGNOSTIC RESULT:
‚ö†Ô∏è  MINIMAL DIFFERENCE: -0.00 dB
   Issue might be elsewhere

STEP 2: CORRECTED PRUNING EXPERIMENTS

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 10.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 10.0%
   Total channels: 6858
   Channels to keep: 6172
   Threshold: 21.201897
   Logic: KEEP channels with importance >= 21.201897
   Result: 6172/6858 channels kept (90.0%)
   Actual sparsity: 10.0%
   Applied masks to 57 layers
   Zeroed 686/6858 channels (10.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 19.370 dB (drop: 15.922 dB)
   SSIM: 0.7133 (drop: 0.2557)
   Inference: 9.91 ms (100.9 FPS)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 20.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 20.0%
   Total channels: 6858
   Channels to keep: 5486
   Threshold: 30.632299
   Logic: KEEP channels with importance >= 30.632299
   Result: 5486/6858 channels kept (80.0%)
   Actual sparsity: 20.0%
   Applied masks to 57 layers
   Zeroed 1372/6858 channels (20.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 19.840 dB (drop: 15.452 dB)
   SSIM: 0.7401 (drop: 0.2289)
   Inference: 10.17 ms (98.3 FPS)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 30.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 30.0%
   Total channels: 6858
   Channels to keep: 4800
   Threshold: 40.152420
   Logic: KEEP channels with importance >= 40.152420
   Result: 4800/6858 channels kept (70.0%)
   Actual sparsity: 30.0%
   Applied masks to 57 layers
   Zeroed 2058/6858 channels (30.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 25.319 dB (drop: 9.973 dB)
   SSIM: 0.8772 (drop: 0.0918)
   Inference: 12.38 ms (80.8 FPS)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 40.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 40.0%
   Total channels: 6858
   Channels to keep: 4114
   Threshold: 50.451553
   Logic: KEEP channels with importance >= 50.451553
   Result: 4114/6858 channels kept (60.0%)
   Actual sparsity: 40.0%
   Applied masks to 57 layers
   Zeroed 2744/6858 channels (40.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 31.874 dB (drop: 3.418 dB)
   SSIM: 0.9603 (drop: 0.0087)
   Inference: 10.10 ms (99.0 FPS)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 50.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 50.0%
   Total channels: 6858
   Channels to keep: 3429
   Threshold: 73.872589
   Logic: KEEP channels with importance >= 73.872589
   Result: 3429/6858 channels kept (50.0%)
   Actual sparsity: 50.0%
   Applied masks to 57 layers
   Zeroed 3429/6858 channels (50.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 32.361 dB (drop: 2.931 dB)
   SSIM: 0.9622 (drop: 0.0068)
   Inference: 10.05 ms (99.5 FPS)

STEP 3: SELECTING OPTIMAL SPARSITY

Sparsity   PSNR       Drop       SSIM       Status         
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
10.0%      19.370     15.922     0.7133     ‚ùå Too high     
20.0%      19.840     15.452     0.7401     ‚ùå Too high     
30.0%      25.319     9.973      0.8772     ‚ùå Too high     
40.0%      31.874     3.418      0.9603     ‚ùå Too high     
50.0%      32.361     2.931      0.9622     ‚ùå Too high     

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
‚ö†Ô∏è  SELECTED: 10.0% sparsity (minimum)
   PSNR: 19.370 dB (drop: 15.922 dB)
   Note: All sparsity levels had h




FINAL RESULTS - CORRECTED PRUNING

üìä BASELINE MODEL:
   PSNR: 35.292 dB
   SSIM: 0.9690
   Parameters: 10,708,215 (10.71M)

üìä PRUNED MODEL (Sparsity: 10.0%):
   PSNR: 18.946 dB (Œî: 16.346 dB)
   SSIM: 0.6772 (Œî: 0.2918)
   Inference: 10.69 ms
   FPS: 93.6
   Est. Param Reduction: ~10%

‚ùå POOR - Fine-tuning required

üíæ Results saved to '/content/pruning_results_corrected.json'
üíæ Model saved to '/content/pruned_model_corrected.pth'

‚úÖ CORRECTED PRUNING COMPLETE!

üìã NEXT STEPS:
   1. ‚ùå Fine-tune the pruned model (required)
   2. Re-evaluate after fine-tuning
   3. Then consider quantization


In [None]:
# ============================================================
# DIAGNOSTIC: Analyze Layer-wise Pruning Distribution
# ============================================================

import matplotlib.pyplot as plt

def analyze_pruning_distribution(importance_dict, sparsity_levels=[0.1, 0.3, 0.5]):
    """
    Analyze which layers get pruned at different sparsity levels
    """
    print("="*60)
    print("ANALYZING LAYER-WISE PRUNING DISTRIBUTION")
    print("="*60)

    # Get all importance scores
    all_importance = torch.cat([imp for imp in importance_dict.values()])
    layer_names = list(importance_dict.keys())

    for sparsity in sparsity_levels:
        print(f"\n{'‚îÄ'*60}")
        print(f"Sparsity: {sparsity:.1%}")
        print(f"{'‚îÄ'*60}")

        # Calculate threshold
        num_to_keep = int(len(all_importance) * (1 - sparsity))
        threshold = torch.sort(all_importance, descending=True)[0][num_to_keep - 1]

        print(f"Threshold: {threshold:.2f}")
        print(f"\nLayer-wise pruning:")
        print(f"{'Layer':<40} {'Channels':<12} {'Pruned':<12} {'Prune %':<10}")
        print("‚îÄ"*60)

        for name, importance in importance_dict.items():
            total_channels = len(importance)
            pruned_channels = (importance < threshold).sum().item()
            prune_pct = (pruned_channels / total_channels) * 100

            # Highlight if heavily pruned
            indicator = "‚ö†Ô∏è " if prune_pct > 50 else "  "

            # Check if it's an early block
            is_early = any(x in name for x in ['block0', 'down0', 'conv0', 'encoder.0', 'encoder.1'])
            layer_label = f"{name:<38}"
            if is_early:
                layer_label = f"üî¥ {name:<36}"  # Mark early layers

            print(f"{indicator}{layer_label} {total_channels:<12} {pruned_channels:<12} {prune_pct:<10.1f}%")

        print()

# Run the analysis
print("\nAnalyzing pruning distribution across layers...\n")
analyze_pruning_distribution(importance_dict_v2, sparsity_levels=[0.1, 0.3, 0.5])


Analyzing pruning distribution across layers...

ANALYZING LAYER-WISE PRUNING DISTRIBUTION

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Sparsity: 10.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Threshold: 21.20

Layer-wise pruning:
Layer                                    Channels     Pruned       Prune %   
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
‚ö†Ô∏è üî¥ block0.conv0.0.0                     120          120          100.0     %
  üî¥ block0.conv0.1.0                     240          0            0.0       %
  üî¥ block0.convblock.0.0                 240          0     

In [None]:
# ============================================================
# FINAL SOLUTION: PROTECTED LAYER-WISE PRUNING (FIXED)
# ============================================================

def protected_layerwise_pruning(importance_dict, sparsity=0.3,
                                min_channels_to_keep=8,
                                protected_keywords=['conv0.0.0', 'down0', 'unet.conv']):
    """
    Layer-wise pruning with protection for critical small layers.

    Args:
        importance_dict: Channel importance per layer
        sparsity: Target sparsity for each layer (but respects min_channels_to_keep)
        min_channels_to_keep: Minimum channels to keep per layer
        protected_keywords: Layers to protect completely (0% pruning)

    Returns:
        prune_masks: Boolean masks per layer
    """
    print("="*60)
    print("PROTECTED LAYER-WISE PRUNING")
    print(f"  Target sparsity per layer: {sparsity:.1%}")
    print(f"  Min channels to keep: {min_channels_to_keep}")
    print(f"  Protected keywords: {protected_keywords}")
    print("="*60)

    prune_masks = {}
    total_channels = 0
    total_pruned = 0

    protected_count = 0
    modified_count = 0

    for name, importance in importance_dict.items():
        num_channels = len(importance)
        total_channels += num_channels

        # Check if layer is protected
        is_protected = any(keyword in name for keyword in protected_keywords)

        if is_protected:
            # Keep ALL channels in protected layers
            mask = torch.ones_like(importance, dtype=torch.bool)
            prune_masks[name] = mask
            protected_count += 1
            print(f"  üõ°Ô∏è  PROTECTED: {name} ({num_channels} channels)")

        else:
            # Apply layer-wise pruning with minimum guarantee
            num_to_keep = max(min_channels_to_keep, int(num_channels * (1 - sparsity)))
            num_to_keep = min(num_to_keep, num_channels)  # Can't keep more than exist

            if num_to_keep >= num_channels:
                # Keep all if we're at minimum
                mask = torch.ones_like(importance, dtype=torch.bool)
            else:
                # Get threshold for this layer
                threshold = torch.sort(importance, descending=True)[0][num_to_keep - 1]
                mask = (importance >= threshold)

            prune_masks[name] = mask
            pruned = (~mask).sum().item()
            total_pruned += pruned

            if pruned > 0:
                modified_count += 1
                prune_pct = (pruned / num_channels) * 100
                if prune_pct > 50:
                    print(f"  ‚ö†Ô∏è  {name}: {num_channels-pruned}/{num_channels} kept ({prune_pct:.1f}% pruned)")

    actual_sparsity = total_pruned / total_channels
    print("\n" + "‚îÄ"*60)
    print("Summary:")
    print(f"  Protected layers: {protected_count}")
    print(f"  Modified layers: {modified_count}")
    print(f"  Total: {total_pruned}/{total_channels} channels pruned")
    print(f"  Actual sparsity: {actual_sparsity:.1%}")
    print("="*60)

    return prune_masks


# ============================================================
# TEST PROTECTED LAYER-WISE PRUNING
# ============================================================

print("\n" + "="*60)
print("TESTING PROTECTED LAYER-WISE PRUNING")
print("="*60)

# Define which layers to protect (all first conv layers + small critical layers)
protected_layers = [
    'conv0.0.0',      # All first convs in blocks
    'down0',          # First downsampling layers
    'conv1.conv',     # Very small context layers
    'unet.conv',      # Final 3-channel conv
]

# Test different sparsity levels
sparsity_levels_protected = [0.2, 0.3, 0.4, 0.5, 0.6]
results_protected = {}

for sparsity in sparsity_levels_protected:
    print(f"\n{'='*60}")
    print(f"Testing Protected Layer-wise Sparsity: {sparsity:.1%}")
    print(f"{'='*60}")

    # Create protected layer-wise masks
    masks_protected = protected_layerwise_pruning(
        importance_dict_v2,
        sparsity=sparsity,
        min_channels_to_keep=8,  # Never go below 8 channels
        protected_keywords=protected_layers
    )

    # Apply pruning
    model_protected = apply_soft_pruning_v2(model, masks_protected)
    model_protected.device()

    # Evaluate
    print(f"\nEvaluating on 50 samples...")
    results = quick_evaluate_v2(model_protected, UCF_PATH, num_samples=50)
    results_protected[sparsity] = results

    psnr_drop = 35.292 - results['PSNR']
    ssim_drop = 0.9690 - results['SSIM']

    print(f"\nüìä Results:")
    print(f"   PSNR: {results['PSNR']:.3f} dB (drop: {psnr_drop:.3f} dB)")
    print(f"   SSIM: {results['SSIM']:.4f} (drop: {ssim_drop:.4f})")
    print(f"   Inference: {results['Inference_Time_ms']:.2f} ms ({results['FPS']:.1f} FPS)")

    # Quality assessment
    if psnr_drop <= 0.5:
        status = "‚úÖ EXCELLENT"
    elif psnr_drop <= 1.0:
        status = "‚úÖ GOOD"
    elif psnr_drop <= 2.0:
        status = "‚ö†Ô∏è  ACCEPTABLE"
    else:
        status = "‚ùå POOR"

    print(f"   Quality: {status}")

    # Cleanup
    del model_protected
    torch.cuda.empty_cache()

# ============================================================
# FINAL COMPARISON
# ============================================================

print("\n" + "="*60)
print("COMPARISON: Global vs Protected Layer-wise")
print("="*60)

print(f"\n{'Sparsity':<12} {'Global':<15} {'Protected':<15} {'Improvement':<15}")
print("‚îÄ"*60)

for sparsity in [0.2, 0.3, 0.4, 0.5]:
    if sparsity in pruning_results_v2:
        global_psnr = pruning_results_v2[sparsity]['PSNR']
    else:
        global_psnr = 0

    if sparsity in results_protected:
        protected_psnr = results_protected[sparsity]['PSNR']
    else:
        protected_psnr = 0

    if global_psnr > 0 and protected_psnr > 0:
        improvement = protected_psnr - global_psnr
        indicator = "‚úÖ MAJOR" if improvement > 10 else ("‚úÖ" if improvement > 5 else "‚ö†Ô∏è ")
        print(f"{sparsity:<12.1%} {global_psnr:<15.3f} {protected_psnr:<15.3f} {indicator} {improvement:+.2f} dB")

# ============================================================
# SELECT BEST AND EVALUATE ON FULL DATASET
# ============================================================

print("\n" + "="*60)
print("SELECTING BEST CONFIGURATION")
print("="*60)

# Find best sparsity with acceptable quality
acceptable_protected = {s: r for s, r in results_protected.items()
                       if (35.292 - r['PSNR']) <= 1.5}

if len(acceptable_protected) > 0:
    best_sparsity_protected = max(acceptable_protected.keys())
    best_results = acceptable_protected[best_sparsity_protected]

    print(f"\n‚úÖ SELECTED: {best_sparsity_protected:.1%} sparsity")
    print(f"   PSNR: {best_results['PSNR']:.3f} dB")
    print(f"   Drop: {35.292 - best_results['PSNR']:.3f} dB")

    # Full evaluation
    print(f"\n{'='*60}")
    print("FULL EVALUATION ON UCF-101")
    print(f"{'='*60}")
    print(f"Creating final model with {best_sparsity_protected:.1%} sparsity...")

    final_masks = protected_layerwise_pruning(
        importance_dict_v2,
        sparsity=best_sparsity_protected,
        min_channels_to_keep=8,
        protected_keywords=protected_layers
    )

    final_model = apply_soft_pruning_v2(model, final_masks)
    final_model.device()

    # Calculate FLOPs for the final model
    from thop import profile, clever_format
    try:
        sample_img0 = torch.randn(1, 3, 256, 256).to(next(final_model.parameters()).device)
        sample_img1 = torch.randn(1, 3, 256, 256).to(next(final_model.parameters()).device)
        with torch.no_grad():
            flops, params = profile(final_model.flownet, inputs=(torch.cat([sample_img0, sample_img1], 1),), verbose=False)
        flops_str, params_str = clever_format([flops, params], "%.3f")
        print(f"\nFLOPs (256x256 input): {flops_str}")
        print(f"Parameters (from profiler): {params_str}")
    except Exception as e:
        print(f"\nFLOPs calculation failed: {e}")
        flops_str = "N/A"


    print(f"\nEvaluating on 379 sequences...")
    print("‚è±Ô∏è  This will take 5-10 minutes...")

    final_results_full = quick_evaluate_v2(final_model, UCF_PATH, num_samples=379)

    # Final report
    print("\n" + "="*60)
    print("FINAL RESULTS - PROTECTED LAYER-WISE PRUNING")
    print("="*60)

    print(f"\nüìä BASELINE:")
    print(f"   PSNR: 35.292 dB")
    print(f"   SSIM: 0.9690")
    # Assuming baseline FLOPs is needed for comparison - get from the first cell output
    baseline_flops_str = "11.684G"
    print(f"   FLOPs (256x256): {baseline_flops_str}")


    print(f"\nüìä OPTIMIZED ({best_sparsity_protected:.1%} sparsity):")
    psnr_drop_final = 35.292 - final_results_full['PSNR']
    print(f"   PSNR: {final_results_full['PSNR']:.3f} dB (Œî: {psnr_drop_final:.3f} dB)")
    print(f"   SSIM: {final_results_full['SSIM']:.4f}")
    print(f"   Inference: {final_results_full['Inference_Time_ms']:.2f} ms")
    print(f"   FPS: {final_results_full['FPS']:.1f}")
    print(f"   FLOPs (256x256): {flops_str}")


    if psnr_drop_final <= 1.0:
        print(f"\n‚úÖ SUCCESS! Quality degradation is acceptable")
        print(f"   Proceed to quantization or video processing")
    elif psnr_drop_final <= 2.0:
        print(f"\n‚ö†Ô∏è  ACCEPTABLE. Fine-tuning recommended")
    else:
        print(f"\n‚ùå Quality drop too high. Needs fine-tuning")

    # Save
    torch.save({
        'flownet_state_dict': final_model.flownet.state_dict(),
        'sparsity': best_sparsity_protected,
        'masks': {k: v.numpy() for k, v in final_masks.items()},
        'results': final_results_full,
        'method': 'protected_layerwise',
        'flops': flops_str
    }, '/content/pruned_model_FINAL.pth')

    print(f"\nüíæ Model saved to '/content/pruned_model_FINAL.pth'")

else:
    print("\n‚ùå No acceptable configuration found")
    print("   All sparsity levels degrade quality too much")
    print("   Recommendation: Use lower sparsity or different pruning strategy")

print("\n" + "="*60)
print("‚úÖ PROTECTED LAYER-WISE PRUNING COMPLETE!")
print("="*60)


TESTING PROTECTED LAYER-WISE PRUNING

Testing Protected Layer-wise Sparsity: 20.0%
PROTECTED LAYER-WISE PRUNING
  Target sparsity per layer: 20.0%
  Min channels to keep: 8
  Protected keywords: ['conv0.0.0', 'down0', 'conv1.conv', 'unet.conv']
  üõ°Ô∏è  PROTECTED: block0.conv0.0.0 (120 channels)
  üõ°Ô∏è  PROTECTED: block1.conv0.0.0 (75 channels)
  üõ°Ô∏è  PROTECTED: block2.conv0.0.0 (45 channels)
  üõ°Ô∏è  PROTECTED: block_tea.conv0.0.0 (45 channels)
  üõ°Ô∏è  PROTECTED: contextnet.conv1.conv1.0 (16 channels)
  üõ°Ô∏è  PROTECTED: contextnet.conv1.conv2.0 (16 channels)
  üõ°Ô∏è  PROTECTED: unet.down0.conv1.0 (32 channels)
  üõ°Ô∏è  PROTECTED: unet.down0.conv2.0 (32 channels)
  üõ°Ô∏è  PROTECTED: unet.conv (3 channels)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Summary:
  Protected layers: 9
  Modified layers: 48
  Total: 1300/6858 channel




üìä Results:
   PSNR: 31.839 dB (drop: 3.453 dB)
   SSIM: 0.9612 (drop: 0.0078)
   Inference: 10.12 ms (98.8 FPS)
   Quality: ‚ùå POOR

Testing Protected Layer-wise Sparsity: 30.0%
PROTECTED LAYER-WISE PRUNING
  Target sparsity per layer: 30.0%
  Min channels to keep: 8
  Protected keywords: ['conv0.0.0', 'down0', 'conv1.conv', 'unet.conv']
  üõ°Ô∏è  PROTECTED: block0.conv0.0.0 (120 channels)
  üõ°Ô∏è  PROTECTED: block1.conv0.0.0 (75 channels)
  üõ°Ô∏è  PROTECTED: block2.conv0.0.0 (45 channels)
  üõ°Ô∏è  PROTECTED: block_tea.conv0.0.0 (45 channels)
  üõ°Ô∏è  PROTECTED: contextnet.conv1.conv1.0 (16 channels)
  üõ°Ô∏è  PROTECTED: contextnet.conv1.conv2.0 (16 channels)
  üõ°Ô∏è  PROTECTED: unet.down0.conv1.0 (32 channels)
  üõ°Ô∏è  PROTECTED: unet.down0.conv2.0 (32 channels)
  üõ°Ô∏è  PROTECTED: unet.conv (3 channels)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î




üìä Results:
   PSNR: 31.910 dB (drop: 3.382 dB)
   SSIM: 0.9634 (drop: 0.0056)
   Inference: 12.36 ms (80.9 FPS)
   Quality: ‚ùå POOR

Testing Protected Layer-wise Sparsity: 40.0%
PROTECTED LAYER-WISE PRUNING
  Target sparsity per layer: 40.0%
  Min channels to keep: 8
  Protected keywords: ['conv0.0.0', 'down0', 'conv1.conv', 'unet.conv']
  üõ°Ô∏è  PROTECTED: block0.conv0.0.0 (120 channels)
  üõ°Ô∏è  PROTECTED: block1.conv0.0.0 (75 channels)
  üõ°Ô∏è  PROTECTED: block2.conv0.0.0 (45 channels)
  üõ°Ô∏è  PROTECTED: block_tea.conv0.0.0 (45 channels)
  üõ°Ô∏è  PROTECTED: contextnet.conv1.conv1.0 (16 channels)
  üõ°Ô∏è  PROTECTED: contextnet.conv1.conv2.0 (16 channels)
  üõ°Ô∏è  PROTECTED: unet.down0.conv1.0 (32 channels)
  üõ°Ô∏è  PROTECTED: unet.down0.conv2.0 (32 channels)
  üõ°Ô∏è  PROTECTED: unet.conv (3 channels)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î




üìä Results:
   PSNR: 32.173 dB (drop: 3.119 dB)
   SSIM: 0.9633 (drop: 0.0057)
   Inference: 9.97 ms (100.3 FPS)
   Quality: ‚ùå POOR

Testing Protected Layer-wise Sparsity: 50.0%
PROTECTED LAYER-WISE PRUNING
  Target sparsity per layer: 50.0%
  Min channels to keep: 8
  Protected keywords: ['conv0.0.0', 'down0', 'conv1.conv', 'unet.conv']
  üõ°Ô∏è  PROTECTED: block0.conv0.0.0 (120 channels)
  üõ°Ô∏è  PROTECTED: block1.conv0.0.0 (75 channels)
  üõ°Ô∏è  PROTECTED: block2.conv0.0.0 (45 channels)
  üõ°Ô∏è  PROTECTED: block_tea.conv0.0.0 (45 channels)
  üõ°Ô∏è  PROTECTED: contextnet.conv1.conv1.0 (16 channels)
  üõ°Ô∏è  PROTECTED: contextnet.conv1.conv2.0 (16 channels)
  üõ°Ô∏è  PROTECTED: unet.down0.conv1.0 (32 channels)
  üõ°Ô∏è  PROTECTED: unet.down0.conv2.0 (32 channels)
  üõ°Ô∏è  PROTECTED: unet.conv (3 channels)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î




üìä Results:
   PSNR: 32.413 dB (drop: 2.879 dB)
   SSIM: 0.9646 (drop: 0.0044)
   Inference: 10.09 ms (99.1 FPS)
   Quality: ‚ùå POOR

Testing Protected Layer-wise Sparsity: 60.0%
PROTECTED LAYER-WISE PRUNING
  Target sparsity per layer: 60.0%
  Min channels to keep: 8
  Protected keywords: ['conv0.0.0', 'down0', 'conv1.conv', 'unet.conv']
  üõ°Ô∏è  PROTECTED: block0.conv0.0.0 (120 channels)
  ‚ö†Ô∏è  block0.conv0.1.0: 96/240 kept (60.0% pruned)
  ‚ö†Ô∏è  block0.convblock.0.0: 96/240 kept (60.0% pruned)
  ‚ö†Ô∏è  block0.convblock.1.0: 96/240 kept (60.0% pruned)
  ‚ö†Ô∏è  block0.convblock.2.0: 96/240 kept (60.0% pruned)
  ‚ö†Ô∏è  block0.convblock.3.0: 96/240 kept (60.0% pruned)
  ‚ö†Ô∏è  block0.convblock.4.0: 96/240 kept (60.0% pruned)
  ‚ö†Ô∏è  block0.convblock.5.0: 96/240 kept (60.0% pruned)
  ‚ö†Ô∏è  block0.convblock.6.0: 96/240 kept (60.0% pruned)
  ‚ö†Ô∏è  block0.convblock.7.0: 96/240 kept (60.0% pruned)
  üõ°Ô∏è  PROTECTED: block1.conv0.0.0 (75 channels)
  ‚ö†Ô∏è  block1.conv

                                                           


üìä Results:
   PSNR: 32.700 dB (drop: 2.592 dB)
   SSIM: 0.9646 (drop: 0.0044)
   Inference: 10.45 ms (95.7 FPS)
   Quality: ‚ùå POOR

COMPARISON: Global vs Protected Layer-wise

Sparsity     Global          Protected       Improvement    
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
20.0%        19.840          31.839          ‚úÖ MAJOR +12.00 dB
30.0%        25.319          31.910          ‚úÖ +6.59 dB
40.0%        31.874          32.173          ‚ö†Ô∏è  +0.30 dB
50.0%        32.361          32.413          ‚ö†Ô∏è  +0.05 dB

SELECTING BEST CONFIGURATION

‚ùå No acceptable configuration found
   All sparsity levels degrade quality too much
   Recommendation: Use lower sparsity or different pruning strategy

‚úÖ PROTECTED LAYER-WISE PRUNING COMPLETE!




### ATTEMPT 1

In [None]:
# ============================================================
# COMPLETE CORRECTED PRUNING IMPLEMENTATION + TABLE EXPORT
# ============================================================

import torch
import torch.nn as nn
import numpy as np
import copy
import os
from tqdm import tqdm
import time
import json
import math
import pandas as pd
from IPython.display import display, Markdown

# ============================================================
# CORRECTED HELPER FUNCTIONS (as provided)
# ============================================================

def compute_channel_importance_v2(module, layer_types=[nn.Conv2d]):
    """
    Compute L1-norm importance for each channel (VERIFIED CORRECT)
    """
    importance_dict = {}

    for name, layer in module.named_modules():
        if isinstance(layer, tuple(layer_types)):
            weight = layer.weight.data
            # L1 norm per output channel: sum over [in_channels, H, W]
            importance = torch.norm(weight.view(weight.size(0), -1), p=1, dim=1)
            importance_dict[name] = importance.cpu()

    print(f"‚úÖ Analyzed {len(importance_dict)} convolutional layers")

    # Debug statistics
    all_imp = torch.cat([imp for imp in importance_dict.values()])
    print(f"   Importance range: [{all_imp.min():.2f}, {all_imp.max():.2f}], Mean: {all_imp.mean():.2f}")

    return importance_dict


def global_channel_pruning_v2(importance_dict, sparsity=0.5):
    """
    CORRECTED: Prune channels with SMALLEST importance (least important)

    Args:
        importance_dict: Channel importance scores
        sparsity: Fraction of channels to REMOVE (0-1)

    Returns:
        prune_masks: Boolean masks (True = KEEP, False = PRUNE)
    """
    # Flatten all importance scores
    all_importance = torch.cat([imp for imp in importance_dict.values()])
    total_channels = len(all_importance)

    # Calculate how many to KEEP
    num_to_keep = int(total_channels * (1 - sparsity))

    # Sort importance in DESCENDING order and get threshold
    # Channels with importance >= threshold will be KEPT
    sorted_importance = torch.sort(all_importance, descending=True)[0]
    threshold = sorted_importance[num_to_keep - 1] if num_to_keep > 0 else sorted_importance[-1]

    print(f"   Sparsity: {sparsity:.1%}")
    print(f"   Total channels: {total_channels}")
    print(f"   Channels to keep: {num_to_keep}")
    print(f"   Threshold: {threshold:.6f}")
    print(f"   Logic: KEEP channels with importance >= {threshold:.6f}")

    # Create masks: True = KEEP, False = PRUNE
    prune_masks = {}
    total_kept = 0

    for name, importance in importance_dict.items():
        # CORRECTED: Keep channels with importance >= threshold
        mask = (importance >= threshold)
        prune_masks[name] = mask
        total_kept += mask.sum().item()

    actual_sparsity = 1 - (total_kept / total_channels)
    actual_kept_pct = (total_kept / total_channels) * 100

    print(f"   Result: {total_kept}/{total_channels} channels kept ({actual_kept_pct:.1f}%)")
    print(f"   Actual sparsity: {actual_sparsity:.1%}")

    return prune_masks


def apply_soft_pruning_v2(model, prune_masks):
    """
    Apply soft pruning by zeroing out weights (VERIFIED CORRECT)

    NOTE: This function creates a fresh RIFEModel instance and loads weights from 'train_log'.
    If you want to apply masks directly to an in-memory model, adjust accordingly.
    """
    # Import inside function to avoid errors if path differs
    try:
        from model.RIFE import Model as RIFEModel
    except Exception as e:
        raise RuntimeError("Could not import RIFEModel from model.RIFE. Make sure your project path is correct.") from e

    # Create fresh model instance
    pruned_model = RIFEModel()
    pruned_model.load_model('train_log')
    pruned_model.eval()

    # Apply masks
    total_channels = 0
    zeroed_channels = 0
    layers_modified = 0

    for name, module in pruned_model.flownet.named_modules():
        if name in prune_masks:
            mask = prune_masks[name]
            # Some modules may be on CPU - move mask accordingly
            if any(p is None for p in module.parameters()):
                device = torch.device('cpu')
            else:
                device = next(module.parameters()).device

            # Create weight mask: expand to [out_ch, in_ch, H, W]
            weight_mask = mask.view(-1, 1, 1, 1).expand_as(module.weight).to(device)

            # Apply mask: multiply weights by mask (zeros out pruned channels)
            module.weight.data *= weight_mask

            # Apply to bias if exists
            if module.bias is not None:
                bias_mask = mask.to(device)
                module.bias.data *= bias_mask

            # Track statistics
            zeroed = (~mask).sum().item()
            total = len(mask)
            zeroed_channels += zeroed
            total_channels += total
            layers_modified += 1

    if total_channels == 0:
        print("‚ö†Ô∏è  Warning: No layers matched the provided masks.")
    else:
        print(f"   Applied masks to {layers_modified} layers")
        print(f"   Zeroed {zeroed_channels}/{total_channels} channels ({zeroed_channels/total_channels*100:.1f}%)")

    return pruned_model


def quick_evaluate_v2(model, dataset_path, num_samples=50):
    """
    Quick evaluation (same as before, just renamed for consistency)
    """
    import cv2
    from model.pytorch_msssim import ssim_matlab

    model.eval()
    # if model has flownet use its parameters for device detection, else CPU
    try:
        device = next(model.flownet.parameters()).device
    except Exception:
        device = torch.device('cpu')

    dirs = sorted(os.listdir(dataset_path))[:num_samples]
    psnr_list, ssim_list, time_list = [], [], []

    for d in tqdm(dirs, desc="Evaluating", leave=False):
        img0_path = os.path.join(dataset_path, d, 'frame_00.png')
        img1_path = os.path.join(dataset_path, d, 'frame_02.png')
        gt_path = os.path.join(dataset_path, d, 'frame_01_gt.png')

        if not all(map(os.path.exists, [img0_path, img1_path, gt_path])):
            continue

        img0 = torch.tensor(cv2.imread(img0_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)
        img1 = torch.tensor(cv2.imread(img1_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)
        gt = torch.tensor(cv2.imread(gt_path).transpose(2,0,1)/255.).float().unsqueeze(0).to(device)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = time.time()

        with torch.no_grad():
            pred = model.inference(img0, img1)[0]

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        time_list.append((time.time() - start_time) * 1000)

        ssim_val = ssim_matlab(gt, torch.round(pred*255).unsqueeze(0)/255).detach().cpu().numpy()
        ssim_list.append(float(ssim_val))

        out = pred.detach().cpu().numpy().transpose(1,2,0)
        out = np.round(out*255)/255.
        gt_np = gt[0].cpu().numpy().transpose(1,2,0)
        mse = ((gt_np - out)**2).mean()
        psnr = -10 * math.log10(mse + 1e-8)
        psnr_list.append(psnr)

    return {
        'PSNR': np.mean(psnr_list) if len(psnr_list) > 0 else float('nan'),
        'SSIM': np.mean(ssim_list) if len(ssim_list) > 0 else float('nan'),
        'Inference_Time_ms': np.mean(time_list) if len(time_list) > 0 else float('nan'),
        'FPS': 1000. / np.mean(time_list) if len(time_list) > 0 else float('nan')
    }

# Simple parameter counting helper (if not already defined)
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

# ============================================================
# STEP 1: DIAGNOSTIC - COMPARE OLD VS NEW PRUNING
# ============================================================

print("="*60)
print("STEP 1: DIAGNOSTIC - COMPARE OLD VS NEW PRUNING")
print("="*60)

if 'model' not in globals():
    raise RuntimeError("Model not found. Run baseline cell first or ensure `model` exists in the global namespace.")

UCF_PATH = "/content/drive/MyDrive/UCF-101/ucf101_interp_ours"
baseline_psnr = 35.292
baseline_ssim = 0.9690

# Compute importance with new function
print("\nComputing channel importance...")
importance_dict_v2 = compute_channel_importance_v2(model.flownet)

# Test both methods on same sparsity
sparsity_test = 0.3

print(f"\n{'='*60}")
print(f"COMPARING METHODS AT {sparsity_test:.0%} SPARSITY")
print(f"{'='*60}")

# Try old method if available (wrap in try/except)
results_old = {'PSNR': float('nan'), 'SSIM': float('nan')}
try:
    if 'global_channel_pruning' in globals() and 'apply_soft_pruning' in globals():
        print("\nüî¥ OLD METHOD (inverted logic):")
        masks_old = global_channel_pruning(importance_dict, sparsity_test)  # old function from earlier cell
        model_old = apply_soft_pruning(model, masks_old)
        try:
            model_old.device()
        except Exception:
            pass
        results_old = quick_evaluate(model_old, UCF_PATH, num_samples=20)
        print(f"   PSNR: {results_old['PSNR']:.3f} dB")
        print(f"   SSIM: {results_old['SSIM']:.4f}")
        del model_old
    else:
        print("‚ÑπÔ∏è Old pruning functions not found; skipping old-method diagnostic.")
except Exception as e:
    print(f"‚ö†Ô∏è Skipping old-method diagnostic due to error: {e}")

# New corrected method
print("\n‚úÖ NEW METHOD (corrected logic):")
masks_new = global_channel_pruning_v2(importance_dict_v2, sparsity_test)
model_new = apply_soft_pruning_v2(model, masks_new)
try:
    model_new.device()
except Exception:
    pass
results_new = quick_evaluate_v2(model_new, UCF_PATH, num_samples=20)
print(f"   PSNR: {results_new['PSNR']:.3f} dB")
print(f"   SSIM: {results_new['SSIM']:.4f}")

# Comparison
print(f"\n{'='*60}")
print("DIAGNOSTIC RESULT:")
print(f"{'='*60}")
improvement = results_new['PSNR'] - (results_old['PSNR'] if not math.isnan(results_old['PSNR']) else float('nan'))
if not math.isnan(improvement) and improvement > 5.0:
    print(f"‚úÖ MAJOR IMPROVEMENT: +{improvement:.2f} dB")
    print(f"   Old logic was definitely inverted!")
    print(f"   Proceeding with corrected pruning...")
elif not math.isnan(improvement) and improvement > 1.0:
    print(f"‚úÖ IMPROVEMENT: +{improvement:.2f} dB")
    print(f"   New logic is better")
else:
    print(f"‚ö†Ô∏è  MINIMAL DIFFERENCE: {improvement if not math.isnan(improvement) else 'N/A'} dB")
    print(f"   Issue might be elsewhere")

# Cleanup partial objects if exist
try:
    del model_new
except Exception:
    pass
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# ============================================================
# STEP 2: RUN CORRECTED PRUNING EXPERIMENTS
# ============================================================

print("\n" + "="*60)
print("STEP 2: CORRECTED PRUNING EXPERIMENTS")
print("="*60)

sparsity_levels = [0.1, 0.2, 0.3, 0.4, 0.5]
pruning_results_v2 = {}

for sparsity in sparsity_levels:
    print(f"\n{'‚îÄ'*60}")
    print(f"Testing sparsity: {sparsity:.1%}")
    print(f"{'‚îÄ'*60}")

    # Create masks with corrected method
    prune_masks = global_channel_pruning_v2(importance_dict_v2, sparsity)

    # Apply pruning
    pruned_model = apply_soft_pruning_v2(model, prune_masks)
    try:
        pruned_model.device()
    except Exception:
        pass

    # Evaluate
    print(f"\nEvaluating on 50 samples...")
    results = quick_evaluate_v2(pruned_model, UCF_PATH, num_samples=50)
    pruning_results_v2[sparsity] = results

    # Calculate metrics
    psnr_drop = baseline_psnr - results['PSNR']
    ssim_drop = baseline_ssim - results['SSIM']

    print(f"\nüìä Results:")
    print(f"   PSNR: {results['PSNR']:.3f} dB (drop: {psnr_drop:.3f} dB)")
    print(f"   SSIM: {results['SSIM']:.4f} (drop: {ssim_drop:.4f})")
    print(f"   Inference: {results['Inference_Time_ms']:.2f} ms ({results['FPS']:.1f} FPS)")

    # Clear memory
    try:
        del pruned_model
    except Exception:
        pass
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ============================================================
# STEP 3: SELECT OPTIMAL SPARSITY
# ============================================================

print("\n" + "="*60)
print("STEP 3: SELECTING OPTIMAL SPARSITY")
print("="*60)

print(f"\n{'Sparsity':<10} {'PSNR':<10} {'Drop':<10} {'SSIM':<10} {'Status':<15}")
print("‚îÄ"*60)

acceptable = {}
for sparsity in sorted(pruning_results_v2.keys()):
    results = pruning_results_v2[sparsity]
    psnr_drop = baseline_psnr - results['PSNR']

    if psnr_drop <= 0.5:
        status = "‚úÖ Excellent"
        acceptable[sparsity] = results
    elif psnr_drop <= 1.0:
        status = "‚úÖ Good"
        acceptable[sparsity] = results
    elif psnr_drop <= 1.5:
        status = "‚ö†Ô∏è  Marginal"
    else:
        status = "‚ùå Too high"

    print(f"{sparsity:<10.1%} {results['PSNR']:<10.3f} {psnr_drop:<10.3f} {results['SSIM']:<10.4f} {status:<15}")

print("\n" + "‚îÄ"*60)

if len(acceptable) > 0:
    # Choose highest acceptable sparsity (most compression)
    best_sparsity = max(acceptable.keys())
    best_results = acceptable[best_sparsity]
    psnr_drop = baseline_psnr - best_results['PSNR']

    print(f"‚úÖ SELECTED: {best_sparsity:.1%} sparsity")
    print(f"   PSNR: {best_results['PSNR']:.3f} dB (drop: {psnr_drop:.3f} dB)")
    print(f"   SSIM: {best_results['SSIM']:.4f}")
    print(f"   Quality: {'Excellent' if psnr_drop <= 0.5 else 'Good'}")
else:
    # Use minimum sparsity
    best_sparsity = min(pruning_results_v2.keys())
    best_results = pruning_results_v2[best_sparsity]
    psnr_drop = baseline_psnr - best_results['PSNR']

    print(f"‚ö†Ô∏è  SELECTED: {best_sparsity:.1%} sparsity (minimum)")
    print(f"   PSNR: {best_results['PSNR']:.3f} dB (drop: {psnr_drop:.3f} dB)")
    print(f"   Note: All sparsity levels had high degradation")

print("="*60)

# ============================================================
# STEP 4: FINAL EVALUATION ON FULL DATASET
# ============================================================

print("\n" + "="*60)
print("STEP 4: FULL EVALUATION ON UCF-101")
print("="*60)

print(f"\nCreating final pruned model ({best_sparsity:.1%} sparsity)...")
final_masks = global_channel_pruning_v2(importance_dict_v2, best_sparsity)
final_pruned_model = apply_soft_pruning_v2(model, final_masks)
try:
    final_pruned_model.device()
except Exception:
    pass

print(f"\nEvaluating on full UCF-101 dataset (379 sequences)...")
print("‚è±Ô∏è  This may take several minutes depending on your runtime and hardware...")

final_results = quick_evaluate_v2(final_pruned_model, UCF_PATH, num_samples=379)

# ============================================================
# STEP 5: FINAL REPORT
# ============================================================

print("\n" + "="*60)
print("FINAL RESULTS - CORRECTED PRUNING")
print("="*60)

print(f"\nüìä BASELINE MODEL:")
print(f"   PSNR: {baseline_psnr:.3f} dB")
print(f"   SSIM: {baseline_ssim:.4f}")
print(f"   Parameters: {count_parameters(model):,} ({count_parameters(model)/1e6:.2f}M)")

print(f"\nüìä PRUNED MODEL (Sparsity: {best_sparsity:.1%}):")
psnr_drop_final = baseline_psnr - final_results['PSNR']
ssim_drop_final = baseline_ssim - final_results['SSIM']

print(f"   PSNR: {final_results['PSNR']:.3f} dB (Œî: {psnr_drop_final:.3f} dB)")
print(f"   SSIM: {final_results['SSIM']:.4f} (Œî: {ssim_drop_final:.4f})")
print(f"   Inference: {final_results['Inference_Time_ms']:.2f} ms")
print(f"   FPS: {final_results['FPS']:.1f}")
print(f"   Est. Param Reduction: ~{best_sparsity*100:.0f}%")

# Quality assessment
if psnr_drop_final <= 0.5:
    quality_verdict = "‚úÖ EXCELLENT - No fine-tuning needed"
elif psnr_drop_final <= 1.0:
    quality_verdict = "‚úÖ GOOD - Optional fine-tuning"
elif psnr_drop_final <= 2.0:
    quality_verdict = "‚ö†Ô∏è  ACCEPTABLE - Fine-tuning recommended"
else:
    quality_verdict = "‚ùå POOR - Fine-tuning required"

print(f"\n{quality_verdict}")

# Save results
results_summary = {
    'baseline': {
        'PSNR': baseline_psnr,
        'SSIM': baseline_ssim,
        'Parameters': count_parameters(model)
    },
    'pruned_corrected': {
        'sparsity': best_sparsity,
        'PSNR': final_results['PSNR'],
        'SSIM': final_results['SSIM'],
        'PSNR_drop': psnr_drop_final,
        'SSIM_drop': ssim_drop_final,
        'Inference_Time_ms': final_results['Inference_Time_ms'],
        'FPS': final_results['FPS']
    },
    'all_experiments': {str(k): v for k, v in pruning_results_v2.items()}
}

with open('/content/pruning_results_corrected.json', 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f"\nüíæ Results saved to '/content/pruning_results_corrected.json'")

# Save model
torch.save({
    'flownet_state_dict': final_pruned_model.flownet.state_dict(),
    'sparsity': best_sparsity,
    'masks': {k: v.numpy() for k, v in final_masks.items()},
    'results': final_results
}, '/content/pruned_model_corrected.pth')

print(f"üíæ Model saved to '/content/pruned_model_corrected.pth'")

print("\n" + "="*60)
print("‚úÖ CORRECTED PRUNING COMPLETE!")
print("="*60)

# Next steps recommendation
print("\nüìã NEXT STEPS:")
if psnr_drop_final <= 1.0:
    print("   1. ‚úÖ Proceed to quantization (quality is good)")
    print("   2. Optional: Fine-tune for further improvement")
    print("   3. Implement video processing pipeline")
elif psnr_drop_final <= 2.0:
    print("   1. ‚ö†Ô∏è  Fine-tune the pruned model (recommended)")
    print("   2. Then proceed to quantization")
    print("   3. Implement video processing pipeline")
else:
    print("   1. ‚ùå Fine-tune the pruned model (required)")
    print("   2. Re-evaluate after fine-tuning")
    print("   3. Then consider quantization")

print("="*60)

# ============================================================
# INTEGRATED TABLE EXPORT & DISPLAY (CSV + JSON)
# ============================================================

def status_from_psnr_drop(psnr_drop):
    if psnr_drop <= 0.5:
        return "‚úÖ Excellent"
    elif psnr_drop <= 1.0:
        return "‚úÖ Good"
    elif psnr_drop <= 1.5:
        return "‚ö†Ô∏è Marginal"
    else:
        return "‚ùå Too high"

# DIAGNOSTIC TABLE
diag_rows = []
# Only add old row if old results exist
try:
    if not math.isnan(results_old['PSNR']):
        diag_rows.append({"Method": "Old method (inverted logic)", "PSNR_dB": results_old['PSNR'], "SSIM": results_old['SSIM']})
except Exception:
    pass

diag_rows.append({"Method": "New method (corrected)", "PSNR_dB": results_new['PSNR'], "SSIM": results_new['SSIM']})
diag_rows.append({"Method": "Baseline", "PSNR_dB": baseline_psnr, "SSIM": baseline_ssim})

df_diag = pd.DataFrame(diag_rows)

# PRUNING EXPERIMENTS TABLE
rows = []
for s, res in pruning_results_v2.items():
    psnr = res['PSNR']
    ssim = res['SSIM']
    time_ms = res.get('Inference_Time_ms', float('nan'))
    fps = res.get('FPS', float('nan'))
    psnr_drop = baseline_psnr - psnr
    ssim_drop = baseline_ssim - ssim
    rows.append({
        "Sparsity": f"{s:.0%}",
        "PSNR_dB": round(psnr, 3),
        "PSNR_drop_dB": round(psnr_drop, 3),
        "SSIM": round(ssim, 4),
        "SSIM_drop": round(ssim_drop, 4),
        "Inference_ms": round(time_ms, 2),
        "FPS": round(fps, 2),
        "Status": status_from_psnr_drop(psnr_drop)
    })

df_prune = pd.DataFrame(rows).sort_values(by="Sparsity")

# FINAL SUMMARY TABLE
final_summary = {
    "Baseline_PSNR_dB": baseline_psnr,
    "Baseline_SSIM": baseline_ssim,
    "Baseline_Params": count_parameters(model),
    "Selected_Sparsity": f"{best_sparsity:.0%}",
    "Final_PSNR_dB": final_results['PSNR'],
    "Final_SSIM": final_results['SSIM'],
    "PSNR_drop_dB": round(baseline_psnr - final_results['PSNR'], 3),
    "SSIM_drop": round(baseline_ssim - final_results['SSIM'], 4),
    "Inference_ms": round(final_results['Inference_Time_ms'], 2),
    "FPS": round(final_results['FPS'], 2),
    "Est_Param_Reduction_pct": round(best_sparsity * 100, 1),
    "Quality_verdict": quality_verdict
}
df_final = pd.DataFrame([final_summary]).T.reset_index()
df_final.columns = ["Metric", "Value"]

# Display the tables in the notebook
display(Markdown("## Diagnostic comparison (old vs new)"))
display(df_diag)

display(Markdown("## Pruning experiments"))
display(df_prune)

display(Markdown("## Final summary"))
display(df_final)

# Save to CSV and JSON for quick sharing
out_dir = "/content"
df_prune.to_csv(os.path.join(out_dir, 'pruning_experiments_table.csv'), index=False)
df_diag.to_csv(os.path.join(out_dir, 'diagnostic_table.csv'), index=False)
with open(os.path.join(out_dir, 'pruning_summary.json'), 'w') as f:
    json.dump({
        "diagnostic": diag_rows,
        "pruning_experiments": rows,
        "final_summary": final_summary
    }, f, indent=2)

print("\nSaved: ")
print(f" - {os.path.join(out_dir, 'pruning_experiments_table.csv')}")
print(f" - {os.path.join(out_dir, 'diagnostic_table.csv')}")
print(f" - {os.path.join(out_dir, 'pruning_summary.json')}")

STEP 1: DIAGNOSTIC - COMPARE OLD VS NEW PRUNING

Computing channel importance...
‚úÖ Analyzed 57 convolutional layers
   Importance range: [1.89, 188.22], Mean: 73.84

COMPARING METHODS AT 30% SPARSITY
‚ÑπÔ∏è Old pruning functions not found; skipping old-method diagnostic.

‚úÖ NEW METHOD (corrected logic):
   Sparsity: 30.0%
   Total channels: 6858
   Channels to keep: 4800
   Threshold: 40.152420
   Logic: KEEP channels with importance >= 40.152420
   Result: 4800/6858 channels kept (70.0%)
   Actual sparsity: 30.0%
   Applied masks to 57 layers
   Zeroed 2058/6858 channels (30.0%)




   PSNR: 23.118 dB
   SSIM: 0.8412

DIAGNOSTIC RESULT:
‚ö†Ô∏è  MINIMAL DIFFERENCE: N/A dB
   Issue might be elsewhere

STEP 2: CORRECTED PRUNING EXPERIMENTS

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 10.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 10.0%
   Total channels: 6858
   Channels to keep: 6172
   Threshold: 21.201897
   Logic: KEEP channels with importance >= 21.201897
   Result: 6172/6858 channels kept (90.0%)
   Actual sparsity: 10.0%
   Applied masks to 57 layers
   Zeroed 686/6858 channels (10.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 18.283 dB (drop: 17.009 dB)
   SSIM: 0.6607 (drop: 0.3083)
   Inference: 10.02 ms (99.8 FPS)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 20.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 20.0%
   Total channels: 6858
   Channels to keep: 5486
   Threshold: 30.632299
   Logic: KEEP channels with importance >= 30.632299
   Result: 5486/6858 channels kept (80.0%)
   Actual sparsity: 20.0%
   Applied masks to 57 layers
   Zeroed 1372/6858 channels (20.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 18.658 dB (drop: 16.634 dB)
   SSIM: 0.6846 (drop: 0.2844)
   Inference: 12.39 ms (80.7 FPS)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 30.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 30.0%
   Total channels: 6858
   Channels to keep: 4800
   Threshold: 40.152420
   Logic: KEEP channels with importance >= 40.152420
   Result: 4800/6858 channels kept (70.0%)
   Actual sparsity: 30.0%
   Applied masks to 57 layers
   Zeroed 2058/6858 channels (30.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 23.764 dB (drop: 11.528 dB)
   SSIM: 0.8488 (drop: 0.1202)
   Inference: 10.36 ms (96.5 FPS)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 40.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 40.0%
   Total channels: 6858
   Channels to keep: 4114
   Threshold: 50.451553
   Logic: KEEP channels with importance >= 50.451553
   Result: 4114/6858 channels kept (60.0%)
   Actual sparsity: 40.0%
   Applied masks to 57 layers
   Zeroed 2744/6858 channels (40.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 30.409 dB (drop: 4.883 dB)
   SSIM: 0.9552 (drop: 0.0138)
   Inference: 9.96 ms (100.4 FPS)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Testing sparsity: 50.0%
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Sparsity: 50.0%
   Total channels: 6858
   Channels to keep: 3429
   Threshold: 73.872589
   Logic: KEEP channels with importance >= 73.872589
   Result: 3429/6858 channels kept (50.0%)
   Actual sparsity: 50.0%
   Applied masks to 57 layers
   Zeroed 3429/6858 channels (50.0%)

Evaluating on 50 samples...





üìä Results:
   PSNR: 30.919 dB (drop: 4.373 dB)
   SSIM: 0.9572 (drop: 0.0118)
   Inference: 10.81 ms (92.5 FPS)

STEP 3: SELECTING OPTIMAL SPARSITY

Sparsity   PSNR       Drop       SSIM       Status         
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
10.0%      18.283     17.009     0.6607     ‚ùå Too high     
20.0%      18.658     16.634     0.6846     ‚ùå Too high     
30.0%      23.764     11.528     0.8488     ‚ùå Too high     
40.0%      30.409     4.883      0.9552     ‚ùå Too high     
50.0%      30.919     4.373      0.9572     ‚ùå Too high     

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
‚ö†Ô∏è  SELECTED: 10.0% sparsity (minimum)
   PSNR: 18.283 dB (drop: 17.009 dB)
   Note: All sparsity levels had h

                                                             


FINAL RESULTS - CORRECTED PRUNING

üìä BASELINE MODEL:
   PSNR: 35.292 dB
   SSIM: 0.9690




AttributeError: 'Model' object has no attribute 'parameters'