# ResNet-18 Pruning & Quantization (CIFAR-10)
Kaggle / Jupyter / Colab 호환

**Note**: 사전 학습된 체크포인트 필요 (`train_resnet18_cifar10.ipynb`로 학습)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
import torchvision
import torchvision.transforms as transforms
import os
import copy
import numpy as np

# ===================== Platform Detection =====================
def detect_platform():
    """Identify the notebook environment this code is running in.

    Returns:
        'kaggle' when the KAGGLE_KERNEL_RUN_TYPE environment variable is set,
        'colab' when the ``google.colab`` package can be imported,
        'jupyter' otherwise.
    """
    if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
        return 'kaggle'
    try:
        import google.colab  # noqa: F401 -- imported only to probe availability
    except ImportError:
        # Only a missing package means "not Colab". A bare ``except`` would
        # also swallow KeyboardInterrupt / SystemExit raised during import.
        return 'jupyter'
    return 'colab'

PLATFORM = detect_platform()

# ===================== CONFIG =====================
CONFIG = dict(
    checkpoint_path='./checkpoints/resnet18_ckpt.pth',  # path to the trained model
    pruning_ratios=[0.50, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90],
    finetune_epochs=5,
    finetune_lr=0.001,
    batch_size=128,
)

# Per-platform locations: platform -> (save directory, checkpoint path).
# Any platform not listed here (local Jupyter) keeps the CONFIG default
# checkpoint path and saves into ./checkpoints.
_PLATFORM_PATHS = {
    'kaggle': ('/kaggle/working', '/kaggle/working/resnet18_ckpt.pth'),
    'colab': ('/content', '/content/resnet18_ckpt.pth'),
}

DATA_DIR = './data'  # identical on every platform
if PLATFORM in _PLATFORM_PATHS:
    SAVE_DIR, CONFIG['checkpoint_path'] = _PLATFORM_PATHS[PLATFORM]
else:
    SAVE_DIR = './checkpoints'

os.makedirs(SAVE_DIR, exist_ok=True)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Summary of the resolved environment.
for label, value in (
    ('Platform', PLATFORM),
    ('Device', device),
    ('Checkpoint', CONFIG['checkpoint_path']),
    ('Save Dir', SAVE_DIR),
):
    print(f'✓ {label}: {value}')

In [None]:
# ===================== Utility Functions =====================
def get_model_info(model, name='Model'):
    """Print and return the serialized size and sparsity of ``model``.

    The size is measured by saving ``model.state_dict()`` with ``torch.save``
    and reading the resulting file size. Sparsity is computed over
    ``model.parameters()`` only (buffers are not counted).

    Args:
        model: Module whose state dict and parameters are inspected.
        name: Label used in the printed summary line.

    Returns:
        Tuple ``(size_mb, non_zero_params, sparsity)`` where ``size_mb`` is the
        file size in units of 1e6 bytes and ``sparsity`` is the percentage of
        parameter elements equal to zero (0.0 for a model with no parameters).
    """
    import tempfile

    # A private temporary directory is removed even if torch.save raises, and
    # cannot collide with a real file in the output directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = os.path.join(tmp_dir, 'temp.p')
        torch.save(model.state_dict(), temp_path)
        size_mb = os.path.getsize(temp_path) / 1e6

    total_params = sum(p.numel() for p in model.parameters())
    zero_params = sum((p == 0).sum().item() for p in model.parameters())
    non_zero_params = total_params - zero_params
    # Guard against modules that expose no parameters at all.
    sparsity = 100 * zero_params / total_params if total_params else 0.0

    print(f'[{name}] Size: {size_mb:.2f}MB | Params: {non_zero_params:,}/{total_params:,} | Sparsity: {sparsity:.1f}%')
    return size_mb, non_zero_params, sparsity

def evaluate(model, loader, device=device):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    The model is switched to eval mode (and left in it). Each batch is moved
    to ``device`` before the forward pass; gradients are disabled throughout.

    Args:
        model: Module called on each input batch.
        loader: Iterable yielding ``(inputs, labels)`` pairs.
        device: Target device for the batches; defaults to the module-level
            ``device`` as it was when this function was defined.

    Returns:
        ``100 * correct / total`` over all samples seen.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch, targets in loader:
            batch = batch.to(device)
            targets = targets.to(device)
            # max over dim 1 returns (values, indices); the indices are the predicted classes
            predictions = model(batch).max(1)[1]
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()
    return 100 * n_correct / n_seen

print('✓ Utility functions defined')

In [None]:
# ===================== Data & Model =====================
# Per-channel mean / std handed to Normalize in both pipelines.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
# Test pipeline: no augmentation, same normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

trainset = torchvision.datasets.CIFAR10(
    root=DATA_DIR, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root=DATA_DIR, train=False, download=True, transform=transform_test)

# The training loader shuffles and uses the configured batch size; the test
# loader keeps a fixed order with batches of 100.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# ===================== ResNet-18 for CIFAR-10 (Custom) =====================
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm pairs and a shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride is
    not 1 or the channel count changes, in which case it is a 1x1 convolution
    followed by batch norm so the two branches can be added.
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch; only the first convolution applies the stride.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut branch: projection layers only when the shapes would differ.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Apply the main branch, add the shortcut of ``x``, then ReLU."""
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return F.relu(main)

class ResNet(nn.Module):
    """ResNet with a 3x3 stem convolution, four stages and a linear classifier.

    Args:
        block: Callable building one residual block as
            ``block(in_planes, planes, stride)``; must expose ``expansion``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Size of the classifier output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running input-channel count, advanced by _make_layer as blocks are added.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (planes, stride of the first block) per stage, registered as layer1..layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, num_blocks[index], stride=stride)
            setattr(self, f'layer{index + 1}', stage)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run stem, the four stages, 4x4 average pooling and the classifier."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        feats = F.avg_pool2d(feats, 4)
        # Flatten everything but the batch dimension before the linear layer.
        return self.linear(feats.view(feats.size(0), -1))

def ResNet18():
    """Build a ResNet from BasicBlock with two blocks in each of its four stages."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths)

print(f'✓ Data loaded: {len(trainset)} train, {len(testset)} test')
print('✓ Custom ResNet-18 defined')

In [None]:
# ===================== Load Pretrained Model =====================
# Fail fast, before building the model, when the checkpoint is missing.
if not os.path.exists(CONFIG['checkpoint_path']):
    raise FileNotFoundError(f"Checkpoint not found: {CONFIG['checkpoint_path']}\n"
                            f"Please train the model first using train_resnet18_cifar10.ipynb")

baseline_model = ResNet18().to(device)

# The checkpoint is expected to be a dict holding the weights under 'net'.
checkpoint = torch.load(CONFIG['checkpoint_path'], map_location=device)
state_dict = checkpoint['net']
# A state dict saved from an nn.DataParallel wrapper prefixes every key with
# 'module.'; strip it so the keys match this unwrapped model.
# NOTE(review): whether the training notebook wraps the model is not visible
# here; unprefixed checkpoints pass through unchanged.
if state_dict and all(key.startswith('module.') for key in state_dict):
    state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
baseline_model.load_state_dict(state_dict)

# Reference numbers that the pruning / quantization steps are compared against.
base_acc = evaluate(baseline_model, testloader)
base_size, base_nz, _ = get_model_info(baseline_model, 'Baseline')

print(f'\n✓ Loaded pretrained model')
print(f'✓ Baseline accuracy: {base_acc:.2f}%')

In [None]:
# ===================== Pruning Analysis =====================
# Sweep the configured pruning ratios on copies of the baseline, record the
# accuracy and non-zero parameter count for each, and keep the last ratio whose
# accuracy drop stays under 2 percentage points.
print('Pruning Sensitivity Analysis...')
print('-' * 70)

results = {'ratio': [], 'acc': [], 'params': []}
best_pruned_model, target_ratio = None, 0.0
# Least-damaging candidate seen so far. It is used only if no ratio meets the
# 2% budget, so that later cells still have a model to work with.
fallback_model, fallback_ratio, fallback_drop = None, 0.0, float('inf')

for r in CONFIG['pruning_ratios']:
    model_temp = copy.deepcopy(baseline_model)

    # Global L1 unstructured pruning over every Conv2d / Linear weight.
    params_to_prune = [(m, 'weight') for m in model_temp.modules()
                       if isinstance(m, (nn.Conv2d, nn.Linear))]
    prune.global_unstructured(params_to_prune, pruning_method=prune.L1Unstructured, amount=r)
    # Make the pruning permanent: the zeros are baked into 'weight' and the
    # mask reparametrization is removed.
    for m, _ in params_to_prune:
        prune.remove(m, 'weight')

    acc = evaluate(model_temp, testloader)
    _, nz_params, sparsity = get_model_info(model_temp, f'Prune {r*100:.0f}%')

    results['ratio'].append(r)
    results['acc'].append(acc)
    results['params'].append(nz_params)

    drop = base_acc - acc
    status = 'Safe' if drop < 1 else 'Caution' if drop < 2 else 'Warning' if drop < 5 else 'Collapse'
    print(f'  → Acc: {acc:.2f}% (drop: {drop:.2f}%) [{status}]')

    if drop < 2.0:
        best_pruned_model = copy.deepcopy(model_temp)
        target_ratio = r
    elif best_pruned_model is None and drop < fallback_drop:
        fallback_model, fallback_ratio, fallback_drop = copy.deepcopy(model_temp), r, drop

if best_pruned_model is None and fallback_model is not None:
    # No ratio met the budget: continue with the smallest observed drop rather
    # than leaving best_pruned_model as None for the following cells.
    print(f'! No ratio kept the drop under 2%; falling back to '
          f'{fallback_ratio*100:.0f}% (drop: {fallback_drop:.2f}%)')
    best_pruned_model, target_ratio = fallback_model, fallback_ratio

print('-' * 70)
print(f'✓ Best pruning ratio: {target_ratio*100:.0f}%')

In [None]:
# ===================== Fine-tuning =====================
# Fine-tune the selected pruned model while keeping its pruned weights at zero.
if best_pruned_model is None:
    raise RuntimeError('No pruned model was selected by the pruning analysis; '
                       'nothing to fine-tune.')

print(f'Fine-tuning pruned model ({target_ratio*100:.0f}% sparsity)...')

ft_model = copy.deepcopy(best_pruned_model).to(device)
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(ft_model.parameters(), lr=CONFIG['finetune_lr'], momentum=0.9, weight_decay=5e-4)

# The pruning masks were removed when pruning was made permanent, so plain SGD
# updates would turn the zeroed weights non-zero again. Record which entries
# are zero now and re-zero them after every optimizer step.
sparsity_masks = [
    (m.weight, (m.weight.detach() != 0).to(m.weight.dtype))
    for m in ft_model.modules()
    if isinstance(m, (nn.Conv2d, nn.Linear))
]

# Accuracy of the chosen ratio before fine-tuning, as recorded by the sweep.
pruned_acc_before = results['acc'][results['ratio'].index(target_ratio)]

for epoch in range(CONFIG['finetune_epochs']):
    ft_model.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_ft.zero_grad()
        loss = criterion(ft_model(inputs), labels)
        loss.backward()
        optimizer_ft.step()
        # Restore the sparsity pattern that the update may have disturbed.
        with torch.no_grad():
            for weight, mask in sparsity_masks:
                weight.mul_(mask)
        running_loss += loss.item()
    print(f'  FT Epoch {epoch+1}/{CONFIG["finetune_epochs"]} | Loss: {running_loss/len(trainloader):.4f}')

ft_acc = evaluate(ft_model, testloader)
print(f'\n✓ Fine-tuning done')
print(f'  Before FT: {pruned_acc_before:.2f}%')
print(f'  After FT:  {ft_acc:.2f}%')
# '+' format flag prints an explicit sign, so a loss shows as '-x.xx%' not '+-x.xx%'.
print(f'  Recovered: {ft_acc - pruned_acc_before:+.2f}%')

# Save pruned model
pruned_path = os.path.join(SAVE_DIR, 'pruned_finetuned.pth')
torch.save(ft_model.state_dict(), pruned_path)
print(f'✓ Saved: {pruned_path}')

In [None]:
# ===================== Quantization =====================
print('Quantizing model to INT8...')

# Work on a CPU copy in eval mode (quantization is done on the CPU).
model_fp32 = copy.deepcopy(ft_model)
model_fp32 = model_fp32.cpu().eval()

# Dynamic quantization of the listed module types to qint8.
# NOTE(review): PyTorch's dynamic-quantization documentation lists Linear and
# recurrent layers as the supported module types; the nn.Conv2d entry may have
# no effect, which would leave most of this network in FP32 -- confirm.
quant_targets = {nn.Linear, nn.Conv2d}
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, quant_targets, dtype=torch.qint8)

# Evaluate the quantized model on the CPU and measure its serialized size.
q_acc = evaluate(model_int8, testloader, device='cpu')
q_size, _, _ = get_model_info(model_int8, 'Quantized')

rule = '=' * 50
print('\n' + rule)
print('FINAL COMPARISON')
print(rule)
print(f'Baseline:   {base_acc:.2f}% | {base_size:.2f}MB')
# The pruned row reuses the baseline size: it is printed from base_size.
print(f'Pruned+FT:  {ft_acc:.2f}% | {base_size:.2f}MB (sparse)')
print(f'Quantized:  {q_acc:.2f}% | {q_size:.2f}MB')
print(f'\nCompression: {base_size/q_size:.2f}x smaller')
print(f'Acc drop:    {base_acc - q_acc:.2f}%')

# Save quantized model
q_path = os.path.join(SAVE_DIR, 'quantized_int8.pth')
torch.save(model_int8.state_dict(), q_path)
print(f'\n✓ Saved: {q_path}')

In [None]:
# ===================== Visualization =====================
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax_acc, ax_params = axes
ratio_pct = np.array(results['ratio']) * 100  # shared x-axis, in percent

# Left panel: accuracy at each pruning ratio, with the baseline and
# post-fine-tuning accuracies drawn as horizontal reference lines.
ax_acc.plot(ratio_pct, results['acc'], 'r-o', label='Pruned Acc')
ax_acc.axhline(y=base_acc, color='b', linestyle='--', label=f'Baseline ({base_acc:.1f}%)')
ax_acc.axhline(y=ft_acc, color='g', linestyle=':', label=f'After FT ({ft_acc:.1f}%)')
ax_acc.set(xlabel='Pruning Ratio (%)', ylabel='Accuracy (%)', title='Pruning Sensitivity')
ax_acc.legend()
ax_acc.grid(True, alpha=0.3)

# Right panel: remaining non-zero parameters (in millions) at each ratio.
params_m = np.array(results['params']) / 1e6
ax_params.plot(ratio_pct, params_m, 'b-s')
ax_params.set(xlabel='Pruning Ratio (%)', ylabel='Parameters (M)', title='Parameter Reduction')
ax_params.grid(True, alpha=0.3)

plt.tight_layout()
# Write the figure to disk before showing it.
fig_path = os.path.join(SAVE_DIR, 'pruning_results.png')
plt.savefig(fig_path, dpi=150)
plt.show()

print(f'✓ Saved: {fig_path}')

In [None]:
# ===================== Save (Platform별) =====================
print(f'Platform: {PLATFORM}')
print(f'Models saved in: {SAVE_DIR}')

if PLATFORM == 'kaggle':
    print('✓ Files in /kaggle/working/ (download from Output tab)')

elif PLATFORM == 'colab':
    # Best effort: copy whichever artifacts exist to Google Drive; on any
    # failure report it and leave the files where they are.
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        import shutil
        dest = '/content/drive/MyDrive/pruning_results/'
        os.makedirs(dest, exist_ok=True)
        for artifact in ('pruned_finetuned.pth', 'quantized_int8.pth', 'pruning_results.png'):
            src = os.path.join(SAVE_DIR, artifact)
            if not os.path.exists(src):
                continue  # this artifact was never produced; skip it
            shutil.copy(src, dest)
        print(f'✓ Files copied to Google Drive: {dest}')
    except Exception as e:
        print(f'Drive mount failed: {e}')
        print('Files are in /content/')

else:
    print(f'✓ Files in {SAVE_DIR}/')

print('\n=== All Done! ===')