<a href="https://colab.research.google.com/github/AlperYildirim1/Language-as-Waves/blob/main/CIFAR_Last.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ==========================================
# 1. CONFIGURATION (UNCHANGED)
# ==========================================
# Mini-batch size shared by the training and test data loaders.
BATCH_SIZE = 128
# Number of passes over the training set per experiment.
EPOCHS = 20 # Reduced for ablation speed (Increase to 50 for full paper results)
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bit depth used when quantizing the SLM magnitude mask.
SLM_BITS = 4

# Learning Rates
# The optical masks use a separate learning rate from the digital classifier.
LR_OPTICAL = 0.005
LR_DIGITAL = 0.001

# The banner emoji was stored as mis-decoded UTF-8 bytes; restored to the
# intended gear symbol.
print(f"⚙️ DEVICE: {DEVICE}")

# ==========================================
# 2. DATA PREPARATION (UNCHANGED)
# ==========================================
# Per-channel constants passed to Normalize for both splits (presumably the
# usual CIFAR-10 mean / std values -- confirm before reusing elsewhere).
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: a random horizontal flip is the only augmentation.
transform_train = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Evaluation pipeline: tensor conversion and normalization only.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# CIFAR-10 is downloaded into ./data on first use.
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)

# Only the training loader shuffles; both use two worker processes.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
)

# ==========================================
# 3. CORE LOGIC (EXACT COPY OF YOUR CODE)
# ==========================================
def quantize_slm(x, bits):
    """Quantize ``x`` onto a uniform grid with ``2**bits - 1`` steps per unit.

    The forward value is ``round(x * levels) / levels`` with
    ``levels = 2**bits - 1``, so inputs in ``[0, 1]`` map onto ``2**bits``
    evenly spaced levels.

    The backward pass uses a straight-through estimator: the gradient with
    respect to ``x`` is the identity. ``torch.round`` alone has a zero
    gradient almost everywhere, which would stop any parameter upstream of
    the quantizer from ever being updated.

    Args:
        x: Tensor of values to quantize (expected to be finite).
        bits: Bit depth of the quantizer; must be at least 1.

    Returns:
        Tensor with the same shape as ``x`` holding the quantized values.

    Raises:
        ValueError: If ``bits`` is smaller than 1 (``levels`` would be zero
            and the division below would yield NaN/inf).
    """
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")

    levels = 2**bits - 1
    quantized = torch.round(x * levels) / levels

    # Straight-through estimator: ``x - x.detach()`` is exactly zero in the
    # forward pass (for finite x), so the returned value equals ``quantized``,
    # while autograd sees d(out)/d(x) = 1.
    return quantized.detach() + (x - x.detach())

class StableOpticalLayer(nn.Module):
    """Bank of learnable Fourier-plane masks applied to every image channel.

    Each of the ``num_filters`` masks has a magnitude (``slm_mag``, passed
    through a sigmoid and quantized to ``SLM_BITS`` bits) and an unquantized
    phase (``slm_phase``). Every input channel is filtered independently by
    every mask.

    Args:
        num_filters: Number of complex masks.
        size: Side length of the square masks; must match the spatial size
            of the inputs. Defaults to 32, the value previously hard-coded.
    """

    def __init__(self, num_filters, size=32):
        super().__init__()
        # Initialization: Balanced magnitude, small random phase
        # (slm_mag is a pre-sigmoid value; the effective magnitude is
        # sigmoid(0.8) before quantization.)
        self.slm_mag = nn.Parameter(torch.ones(num_filters, size, size) * 0.8)
        self.slm_phase = nn.Parameter(torch.randn(num_filters, size, size) * 0.1)

    def forward(self, x):
        """Filter ``x`` with every mask and return normalized magnitudes.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` where ``(H, W)`` matches the
                mask size.

        Returns:
            Tensor of shape ``(B * C, num_filters, H, W)``; each ``(H, W)``
            map is divided by its own maximum (plus 1e-6). Output channel
            ``k`` corresponds to mask ``k``.
        """
        B, C, H, W = x.shape
        # Treat every image channel as its own single-channel sample.
        # reshape (rather than view) also accepts non-contiguous inputs,
        # e.g. channels_last tensors.
        x = x.reshape(B * C, 1, H, W)

        # Physical Constraints
        # NOTE: slm_mag only receives a gradient if quantize_slm is
        # differentiable (e.g. uses a straight-through estimator).
        mag = torch.sigmoid(self.slm_mag)
        mag = quantize_slm(mag, SLM_BITS)
        phase = self.slm_phase

        # Fourier Optics (Stable Autograd Version)
        # Shift only the two spatial axes. Without an explicit ``dim`` the
        # shift helpers roll *every* axis: the batch roll cancels out, but the
        # filter axis of the output ends up rotated by num_filters // 2.
        spatial_dims = (-2, -1)
        x_fft = torch.fft.fftshift(torch.fft.fft2(x), dim=spatial_dims)
        h = mag * torch.exp(1j * phase)
        # Broadcasts (B*C, 1, H, W) against (num_filters, H, W).
        obj_fft = x_fft * h
        out_spatial = torch.fft.ifft2(torch.fft.ifftshift(obj_fft, dim=spatial_dims))

        # Intensity Detection
        # (Uses the field magnitude |E|, not the squared magnitude.)
        out_intensity = out_spatial.abs()

        # Hardware Noise Simulation (Training Only)
        # Multiplicative Gaussian noise with a standard deviation of 1 %.
        if self.training:
            out_intensity = out_intensity * (1.0 + torch.randn_like(out_intensity) * 0.01)

        # Stable AGC (Automatic Gain Control)
        # Per-sample, per-filter peak normalization; the epsilon avoids a
        # division by zero for an all-zero map.
        max_val = torch.amax(out_intensity, dim=(2,3), keepdim=True) + 1e-6
        out_normalized = out_intensity / max_val

        return out_normalized

class HybridNetFinal(nn.Module):
    """Optical front end followed by a small fully connected classifier.

    The optical layer output is max-pooled to ``pool_dim x pool_dim`` per
    map, flattened per image and fed to a two-layer MLP with 10 outputs.
    """

    def __init__(self):
        super().__init__()
        # Built first so the parameter-initialization order stays the same.
        self.optics = StableOpticalLayer(num_filters=64)

        # Spatial size of each pooled map.
        self.pool_dim = 8
        # 64 filters x 3 input channels x pooled pixels per map
        # (assumes 3-channel inputs).
        self.flat_dim = 64 * 3 * self.pool_dim**2

        head = [
            nn.Linear(self.flat_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits of shape ``(B, 10)`` for a batch ``x``."""
        batch_size = x.shape[0]
        target_grid = (self.pool_dim, self.pool_dim)
        features = F.adaptive_max_pool2d(self.optics(x), target_grid)
        # The optical layer folds channels into the batch axis; regroup all
        # maps belonging to one image into a single feature vector.
        return self.classifier(features.view(batch_size, -1))

# ==========================================
# 4. ABLATION ENGINE
# ==========================================
def _evaluate_accuracy(model, loader):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``."""
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            pred = model(data).argmax(1)
            correct += pred.eq(target).sum().item()
            seen += target.size(0)
    return 100. * correct / seen


def run_training_session(mode_name, freeze_optics):
    """Train a fresh ``HybridNetFinal`` and report its test accuracy per epoch.

    Args:
        mode_name (str): Label printed in the experiment banner.
        freeze_optics (bool): If True, the optics layer is NOT optimized; it
            stays at its initial (random-phase) state and only the classifier
            is trained. If False, optics and classifier are trained jointly
            with separate learning rates.

    Returns:
        list[float]: Test accuracy in percent after each of the ``EPOCHS``
        epochs.
    """
    print(f"\n🚀 STARTING EXPERIMENT: {mode_name}")
    print(f"   Optics Frozen: {freeze_optics}")

    model = HybridNetFinal().to(DEVICE)

    # --- CRITICAL: OPTIMIZER CONFIGURATION ---
    param_groups = [{'params': model.classifier.parameters(), 'lr': LR_DIGITAL}]
    if freeze_optics:
        # OPTION B: FROZEN OPTICS (Baseline)
        # Disable gradients for the optics instead of merely leaving them out
        # of the optimizer. Otherwise their .grad tensors are never cleared by
        # optimizer.zero_grad(), keep accumulating across steps and inflate
        # the norm used for gradient clipping below.
        model.optics.requires_grad_(False)
    else:
        # OPTION A: LEARNED OPTICS (Proposed Method)
        # Optics group goes first, matching the original group order.
        param_groups.insert(0, {'params': model.optics.parameters(), 'lr': LR_OPTICAL})

    optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-3)

    # Clip only what is actually being optimized.
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    # Cosine decay over the whole run; stepped once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.CrossEntropyLoss()

    acc_history = []

    for epoch in range(EPOCHS):
        model.train()
        for data, target in train_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            # Gradient Clipping
            torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)

            optimizer.step()

        scheduler.step()

        # Testing
        acc = _evaluate_accuracy(model, test_loader)
        acc_history.append(acc)
        print(f"   Epoch {epoch+1}/{EPOCHS} | Test Acc: {acc:.2f}%")

    return acc_history

# ==========================================
# 5. EXECUTION & PLOTTING
# ==========================================
if __name__ == "__main__":
    # Ablation driver: train the model once with learnable optics and once
    # with the optics frozen, then compare per-epoch test accuracy.

    # Experiment 1: Full Model (Optics Learning)
    acc_learned = run_training_session("PROPOSED (Learned Optics)", freeze_optics=False)

    # Experiment 2: Baseline (Optics Frozen / Random)
    acc_frozen = run_training_session("BASELINE (Frozen Optics)", freeze_optics=True)

    # --- PLOTTING FOR PAPER ---
    plt.figure(figsize=(10, 6), dpi=120)
    epochs_x = range(1, EPOCHS + 1)

    plt.plot(epochs_x, acc_learned, 'o-', label='Learned Optics (Proposed)', color='blue', linewidth=2)
    plt.plot(epochs_x, acc_frozen, 'x--', label='Frozen Optics (Random Baseline)', color='gray', linewidth=2)

    plt.xlabel('Epochs', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.title('Ablation Study: Contribution of Optical Learning', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, linestyle='--', alpha=0.7)

    # Calculate Delta
    # Difference in final-epoch accuracy between the two runs.
    final_delta = acc_learned[-1] - acc_frozen[-1]
    print("\n📊 FINAL RESULTS:")
    print(f"   Learned Accuracy: {acc_learned[-1]:.2f}%")
    print(f"   Frozen Accuracy:  {acc_frozen[-1]:.2f}%")
    # Signed format so a negative gain prints as "-1.23%" rather than "+-1.23%".
    print(f"   Optical Gain (Delta): {final_delta:+.2f}%")

    # A gain of more than 3 percentage points is treated as significant.
    if final_delta > 3.0:
        print("✅ CONCLUSION: The optical layer provides significant learnable features.")
    else:
        print("⚠️ CONCLUSION: The digital head dominates performance (Consider reducing classifier size).")

    plt.show()