In [90]:
from DSN import DeepSpectralNet
from SBN import SpectralBillinearNet
import torch
import torch.nn as nn
import torch.nn.functional as F
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from torchvision import datasets,transforms
from torch.utils.data import DataLoader

transform = transforms.ToTensor()

mnist_train = datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

mnist_test = datasets.FashionMNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # 28x28 → 784
])

train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
test_ds  = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=256)

In [4]:
device='cpu'

In [162]:
model = DeepSpectralNet(
    layers_dim=[784,128,64,32,16,8,10],
    #use_final_linear=True,       # CRUCIAL pour classification
    use_layernorm=True
).to(device)

In [166]:
import torch
import torch.nn.functional as F

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)     # logits bruts
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} | loss = {loss.item():.4f}")

Epoch 0 | loss = 0.3949
Epoch 1 | loss = 0.2400
Epoch 2 | loss = 0.4724
Epoch 3 | loss = 0.3622
Epoch 4 | loss = 0.4466


In [167]:
model.eval()

correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)

        logits = model(x)              # (batch, 10)
        preds = logits.argmax(dim=1)   # classe prédite

        correct += (preds == y).sum().item()
        total += y.size(0)

accuracy = correct / total
print(f"Test accuracy: {accuracy * 100:.2f}%")

Test accuracy: 86.71%


In [183]:
# --- 1. Dataset ---
transform = transforms.Compose([
    transforms.ToTensor(),              # Passe de [0, 255] à [0, 1]
    transforms.Normalize((0.5,), (0.5,)) # Passe de [0, 1] à [-1, 1]
])

train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
test_ds  = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=256)


In [40]:
# --- 2. CNN Model ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool  = nn.MaxPool2d(2, 2)
        self.fc1   = nn.Linear(32*7*7, 128)
        self.fc2   = nn.Linear(128, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 -> 7x7
        x = x.view(x.size(0), -1)             # flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleCNN().to(device)
model = torch.compile(model)
# --- 3. Loss + Optimizer ---
criterion = nn.CrossEntropyLoss()

In [45]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
for epoch in range(10):
    model.train()
    running_loss = 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        x = x.view(x.size(0), 1, 28, 28) 

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)

    print(f"Epoch {epoch} | Loss: {running_loss / len(train_ds):.4f}")

Epoch 0 | Loss: 0.1449
Epoch 1 | Loss: 0.1252
Epoch 2 | Loss: 0.1143
Epoch 3 | Loss: 0.1040
Epoch 4 | Loss: 0.0960


In [47]:
 
# --- 5. Evaluation ---
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        
        # reshape pour la CNN
        x = x.view(-1, 1, 28, 28)
        
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.2f}%")

Test Accuracy: 91.91%


In [48]:
for epoch in range(5):
    model.train()
    running_loss = 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        x = x.view(x.size(0), 1, 28, 28) 

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)

    print(f"Epoch {epoch} | Loss: {running_loss / len(train_ds):.4f}")
# --- 5. Evaluation ---
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        
        # reshape pour la CNN
        x = x.view(-1, 1, 28, 28)
        
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.2f}%")

Epoch 0 | Loss: 0.0876
Epoch 1 | Loss: 0.0798
Epoch 2 | Loss: 0.0731
Epoch 3 | Loss: 0.0662
Epoch 4 | Loss: 0.0600
Test Accuracy: 91.62%


In [81]:
# --- 2. CNN Model ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, padding=1)
        self.pool  = nn.MaxPool2d(4, 4)
        self.fc2   = DeepSpectralNet(layers_dim=[8*1*1,32,10],use_layernorm=True)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 -> 7x7
        x = x.view(x.size(0), -1)             # flatten
        x = self.fc2(x)
        return x

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = SimpleCNN().to(device)
model = torch.compile(model)
# --- 3. Loss + Optimizer ---
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print("Nombre total de paramètres (spectral):", count_parameters(model))

Nombre total de paramètres (spectral): 2146


In [88]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [89]:

# --- 4. Training Loop ---
for epoch in range(10):  
    model.train()
    running_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        x = x.view(-1, 1, 28, 28)  # batch_size, 1 canal, 28x28
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * x.size(0)
    print(f"Epoch {epoch} | Loss: {running_loss / len(train_ds):.4f}")

Epoch 0 | Loss: 2.6188
Epoch 1 | Loss: 2.3273
Epoch 2 | Loss: 2.3136


KeyboardInterrupt: 

In [85]:
 
# --- 5. Evaluation ---
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        
        # reshape pour la CNN
        x = x.view(-1, 1, 28, 28)
        
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.2f}%")

Test Accuracy: 77.67%


In [75]:

# --- 4. Training Loop ---
for epoch in range(5):  
    model.train()
    running_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        x = x.view(-1, 1, 28, 28)  # batch_size, 1 canal, 28x28
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * x.size(0)
    print(f"Epoch {epoch} | Loss: {running_loss / len(train_ds):.4f}")
     
# --- 5. Evaluation ---
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        
        # reshape pour la CNN
        x = x.view(-1, 1, 28, 28)
        
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.2f}%")

Epoch 0 | Loss: 0.1127
Epoch 1 | Loss: 0.1070
Epoch 2 | Loss: 0.0982
Epoch 3 | Loss: 0.0893
Epoch 4 | Loss: 0.0858
Test Accuracy: 89.77%


In [70]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        
        # reshape pour la CNN
        x = x.view(-1, 1, 28, 28)
        
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.2f}%")

Test Accuracy: 89.89%


In [56]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool  = nn.MaxPool2d(2, 2)
        self.fc1   = nn.Linear(32*7*7, 128)
        self.fc2   = nn.Linear(128, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 -> 7x7
        x = x.view(x.size(0), -1)             # flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
model = SimpleCNN()
print("Nombre total de paramètres (classique):", count_parameters(model))

Nombre total de paramètres (classique): 206922


In [120]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpectralCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        
        # 1. Extraction Spatiale Dense
        # Augmentation des canaux pour capturer plus de primitives visuelles
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2), # 28x28 -> 14x14
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)  # 14x14 -> 7x7
        )
        
        # 2. Global Pooling pour préparer l'entrée du SBN
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        
        # 3. SBN Large (High-Capacity)
        # On augmente 'hidden_dim' à 256 pour égaler le nombre de paramètres d'un gros MLP.
        # layers_dim = [Entrée, Cachée, Sortie]
        # Le degré algébrique final sera de 2^L = 2^2 = 4 (modèle quartique)[cite: 43, 53].
        self.sbn_head = SpectralBillinearNet(
            layers_dim=[128, 256, num_classes], 
            use_layernorm=True      # Indispensable pour stabiliser les activations bilinéaires [cite: 109, 176]
        )

    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1) 
        
        # Le SBN applique : y = Σ λ_k * (x^T p_k)(l_k^T x) + β [cite: 30, 39]
        x = self.sbn_head(x)
        return x

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = SpectralCNN().to(device)
model = torch.compile(model)
# --- 3. Loss + Optimizer ---
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-1)
print("Nombre total de paramètres (spectral):", count_parameters(model))

Nombre total de paramètres (spectral): 293834


In [125]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [130]:

# --- 4. Training Loop ---
for epoch in range(5):  
    model.train()
    running_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        x = x.view(-1, 1, 28, 28)  # batch_size, 1 canal, 28x28
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * x.size(0)
    print(f"Epoch {epoch} | Loss: {running_loss / len(train_ds):.4f}")

Epoch 0 | Loss: 0.0683
Epoch 1 | Loss: 0.0623
Epoch 2 | Loss: 0.0584
Epoch 3 | Loss: 0.0549
Epoch 4 | Loss: 0.0512


In [131]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        
        # reshape pour la CNN
        x = x.view(-1, 1, 28, 28)
        
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.2f}%")

Test Accuracy: 91.39%


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpectralCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        
        # 1. Extraction Spatiale Dense
        # Augmentation des canaux pour capturer plus de primitives visuelles
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2), # 28x28 -> 14x14
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)  # 14x14 -> 7x7
        )
        
        # 2. Global Pooling pour préparer l'entrée du SBN
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        
        # 3. SBN Large (High-Capacity)
        # On augmente 'hidden_dim' à 256 pour égaler le nombre de paramètres d'un gros MLP.
        # layers_dim = [Entrée, Cachée, Sortie]
        # Le degré algébrique final sera de 2^L = 2^2 = 4 (modèle quartique)[cite: 43, 53].
        self.sbn_head = DeepSpectralNet(
            layers_dim=[128, 256, num_classes], 
            use_layernorm=True     
        )

    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1) 
        
        # Le SBN applique : y = Σ λ_k * (x^T p_k)(l_k^T x) + β [cite: 30, 39]
        x = self.sbn_head(x)
        return x

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = SpectralCNN().to(device)
model = torch.compile(model)
# --- 3. Loss + Optimizer ---
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-1)
print("Nombre total de paramètres (spectral):", count_parameters(model))

Nombre total de paramètres (spectral): 211530


In [None]:

# --- 4. Training Loop ---
for epoch in range(5):  
    model.train()
    running_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        x = x.view(-1, 1, 28, 28)  # batch_size, 1 canal, 28x28
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * x.size(0)
    print(f"Epoch {epoch} | Loss: {running_loss / len(train_ds):.4f}")

In [137]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        
        # reshape pour la CNN
        x = x.view(-1, 1, 28, 28)
        
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.2f}%")

Test Accuracy: 10.03%


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = SpectralBillinearNet(
    layers_dim=[784,128,128,10],
    use_final_linear=True,       # CRUCIAL pour classification
    use_layernorm=True
).to('cpu')
model = torch.compile(model)
# --- 3. Loss + Optimizer ---
criterion = nn.CrossEntropyLoss()

In [161]:

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for epoch in range(5):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)     # logits bruts
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} | loss = {loss.item():.4f}")

Epoch 0 | loss = 0.1833
Epoch 1 | loss = 0.2670
Epoch 2 | loss = 0.2460
Epoch 3 | loss = 0.2093
Epoch 4 | loss = 0.2236


In [162]:
model.eval()

correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)

        logits = model(x)              # (batch, 10)
        preds = logits.argmax(dim=1)   # classe prédite

        correct += (preds == y).sum().item()
        total += y.size(0)

accuracy = correct / total
print(f"Test accuracy: {accuracy * 100:.2f}%")

Test accuracy: 88.85%
