# Ablation Study 2.6: Data Augmentation Analysis

## Motivation

The current training pipeline uses three data augmentation techniques:
1. Random point dropout
2. Random scale (scaling point clouds)
3. Random shift (translating point clouds)

This ablation examines the contribution of each augmentation technique individually and in combination to understand which augmentations are most beneficial for model performance.

## Experimental Plan

1. Test different augmentation combinations:
   - No augmentation (baseline)
   - Only dropout
   - Only scale
   - Only shift
   - Dropout + Scale
   - Dropout + Shift
   - Scale + Shift
   - All augmentations (current default)
2. Keep all other hyperparameters constant (step=1, temp=5.0, spike=True)
3. Train models on ModelNet40 dataset
4. Compare accuracy and generalization

## Expected Insight

This experiment reveals which augmentation techniques contribute most to model performance and generalization. Some augmentations may be redundant or even harmful, while others may be critical for good performance.

## Dataset Setup

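Before running this notebook, download and extract the ModelNet40 dataset (for example to `data/modelnet40_normal_resampled/`). Note that `load_data` below uses a hard-coded, machine-specific `data_path`, so make sure that path matches the folder where you extracted the dataset.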

In [None]:
import os
import sys
import torch
import numpy as np
import datetime
import importlib
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

sys.path.append('..')
from data_utils.ModelNetDataLoader import ModelNetDataLoader
from models.spike_model import SpikeModel
import provider

print('Imports successful!')
from cache_utils import load_training_history, save_training_history, cache_checkpoint, load_cached_checkpoint_path, best_metric
from viz_utils import plot_training_curves, summarize_histories, plot_metric_table, plot_metric_bars


## Augmentation Functions

In [None]:
def apply_augmentations(points, use_dropout=True, use_scale=True, use_shift=True):
    """Run the enabled augmentations over a batch of point clouds.

    The steps always run in the same order: dropout, then scale, then shift.
    Dropout replaces the whole batch with whatever
    ``provider.random_point_dropout`` returns; scale and shift only rewrite
    the first three channels of the last axis (presumably the xyz
    coordinates -- confirm against ``provider``), assigning the result back
    into the batch in place.

    Args:
        points: Batch of point clouds, indexed as ``points[:, :, 0:3]``.
        use_dropout: Apply ``provider.random_point_dropout`` when true.
        use_scale: Apply ``provider.random_scale_point_cloud`` when true.
        use_shift: Apply ``provider.shift_point_cloud`` when true.

    Returns:
        The augmented batch. With every flag off, ``points`` is returned
        untouched.
    """
    augmented = points
    if use_dropout:
        augmented = provider.random_point_dropout(augmented)

    # Same selection as [:, :, 0:3], spelled once and reused below.
    first_three = (slice(None), slice(None), slice(0, 3))
    if use_scale:
        augmented[first_three] = provider.random_scale_point_cloud(augmented[first_three])
    if use_shift:
        augmented[first_three] = provider.shift_point_cloud(augmented[first_three])
    return augmented

print('Augmentation functions defined!')

## Configuration for Different Augmentation Combinations

In [None]:
class Args:
    """Settings for a single augmentation-ablation training run.

    Everything except the three augmentation switches and ``num_category``
    is fixed, so runs differ only in which augmentations are enabled.
    ``log_dir`` is derived from the enabled switches, e.g.
    ``ablation_dropout_scale_modelnet40`` or ``ablation_no_aug_modelnet40``.

    Args:
        use_dropout: Enable random point dropout.
        use_scale: Enable random scaling.
        use_shift: Enable random shifting.
        num_category: Number of classes; also embedded in ``log_dir``.
    """

    def __init__(self, use_dropout=True, use_scale=True, use_shift=True, num_category=40):
        # Attribute order is kept stable because vars(args) is stored as run metadata.
        self.use_cpu = False
        self.gpu = '0'
        self.batch_size = 24
        self.model = 'pointnet_cls'
        self.num_category = num_category
        self.epoch = 200
        self.learning_rate = 0.001
        self.num_point = 1024
        self.optimizer = 'Adam'

        # Name the log directory after the switches that are turned on.
        switches = (('dropout', use_dropout), ('scale', use_scale), ('shift', use_shift))
        enabled = [label for label, active in switches if active]
        tag = '_'.join(enabled) or 'no_aug'
        self.log_dir = 'ablation_{}_modelnet{}'.format(tag, num_category)

        self.decay_rate = 1e-4
        self.use_normals = False
        self.process_data = False
        self.use_uniform_sample = False
        # Spiking-model settings, held constant across the ablation.
        self.step = 1
        self.spike = True
        self.temp = 5.0
        self.use_dropout = use_dropout
        self.use_scale = use_scale
        self.use_shift = use_shift

# The eight on/off combinations of the three augmentations, in the order
# they are trained and reported.
augmentation_configs = [
    {'name': 'No Augmentation', 'dropout': False, 'scale': False, 'shift': False},
    {'name': 'Only Dropout', 'dropout': True, 'scale': False, 'shift': False},
    {'name': 'Only Scale', 'dropout': False, 'scale': True, 'shift': False},
    {'name': 'Only Shift', 'dropout': False, 'scale': False, 'shift': True},
    {'name': 'Dropout + Scale', 'dropout': True, 'scale': True, 'shift': False},
    {'name': 'Dropout + Shift', 'dropout': True, 'scale': False, 'shift': True},
    {'name': 'Scale + Shift', 'dropout': False, 'scale': True, 'shift': True},
    {'name': 'All Augmentations', 'dropout': True, 'scale': True, 'shift': True},
]

# One Args instance per configuration, index-aligned with augmentation_configs.
args_list = [
    Args(
        num_category=40,
        use_dropout=entry['dropout'],
        use_scale=entry['scale'],
        use_shift=entry['shift'],
    )
    for entry in augmentation_configs
]

# Show which log directory each configuration will write to.
for cfg, args in zip(augmentation_configs, args_list):
    print('{}: {}'.format(cfg['name'], args.log_dir))

## Helper Functions

In [None]:
def setup_experiment(args):
    """Select the GPU and make sure the run's output folders exist.

    Exports ``args.gpu`` as ``CUDA_VISIBLE_DEVICES`` and creates
    ``../log/classification/<args.log_dir>`` plus its ``checkpoints``
    subfolder. Existing folders are left alone.

    Args:
        args: Run settings; ``gpu`` and ``log_dir`` are read.

    Returns:
        ``(exp_dir, checkpoints_dir)`` as ``Path`` objects.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    run_root = Path('../log/classification', args.log_dir)
    ckpt_root = run_root / 'checkpoints'

    run_root.mkdir(parents=True, exist_ok=True)
    ckpt_root.mkdir(exist_ok=True)
    return run_root, ckpt_root

def load_data(args, data_path=None):
    """Build the train and test DataLoaders for one run.

    The dataset root is resolved in this order: the ``data_path`` argument,
    then a ``data_path`` attribute on ``args`` (if present and not None),
    then the historical hard-coded location. The fallback keeps existing
    calls of ``load_data(args)`` working unchanged while letting other
    machines point at their own copy of the dataset.

    Args:
        args: Run settings. ``batch_size`` is read here and the whole object
            is forwarded to ``ModelNetDataLoader``.
        data_path: Optional dataset root directory overriding the default.

    Returns:
        ``(trainDataLoader, testDataLoader)``. The training loader shuffles
        and drops the last incomplete batch; the test loader does neither.
    """
    if data_path is None:
        data_path = getattr(args, 'data_path', None)
    if data_path is None:
        # Machine-specific legacy default; pass data_path (or set
        # args.data_path) to use a different location.
        data_path = 'C:\\Users\\VIICTTE\\ML_Project\\modelnet40_normal_resampled'

    train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train')
    test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test')
    trainDataLoader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True
    )
    testDataLoader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4
    )
    return trainDataLoader, testDataLoader

def create_model(args):
    """Instantiate the classifier and its loss for one run.

    The model module named by ``args.model`` is imported from ``../models``.
    When ``args.spike`` is true the network is wrapped in ``SpikeModel``
    (built with ``args.step`` and ``args.temp``) and its spike state is
    switched on. Unless ``args.use_cpu`` is set, both the classifier and the
    criterion are moved to CUDA.

    Args:
        args: Run settings; reads ``model``, ``num_category``,
            ``use_normals``, ``spike``, ``step``, ``temp`` and ``use_cpu``.

    Returns:
        ``(classifier, criterion)``.
    """
    # This function runs once per configuration; only register the models
    # directory the first time so sys.path does not accumulate duplicates.
    models_dir = '../models'
    if models_dir not in sys.path:
        sys.path.append(models_dir)

    model = importlib.import_module(args.model)
    classifier = model.get_model(args.num_category, normal_channel=args.use_normals)
    if args.spike:
        classifier = SpikeModel(classifier, args.step, args.temp)
        classifier.set_spike_state(True)
    criterion = model.get_loss()
    if not args.use_cpu:
        classifier = classifier.cuda()
        criterion = criterion.cuda()
    return classifier, criterion

print('Helper functions defined!')

## Training Function with Configurable Augmentation

In [None]:
def train_model(args, exp_dir, checkpoints_dir, max_epochs=None):
    """Train one classifier with the augmentations selected in ``args``.

    Each epoch runs one pass over the training loader (applying
    ``apply_augmentations`` to every batch), steps the learning-rate
    schedule, then evaluates on the test loader. The weights with the best
    test accuracy so far are written to ``checkpoints_dir / 'best_model.pth'``.

    Args:
        args: Run settings. Reads ``epoch``, ``learning_rate``,
            ``decay_rate``, ``use_cpu`` and the three ``use_*`` augmentation
            switches. ``args.epoch`` is overwritten when ``max_epochs`` is
            truthy.
        exp_dir: Experiment directory. Not used inside this function; kept
            for the existing call signature.
        checkpoints_dir: ``Path`` of the folder receiving ``best_model.pth``.
        max_epochs: Optional cap on the number of epochs. ``None`` or ``0``
            leaves ``args.epoch`` as it is.

    Returns:
        ``(classifier, history, best_acc)`` where ``history`` is a list of
        ``{'epoch', 'train_acc', 'test_acc'}`` dicts (one per epoch, plain
        floats) and ``best_acc`` is the highest ``test_acc`` seen.
    """
    if max_epochs:
        args.epoch = max_epochs
    trainDataLoader, testDataLoader = load_data(args)
    classifier, criterion = create_model(args)

    # NOTE(review): args.optimizer is not consulted; Adam is always used.
    optimizer = torch.optim.Adam(classifier.parameters(), lr=args.learning_rate, weight_decay=args.decay_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

    best_acc = 0.0
    history = []

    aug_flags = (('dropout', args.use_dropout), ('scale', args.use_scale), ('shift', args.use_shift))
    aug_desc = [label for label, enabled in aug_flags if enabled]
    aug_str = ', '.join(aug_desc) if aug_desc else 'no augmentation'
    print(f'Training with: {aug_str}')

    for epoch in range(args.epoch):
        print(f'Epoch {epoch+1}/{args.epoch}')
        classifier.train()
        mean_correct = []

        for points, target in tqdm(trainDataLoader):
            optimizer.zero_grad()
            # Augmentations operate on NumPy arrays, so leave tensor land briefly.
            points = points.data.numpy()

            # Apply selected augmentations
            points = apply_augmentations(points, args.use_dropout, args.use_scale, args.use_shift)

            # Swap axes 1 and 2 (presumably points-major -> channels-major
            # for the network -- confirm against the model definition).
            points = torch.Tensor(points).transpose(2, 1)
            if not args.use_cpu:
                points, target = points.cuda(), target.cuda()
            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            loss.backward()
            optimizer.step()
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))

        # Advance the LR schedule only after this epoch's optimizer updates.
        # Stepping the scheduler before optimizer.step() skips the first
        # value of the schedule on PyTorch >= 1.1 and triggers a warning.
        scheduler.step()

        train_acc = float(np.mean(mean_correct))

        with torch.no_grad():
            classifier.eval()
            test_correct = []
            for points, target in testDataLoader:
                if not args.use_cpu:
                    points, target = points.cuda(), target.cuda()
                points = points.transpose(2, 1)
                pred, _ = classifier(points)
                pred_choice = pred.data.max(1)[1]
                correct = pred_choice.eq(target.long().data).cpu().sum()
                test_correct.append(correct.item() / float(points.size()[0]))
            # Mean of per-batch accuracies: every batch carries equal weight,
            # including a final batch that may be smaller than the others.
            test_acc = float(np.mean(test_correct))

        print(f'Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
        history.append({'epoch': epoch+1, 'train_acc': train_acc, 'test_acc': test_acc})

        # ">=" means a later epoch that ties the best score replaces the checkpoint.
        if test_acc >= best_acc:
            best_acc = test_acc
            torch.save({'model_state_dict': classifier.state_dict()}, str(checkpoints_dir / 'best_model.pth'))

    return classifier, history, best_acc

print('Training function defined!')

## Train Models with Different Augmentation Combinations

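**Note**: Training takes significant time. The cell below already caps each run at 10 epochs via `max_epochs` for quick testing; raise that value for final results, or train fewer combinations to save time. If a cached "All Augmentations" baseline history exists it is reused instead of being retrained, so make sure it was produced with a comparable epoch budget before comparing it against the other runs.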

In [None]:
# Train models with different augmentation combinations.
# The all-augmentations run doubles as the shared baseline: its history and
# checkpoint are cached under BASELINE_CACHE_NAME and reused when available,
# so only the remaining configurations need to be trained here.
BASELINE_CACHE_NAME = 'baseline_step_1'
results = {}

for cfg, args in zip(augmentation_configs, args_list):
    print(f"\n=== Training: {cfg['name']} ===")
    is_baseline = cfg['dropout'] and cfg['scale'] and cfg['shift']

    if is_baseline:
        baseline_history, baseline_meta = load_training_history(BASELINE_CACHE_NAME, with_metadata=True)
        if baseline_history:
            # Cached histories may name the test metric either way.
            best_acc = best_metric(baseline_history, ['test_acc', 'test_instance_acc']) or 0.0
            print(f"Loaded cached baseline history with {len(baseline_history)} epoch(s). Best accuracy: {best_acc:.4f}")
            # No live model is available when the history comes from the cache.
            results[cfg['name']] = {
                'classifier': None,
                'history': baseline_history,
                'best_acc': best_acc,
                'config': cfg,
                'metadata': baseline_meta,
            }
            cached_ckpt = load_cached_checkpoint_path(BASELINE_CACHE_NAME)
            if cached_ckpt:
                print(f'Cached baseline checkpoint available at: {cached_ckpt}')
            continue
        print('No cached baseline history found; training full-augmentation baseline from scratch.')

    exp_dir, ckpt_dir = setup_experiment(args)
    classifier, history, best_acc = train_model(args, exp_dir, ckpt_dir, max_epochs=10)  # Reduce to 10 for testing

    if is_baseline:
        # Store the freshly trained baseline so later runs can reuse it.
        metadata = {
            'variant': 'baseline_all_augmentations',
            'config': dict(vars(args)),
            'max_epochs': args.epoch,
            'augmentation_config': cfg,
        }
        history_path = save_training_history(history, BASELINE_CACHE_NAME, metadata=metadata)
        print(f'Saved baseline history to {history_path}')
        best_ckpt = ckpt_dir / 'best_model.pth'
        if best_ckpt.exists():
            cached_ckpt = cache_checkpoint(best_ckpt, BASELINE_CACHE_NAME)
            print(f'Cached baseline checkpoint to {cached_ckpt}')

    results[cfg['name']] = {
        'classifier': classifier,
        'history': history,
        'best_acc': best_acc,
        'config': cfg,
    }
    print(f"{cfg['name']}: Best Accuracy = {best_acc:.4f}")

## Visualization and Analysis

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path

# Display name -> candidate history keys for each plotted metric.
metrics = {
    "Train Accuracy": ["train_acc"],
    "Test Accuracy": ["test_instance_acc", "test_acc"],
}

# Collect the non-empty histories, in configuration order. The locals()
# guard lets this cell run (and report nicely) before any training happened.
available_histories = {}
if 'results' in locals():
    for cfg in augmentation_configs:
        name = cfg['name']
        entry = results.get(name, {})
        history = entry.get('history', []) if isinstance(entry, dict) else []
        if history:
            available_histories[name] = history

if not available_histories:
    print('No training history available for visualization. Run the training cells above first.')
else:
    # 1) Accuracy-per-epoch curves.
    fig, axes = plot_training_curves(
        available_histories,
        metrics,
        title="Augmentation Ablation: Accuracy per Epoch",
        max_cols=2
    )

    figures_dir = Path("../log/figures") / "augmentation"
    figures_dir.mkdir(parents=True, exist_ok=True)
    curve_path = figures_dir / "accuracy_curves.png"
    fig.savefig(curve_path, dpi=150, bbox_inches="tight")
    plt.show()

    # 2) Summary table.
    summary_stats = summarize_histories(available_histories, metrics)
    table_fig, table_ax = plot_metric_table(
        summary_stats,
        title="Augmentation Ablation Summary",
        value_fmt="{:.4f}",
        include_first=True
    )
    table_path = figures_dir / "accuracy_summary.png"
    table_fig.savefig(table_path, dpi=150, bbox_inches="tight")
    plt.show()

    # 3) Best test accuracy per configuration.
    bar_fig, bar_ax = plot_metric_bars(
        summary_stats,
        metric_name="Test Accuracy",
        title="Best Test Accuracy by Augmentation",
        ylabel="Test Accuracy"
    )
    bar_path = figures_dir / "best_accuracy.png"
    bar_fig.savefig(bar_path, dpi=150, bbox_inches="tight")
    plt.show()

    # 4) Best-accuracy delta of every configuration versus "No Augmentation".
    baseline_best = summary_stats.get('No Augmentation', {}).get('Test Accuracy', {}).get('best')
    if baseline_best is not None:
        contrib_fig, contrib_ax = plt.subplots(figsize=(8, 4))
        labels = []
        deltas = []
        for label, metric_stats in summary_stats.items():
            best_value = metric_stats.get('Test Accuracy', {}).get('best')
            if best_value is None:
                continue
            labels.append(label)
            deltas.append(best_value - baseline_best)
        bars = contrib_ax.bar(labels, deltas, color='#6a5acd')
        contrib_ax.axhline(0.0, color='black', linestyle='--', linewidth=0.8)
        contrib_ax.set_ylabel('Accuracy Δ vs. No Augmentation')
        contrib_ax.set_title('Accuracy Gain Relative to Baseline Augmentation')
        contrib_ax.grid(True, axis='y', linestyle='--', alpha=0.3)
        for bar, delta in zip(bars, deltas):
            contrib_ax.text(bar.get_x() + bar.get_width() / 2, delta, f'{delta*100:.2f}%', ha='center', va='bottom')
        contrib_fig.tight_layout()
        contrib_path = figures_dir / "accuracy_delta.png"
        contrib_fig.savefig(contrib_path, dpi=150, bbox_inches="tight")
        plt.show()

    # 5) Train-minus-test gap at the last epoch. The data is gathered first
    # so no empty figure is left open when nothing can be plotted.
    labels = []
    gaps = []
    for label, metric_stats in summary_stats.items():
        train_last = metric_stats.get('Train Accuracy', {}).get('last')
        test_last = metric_stats.get('Test Accuracy', {}).get('last')
        if train_last is None or test_last is None:
            continue
        labels.append(label)
        gaps.append(train_last - test_last)
    if labels:
        gap_fig, gap_ax = plt.subplots(figsize=(8, 4))
        bars = gap_ax.bar(labels, gaps, color='#ff8c00')
        gap_ax.axhline(0.0, color='black', linestyle='--', linewidth=0.8)
        gap_ax.set_ylabel('Generalization Gap (Train - Test)')
        gap_ax.set_title('Generalization Gap by Augmentation Policy')
        gap_ax.grid(True, axis='y', linestyle='--', alpha=0.3)
        for bar, gap in zip(bars, gaps):
            gap_ax.text(bar.get_x() + bar.get_width() / 2, gap, f'{gap*100:.2f}%', ha='center', va='bottom')
        gap_fig.tight_layout()
        gap_path = figures_dir / "generalization_gap.png"
        gap_fig.savefig(gap_path, dpi=150, bbox_inches="tight")
        plt.show()

    def _fmt(value):
        """Format a metric to four decimals, or '-' when it is missing."""
        return '-' if value is None else f"{value:.4f}"

    # 6) Plain-text recap of the same numbers.
    print('\nDetailed metrics:')
    for label, metric_stats in summary_stats.items():
        train_stats = metric_stats.get('Train Accuracy', {})
        test_stats = metric_stats.get('Test Accuracy', {})
        # Single braces so the formatted values are interpolated rather than
        # the literal expression text being printed.
        print(
            f"  {label}: train_last={_fmt(train_stats.get('last'))}, "
            f"test_last={_fmt(test_stats.get('last'))}, "
            f"best_test={_fmt(test_stats.get('best'))}"
        )

    if baseline_best is not None:
        print(f"\nBaseline (No Augmentation) best test accuracy: {baseline_best:.4f}")


## Summary and Insights

In [None]:
# Text summary of the ablation. Requires `results` to hold an entry for
# every configuration in `augmentation_configs`.
print('\n' + '='*60)
print('ABLATION STUDY 2.6 SUMMARY: DATA AUGMENTATION')
print('='*60)

print('\n1. Final Accuracy Results:')
for cfg in augmentation_configs:
    acc = results[cfg['name']]['best_acc']
    print(f"   {cfg['name']:20s}: {acc:.4f}")

best_config = max(augmentation_configs, key=lambda cfg: results[cfg['name']]['best_acc'])
best_acc = results[best_config['name']]['best_acc']
print(f"\n2. Best Configuration: {best_config['name']} with accuracy {best_acc:.4f}")

print('\n3. Individual Augmentation Contributions:')
baseline_acc = results['No Augmentation']['best_acc']
for aug_name in ['Only Dropout', 'Only Scale', 'Only Shift']:
    acc = results[aug_name]['best_acc']
    improvement = (acc - baseline_acc) * 100
    # Let the format spec emit the sign so a drop reads "-1.23%", not "+-1.23%".
    print(f"   {aug_name:15s}: {improvement:+6.2f}% over baseline")

print('\n4. Generalization Analysis:')
for cfg in augmentation_configs:
    history = results[cfg['name']]['history']
    final_epoch = history[-1] if history else {}
    final_train = final_epoch.get('train_acc')
    # A cached baseline history may store the test metric under
    # 'test_instance_acc' instead of 'test_acc'.
    final_test = final_epoch.get('test_acc', final_epoch.get('test_instance_acc'))
    if final_train is None or final_test is None:
        print(f"   {cfg['name']:20s}: n/a (train/test accuracy missing from history)")
        continue
    gap = (final_train - final_test) * 100
    print(f"   {cfg['name']:20s}: {gap:5.2f}% gap (train-test)")

print('\n5. Key Insights:')
if best_config['name'] == 'All Augmentations':
    print('   - All augmentations together provide the best performance')
    print('   - Augmentations are complementary and work well in combination')
    print('   - Current default configuration is optimal')
elif best_config['name'] == 'No Augmentation':
    print('   - Augmentations do NOT help performance')
    print('   - Model may be overfitting to augmented data')
    print('   - Consider removing augmentations or adjusting their parameters')
else:
    print(f"   - {best_config['name']} provides the best performance")
    print('   - Not all augmentations are beneficial')
    print('   - Some augmentations may be redundant or harmful')

print('\n6. Recommendations:')
most_important = max(['Only Dropout', 'Only Scale', 'Only Shift'],
                     key=lambda name: results[name]['best_acc'])
print(f"   - Most important single augmentation: {most_important}")
print(f"   - Best overall configuration: {best_config['name']}")

# best_acc is the overall maximum, so this is true only when some other
# configuration strictly beats the all-augmentations run.
if results['All Augmentations']['best_acc'] < best_acc:
    print('   - Consider using fewer augmentations for better performance')
    print('   - Some augmentations may be introducing too much noise')
else:
    print('   - Keep all augmentations for robust training')
    print('   - Augmentations improve generalization')

print('\n' + '='*60)