In [1]:
# CIFAR Dataset Preprocessing 

import os
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from pathlib import Path
import torch.nn.functional as F

DATA_ROOT = r'C:\Users\Akshay Patil\Desktop\cifar-10-python\cifar-10-batches-py'
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

def load_cifar_batch(file_path):
    """Load CIFAR-10 batch from local file"""
    with open(file_path, 'rb') as f:
        batch = pickle.load(f, encoding='latin1')
    images = batch['data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = batch['labels']
    return images, labels

class TestDataset(torch.utils.data.Dataset):
    """Custom dataset to load test images in original order"""
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.image_files = sorted(self.root_dir.glob('*.png'), 
                                 key=lambda x: int(x.stem.split('_')[1]))
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        image = Image.open(img_path)
        if self.transform:
            image = self.transform(image)
        return image

def generate_sample_data():
    # Define transforms with normalization
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD)
    ])
    
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD)
    ])

    trainset = torchvision.datasets.ImageFolder(
        root=str(Path(DATA_ROOT) / 'train_organized'),
        transform=train_transform
    )
    trainloader = DataLoader(trainset, batch_size=60000, shuffle=False)
    input_images, class_labels = next(iter(trainloader))
    
    testset = TestDataset(
        root_dir=Path(DATA_ROOT) / 'test_original',
        transform=test_transform
    )
    testloader = DataLoader(testset, batch_size=10000, shuffle=False)
    test_images = next(iter(testloader))
    
    _, test_labels = load_cifar_batch(Path(DATA_ROOT) / 'test_batch')
    test_labels = torch.tensor(test_labels, dtype=torch.long)
    
    return input_images, class_labels, test_images, test_labels

if __name__ == '__main__':
    train_images, train_labels, test_images, test_labels = generate_sample_data()
    
    print(f"Training data shape: {train_images.shape}")
    print(f"Training labels shape: {train_labels.shape}")
    print(f"Test data shape: {test_images.shape}")
    print(f"Test labels shape: {test_labels.shape}")
    train_images = train_images.reshape(49999,3,32,32)
    test_images = test_images.reshape(10000,3,32,32)
    train_images = F.normalize(train_images,p=2,dim=1)
    test_images = F.normalize(test_images,p=2,dim=1)

Training data shape: torch.Size([49999, 3, 32, 32])
Training labels shape: torch.Size([49999])
Test data shape: torch.Size([10000, 3, 32, 32])
Test labels shape: torch.Size([10000])


In [None]:
# Chebyshev NN for CIFAR (Same Width across all layers)
# Architechture can be made more in depth...
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SharedChebyshevLayer(nn.Module):
    def __init__(self, degree):
        super().__init__()
        self.degree = degree
        self.coeffs = nn.Parameter(torch.randn(3072, degree+1, dtype=torch.float64))
        
    def _chebyshev_basis(self, x):
        """Inplace-operation-free Chebyshev polynomial basis computation"""
        batch_size, feat_dim = x.shape
        x_flat = x.reshape(-1)
        
        T_list = [torch.ones_like(x_flat)]
        if self.degree >= 1:
            T_list.append(x_flat)
            for k in range(1, self.degree):
                Tk = 2 * x_flat * T_list[k] - T_list[k-1]
                T_list.append(Tk)

        T = torch.stack(T_list, dim=1)
        return T.view(batch_size, feat_dim, self.degree+1)

    def forward(self, x):
        basis = self._chebyshev_basis(x)
        return torch.einsum('bik,ik->bi', basis, self.coeffs)

class ChebyshevSharedNN(nn.Module):
    def __init__(self, num_layers=3, cheb_degree=10, num_classes=10):
        super().__init__()
        self.num_layers = num_layers
        self.cheb_degree = cheb_degree
        self.num_classes = num_classes
        self.init_transform = nn.Linear(3072, 3072, dtype=torch.float64)
        self.transforms = nn.ModuleList([
            nn.Linear(3072, 3072, dtype=torch.float64) 
            for _ in range(num_layers)
        ])
        self.cheb_layers = nn.ModuleList([
            SharedChebyshevLayer(cheb_degree)
            for _ in range(num_layers+1)
        ])
        self.class_projections = nn.ModuleList([
            nn.Linear(3072, 1, dtype=torch.float64)
            for _ in range(num_classes)
        ])

    def forward(self, x):
        x = x.to(torch.float64)
        x = F.normalize(self.init_transform(x), p=2.0, dim=1)
        for transform, cheb_layer in zip(self.transforms, self.cheb_layers[0:3]):
            modulated = cheb_layer(x)
            x = F.normalize(transform(modulated), p=2.0, dim=1)
        x = self.cheb_layers[3](x)
        x = F.normalize(x,p = 2.0 , dim = 1)
        return torch.cat([proj(x) for proj in self.class_projections], dim=1)



input_vectors = torch.tensor(train_images, dtype=torch.float64)
target_labels = train_labels 
train_dataset = TensorDataset(input_vectors, target_labels)
test_dataset = TensorDataset(
    torch.tensor(test_images, dtype=torch.float64),
    torch.tensor(test_labels, dtype=torch.long)
)
train_loader = DataLoader(train_dataset, batch_size=1000, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

model = ChebyshevSharedNN(num_layers=3, cheb_degree=20).to(device) # hyperparmeter degree of the series , based on this we can create a lesser depth circuit by giving more degrees of freedom to the activation function tuned to the dataset...
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)

for epoch in range(100):
    model.train()
    train_loss, train_correct = 0.0, 0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        train_correct += (preds == labels).sum().item()
    
    model.eval()
    test_loss, test_correct = 0.0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            test_loss += criterion(outputs, labels).item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            test_correct += (preds == labels).sum().item()
    train_loss /= len(train_loader.dataset)
    train_acc = train_correct / len(train_loader.dataset)
    test_loss /= len(test_loader.dataset)
    test_acc = test_correct / len(test_loader.dataset)
    
    # Update learning rate
    scheduler.step(test_loss)
    
    print(f'Epoch {epoch+1}/100')
    print(f'Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}')
    print(f'Test Loss: {test_loss:.4f} | Acc: {test_acc:.4f}')
    print('-'*50)


  input_vectors = torch.tensor(train_images, dtype=torch.float64)
  torch.tensor(test_images, dtype=torch.float64),
  torch.tensor(test_labels, dtype=torch.long)


Epoch 1/100
Train Loss: 2.2731 | Acc: 0.2192
Test Loss: 2.2382 | Acc: 0.3354
--------------------------------------------------
Epoch 2/100
Train Loss: 2.2167 | Acc: 0.3728
Test Loss: 2.1940 | Acc: 0.4013
--------------------------------------------------
Epoch 3/100
Train Loss: 2.1734 | Acc: 0.4229
Test Loss: 2.1531 | Acc: 0.4369
--------------------------------------------------
Epoch 4/100
Train Loss: 2.1284 | Acc: 0.4583
Test Loss: 2.1086 | Acc: 0.4623
--------------------------------------------------
Epoch 5/100
Train Loss: 2.0826 | Acc: 0.4805
Test Loss: 2.0687 | Acc: 0.4732
--------------------------------------------------
Epoch 6/100
Train Loss: 2.0349 | Acc: 0.4989
Test Loss: 2.0275 | Acc: 0.4816
--------------------------------------------------
Epoch 7/100
Train Loss: 1.9849 | Acc: 0.5159
Test Loss: 1.9818 | Acc: 0.4858
--------------------------------------------------
Epoch 8/100
Train Loss: 1.9332 | Acc: 0.5261
Test Loss: 1.9420 | Acc: 0.4958
---------------------------

In [None]:
# Chebyshev FeedForward NN for CIFAR (Different Width for Different Layers)


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SharedChebyshevLayer(nn.Module):
    def __init__(self, degree: int, input_dim: int):
        super().__init__()
        self.degree = degree
        self.coeffs = nn.Parameter(
            torch.randn(input_dim, degree + 1, dtype=torch.float64)
        )

    def _chebyshev_basis(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes Chebyshev basis T_0 … T_degree for every scalar in x
        without in-place ops (keeps autograd happy).
        """
        bsz, feat = x.shape
        x_flat = x.reshape(-1)

        basis = [torch.ones_like(x_flat)]
        if self.degree >= 1:
            basis.append(x_flat)
            for k in range(1, self.degree):
                Tk = 2 * x_flat * basis[k] - basis[k - 1]
                basis.append(Tk)

        T = torch.stack(basis, dim=1)                   
        return T.view(bsz, feat, self.degree + 1)      

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        T = self._chebyshev_basis(x)
        return torch.einsum("bik,ik->bi", T, self.coeffs) 


class ChebyshevSharedNN(nn.Module):
    def __init__(
        self,
        num_layers: int = 5,
        cheb_degree: int = 15,
        num_classes: int = 10,
        hidden_dims: list[int] | None = None,
    ):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [3072] * (num_layers + 1)
        assert len(hidden_dims) == num_layers + 1, (
        )

        self.init_linear   = nn.Linear(3072, hidden_dims[0], dtype=torch.float64)
        self.transforms    = nn.ModuleList()
        self.cheb_layers   = nn.ModuleList()

        self.cheb_layers.append(SharedChebyshevLayer(cheb_degree, hidden_dims[0]))

        for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]):
            self.transforms.append(nn.Linear(in_dim, out_dim, dtype=torch.float64))
            self.cheb_layers.append(SharedChebyshevLayer(cheb_degree, out_dim))

        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dims[-1], 1, dtype=torch.float64) for _ in range(num_classes)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(torch.float64)

        # Initial projection + Chebyshev modulation
        x = F.normalize(self.init_linear(x), p=2, dim=1)
        x = F.normalize(self.cheb_layers[0](x), p=2, dim=1)

        # Main stack
        for linear, cheb in zip(self.transforms, self.cheb_layers[1:]):
            x = F.normalize(linear(x),           p=2, dim=1)
            x = F.normalize(cheb(x),             p=2, dim=1)
        return torch.cat([head(x) for head in self.heads], dim=1)

input_vectors = torch.tensor(train_images, dtype=torch.float64)
target_labels = torch.tensor(train_labels, dtype=torch.long)

train_ds = TensorDataset(input_vectors, target_labels)
test_ds  = TensorDataset(
    torch.tensor(test_images, dtype=torch.float64),
    torch.tensor(test_labels, dtype=torch.long)
)

train_loader = DataLoader(train_ds, batch_size=1000, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=1000, shuffle=False)

model = ChebyshevSharedNN(
    num_layers   = 5,
    cheb_degree  = 20,
    hidden_dims  = [3072, 1536 , 768 , 384 , 192 , 81 ], 
).to(device)

criterion  = nn.CrossEntropyLoss()
optimizer  = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-4)
scheduler  = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)

EPOCHS = 1000
for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss, running_correct = 0.0, 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()

        logits = model(X)
        loss   = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss   += loss.item() * X.size(0)
        running_correct += (logits.argmax(1) == y).sum().item()

    train_loss = running_loss / len(train_loader.dataset)
    train_acc  = running_correct / len(train_loader.dataset)

    model.eval()
    val_loss, val_correct = 0.0, 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)

            val_loss     += criterion(logits, y).item() * X.size(0)
            val_correct  += (logits.argmax(1) == y).sum().item()

    val_loss /= len(test_loader.dataset)
    val_acc  = val_correct / len(test_loader.dataset)

    scheduler.step(val_loss)
    print(
        f"Epoch {epoch:3d}/{EPOCHS} | "
        f"Train: loss {train_loss:.4f}, acc {train_acc:.4f} | "
        f"Val: loss {val_loss:.4f}, acc {val_acc:.4f}"
    )


  input_vectors = torch.tensor(train_images, dtype=torch.float64)
  target_labels = torch.tensor(train_labels, dtype=torch.long)
  torch.tensor(test_images, dtype=torch.float64),
  torch.tensor(test_labels, dtype=torch.long)


Epoch   1/1000 | Train: loss 2.3040, acc 0.1009 | Val: loss 2.3015, acc 0.1155
Epoch   2/1000 | Train: loss 2.2808, acc 0.1263 | Val: loss 2.2460, acc 0.1960
Epoch   3/1000 | Train: loss 2.2233, acc 0.2750 | Val: loss 2.2031, acc 0.3166
Epoch   4/1000 | Train: loss 2.1926, acc 0.3330 | Val: loss 2.1810, acc 0.3414
Epoch   5/1000 | Train: loss 2.1744, acc 0.3673 | Val: loss 2.1666, acc 0.3934
Epoch   6/1000 | Train: loss 2.1603, acc 0.3904 | Val: loss 2.1553, acc 0.3952
Epoch   7/1000 | Train: loss 2.1499, acc 0.4081 | Val: loss 2.1488, acc 0.4030
Epoch   8/1000 | Train: loss 2.1413, acc 0.4190 | Val: loss 2.1426, acc 0.4240
Epoch   9/1000 | Train: loss 2.1347, acc 0.4300 | Val: loss 2.1345, acc 0.4235
Epoch  10/1000 | Train: loss 2.1267, acc 0.4405 | Val: loss 2.1303, acc 0.4272
Epoch  11/1000 | Train: loss 2.1217, acc 0.4467 | Val: loss 2.1247, acc 0.4385
Epoch  12/1000 | Train: loss 2.1153, acc 0.4555 | Val: loss 2.1203, acc 0.4418
Epoch  13/1000 | Train: loss 2.1092, acc 0.4635 | Va

In [None]:
# Direct Lagrangian OVR (Per Class , Different Strategies can be used as in Lagrangian NN paper , also depth and width can be changed based on traditional
# NN methods that gives the best result accuracy wise) - 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



input_vectors = torch.tensor(train_images, dtype=torch.float64)
target_labels = train_labels  # Convert one-hot to class labels

class LagrangianNNOVRClassifier(nn.Module):
    def __init__(self, degree, num_classes=10):
        super().__init__()
        self.degree = degree
        self.num_classes = num_classes
        
        self.transformation_params = nn.ParameterList([
            nn.Parameter(torch.empty(3072,3072 , dtype=torch.float64)) 
            for _ in range(degree)
        ])
        

        self.bias_params = nn.ParameterList([
            nn.Parameter(torch.empty(3072, dtype=torch.float64))
            for _ in range(degree)
        ])
        

        self.alpha_params = nn.ParameterList([
            nn.Parameter(torch.empty(num_classes, 3072, dtype=torch.float64))
            for _ in range(degree-1)
        ])
        

        self.class_projections = nn.ModuleList([
            nn.Linear(3072, 1, dtype=torch.float64) 
            for _ in range(num_classes)
        ])

        with torch.no_grad():
            for param in self.transformation_params:
                nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='linear')
            for bias in self.bias_params:
                nn.init.zeros_(bias)  
            for alpha in self.alpha_params:
                nn.init.normal_(alpha)
            for proj in self.class_projections:
                nn.init.kaiming_normal_(proj.weight)

    def forward(self, x):
        x = x.to(torch.float64)
        batch_size = x.size(0)
        

        x = torch.matmul(self.transformation_params[0], x.unsqueeze(-1)).squeeze()
        x = x + self.bias_params[0]  
        

        class_outputs = []
        for c in range(self.num_classes):
            x_class = x.clone()
            

            for i, (transform, bias, alpha) in enumerate(zip(
                self.transformation_params[1:], 
                self.bias_params[1:], 
                self.alpha_params
            )):

                alpha_clamped = torch.sigmoid(alpha[c])
                modulated = x_class * alpha_clamped
                
                modulated = torch.matmul(transform, modulated.unsqueeze(-1)).squeeze()
                modulated = modulated + bias  
                x_class = F.normalize(modulated, p=2.0, dim=1)
            
            class_outputs.append(x_class)
        
        scores = torch.cat([
            proj(features) for proj, features in zip(self.class_projections, class_outputs)
        ], dim=1)
        
        return scores

def ovr_hinge_loss(scores, targets):
    batch_size = scores.size(0)
    correct_scores = scores[torch.arange(batch_size), targets].unsqueeze(1)
    margins = 1 - (correct_scores - scores)
    margins[torch.arange(batch_size), targets] = 0 
    return torch.clamp(margins, min=0).sum(dim=1).mean()

degree = 10
num_classes = 10
model = LagrangianNNOVRClassifier(degree, num_classes).to(device)

dataset = TensorDataset(input_vectors, train_labels)
batch_size = 512
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(
    torch.tensor(test_images, dtype=torch.float64),
    torch.tensor(test_labels, dtype=torch.long)
)
test_dataloader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

optimizer = optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)

for epoch in range(100):
    model.train()
    epoch_loss = 0.0
    train_correct = 0
    
    for batch_input, batch_target in dataloader:
        batch_input, batch_target = batch_input.to(device), batch_target.to(device)
        
        optimizer.zero_grad()
        scores = model(batch_input)
        loss = ovr_hinge_loss(scores, batch_target)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        preds = torch.argmax(scores, dim=1)
        train_correct += (preds == batch_target).sum().item()
    
    model.eval()
    val_correct = 0
    with torch.no_grad():
        for batch_input, batch_target in dataloader:
            batch_input, batch_target = batch_input.to(device), batch_target.to(device)
            scores = model(batch_input)
            preds = torch.argmax(scores, dim=1)
            val_correct += (preds == batch_target).sum().item()

    test_correct = 0
    with torch.no_grad():
        for batch_input, batch_target in test_dataloader:
            batch_input, batch_target = batch_input.to(device), batch_target.to(device)
            scores = model(batch_input)
            preds = torch.argmax(scores, dim=1)
            test_correct += (preds == batch_target).sum().item()
    
    
    train_acc = train_correct / len(dataset)
    val_acc = val_correct / len(dataset)
    test_acc = test_correct / len(test_dataset)
    avg_loss = epoch_loss / len(dataloader)
    
    scheduler.step(avg_loss)
    
    print(f'Epoch [{epoch+1}/100] |'
          f'Loss: {avg_loss:.4f} |'
          f'Train Acc: {train_acc:.4f} |'
          f'Val Acc: {val_acc:.4f} |'
          f'Test Acc: {test_acc:.4f}|')


  input_vectors = torch.tensor(train_images, dtype=torch.float64)
  torch.tensor(test_images, dtype=torch.float64),
  torch.tensor(test_labels, dtype=torch.long)


In [None]:
# Convolutional Lagrangian OVR (per class and not per sample) - Just 10000 Parameters , Optimum Arcitechtures that work best for this via traditional NN that 
# are already implemented can be used.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
import numpy as np


def generate_sample_data():
    return train_images.reshape(-1,3,32,32), train_labels, test_images.reshape(-1,3,32,32), test_labels

class ConvolutionalLagrangianNN(nn.Module):
    """
    10 shared Conv layers + class-specific alpha modulations +
    1 Lagrangian transform layer + ten 1-D heads.

    Total parameters: 10 000
    """
    def __init__(self,
                 num_classes: int = 10,
                 conv_degree: int = 10,
                 lagrangian_degree: int = 1):
        super().__init__()
        self.num_classes = num_classes
        self.conv_degree = conv_degree
        self.lagrangian_degree = lagrangian_degree


        channels = [3] + [9] * conv_degree          
        self.pooling_layers = {1, 3, 5, 7, 9}     

        self.conv_layers      = nn.ModuleList()
        self.conv_batch_norms = nn.ModuleList()
        for i in range(conv_degree):
            self.conv_layers.append(
                nn.Conv2d(channels[i], channels[i + 1],
                          kernel_size=3, padding=1, bias=True)
            )
            self.conv_batch_norms.append(
                nn.BatchNorm2d(channels[i + 1])
            )

        self.final_conv_size  = 1 * 1 * channels[-1]    
        self.lagrangian_size  = 32                     
        self.dim_expansion    = nn.Linear(self.final_conv_size,
                                          self.lagrangian_size)

        self.conv_alpha_params = nn.ParameterList([
            nn.Parameter(torch.empty(num_classes, channels[i + 1]))
            for i in range(conv_degree)
        ])

        self.lagrangian_transforms = nn.ParameterList([
            nn.Parameter(torch.empty(self.lagrangian_size,
                                     self.lagrangian_size))
            for _ in range(lagrangian_degree)
        ])
        self.lagrangian_biases = nn.ParameterList([
            nn.Parameter(torch.empty(self.lagrangian_size))
            for _ in range(lagrangian_degree)
        ])
        self.lagrangian_alpha_params = nn.ParameterList([
            nn.Parameter(torch.empty(num_classes, self.lagrangian_size))
            for _ in range(lagrangian_degree)
        ])

        self.class_projections = nn.ModuleList([
            nn.Linear(self.lagrangian_size, 1)
            for _ in range(num_classes)
        ])

        self._initialize_parameters()


    def _initialize_parameters(self):
        with torch.no_grad():
            for conv in self.conv_layers:
                nn.init.kaiming_normal_(conv.weight, mode='fan_out',
                                        nonlinearity='relu')
                nn.init.zeros_(conv.bias)
            nn.init.kaiming_normal_(self.dim_expansion.weight,
                                    mode='fan_in', nonlinearity='linear')
            nn.init.zeros_(self.dim_expansion.bias)

            for alpha in self.conv_alpha_params:
                nn.init.normal_(alpha, mean=0., std=0.1)
            for T in self.lagrangian_transforms:
                nn.init.kaiming_normal_(T, mode='fan_in',
                                        nonlinearity='linear')
            for b in self.lagrangian_biases:
                nn.init.zeros_(b)
            for alpha in self.lagrangian_alpha_params:
                nn.init.normal_(alpha, mean=0., std=0.1)
            for proj in self.class_projections:
                nn.init.kaiming_normal_(proj.weight)
                nn.init.zeros_(proj.bias)

    def forward(self, x):
        x = x.float()
        batch_size = x.size(0)
        class_outputs = []

        for c in range(self.num_classes):
            x_c = x.clone()

            for i, (conv, bn, alpha) in enumerate(zip(
                    self.conv_layers, self.conv_batch_norms,
                    self.conv_alpha_params)):
                x_c = conv(x_c)
                x_c = bn(x_c)
                alpha_c = torch.sigmoid(alpha[c]).view(1, -1, 1, 1)
                x_c = x_c * alpha_c

                if (i + 1) in self.pooling_layers:
                    x_c = F.max_pool2d(x_c, 2)

            x_c = self.dim_expansion(x_c.view(batch_size, -1))

            T, b = self.lagrangian_transforms[0], self.lagrangian_biases[0]
            x_c = x_c @ T.t() + b
            alpha_L = torch.sigmoid(self.lagrangian_alpha_params[0][c])
            x_c = x_c * alpha_L
            x_c = F.normalize(x_c, p=2.0, dim=1)

            class_outputs.append(x_c)

        scores = torch.cat(
            [proj(feat) for proj, feat in zip(self.class_projections,
                                              class_outputs)],
            dim=1)
        return scores


def ovr_hinge_loss(scores, targets):
    """One-vs-rest hinge loss (margin = 1)."""
    correct = scores[torch.arange(scores.size(0)), targets].unsqueeze(1)
    margins = 1 - (correct - scores)
    margins[torch.arange(scores.size(0)), targets] = 0
    return torch.clamp(margins, min=0).sum(dim=1).mean()


def train_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_images, class_labels, test_images, test_labels = generate_sample_data()

    
    model = ConvolutionalLagrangianNN().to(device)

    train_dataset = TensorDataset(input_images, class_labels)
    test_dataset = TensorDataset(test_images, test_labels)
    
    batch_size = 256
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
    
    print(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")
    print(f"Training on {len(train_dataset)} samples, testing on {len(test_dataset)} samples")

    num_epochs = 100
    best_test_acc = 0.0
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        train_correct = 0
        
        for batch_idx, (batch_input, batch_target) in enumerate(train_dataloader):
            batch_input, batch_target = batch_input.to(device), batch_target.to(device)
            
            optimizer.zero_grad()
            scores = model(batch_input)
            loss = ovr_hinge_loss(scores, batch_target)
            loss.backward()
            
            optimizer.step()

            epoch_loss += loss.item()
            preds = torch.argmax(scores, dim=1)
            train_correct += (preds == batch_target).sum().item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        if epoch % 20 == 0:
            model.eval()
            val_correct = 0
            with torch.no_grad():
                for batch_input, batch_target in train_dataloader:
                    batch_input, batch_target = batch_input.to(device), batch_target.to(device)
                    scores = model(batch_input)
                    preds = torch.argmax(scores, dim=1)
                    val_correct += (preds == batch_target).sum().item()

            test_correct = 0
            with torch.no_grad():
                for batch_input, batch_target in test_dataloader:
                    batch_input, batch_target = batch_input.to(device), batch_target.to(device)
                    scores = model(batch_input)
                    preds = torch.argmax(scores, dim=1)
                    test_correct += (preds == batch_target).sum().item()
    
            train_acc = train_correct / len(train_dataset)
            val_acc = val_correct / len(train_dataset)
            test_acc = test_correct / len(test_dataset)

        avg_loss = epoch_loss / len(train_dataloader)

        scheduler.step(avg_loss)

        if test_acc > best_test_acc:
            best_test_acc = test_acc
            torch.save(model.state_dict(), 'best_conv_lagrangian_model.pth')
        
        print(f'Epoch [{epoch+1}/{num_epochs}] | '
              f'Loss: {avg_loss:.4f} | '
              f'Train Acc: {train_acc:.4f} | '
              f'Val Acc: {val_acc:.4f} | '
              f'Test Acc: {test_acc:.4f} | '
              f'Best Test: {best_test_acc:.4f}')

if __name__ == "__main__":
    train_model()

# One can implement all the Methods or Strategies discussed in the Best NN is Lagrangian NN Paper..here only OVR along with convolution is implemented...
# One can also increase the number of parameters as here only 10000 parameters are considered....
# And also evaluation can be done only at particular epochs to save the compute time drastically...

Model initialized with 10000 parameters
Training on 49999 samples, testing on 10000 samples
Epoch 1, Batch 0, Loss: 8.9698
Epoch 1, Batch 100, Loss: 4.2937
Epoch [1/100] | Loss: 5.3860 | Train Acc: 0.2946 | Val Acc: 0.3659 | Test Acc: 0.3674 | Best Test: 0.3674
Epoch 2, Batch 0, Loss: 3.9823
Epoch 2, Batch 100, Loss: 3.6648
Epoch [2/100] | Loss: 3.8197 | Train Acc: 0.4019 | Val Acc: 0.3992 | Test Acc: 0.4001 | Best Test: 0.4001
Epoch 3, Batch 0, Loss: 3.6025
Epoch 3, Batch 100, Loss: 3.5653
Epoch [3/100] | Loss: 3.3597 | Train Acc: 0.4425 | Val Acc: 0.4568 | Test Acc: 0.4483 | Best Test: 0.4483
Epoch 4, Batch 0, Loss: 3.4796
Epoch 4, Batch 100, Loss: 2.7732
Epoch [4/100] | Loss: 3.0767 | Train Acc: 0.4677 | Val Acc: 0.4801 | Test Acc: 0.4726 | Best Test: 0.4726
Epoch 5, Batch 0, Loss: 3.1537
Epoch 5, Batch 100, Loss: 3.1095
Epoch [5/100] | Loss: 2.8908 | Train Acc: 0.4888 | Val Acc: 0.4802 | Test Acc: 0.4741 | Best Test: 0.4741
Epoch 6, Batch 0, Loss: 3.0613
Epoch 6, Batch 100, Loss: 2

In [1]:
# Any and Every Existing type of NN architechture can be applied the Lagrangian Per Sample (Probabilistic Interpretation of NN from Getting Exact Equations of NN
# Paper), Lagrangian Per Class (SVM Analog of NN - Various Stategies from Best NN is Lagrangian NN Paper) or Chebyshev Learnable Activation (Deterministic Interpretation  
# of NN Paper along with Bayes Theorem Every Term Explicity Computed for NN paper) - Advantage of using this is fully Interpretable Equation wise version 
# of NN (either N linear Lagrangian Equation for N samples , or M linear Lagrangian Equations (Various Stragtegies) for M classes of N samples or one
# single high degree Polynomial equation for all classes (Chebyshev Activation NN))


In [None]:
# Other Strategies in the Lagrangian OVR other than the one mentioned in paper is , forming groups per class and assigning one lagrangian vector to that group
# to improve the expressivity , if the groups are equal to 1 then it forms as Lagrangian per sample , a case which is already implemented...

In [None]:
# Chebyshev Convolutional NN , separate activation per layer
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from pathlib import Path

import torch._dynamo
torch._dynamo.config.suppress_errors = True

torch.set_num_threads(8)         
torch.backends.mkldnn.enabled = True
device = torch.device("cpu")      
dtype = torch.float32        

def cheb_nodes_first_kind(n, a=-1.0, b=1.0):
    k = torch.arange(1, n + 1, dtype=dtype)
    x = torch.cos((2 * k - 1) * math.pi / (2 * n))
    return 0.5 * (a + b) + 0.5 * (b - a) * x

def cheb_nodes_second_kind(n, a=-1.0, b=1.0):
    k = torch.arange(0, n, dtype=dtype)
    x = torch.cos(k * math.pi / (n - 1))
    return 0.5 * (a + b) + 0.5 * (b - a) * x

class InputScaler(nn.Module):
    def __init__(self, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        self.register_buffer("running_min", torch.tensor(0., dtype=dtype))
        self.register_buffer("running_max", torch.tensor(0., dtype=dtype))
        self.register_buffer("initialized", torch.tensor(False))

    @torch.no_grad()
    def update_stats(self, x):
        batch_min, batch_max = x.min(), x.max()
        if not self.initialized:
            self.running_min.copy_(batch_min)
            self.running_max.copy_(batch_max)
            self.initialized.fill_(True)
        else:
            self.running_min.mul_(self.momentum).add_(batch_min * (1 - self.momentum))
            self.running_max.mul_(self.momentum).add_(batch_max * (1 - self.momentum))

    def scale(self, x, training=True):
        if training and self.training:
            self.update_stats(x)

        if not self.initialized:
            min_val, max_val = x.min(), x.max()
        else:
            min_val, max_val = self.running_min, self.running_max

        range_val = torch.clamp(max_val - min_val, min=1e-8)
        scaled = 2.0 * (x - min_val) / range_val - 1.0
        return torch.clamp(scaled, -1.0, 1.0)

class SharedChebyshevLayer(nn.Module):
    def __init__(self, feature_size, degree, use_nodes=True, node_kind="first"):
        super().__init__()
        self.degree = degree
        self.feature_size = feature_size
        self.coeffs = nn.Parameter(torch.randn(degree + 1, dtype=dtype) * 0.01)
        self.scaler = InputScaler()
        
        self.register_buffer("nodes",
            cheb_nodes_first_kind(degree+1) if use_nodes and node_kind=="first"
            else cheb_nodes_second_kind(degree+1) if use_nodes else torch.linspace(-1,1,degree+1)
        )

    def _scale_inputs(self, x, training=True):
        return self.scaler.scale(x, training)

    def _basis(self, x):
        T0, T1 = torch.ones_like(x), x
        basis = [T0, T1]
        for k in range(1, self.degree):
            Tk = 2.0 * x * basis[-1] - basis[-2]
            if k < self.nodes.numel():
                node_w = torch.exp(-0.5 * torch.abs(x - self.nodes[k]))
                Tk *= node_w
            basis.append(Tk)
        return torch.stack(basis[:self.degree+1], dim=-1)

    def forward(self, x):
        xs = self._scale_inputs(x, self.training)
        
        if xs.dim() == 4: 
            B, C, H, W = xs.shape
            xs_flat = xs.view(B, C, H*W)
            Phi = self._basis(xs_flat) 
            y = torch.einsum("bchk,k->bch", Phi, self.coeffs)
            return y.view(B, C, H, W)
        else: 
            Phi = self._basis(xs)  
            return torch.einsum("bfk,k->bf", Phi, self.coeffs)

class BottleneckBlock(nn.Module):
    def __init__(self, in_ch, out_ch, reduction_ratio=4):
        super().__init__()
        mid = max(in_ch // reduction_ratio, 8)
        self.reduce = nn.Conv2d(in_ch, mid, 1, bias=False)
        self.conv = nn.Conv2d(mid, mid, 3, padding=1, bias=False)
        self.expand = nn.Conv2d(mid, out_ch, 1, bias=False)
    
    def forward(self, x):
        return self.expand(self.conv(self.reduce(x)))

class ConvolutionalChebyshevNN(nn.Module):
    def __init__(self, num_classes=10, conv_depth=5,
                 cheb_degree=20, use_nodes=True, node_kind="second", use_bottleneck=True):
        super().__init__()
        ch = [3, 32, 64, 128, 256, 384] if use_bottleneck else [3,48,96,192,320,512]
        self.conv_layers, self.cheb_layers, self.norms = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        self.pool, self.pool_after = nn.MaxPool2d(2), {1,2,3,4}

        for i in range(conv_depth):
            self.conv_layers.append(
                BottleneckBlock(ch[i],ch[i+1]) if use_bottleneck and i>0
                else nn.Conv2d(ch[i],ch[i+1],3,padding=1,bias=False)
            )
            self.cheb_layers.append(
                SharedChebyshevLayer(ch[i+1], cheb_degree, use_nodes, node_kind)
            )
            self.norms.append(nn.LayerNorm(ch[i+1]))

        self.final_cheb = SharedChebyshevLayer(ch[-1], cheb_degree, use_nodes, node_kind)
        self.heads = nn.ModuleList([nn.Linear(ch[-1],1) for _ in range(num_classes)])

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode="fan_in")
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.to(memory_format=torch.channels_last)
        for i,(cv,chb,ln) in enumerate(zip(self.conv_layers,self.cheb_layers,self.norms)):
            x = cv(x)
            x = chb(x)
            x = ln(x.permute(0,2,3,1)).permute(0,3,1,2)
            if i in self.pool_after: x = self.pool(x)
        x = F.adaptive_avg_pool2d(x,1).flatten(1)
        x = F.normalize(self.final_cheb(x),p=2,dim=1,eps=1e-6)
        return torch.cat([h(x) for h in self.heads], dim=1)

def count_parameters(model):
    tot = sum(p.numel() for p in model.parameters())
    train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return tot, train

# def compile_net(net):
#     try:        
#         return torch.compile(net, mode="default", dynamic=True)
#     except Exception:
#         return torch.jit.script(net)

def train(model, train_images, train_labels, test_images, test_labels, epochs=50, bs=64, patience=5):
    train_loader = DataLoader(TensorDataset(train_images, train_labels), bs , shuffle=True, num_workers=8,
                              pin_memory=False, persistent_workers=True)
    test_loader = DataLoader(TensorDataset(test_images, test_labels), bs , shuffle=False,num_workers=8,
                              pin_memory=False, persistent_workers=True)

    criterion = nn.CrossEntropyLoss()
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min",
                                                     patience=2, factor=0.5, min_lr=1e-5)
    best_val, wait = 1e9, 0
    start = time.time()

    for ep in range(1, epochs+1):
        model.train()
        tr_loss, tr_hit = 0., 0
        for i, (x, y) in enumerate(train_loader):
            if i%100 ==0:
                print(f'Batch Number {i}')
            x,y = x.to(device,dtype), y.to(device)
            opt.zero_grad()
            out = model(x)
            loss = criterion(out,y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
            opt.step()

            tr_loss += loss.item()*x.size(0)
            tr_hit += (out.argmax(1)==y).sum().item()

        model.eval()
        vl_loss, vl_hit = 0., 0
        with torch.no_grad():
            for x,y in test_loader:
                x,y = x.to(device,dtype), y.to(device)
                out = model(x)
                vl_loss += criterion(out,y).item()*x.size(0)
                vl_hit += (out.argmax(1)==y).sum().item()

        tr_loss /= len(train_loader.dataset); tr_acc = tr_hit/len(train_loader.dataset)
        vl_loss /= len(test_loader.dataset) ; vl_acc = vl_hit/len(test_loader.dataset)
        scheduler.step(vl_loss)
        
        print(f"Ep {ep:02d} | "
              f"Train {tr_loss:.3f}/{tr_acc:.3f} | "
              f"Val {vl_loss:.3f}/{vl_acc:.3f} | "
              f"Train Accuracy {tr_acc:.3f}/{tr_acc:.3f} | "
              f"Test Accuracy {vl_acc:.3f}/{vl_acc:.3f} | "
              f"LR {opt.param_groups[0]['lr']:.2e} | "
              f"Time {(time.time()-start)/60:.1f}m")

if __name__ == "__main__":
    net = ConvolutionalChebyshevNN(num_classes=10, cheb_degree=10, use_bottleneck=True)
    print(f"Total params: {count_parameters(net)[0]:,}")
    train(net, train_images, train_labels, test_images, test_labels, epochs=100, bs=64, patience=1)

    


Total params: 112,556
Batch Number 0
Batch Number 100
Batch Number 200
Batch Number 300
Batch Number 400
Batch Number 500
Batch Number 600
Batch Number 700
Ep 01 | Train 1.778/0.348 | Val 1.543/0.446 | Train Accuracy 0.348/0.348 | Test Accuracy 0.446/0.446 | LR 1.00e-03 | Time 28.2m
Batch Number 0
Batch Number 100
Batch Number 200
Batch Number 300
Batch Number 400
Batch Number 500
Batch Number 600
Batch Number 700
Ep 02 | Train 1.436/0.482 | Val 1.399/0.493 | Train Accuracy 0.482/0.482 | Test Accuracy 0.493/0.493 | LR 1.00e-03 | Time 57.0m
Batch Number 0
Batch Number 100
Batch Number 200
Batch Number 300
Batch Number 400
Batch Number 500
Batch Number 600
Batch Number 700
Ep 03 | Train 1.257/0.551 | Val 1.215/0.564 | Train Accuracy 0.551/0.551 | Test Accuracy 0.564/0.564 | LR 1.00e-03 | Time 86.3m
Batch Number 0
Batch Number 100
Batch Number 200
Batch Number 300
Batch Number 400
Batch Number 500
Batch Number 600
Batch Number 700
Ep 04 | Train 1.138/0.594 | Val 1.139/0.593 | Train Accura

In [None]:
# Chebyshev Convolutional NN , separate activation per feature element per layer 
# This Architechture of NN does not exist Already in traditional NN , but it is compute extensive
# Also different weights and biases for different classes can be another strategy along with Chebyshev Activations 
# Etc ... can  be tried to increase the expressivity of the Conventional NN further , with ultimate goal to construct compact
# human understanable equations from these obtained NN equations , i.e. first use unsupervised learning to reduce the number 
# of features and then make the interpreted equations human understandable on this compact manifold , and then scale them back to given feature size...
# Universal Audio , Image , Video , Text best P(Y/X) per class human understandable
# equations or all the terms of Bayes Equations per class is another ultimate scope of this research...this can lead to further innovations in the field
# and then extending this research from Supervised or Generative or Unsuperivised(Autoencoder as done in Bayes Theorem) to other types of NN Architechtures 
# namely Reinforcement Learning and World Models...


import math, time, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from pathlib import Path

import torch._dynamo
torch._dynamo.config.suppress_errors = True


torch.set_num_threads(8)         
torch.backends.mkldnn.enabled = True
device = torch.device("cpu")      
dtype   = torch.float32        


def cheb_nodes_first_kind(n, a=-1.0, b=1.0):
    k = torch.arange(1, n + 1, dtype=dtype)
    x = torch.cos((2 * k - 1) * math.pi / (2 * n))
    return 0.5 * (a + b) + 0.5 * (b - a) * x

def cheb_nodes_second_kind(n, a=-1.0, b=1.0):
    k = torch.arange(0, n, dtype=dtype)
    x = torch.cos(k * math.pi / (n - 1))
    return 0.5 * (a + b) + 0.5 * (b - a) * x

class InputScaler(nn.Module):
    def __init__(self, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        self.register_buffer("running_min", torch.tensor(0., dtype=dtype))
        self.register_buffer("running_max", torch.tensor(0., dtype=dtype))
        self.register_buffer("initialized", torch.tensor(False))

    @torch.no_grad()
    def update_stats(self, x):
        batch_min, batch_max = x.min(), x.max()
        if not self.initialized:
            self.running_min.copy_(batch_min)
            self.running_max.copy_(batch_max)
            self.initialized.fill_(True)
        else:
            self.running_min.mul_(self.momentum).add_(batch_min * (1 - self.momentum))
            self.running_max.mul_(self.momentum).add_(batch_max * (1 - self.momentum))

    def scale(self, x, training=True):
        if training and self.training:
            self.update_stats(x)

        if not self.initialized:
            min_val, max_val = x.min(), x.max()
        else:
            min_val, max_val = self.running_min, self.running_max

        range_val = torch.clamp(max_val - min_val, min=1e-8)
        scaled = 2.0 * (x - min_val) / range_val - 1.0
        return torch.clamp(scaled, -1.0, 1.0)

class SharedChebyshevLayer(nn.Module):
    def __init__(self, feature_size, degree, use_nodes=True, node_kind="first"):
        super().__init__()
        self.degree, self.feature_size = degree, feature_size
        self.coeffs = nn.Parameter(torch.randn(feature_size, degree + 1, dtype=dtype)*0.01)

        self.scalers = nn.ModuleList([InputScaler() for _ in range(feature_size)])
        self.register_buffer("nodes",
            cheb_nodes_first_kind(degree+1) if use_nodes and node_kind=="first"
            else cheb_nodes_second_kind(degree+1) if use_nodes else torch.linspace(-1,1,degree+1)
        )

    def _scale_inputs(self, x, training=True):
        if x.dim() == 4:            # (B,C,H,W)
            B,C,H,W = x.shape
            x_flat  = x.view(B, C, -1)
            scaled  = torch.zeros_like(x_flat)
            for c in range(C):
                scaled[:,c,:] = self.scalers[c].scale(x_flat[:,c,:], training)
            return scaled.view(B,C,H,W)
        else:                       # (B,F)
            out = torch.zeros_like(x)
            for f in range(x.size(1)):
                out[:,f] = self.scalers[f].scale(x[:,f], training)
            return out

    def _basis(self, x):
        T0, T1 = torch.ones_like(x), x
        basis  = [T0, T1]
        for k in range(1, self.degree):
            Tk = 2.0 * x * basis[-1] - basis[-2]
            if k < self.nodes.numel():
                node_w = torch.exp(-0.5 * torch.abs(x - self.nodes[k]))
                Tk *= node_w
            basis.append(Tk)
        return torch.stack(basis[:self.degree+1], dim=-1)

    def forward(self, x):
        xs = self._scale_inputs(x, self.training)
        if xs.dim()==4:
            B, C, H, W = xs.shape
            Phi = self._basis(xs.view(B,C,-1))
            y   = torch.einsum("bchk,ck->bch", Phi, self.coeffs)
            return y.view(B,C,H,W)
        else:
            Phi = self._basis(xs)
            return torch.einsum("bfk,fk->bf", Phi, self.coeffs)

class BottleneckBlock(nn.Module):
    def __init__(self, in_ch, out_ch, reduction_ratio=4):
        super().__init__()
        mid = max(in_ch // reduction_ratio, 8)
        self.reduce = nn.Conv2d(in_ch, mid, 1, bias=False)
        self.conv   = nn.Conv2d(mid, mid, 3, padding=1, bias=False)
        self.expand = nn.Conv2d(mid, out_ch, 1, bias=False)
    def forward(self, x):
        return self.expand(self.conv(self.reduce(x)))

class ConvolutionalChebyshevNN(nn.Module):
    def __init__(self, num_classes=10, conv_depth=5,
                 cheb_degree=20, use_nodes=True, node_kind="second", use_bottleneck=True):
        super().__init__()
        ch = [3, 32, 64, 128, 256, 384] if use_bottleneck else [3,48,96,192,320,512]
        self.conv_layers, self.cheb_layers, self.norms = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        self.pool, self.pool_after = nn.MaxPool2d(2), {1,2,3,4}

        for i in range(conv_depth):
            self.conv_layers.append(
                BottleneckBlock(ch[i],ch[i+1]) if use_bottleneck and i>0
                else nn.Conv2d(ch[i],ch[i+1],3,padding=1,bias=False)
            )
            self.cheb_layers.append(
                SharedChebyshevLayer(ch[i+1], cheb_degree, use_nodes, node_kind)
            )
            self.norms.append(nn.LayerNorm(ch[i+1]))

        self.final_cheb = SharedChebyshevLayer(ch[-1], cheb_degree, use_nodes, node_kind)
        self.heads = nn.ModuleList([nn.Linear(ch[-1],1) for _ in range(num_classes)])

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode="fan_in")
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.to(memory_format=torch.channels_last)
        for i,(cv,chb,ln) in enumerate(zip(self.conv_layers,self.cheb_layers,self.norms)):
            x = cv(x)
            x = chb(x)
            x = ln(x.permute(0,2,3,1)).permute(0,3,1,2)
            if i in self.pool_after: x = self.pool(x)
        x = F.adaptive_avg_pool2d(x,1).flatten(1)
        x = F.normalize(self.final_cheb(x),p=2,dim=1,eps=1e-6)
        return torch.cat([h(x) for h in self.heads], dim=1)

def count_parameters(model):
    tot = sum(p.numel() for p in model.parameters())
    train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return tot, train

def compile_net(net):
    try:        
        return torch.compile(net, mode="default", dynamic=True)
    except Exception:
        return torch.jit.script(net)

def train(model, train_images,train_labels, test_images,test_labels, epochs=50, bs=64, patience=5):

    train_loader = DataLoader(TensorDataset(train_images, train_labels), bs , shuffle=True, num_workers=8,
                              pin_memory=False, persistent_workers=True)
    test_loader  = DataLoader(TensorDataset(test_images, test_labels), bs , shuffle=False,num_workers=8,
                              pin_memory=False, persistent_workers=True)

    criterion = nn.CrossEntropyLoss()
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min",
                                                     patience=2, factor=0.5, min_lr=1e-5)
    best_val, wait = 1e9, 0
    start = time.time()

    for ep in range(1, epochs+1):
        model.train()
        tr_loss, tr_hit = 0., 0
        for i, (x, y) in enumerate(train_loader):
            print(f'Batch Number {i}')
            x,y = x.to(device,dtype), y.to(device)
            opt.zero_grad()
            out = model(x)
            loss = criterion(out,y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
            opt.step()

            tr_loss += loss.item()*x.size(0)
            tr_hit  += (out.argmax(1)==y).sum().item()

        model.eval()
        vl_loss, vl_hit = 0., 0
        with torch.no_grad():
            for x,y in test_loader:
                x,y = x.to(device,dtype), y.to(device)
                out = model(x)
                vl_loss += criterion(out,y).item()*x.size(0)
                vl_hit  += (out.argmax(1)==y).sum().item()

        tr_loss /= len(train_loader.dataset); tr_acc = tr_hit/len(train_loader.dataset)
        vl_loss /= len(test_loader.dataset) ; vl_acc = vl_hit/len(test_loader.dataset)
        scheduler.step(vl_loss)

        print(f"Ep {ep:02d} | "
              f"Train {tr_loss:.3f}/{tr_acc:.3f} | "
              f"Val {vl_loss:.3f}/{vl_acc:.3f} | "
              f"LR {opt.param_groups[0]['lr']:.2e} | "
              f"Time {(time.time()-start)/60:.1f}m")
        if vl_loss < best_val:
            best_val, wait = vl_loss, 0
            torch.save(model.state_dict(), "best_cheb_cpu.pth")
        else:
            wait += 1
            if wait > patience:
                print("Early stopping.")
                break

if __name__ == "__main__":

    net = ConvolutionalChebyshevNN(num_classes=10, cheb_degree=20, use_bottleneck=False)
    net = compile_net(net).to(device)
    print(f"Total params: {count_parameters(net)[0]:,}")

    train(net, train_images, train_labels, test_images, test_labels, epochs=100, bs=128, patience=1)


In [None]:
# Chebyshev VIT and other architechtures can also be implemented...