In [1]:
import torch
import scipy


In [23]:
import scipy.fft 
import torch

import torch.nn as nn
from torch_dct import dct, idct  # pip install torch-dct

from torch.utils.data import DataLoader,TensorDataset

import time

In [13]:
# Generate the layers

class LinearDCT(nn.Module):
    """
    Linear layer with weights trained in DCT domain.
    During forward pass, IDCT is applied to recover real-space weights.
    """
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # DCT-domain weights (learned)
        self.weight_dct = nn.Parameter(torch.randn(out_features, in_features))

        # Optional bias
        self.bias = nn.Parameter(torch.randn(out_features)) if bias else None

    def forward(self, x):
        # Convert DCT weights back to real-space via IDCT
        weight_real = idct(self.weight_dct, norm='ortho')  # Shape: [out_features, in_features]

        # Apply linear transformation
        output = x @ weight_real.T

        if self.bias is not None:
            output += self.bias

        return output
    


class LinearDCTModel(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.model = nn.Sequential(LinearDCT(in_features, in_features//2, bias), LinearDCT(in_features//2, out_features, bias))

    def forward(self, x):
        # x shape: (batch_size, in_features)
        return self.model(x)
    
class LinearStandardModel(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(in_features, in_features//2, bias), nn.Linear(in_features//2, out_features, bias))

    def forward(self, x):
        # x shape: (batch_size, in_features)
        return self.model(x)
    


In [4]:
def generate_true_samples(input_size, output_size, instances, seed=0):
    if seed is not None:
        torch.manual_seed(seed)

    # Random inputs
    x = torch.randn(instances, input_size)

    # True weight and bias
    W = torch.randn(output_size, input_size)
    b = torch.randn(output_size)

    # True function: y = x @ W.T + b
    y = x @ W.T + b  # shape: (instances, output_size)

    return x, y, (W, b)

In [5]:
batch_size = 4
in_features = 1000
out_features = 500

x = torch.randn(batch_size, in_features)

layer = LinearDCT(in_features, out_features)
output = layer(x)

print("Output shape:", output.shape)

Output shape: torch.Size([4, 500])


In [28]:
INPUT_SIZE  = 1024
OUTPUT_SIZE = 512
TRAIN_SAMPLES = 16_000
VAL_SAMPLES   = 2_000
BATCH_SIZE    = 128
EPOCHS        = 100
LR            = 1e-3
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"


x_train, y_train, _ = generate_true_samples(INPUT_SIZE, OUTPUT_SIZE, TRAIN_SAMPLES, seed=0)
x_val,   y_val, _   = generate_true_samples(INPUT_SIZE, OUTPUT_SIZE, VAL_SAMPLES,   seed=1)

train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(TensorDataset(x_val,   y_val),   batch_size=BATCH_SIZE, shuffle=False)


In [35]:

def run_epoch(model, loader, training=False):
    if training:
        model.train()
    else:
        model.eval()
    running_loss = 0.0
    with torch.set_grad_enabled(training):
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            preds  = model(xb)
            loss   = criterion(preds, yb)

            if training:
                optim.zero_grad(set_to_none=True)
                loss.backward()
                optim.step()

            running_loss += loss.item() * xb.size(0)
    return running_loss / len(loader.dataset)


def train(model, train_loader, val_loader, epochs, optim, criterion):
    start_time = time.time()

    for epoch in range(epochs):
        train_loss = run_epoch(model, train_loader, training=True)
        val_loss   = run_epoch(model, val_loader, training=False)

        print(f"Epoch {epoch:02d} │ train MSE {train_loss:.4f} │ val MSE {val_loss:.4f}")
    print("\nTraining finished!")
    print(f"Total time: {time.time() - start_time:.2f} seconds")
       

In [36]:
model_dct = LinearDCTModel(INPUT_SIZE, OUTPUT_SIZE).to(DEVICE)
criterion = nn.MSELoss()
optim      = torch.optim.AdamW(model_dct.parameters(), lr=LR)

In [37]:
train(model_dct, train_loader, val_loader, EPOCHS, optim, criterion)

Epoch 00 │ train MSE 468490.4352 │ val MSE 412399.0720
Epoch 01 │ train MSE 364776.5240 │ val MSE 326219.4917
Epoch 02 │ train MSE 287162.4267 │ val MSE 260345.5394
Epoch 03 │ train MSE 227945.1441 │ val MSE 209249.8257
Epoch 04 │ train MSE 182173.4161 │ val MSE 169169.7791
Epoch 05 │ train MSE 146411.7414 │ val MSE 137452.6009
Epoch 06 │ train MSE 118230.3621 │ val MSE 112157.3716
Epoch 07 │ train MSE 95857.1226 │ val MSE 91860.0961
Epoch 08 │ train MSE 77989.2368 │ val MSE 75481.8450
Epoch 09 │ train MSE 63645.3261 │ val MSE 62212.3918
Epoch 10 │ train MSE 52081.9649 │ val MSE 51418.4421
Epoch 11 │ train MSE 42724.7375 │ val MSE 42614.6711
Epoch 12 │ train MSE 35128.1999 │ val MSE 35410.8958
Epoch 13 │ train MSE 28945.1020 │ val MSE 29505.3041
Epoch 14 │ train MSE 23899.3057 │ val MSE 24657.5531
Epoch 15 │ train MSE 19773.2426 │ val MSE 20668.2787
Epoch 16 │ train MSE 16392.3953 │ val MSE 17380.5744
Epoch 17 │ train MSE 13617.4943 │ val MSE 14669.8756
Epoch 18 │ train MSE 11336.2021 

In [32]:
model_linear = LinearDCTModel(INPUT_SIZE, OUTPUT_SIZE).to(DEVICE)
criterion = nn.MSELoss()
optim      = torch.optim.AdamW(model_linear.parameters(), lr=LR)

In [38]:
train(model_linear, train_loader, val_loader, EPOCHS, optim, criterion)

Epoch 00 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 01 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 02 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 03 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 04 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 05 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 06 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 07 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 08 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 09 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 10 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 11 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 12 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 13 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 14 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 15 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 16 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 17 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 18 │ train MSE 144.9702 │ val MSE 1906.6094
Epoch 19 │ train MSE 144.9702 │ val MSE 1906.6094
