In [1]:
import pennylane as qml
import torch
from torch import nn
from torch.utils.data import DataLoader
from pennylane import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cpu device


  return torch._C._cuda_getDeviceCount() > 0


In [2]:
import sys
sys.path.insert(0, "../")
import utils.utils as utils
import models.fourier_models as fm
import models.quantum_models as qm

In [3]:
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

In [4]:
# generate a suitable regression task
n_samples = 200
n_features = 3
n_informative = 3
n_targets = 1
noise = 0.0
random_state = 42
X, y = make_regression(n_samples=n_samples, n_features=n_features, n_informative=n_informative, n_targets=n_targets, noise=noise, random_state=random_state)
X, y = torch.from_numpy(X), torch.from_numpy(y)
# Scale data to interval [-pi/2, pi/2]
X_scaled = utils.data_scaler(X, interval=(-torch.pi/2, torch.pi/2))

# Split the data set into training and testing
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y, test_size=0.2, random_state=2)

# torch dataloaders
train_data_list = []
for i in range(len(X_train)):
    data_point = (X_train[i], y_train[i])
    train_data_list.append(data_point)

test_data_list = []
for i in range(len(X_test)):
    data_point = (X_test[i], y_test[i])
    test_data_list.append(data_point)
    
train_dataloader = DataLoader(train_data_list, batch_size=200, shuffle=True)
test_dataloader = DataLoader(test_data_list, batch_size=200, shuffle=True)

In [5]:
n_qubits = n_features
model = qm.QuantumRegressionModel(n_qubits)

In [6]:
def train(dataloader, model, loss_fn, optimizer, printing=False):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred.flatten(), y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            if printing == True:
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
                for param_group in optimizer.param_groups:
                    print("lr: ", param_group['lr'])
        return loss

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred.flatten(), y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss

In [7]:
loss_fn = nn.MSELoss(reduction='mean') # equiv. to torch.linalg.norm(input-target)**2
optimizer = torch.optim.Adam(model.parameters(), lr=0.8)
train(train_dataloader, model, loss_fn, optimizer, printing=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.2)

epochs = 100
for t in tqdm(range(epochs)):
    # print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer, printing=True)
    # test(test_dataloader, model, loss_fn)
    # scheduler.step()
print("Done!")
print(train(train_dataloader, model, loss_fn, optimizer, printing=True))
print(test(test_dataloader, model, loss_fn))

loss: 10550.017158  [    0/  160]
lr:  0.8


  1%|          | 1/100 [00:03<06:25,  3.90s/it]

loss: 10512.832161  [    0/  160]
lr:  0.2


  2%|▏         | 2/100 [00:07<06:21,  3.89s/it]

loss: 10430.224336  [    0/  160]
lr:  0.2


  3%|▎         | 3/100 [00:11<06:16,  3.88s/it]

loss: 10398.837042  [    0/  160]
lr:  0.2


  4%|▍         | 4/100 [00:15<06:12,  3.88s/it]

loss: 10374.762585  [    0/  160]
lr:  0.2


  5%|▌         | 5/100 [00:19<06:10,  3.90s/it]

loss: 10287.908892  [    0/  160]
lr:  0.2


  6%|▌         | 6/100 [00:23<06:05,  3.88s/it]

loss: 10261.564912  [    0/  160]
lr:  0.2


  7%|▋         | 7/100 [00:27<06:00,  3.87s/it]

loss: 10185.088427  [    0/  160]
lr:  0.2


  8%|▊         | 8/100 [00:30<05:53,  3.84s/it]

loss: 10089.008541  [    0/  160]
lr:  0.2


  9%|▉         | 9/100 [00:34<05:49,  3.85s/it]

loss: 10002.009880  [    0/  160]
lr:  0.2


 10%|█         | 10/100 [00:38<05:44,  3.83s/it]

loss: 9940.924570  [    0/  160]
lr:  0.2


 11%|█         | 11/100 [00:42<05:39,  3.82s/it]

loss: 9856.193665  [    0/  160]
lr:  0.2


 12%|█▏        | 12/100 [00:46<05:35,  3.81s/it]

loss: 9724.495464  [    0/  160]
lr:  0.2


 13%|█▎        | 13/100 [00:49<05:30,  3.80s/it]

loss: 9649.723119  [    0/  160]
lr:  0.2


 14%|█▍        | 14/100 [00:53<05:26,  3.80s/it]

loss: 9593.594967  [    0/  160]
lr:  0.2


 15%|█▌        | 15/100 [00:57<05:23,  3.81s/it]

loss: 9525.150861  [    0/  160]
lr:  0.2


 16%|█▌        | 16/100 [01:01<05:20,  3.82s/it]

loss: 9451.047391  [    0/  160]
lr:  0.2


 17%|█▋        | 17/100 [01:05<05:15,  3.80s/it]

loss: 9385.247678  [    0/  160]
lr:  0.2


 18%|█▊        | 18/100 [01:09<05:15,  3.85s/it]

loss: 9324.996566  [    0/  160]
lr:  0.2


 19%|█▉        | 19/100 [01:12<05:10,  3.83s/it]

loss: 9263.412658  [    0/  160]
lr:  0.2


 20%|██        | 20/100 [01:17<05:13,  3.92s/it]

loss: 9188.071190  [    0/  160]
lr:  0.2


 21%|██        | 21/100 [01:20<05:05,  3.86s/it]

loss: 9106.841108  [    0/  160]
lr:  0.2


 22%|██▏       | 22/100 [01:24<04:59,  3.84s/it]

loss: 9047.135244  [    0/  160]
lr:  0.2


 23%|██▎       | 23/100 [01:28<04:54,  3.83s/it]

loss: 8967.286818  [    0/  160]
lr:  0.2


 24%|██▍       | 24/100 [01:32<04:50,  3.83s/it]

loss: 8911.908090  [    0/  160]
lr:  0.2


 25%|██▌       | 25/100 [01:35<04:45,  3.81s/it]

loss: 8860.662021  [    0/  160]
lr:  0.2


 26%|██▌       | 26/100 [01:39<04:43,  3.83s/it]

loss: 8794.943774  [    0/  160]
lr:  0.2


 27%|██▋       | 27/100 [01:43<04:43,  3.88s/it]

loss: 8730.125029  [    0/  160]
lr:  0.2


 28%|██▊       | 28/100 [01:47<04:41,  3.91s/it]

loss: 8675.511232  [    0/  160]
lr:  0.2


 29%|██▉       | 29/100 [01:51<04:37,  3.91s/it]

loss: 8616.353695  [    0/  160]
lr:  0.2


 30%|███       | 30/100 [01:55<04:34,  3.92s/it]

loss: 8556.586304  [    0/  160]
lr:  0.2


 31%|███       | 31/100 [01:59<04:27,  3.88s/it]

loss: 8490.714843  [    0/  160]
lr:  0.2


 32%|███▏      | 32/100 [02:03<04:23,  3.88s/it]

loss: 8433.206569  [    0/  160]
lr:  0.2


 33%|███▎      | 33/100 [02:07<04:16,  3.83s/it]

loss: 8376.771525  [    0/  160]
lr:  0.2


 34%|███▍      | 34/100 [02:10<04:13,  3.85s/it]

loss: 8319.938779  [    0/  160]
lr:  0.2


 35%|███▌      | 35/100 [02:14<04:11,  3.87s/it]

loss: 8259.041084  [    0/  160]
lr:  0.2


 36%|███▌      | 36/100 [02:18<04:08,  3.89s/it]

loss: 8200.637601  [    0/  160]
lr:  0.2


 37%|███▋      | 37/100 [02:22<04:08,  3.94s/it]

loss: 8148.630525  [    0/  160]
lr:  0.2


 38%|███▊      | 38/100 [02:26<04:06,  3.98s/it]

loss: 8096.360657  [    0/  160]
lr:  0.2


 39%|███▉      | 39/100 [02:30<04:04,  4.00s/it]

loss: 8049.481171  [    0/  160]
lr:  0.2


 40%|████      | 40/100 [02:35<04:01,  4.02s/it]

loss: 7993.838547  [    0/  160]
lr:  0.2


 41%|████      | 41/100 [02:38<03:54,  3.98s/it]

loss: 7931.710735  [    0/  160]
lr:  0.2


 42%|████▏     | 42/100 [02:42<03:51,  3.99s/it]

loss: 7867.002467  [    0/  160]
lr:  0.2


 43%|████▎     | 43/100 [02:46<03:47,  3.99s/it]

loss: 7818.324735  [    0/  160]
lr:  0.2


 44%|████▍     | 44/100 [02:50<03:42,  3.98s/it]

loss: 7769.093182  [    0/  160]
lr:  0.2


 45%|████▌     | 45/100 [02:54<03:35,  3.92s/it]

loss: 7714.920478  [    0/  160]
lr:  0.2


 46%|████▌     | 46/100 [02:58<03:31,  3.93s/it]

loss: 7656.876526  [    0/  160]
lr:  0.2


 47%|████▋     | 47/100 [03:02<03:27,  3.92s/it]

loss: 7607.236423  [    0/  160]
lr:  0.2


 48%|████▊     | 48/100 [03:06<03:25,  3.95s/it]

loss: 7557.744148  [    0/  160]
lr:  0.2


 49%|████▉     | 49/100 [03:10<03:25,  4.03s/it]

loss: 7506.900291  [    0/  160]
lr:  0.2


 50%|█████     | 50/100 [03:14<03:21,  4.04s/it]

loss: 7456.365856  [    0/  160]
lr:  0.2


 51%|█████     | 51/100 [03:18<03:17,  4.02s/it]

loss: 7417.920991  [    0/  160]
lr:  0.2


 52%|█████▏    | 52/100 [03:22<03:10,  3.97s/it]

loss: 7383.954269  [    0/  160]
lr:  0.2


 53%|█████▎    | 53/100 [03:26<03:04,  3.92s/it]

loss: 7366.648331  [    0/  160]
lr:  0.2


 54%|█████▍    | 54/100 [03:30<03:01,  3.95s/it]

loss: 7335.109624  [    0/  160]
lr:  0.2


 55%|█████▌    | 55/100 [03:34<02:58,  3.97s/it]

loss: 7226.715488  [    0/  160]
lr:  0.2


 56%|█████▌    | 56/100 [03:38<02:58,  4.05s/it]

loss: 7161.852039  [    0/  160]
lr:  0.2


 57%|█████▋    | 57/100 [03:42<02:51,  3.98s/it]

loss: 7153.320964  [    0/  160]
lr:  0.2


 58%|█████▊    | 58/100 [03:46<02:48,  4.01s/it]

loss: 7089.978063  [    0/  160]
lr:  0.2


 59%|█████▉    | 59/100 [03:50<02:43,  3.98s/it]

loss: 7021.238641  [    0/  160]
lr:  0.2


 60%|██████    | 60/100 [03:54<02:39,  3.98s/it]

loss: 6966.885878  [    0/  160]
lr:  0.2


 61%|██████    | 61/100 [03:58<02:32,  3.91s/it]

loss: 6942.042842  [    0/  160]
lr:  0.2


 62%|██████▏   | 62/100 [04:02<02:30,  3.96s/it]

loss: 6888.430851  [    0/  160]
lr:  0.2


 63%|██████▎   | 63/100 [04:06<02:26,  3.95s/it]

loss: 6841.773274  [    0/  160]
lr:  0.2


 64%|██████▍   | 64/100 [04:10<02:22,  3.95s/it]

loss: 6782.128846  [    0/  160]
lr:  0.2


 65%|██████▌   | 65/100 [04:13<02:16,  3.90s/it]

loss: 6727.683275  [    0/  160]
lr:  0.2


 66%|██████▌   | 66/100 [04:17<02:12,  3.91s/it]

loss: 6672.088262  [    0/  160]
lr:  0.2


 67%|██████▋   | 67/100 [04:21<02:08,  3.90s/it]

loss: 6625.953158  [    0/  160]
lr:  0.2


 68%|██████▊   | 68/100 [04:25<02:05,  3.93s/it]

loss: 6586.056432  [    0/  160]
lr:  0.2


 69%|██████▉   | 69/100 [04:29<02:01,  3.92s/it]

loss: 6541.471038  [    0/  160]
lr:  0.2


 70%|███████   | 70/100 [04:33<01:57,  3.91s/it]

loss: 6496.152541  [    0/  160]
lr:  0.2


 71%|███████   | 71/100 [04:37<01:53,  3.92s/it]

loss: 6451.236977  [    0/  160]
lr:  0.2


 72%|███████▏  | 72/100 [04:41<01:50,  3.94s/it]

loss: 6434.196911  [    0/  160]
lr:  0.2


 73%|███████▎  | 73/100 [04:45<01:45,  3.90s/it]

loss: 6542.771890  [    0/  160]
lr:  0.2


 74%|███████▍  | 74/100 [04:49<01:41,  3.90s/it]

loss: 6654.352842  [    0/  160]
lr:  0.2


 75%|███████▌  | 75/100 [04:53<01:37,  3.90s/it]

loss: 6367.194695  [    0/  160]
lr:  0.2


 76%|███████▌  | 76/100 [04:57<01:33,  3.91s/it]

loss: 6389.128408  [    0/  160]
lr:  0.2


 77%|███████▋  | 77/100 [05:00<01:29,  3.90s/it]

loss: 6378.392431  [    0/  160]
lr:  0.2


 78%|███████▊  | 78/100 [05:04<01:26,  3.91s/it]

loss: 6215.627990  [    0/  160]
lr:  0.2


 79%|███████▉  | 79/100 [05:08<01:22,  3.95s/it]

loss: 6253.804883  [    0/  160]
lr:  0.2


 80%|████████  | 80/100 [05:12<01:19,  3.96s/it]

loss: 6096.898477  [    0/  160]
lr:  0.2


 81%|████████  | 81/100 [05:16<01:14,  3.91s/it]

loss: 6164.164995  [    0/  160]
lr:  0.2


 82%|████████▏ | 82/100 [05:20<01:10,  3.93s/it]

loss: 6014.735530  [    0/  160]
lr:  0.2


 83%|████████▎ | 83/100 [05:24<01:06,  3.90s/it]

loss: 6043.102668  [    0/  160]
lr:  0.2


 84%|████████▍ | 84/100 [05:28<01:02,  3.90s/it]

loss: 5931.811865  [    0/  160]
lr:  0.2


 85%|████████▌ | 85/100 [05:32<00:57,  3.85s/it]

loss: 5932.606803  [    0/  160]
lr:  0.2


 86%|████████▌ | 86/100 [05:36<00:54,  3.86s/it]

loss: 5863.185257  [    0/  160]
lr:  0.2


 87%|████████▋ | 87/100 [05:39<00:49,  3.83s/it]

loss: 5812.831942  [    0/  160]
lr:  0.2


 88%|████████▊ | 88/100 [05:43<00:46,  3.84s/it]

loss: 5787.505213  [    0/  160]
lr:  0.2


 89%|████████▉ | 89/100 [05:47<00:41,  3.81s/it]

loss: 5714.053733  [    0/  160]
lr:  0.2


 90%|█████████ | 90/100 [05:51<00:38,  3.84s/it]

loss: 5699.157084  [    0/  160]
lr:  0.2


 91%|█████████ | 91/100 [05:55<00:34,  3.83s/it]

loss: 5637.463759  [    0/  160]
lr:  0.2


 92%|█████████▏| 92/100 [05:59<00:30,  3.87s/it]

loss: 5608.251631  [    0/  160]
lr:  0.2


 93%|█████████▎| 93/100 [06:02<00:26,  3.86s/it]

loss: 5559.046311  [    0/  160]
lr:  0.2


 94%|█████████▍| 94/100 [06:06<00:23,  3.89s/it]

loss: 5521.903112  [    0/  160]
lr:  0.2


 95%|█████████▌| 95/100 [06:10<00:19,  3.88s/it]

loss: 5482.368794  [    0/  160]
lr:  0.2


 96%|█████████▌| 96/100 [06:14<00:15,  3.90s/it]

loss: 5441.635593  [    0/  160]
lr:  0.2


 97%|█████████▋| 97/100 [06:18<00:11,  3.89s/it]

loss: 5402.480214  [    0/  160]
lr:  0.2


 98%|█████████▊| 98/100 [06:22<00:07,  3.91s/it]

loss: 5361.304393  [    0/  160]
lr:  0.2


 99%|█████████▉| 99/100 [06:26<00:03,  3.88s/it]

loss: 5327.220246  [    0/  160]
lr:  0.2


100%|██████████| 100/100 [06:30<00:00,  3.90s/it]

loss: 5288.198656  [    0/  160]
lr:  0.2
Done!





loss: 5246.066722  [    0/  160]
lr:  0.2
5246.066721595605
Test Error: 
 Accuracy: 0.0%, Avg loss: 3764.587826 

3764.587826416885


In [8]:
# generate frequencies
max_freq = 2
dim = X_scaled[0].shape[0]

W = utils.freq_generator(max_freq, dim)

# compute best approximation
ba_coeffs = utils.fourier_best_approx(W, X_train, y_train)

print("training_loss: ",utils.loss(W, ba_coeffs, X_train, y_train))
print("test loss: ",utils.loss(W, ba_coeffs, X_test, y_test))

training_loss:  tensor(0.0310, dtype=torch.float64)
test loss:  tensor(0.5630, dtype=torch.float64)


In [14]:
model(X_train[0])

tensor([-52.9373], dtype=torch.float64, grad_fn=<SqueezeBackward3>)

In [13]:
y_train[0]

tensor(-76.4457, dtype=torch.float64)

In [11]:
utils.loss(W, ba_coeffs, X_train[0:2], y_train[0:2])

tensor(0.3404, dtype=torch.float64)

In [12]:
epochs = 100
for t in tqdm(range(epochs)):
    # print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer, printing=True)
    # test(test_dataloader, model, loss_fn)
    # scheduler.step()
print("Done!")
print(train(train_dataloader, model, loss_fn, optimizer, printing=True))
print(test(test_dataloader, model, loss_fn))

  1%|          | 1/100 [00:03<06:24,  3.89s/it]

loss: 5206.724404  [    0/  160]
lr:  0.2


  2%|▏         | 2/100 [00:07<06:30,  3.99s/it]

loss: 5175.474531  [    0/  160]
lr:  0.2


  3%|▎         | 3/100 [00:11<06:19,  3.91s/it]

loss: 5136.269908  [    0/  160]
lr:  0.2


  4%|▍         | 4/100 [00:15<06:18,  3.94s/it]

loss: 5098.055116  [    0/  160]
lr:  0.2


  5%|▌         | 5/100 [00:19<06:10,  3.90s/it]

loss: 5061.806967  [    0/  160]
lr:  0.2


  6%|▌         | 6/100 [00:23<06:08,  3.92s/it]

loss: 5024.493267  [    0/  160]
lr:  0.2


  7%|▋         | 7/100 [00:27<06:01,  3.88s/it]

loss: 4992.669940  [    0/  160]
lr:  0.2


  8%|▊         | 8/100 [00:31<05:57,  3.89s/it]

loss: 4957.472378  [    0/  160]
lr:  0.2


  9%|▉         | 9/100 [00:35<06:02,  3.98s/it]

loss: 4921.097668  [    0/  160]
lr:  0.2


 10%|█         | 10/100 [00:39<05:58,  3.99s/it]

loss: 4885.496459  [    0/  160]
lr:  0.2


 11%|█         | 11/100 [00:43<05:51,  3.95s/it]

loss: 4848.792180  [    0/  160]
lr:  0.2


 12%|█▏        | 12/100 [00:47<05:48,  3.97s/it]

loss: 4814.403893  [    0/  160]
lr:  0.2


 13%|█▎        | 13/100 [00:51<05:43,  3.94s/it]

loss: 4778.075642  [    0/  160]
lr:  0.2


 14%|█▍        | 14/100 [00:55<05:41,  3.97s/it]

loss: 4746.127851  [    0/  160]
lr:  0.2


 15%|█▌        | 15/100 [00:59<05:34,  3.94s/it]

loss: 4711.060097  [    0/  160]
lr:  0.2


 16%|█▌        | 16/100 [01:03<05:30,  3.94s/it]

loss: 4678.385322  [    0/  160]
lr:  0.2


 17%|█▋        | 17/100 [01:06<05:25,  3.92s/it]

loss: 4648.304586  [    0/  160]
lr:  0.2


 18%|█▊        | 18/100 [01:10<05:22,  3.94s/it]

loss: 4622.005780  [    0/  160]
lr:  0.2


 19%|█▉        | 19/100 [01:14<05:20,  3.96s/it]

loss: 4614.638359  [    0/  160]
lr:  0.2


 20%|██        | 20/100 [01:18<05:16,  3.95s/it]

loss: 4668.376926  [    0/  160]
lr:  0.2


 21%|██        | 21/100 [01:22<05:10,  3.93s/it]

loss: 4842.591329  [    0/  160]
lr:  0.2


 22%|██▏       | 22/100 [01:26<05:08,  3.95s/it]

loss: 4802.524265  [    0/  160]
lr:  0.2


 23%|██▎       | 23/100 [01:30<05:02,  3.93s/it]

loss: 4499.632151  [    0/  160]
lr:  0.2


 24%|██▍       | 24/100 [01:34<04:57,  3.92s/it]

loss: 4543.210089  [    0/  160]
lr:  0.2


 25%|██▌       | 25/100 [01:38<04:53,  3.92s/it]

loss: 4560.691276  [    0/  160]
lr:  0.2


 26%|██▌       | 26/100 [01:42<04:48,  3.90s/it]

loss: 4394.895650  [    0/  160]
lr:  0.2


 27%|██▋       | 27/100 [01:46<04:43,  3.88s/it]

loss: 4466.383919  [    0/  160]
lr:  0.2


 28%|██▊       | 28/100 [01:49<04:38,  3.87s/it]

loss: 4365.783415  [    0/  160]
lr:  0.2


 29%|██▉       | 29/100 [01:53<04:32,  3.84s/it]

loss: 4336.731767  [    0/  160]
lr:  0.2


 30%|███       | 30/100 [01:57<04:31,  3.87s/it]

loss: 4320.889389  [    0/  160]
lr:  0.2


 31%|███       | 31/100 [02:01<04:26,  3.86s/it]

loss: 4236.994112  [    0/  160]
lr:  0.2


 32%|███▏      | 32/100 [02:05<04:23,  3.87s/it]

loss: 4270.410928  [    0/  160]
lr:  0.2


 33%|███▎      | 33/100 [02:09<04:18,  3.86s/it]

loss: 4177.148074  [    0/  160]
lr:  0.2


 34%|███▍      | 34/100 [02:13<04:17,  3.90s/it]

loss: 4178.290074  [    0/  160]
lr:  0.2


 35%|███▌      | 35/100 [02:17<04:13,  3.89s/it]

loss: 4132.965037  [    0/  160]
lr:  0.2


 36%|███▌      | 36/100 [02:20<04:09,  3.89s/it]

loss: 4091.834280  [    0/  160]
lr:  0.2


 37%|███▋      | 37/100 [02:24<04:06,  3.91s/it]

loss: 4078.032457  [    0/  160]
lr:  0.2


 38%|███▊      | 38/100 [02:28<04:03,  3.92s/it]

loss: 4028.052506  [    0/  160]
lr:  0.2


 39%|███▉      | 39/100 [02:32<03:57,  3.90s/it]

loss: 4009.096990  [    0/  160]
lr:  0.2


 40%|████      | 40/100 [02:36<03:57,  3.96s/it]

loss: 3979.444274  [    0/  160]
lr:  0.2


 41%|████      | 41/100 [02:40<03:52,  3.94s/it]

loss: 3935.142857  [    0/  160]
lr:  0.2


 42%|████▏     | 42/100 [02:44<03:50,  3.97s/it]

loss: 3922.559849  [    0/  160]
lr:  0.2


 43%|████▎     | 43/100 [02:48<03:42,  3.90s/it]

loss: 3879.189949  [    0/  160]
lr:  0.2


 44%|████▍     | 44/100 [02:52<03:37,  3.88s/it]

loss: 3851.345019  [    0/  160]
lr:  0.2


 45%|████▌     | 45/100 [02:56<03:32,  3.87s/it]

loss: 3827.289168  [    0/  160]
lr:  0.2


 46%|████▌     | 46/100 [03:00<03:29,  3.89s/it]

loss: 3792.927515  [    0/  160]
lr:  0.2


 47%|████▋     | 47/100 [03:03<03:23,  3.83s/it]

loss: 3767.465005  [    0/  160]
lr:  0.2


 48%|████▊     | 48/100 [03:07<03:19,  3.83s/it]

loss: 3738.847760  [    0/  160]
lr:  0.2


 49%|████▉     | 49/100 [03:11<03:15,  3.83s/it]

loss: 3711.340542  [    0/  160]
lr:  0.2


 50%|█████     | 50/100 [03:15<03:13,  3.87s/it]

loss: 3686.285624  [    0/  160]
lr:  0.2


 51%|█████     | 51/100 [03:19<03:07,  3.82s/it]

loss: 3653.770594  [    0/  160]
lr:  0.2


 52%|█████▏    | 52/100 [03:23<03:04,  3.83s/it]

loss: 3631.073172  [    0/  160]
lr:  0.2


 53%|█████▎    | 53/100 [03:26<02:59,  3.83s/it]

loss: 3604.550025  [    0/  160]
lr:  0.2


 54%|█████▍    | 54/100 [03:30<02:56,  3.85s/it]

loss: 3575.778578  [    0/  160]
lr:  0.2


 55%|█████▌    | 55/100 [03:34<02:52,  3.82s/it]

loss: 3551.203300  [    0/  160]
lr:  0.2


 56%|█████▌    | 56/100 [03:38<02:48,  3.83s/it]

loss: 3523.831043  [    0/  160]
lr:  0.2


 57%|█████▋    | 57/100 [03:42<02:49,  3.95s/it]

loss: 3499.009193  [    0/  160]
lr:  0.2


 58%|█████▊    | 58/100 [03:46<02:47,  3.98s/it]

loss: 3473.751350  [    0/  160]
lr:  0.2


 59%|█████▉    | 59/100 [03:50<02:46,  4.06s/it]

loss: 3449.022474  [    0/  160]
lr:  0.2


 60%|██████    | 60/100 [03:54<02:39,  3.99s/it]

loss: 3422.605155  [    0/  160]
lr:  0.2


 61%|██████    | 61/100 [03:58<02:33,  3.95s/it]

loss: 3397.446574  [    0/  160]
lr:  0.2


 62%|██████▏   | 62/100 [04:02<02:29,  3.94s/it]

loss: 3372.814734  [    0/  160]
lr:  0.2


 63%|██████▎   | 63/100 [04:06<02:24,  3.90s/it]

loss: 3348.591814  [    0/  160]
lr:  0.2


 64%|██████▍   | 64/100 [04:10<02:19,  3.88s/it]

loss: 3324.287280  [    0/  160]
lr:  0.2


 65%|██████▌   | 65/100 [04:13<02:15,  3.87s/it]

loss: 3300.740531  [    0/  160]
lr:  0.2


 66%|██████▌   | 66/100 [04:17<02:11,  3.86s/it]

loss: 3275.313811  [    0/  160]
lr:  0.2


 67%|██████▋   | 67/100 [04:21<02:06,  3.83s/it]

loss: 3251.959536  [    0/  160]
lr:  0.2


 68%|██████▊   | 68/100 [04:25<02:02,  3.83s/it]

loss: 3227.934671  [    0/  160]
lr:  0.2


 69%|██████▉   | 69/100 [04:29<01:58,  3.82s/it]

loss: 3203.818941  [    0/  160]
lr:  0.2


 70%|███████   | 70/100 [04:33<01:56,  3.88s/it]

loss: 3181.338151  [    0/  160]
lr:  0.2


 71%|███████   | 71/100 [04:37<01:53,  3.91s/it]

loss: 3157.809899  [    0/  160]
lr:  0.2


 72%|███████▏  | 72/100 [04:41<01:49,  3.90s/it]

loss: 3134.567302  [    0/  160]
lr:  0.2


 73%|███████▎  | 73/100 [04:44<01:44,  3.88s/it]

loss: 3112.545031  [    0/  160]
lr:  0.2


 74%|███████▍  | 74/100 [04:48<01:40,  3.86s/it]

loss: 3089.463764  [    0/  160]
lr:  0.2


 75%|███████▌  | 75/100 [04:52<01:35,  3.81s/it]

loss: 3067.204567  [    0/  160]
lr:  0.2


 76%|███████▌  | 76/100 [04:56<01:31,  3.82s/it]

loss: 3045.728891  [    0/  160]
lr:  0.2


 77%|███████▋  | 77/100 [05:00<01:27,  3.82s/it]

loss: 3024.130490  [    0/  160]
lr:  0.2


 78%|███████▊  | 78/100 [05:03<01:24,  3.84s/it]

loss: 3004.037799  [    0/  160]
lr:  0.2


 79%|███████▉  | 79/100 [05:07<01:20,  3.82s/it]

loss: 2986.657081  [    0/  160]
lr:  0.2


 80%|████████  | 80/100 [05:11<01:16,  3.83s/it]

loss: 2974.041771  [    0/  160]
lr:  0.2


 81%|████████  | 81/100 [05:15<01:12,  3.84s/it]

loss: 2977.101700  [    0/  160]
lr:  0.2


 82%|████████▏ | 82/100 [05:19<01:09,  3.88s/it]

loss: 3007.740589  [    0/  160]
lr:  0.2


 83%|████████▎ | 83/100 [05:23<01:05,  3.86s/it]

loss: 3072.177468  [    0/  160]
lr:  0.2


 84%|████████▍ | 84/100 [05:27<01:01,  3.86s/it]

loss: 3058.650862  [    0/  160]
lr:  0.2


 85%|████████▌ | 85/100 [05:31<00:58,  3.88s/it]

loss: 2926.051969  [    0/  160]
lr:  0.2


 86%|████████▌ | 86/100 [05:34<00:54,  3.90s/it]

loss: 2857.825534  [    0/  160]
lr:  0.2


 87%|████████▋ | 87/100 [05:38<00:50,  3.89s/it]

loss: 2927.133395  [    0/  160]
lr:  0.2


 88%|████████▊ | 88/100 [05:42<00:46,  3.88s/it]

loss: 2919.344383  [    0/  160]
lr:  0.2


 89%|████████▉ | 89/100 [05:46<00:42,  3.89s/it]

loss: 2859.576430  [    0/  160]
lr:  0.2


 90%|█████████ | 90/100 [05:50<00:39,  3.93s/it]

loss: 2866.135236  [    0/  160]
lr:  0.2


 91%|█████████ | 91/100 [05:54<00:35,  3.90s/it]

loss: 2903.768703  [    0/  160]
lr:  0.2


 92%|█████████▏| 92/100 [05:58<00:31,  3.89s/it]

loss: 2799.662012  [    0/  160]
lr:  0.2


 93%|█████████▎| 93/100 [06:02<00:27,  3.89s/it]

loss: 2848.164093  [    0/  160]
lr:  0.2


 94%|█████████▍| 94/100 [06:06<00:23,  3.91s/it]

loss: 2802.284026  [    0/  160]
lr:  0.2


 95%|█████████▌| 95/100 [06:09<00:19,  3.89s/it]

loss: 2739.898477  [    0/  160]
lr:  0.2


 96%|█████████▌| 96/100 [06:13<00:15,  3.90s/it]

loss: 2746.570226  [    0/  160]
lr:  0.2


 97%|█████████▋| 97/100 [06:17<00:11,  3.89s/it]

loss: 2712.883031  [    0/  160]
lr:  0.2


 98%|█████████▊| 98/100 [06:21<00:07,  3.93s/it]

loss: 2678.038792  [    0/  160]
lr:  0.2


 99%|█████████▉| 99/100 [06:25<00:03,  3.90s/it]

loss: 2670.430919  [    0/  160]
lr:  0.2


100%|██████████| 100/100 [06:29<00:00,  3.90s/it]

loss: 2608.851343  [    0/  160]
lr:  0.2
Done!





loss: 2612.068807  [    0/  160]
lr:  0.2
2612.0688069459147
Test Error: 
 Accuracy: 0.0%, Avg loss: 1601.819688 

1601.8196875253097


In [15]:
W = utils.freq_generator(max_freq, dim).to(device)
model = fm.Fourier_model(W)
model.to(device)
loss_fn = nn.MSELoss(reduction='mean') # equiv. to torch.linalg.norm(input-target)**2
optimizer = torch.optim.Adam(model.parameters(), lr=0.2)

epochs = 200
for t in tqdm(range(epochs)):
    # print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer, printing=True)
    # test(test_dataloader, model, loss_fn)
    # scheduler.step()
print("Done!")
print(train(train_dataloader, model, loss_fn, optimizer, printing=True))
print(test(test_dataloader, model, loss_fn))

 70%|███████   | 140/200 [00:00<00:00, 707.15it/s]

loss: 10536.059660  [    0/  160]
lr:  0.2
loss: 10050.367515  [    0/  160]
lr:  0.2
loss: 9588.613246  [    0/  160]
lr:  0.2
loss: 9150.663023  [    0/  160]
lr:  0.2
loss: 8735.713804  [    0/  160]
lr:  0.2
loss: 8343.156592  [    0/  160]
lr:  0.2
loss: 7972.735636  [    0/  160]
lr:  0.2
loss: 7624.208262  [    0/  160]
lr:  0.2
loss: 7297.212123  [    0/  160]
lr:  0.2
loss: 6991.237904  [    0/  160]
lr:  0.2
loss: 6705.645401  [    0/  160]
lr:  0.2
loss: 6439.703227  [    0/  160]
lr:  0.2
loss: 6192.630106  [    0/  160]
lr:  0.2
loss: 5963.618918  [    0/  160]
lr:  0.2
loss: 5751.840501  [    0/  160]
lr:  0.2
loss: 5556.436570  [    0/  160]
lr:  0.2
loss: 5376.512036  [    0/  160]
lr:  0.2
loss: 5211.132126  [    0/  160]
lr:  0.2
loss: 5059.325473  [    0/  160]
lr:  0.2
loss: 4920.092317  [    0/  160]
lr:  0.2
loss: 4792.416413  [    0/  160]
lr:  0.2
loss: 4675.279190  [    0/  160]
lr:  0.2
loss: 4567.674936  [    0/  160]
lr:  0.2
loss: 4468.625886  [    0/  160]

100%|██████████| 200/200 [00:00<00:00, 706.11it/s]

loss: 2201.906521  [    0/  160]
lr:  0.2
loss: 2194.728726  [    0/  160]
lr:  0.2
loss: 2187.591161  [    0/  160]
lr:  0.2
loss: 2180.493815  [    0/  160]
lr:  0.2
loss: 2173.436672  [    0/  160]
lr:  0.2
loss: 2166.419704  [    0/  160]
lr:  0.2
loss: 2159.442877  [    0/  160]
lr:  0.2
loss: 2152.506146  [    0/  160]
lr:  0.2
loss: 2145.609459  [    0/  160]
lr:  0.2
loss: 2138.752753  [    0/  160]
lr:  0.2
loss: 2131.935961  [    0/  160]
lr:  0.2
loss: 2125.159006  [    0/  160]
lr:  0.2
loss: 2118.421806  [    0/  160]
lr:  0.2
loss: 2111.724271  [    0/  160]
lr:  0.2
loss: 2105.066309  [    0/  160]
lr:  0.2
loss: 2098.447820  [    0/  160]
lr:  0.2
loss: 2091.868702  [    0/  160]
lr:  0.2
loss: 2085.328847  [    0/  160]
lr:  0.2
loss: 2078.828144  [    0/  160]
lr:  0.2
loss: 2072.366479  [    0/  160]
lr:  0.2
loss: 2065.943733  [    0/  160]
lr:  0.2
loss: 2059.559785  [    0/  160]
lr:  0.2
loss: 2053.214510  [    0/  160]
lr:  0.2
loss: 2046.907778  [    0/  160]
l


