In [74]:
import torch
import numpy as np

In [75]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [76]:
import torch
import numpy as np

def generate_data(n_samples=1000, n_grid=128):
    x = np.linspace(0, 1, n_grid)

    inputs = []
    outputs = []

    for _ in range(n_samples):
        # random function
        f = np.sin(2*np.pi*x*np.random.uniform(1,4))  
        f += 0.3*np.sin(6*np.pi*x*np.random.uniform(0.3,1))  
        
        g = f * f + 0.3 * np.roll(f, 2)  
        
        inputs.append(f)
        outputs.append(g)

    return torch.tensor(inputs, dtype=torch.float32), torch.tensor(outputs, dtype=torch.float32)


In [84]:
import pywt
import torch.nn as nn
import torch.nn.functional as F

class WaveletLayer(nn.Module):
    def __init__(self, in_channels, out_channels, wavelet='db4', level=1):
        super().__init__()
        self.wavelet = wavelet
        self.level = level
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        """
        x: [batch, channels, grid]
        """
        import pywt

        device = x.device
        batch, channels, n = x.shape
        out_channels = self.linear.out_features

        cA_list, cD_list = [], []

        # DWT on each channel and batch
        for b in range(batch):
            cA_ch, cD_ch = [], []
            for c in range(channels):
                cA, cD = pywt.dwt(x[b, c].detach().cpu().numpy(), self.wavelet)
                cA_ch.append(torch.tensor(cA, device=device, dtype=x.dtype))
                cD_ch.append(torch.tensor(cD, device=device, dtype=x.dtype))
            cA_list.append(torch.stack(cA_ch))
            cD_list.append(torch.stack(cD_ch))

        cA = torch.stack(cA_list)
        cD = torch.stack(cD_list)

        # Linear mixing on approximation coefficients
        cA = self.linear(cA.permute(0, 2, 1)).permute(0, 2, 1)

        # IDWT to reconstruct
        out = []
        for b in range(batch):
            rec = []
            for ch in range(out_channels):
                rec_ch = pywt.idwt(cA[b, ch].detach().cpu().numpy(),
                                   cD[b, ch].detach().cpu().numpy(),
                                   self.wavelet)
                rec.append(torch.tensor(rec_ch, device=device, dtype=x.dtype))
            out.append(torch.stack(rec))
        out = torch.stack(out)
        return out

In [87]:
class WNO(nn.Module):
    def __init__(self, in_channels=1, width=16, wave_layers=4):
        """
        in_channels: number of input features per grid point
        width: lifting dimension
        wave_layers: number of Wavelet layers
        """
        super().__init__()
        self.width = width

        # Input lifting
        self.fc0 = nn.Linear(in_channels, width)

        # Multiple Wavelet layers
        self.wave_layers = nn.ModuleList([WaveletLayer(width, width) for _ in range(wave_layers)])

        # Output projection
        self.fc1 = nn.Linear(width, 64)
        self.fc2 = nn.Linear(64, 1)


    def forward(self, x):
            """
            x: [batch, grid] or [batch, grid, in_channels]
            """
            if x.ndim == 2:
                x = x.unsqueeze(-1)  # [batch, grid, in_channels]

            x = self.fc0(x)          # lift â†’ [batch, grid, width]
            x = x.permute(0, 2, 1)   # [batch, width, grid]

            # Apply wavelet layers
            for wave in self.wave_layers:
                x = x + wave(x)      # residual connection
                x = F.gelu(x)

            x = x.permute(0, 2, 1)   # [batch, grid, width]
            x = F.gelu(self.fc1(x))  # intermediate projection
            x = self.fc2(x)           # output projection

            return x.squeeze(-1)      # [batch, grid]


In [89]:
# Training 
model = WNO().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

x, y = generate_data(1000)

x = x.float().to(device)
y = y.float().to(device)

x_train = x[:800]
y_train = y[:800]

x_test = x[800:]
y_test = y[800:]


for epoch in range(5):
    pred = model(x_train)
    loss = loss_fn(pred, y_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    
    print("Epoch", epoch, "Loss:", loss.item())


Epoch 0 Loss: 0.6257686018943787
Epoch 1 Loss: 0.6093882918357849
Epoch 2 Loss: 0.5933340191841125
Epoch 3 Loss: 0.5776136517524719
Epoch 4 Loss: 0.5622329115867615


In [90]:
model.eval()
with torch.no_grad():
    pred = model(x_test)
    test_loss = loss_fn(pred, y_test)
    print("Test Loss:", test_loss.item())


Test Loss: 0.5516216158866882
