In [1]:
import numpy as np
import torch

torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.cuda.get_device_name(0))

import MLPLoRA

NVIDIA GeForce RTX 3090


# Generate Random Gaussian Data and Fixed Target Function (Two-Layer)

In [2]:
def relu(x):
    return np.maximum(x, 0)


def generate_target_func(input_dim, hidden_dim, output_dim, rank_W1, rank_W2):
    fixed_W1 = np.random.normal(size=(input_dim, hidden_dim)) / np.sqrt(input_dim)
    fixed_W2 = np.random.normal(size=(hidden_dim, output_dim)) / np.sqrt(output_dim + hidden_dim)

    print('Operator Norm of fixed W1: ', np.linalg.norm(fixed_W1, ord=2))
    print('Operator Norm of fixed W2: ', np.linalg.norm(fixed_W2, ord=2))

    A1 = np.random.normal(size=(input_dim, rank_W1)) / np.sqrt(input_dim)
    B1 = np.random.normal(size=(rank_W1, hidden_dim)) / np.sqrt(hidden_dim)
    A2 = np.random.normal(size=(hidden_dim, rank_W2)) / np.sqrt(hidden_dim)
    B2 = np.random.normal(size=(rank_W2, output_dim)) / np.sqrt(output_dim)

    W1 = fixed_W1 + 3 * A1 @ B1
    W2 = fixed_W2 + 3 * A2 @ B2

    print('Operator Norm of W1: ', np.linalg.norm(W1, ord=2))
    print('Operator Norm of W2: ', np.linalg.norm(W2, ord=2))
    
    return lambda X: relu(X @ W1) @ W2, fixed_W1, fixed_W2


def generate_data(input_dim, num_train, num_test, target_func):
    X = np.random.normal(size=(num_train + num_test, input_dim))
    X /= np.linalg.norm(X, axis=1, keepdims=True)
    Y = target_func(X)
    return (X[:num_train], Y[:num_train]), (X[num_train:], Y[num_train:])

def data_to_tensor(data_pair):
    return torch.from_numpy(data_pair[0].astype(np.float32)), torch.from_numpy(data_pair[1].astype(np.float32))

In [3]:
input_dim = 500
hidden_dim = 100
output_dim = 20
num_train_samples = 10000
num_test_samples = 2000
rank_W1 = 10
rank_W2 = 10

target_func, fixed_W1, fixed_W2 = generate_target_func(input_dim, hidden_dim, output_dim, rank_W1, rank_W2)
train_data, test_data = generate_data(input_dim, num_train_samples, num_test_samples, target_func)

Operator Norm of fixed W1:  1.4243566211839866
Operator Norm of fixed W2:  1.2191197458343974
Operator Norm of W1:  4.100107873890544
Operator Norm of W2:  4.640019773511152


In [5]:
import measurements
lr = 10
batch_size = 200
phys_bs = 5000
num_epochs = 300

msrs = measurements.Measurement(verbose=True)
msrs.add_train_recorder('MSE', phys_bs, verbose=True)
msrs.add_test_recorder('MSE', phys_bs, verbose=True)

mlp = MLPLoRA.fully_connected_net(input_dim, output_dim, [hidden_dim], 'relu').cuda()
optimizer = torch.optim.SGD(mlp.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss()

MLPLoRA.train(data_to_tensor(train_data), data_to_tensor(test_data), mlp, loss_fn, optimizer, batch_size, num_epochs, msrs)

Epoch #0
  Metrics on training data:
    MSE: 0.05905507504940033.
  Metrics on testing data:
    MSE: 0.0583227276802063.
Epoch #1
  Metrics on training data:
    MSE: 0.05726661905646324.
  Metrics on testing data:
    MSE: 0.05664518103003502.
Epoch #2
  Metrics on training data:
    MSE: 0.05499308183789253.
  Metrics on testing data:
    MSE: 0.05453081801533699.
Epoch #3
  Metrics on training data:
    MSE: 0.05192647501826286.
  Metrics on testing data:
    MSE: 0.051678940653800964.
Epoch #4
  Metrics on training data:
    MSE: 0.04791194200515747.
  Metrics on testing data:
    MSE: 0.047939784824848175.
Epoch #5
  Metrics on training data:
    MSE: 0.04321407526731491.
  Metrics on testing data:
    MSE: 0.04352974146604538.
Epoch #6
  Metrics on training data:
    MSE: 0.03867059946060181.
  Metrics on testing data:
    MSE: 0.039262909442186356.
Epoch #7
  Metrics on training data:
    MSE: 0.038721006363630295.
  Metrics on testing data:
    MSE: 0.03953607380390167.
Epoch

    MSE: 0.014972185716032982.
  Metrics on testing data:
    MSE: 0.01750289648771286.
Epoch #67
  Metrics on training data:
    MSE: 0.01492446567863226.
  Metrics on testing data:
    MSE: 0.017468607053160667.
Epoch #68
  Metrics on training data:
    MSE: 0.014848470687866211.
  Metrics on testing data:
    MSE: 0.01739986054599285.
Epoch #69
  Metrics on training data:
    MSE: 0.014831366017460823.
  Metrics on testing data:
    MSE: 0.017393646761775017.
Epoch #70
  Metrics on training data:
    MSE: 0.014777373522520065.
  Metrics on testing data:
    MSE: 0.017370009794831276.
Epoch #71
  Metrics on training data:
    MSE: 0.014781555160880089.
  Metrics on testing data:
    MSE: 0.017400147393345833.
Epoch #72
  Metrics on training data:
    MSE: 0.014717819169163704.
  Metrics on testing data:
    MSE: 0.017351916059851646.
Epoch #73
  Metrics on training data:
    MSE: 0.014686309732496738.
  Metrics on testing data:
    MSE: 0.0173238106071949.
Epoch #74
  Metrics on trai

Epoch #133
  Metrics on training data:
    MSE: 0.013048775494098663.
  Metrics on testing data:
    MSE: 0.01683064177632332.
Epoch #134
  Metrics on training data:
    MSE: 0.012996014207601547.
  Metrics on testing data:
    MSE: 0.01677047461271286.
Epoch #135
  Metrics on training data:
    MSE: 0.013069882057607174.
  Metrics on testing data:
    MSE: 0.016895638778805733.
Epoch #136
  Metrics on training data:
    MSE: 0.012944430112838745.
  Metrics on testing data:
    MSE: 0.016814038157463074.
Epoch #137
  Metrics on training data:
    MSE: 0.013503787107765675.
  Metrics on testing data:
    MSE: 0.017408400774002075.
Epoch #138
  Metrics on training data:
    MSE: 0.013061825186014175.
  Metrics on testing data:
    MSE: 0.01693565584719181.
Epoch #139
  Metrics on training data:
    MSE: 0.012730718590319157.
  Metrics on testing data:
    MSE: 0.01661917380988598.
Epoch #140
  Metrics on training data:
    MSE: 0.013148687779903412.
  Metrics on testing data:
    MSE: 0.

Epoch #200
  Metrics on training data:
    MSE: 0.011471958830952644.
  Metrics on testing data:
    MSE: 0.016647160053253174.
Epoch #201
  Metrics on training data:
    MSE: 0.011402335949242115.
  Metrics on testing data:
    MSE: 0.016552060842514038.
Epoch #202
  Metrics on training data:
    MSE: 0.011820375919342041.
  Metrics on testing data:
    MSE: 0.016896989196538925.
Epoch #203
  Metrics on training data:
    MSE: 0.0112290158867836.
  Metrics on testing data:
    MSE: 0.0163981094956398.
Epoch #204
  Metrics on training data:
    MSE: 0.012059350498020649.
  Metrics on testing data:
    MSE: 0.01713424362242222.
Epoch #205
  Metrics on training data:
    MSE: 0.011594091542065144.
  Metrics on testing data:
    MSE: 0.016747470945119858.
Epoch #206
  Metrics on training data:
    MSE: 0.011731425300240517.
  Metrics on testing data:
    MSE: 0.01693541184067726.
Epoch #207
  Metrics on training data:
    MSE: 0.011636110953986645.
  Metrics on testing data:
    MSE: 0.01

Epoch #267
  Metrics on training data:
    MSE: 0.01009021420031786.
  Metrics on testing data:
    MSE: 0.016304505988955498.
Epoch #268
  Metrics on training data:
    MSE: 0.009951578453183174.
  Metrics on testing data:
    MSE: 0.01627219282090664.
Epoch #269
  Metrics on training data:
    MSE: 0.01079690083861351.
  Metrics on testing data:
    MSE: 0.017036564648151398.
Epoch #270
  Metrics on training data:
    MSE: 0.01010918989777565.
  Metrics on testing data:
    MSE: 0.016366582363843918.
Epoch #271
  Metrics on training data:
    MSE: 0.010071857832372189.
  Metrics on testing data:
    MSE: 0.01638125814497471.
Epoch #272
  Metrics on training data:
    MSE: 0.010983407497406006.
  Metrics on testing data:
    MSE: 0.017327001318335533.
Epoch #273
  Metrics on training data:
    MSE: 0.010210339911282063.
  Metrics on testing data:
    MSE: 0.016539350152015686.
Epoch #274
  Metrics on training data:
    MSE: 0.010145113803446293.
  Metrics on testing data:
    MSE: 0.0

In [6]:
msrs = measurements.Measurement(verbose=True)
msrs.add_train_recorder('MSE', phys_bs, verbose=True)
msrs.add_test_recorder('MSE', phys_bs, verbose=True)

lora, lora_layers = MLPLoRA.get_mlp_lora([fixed_W1, fixed_W2], [rank_W1, rank_W2], 'relu')
lora = lora.cuda()
optimizer = torch.optim.SGD(lora.parameters(), lr=50)
loss_fn = torch.nn.MSELoss()

MLPLoRA.train(data_to_tensor(train_data), data_to_tensor(test_data), lora, loss_fn, optimizer, batch_size, num_epochs, msrs)

Epoch #0
  Metrics on training data:
    MSE: 0.04993470758199692.
  Metrics on testing data:
    MSE: 0.04920484498143196.
Epoch #1
  Metrics on training data:
    MSE: 0.038984403014183044.
  Metrics on testing data:
    MSE: 0.03899824619293213.
Epoch #2
  Metrics on training data:
    MSE: 0.03008093126118183.
  Metrics on testing data:
    MSE: 0.0308341383934021.
Epoch #3
  Metrics on training data:
    MSE: 0.02412392571568489.
  Metrics on testing data:
    MSE: 0.02529069595038891.
Epoch #4
  Metrics on training data:
    MSE: 0.019846230745315552.
  Metrics on testing data:
    MSE: 0.021218284964561462.
Epoch #5
  Metrics on training data:
    MSE: 0.01710786484181881.
  Metrics on testing data:
    MSE: 0.01851341314613819.
Epoch #6
  Metrics on training data:
    MSE: 0.015144884586334229.
  Metrics on testing data:
    MSE: 0.01652984321117401.
Epoch #7
  Metrics on training data:
    MSE: 0.013562090694904327.
  Metrics on testing data:
    MSE: 0.014909537509083748.
Epo

Epoch #65
  Metrics on training data:
    MSE: 2.271329549330403e-06.
  Metrics on testing data:
    MSE: 2.570464403106598e-06.
Epoch #66
  Metrics on training data:
    MSE: 1.9803546820185147e-06.
  Metrics on testing data:
    MSE: 2.2358560727298027e-06.
Epoch #67
  Metrics on training data:
    MSE: 1.7312233921984443e-06.
  Metrics on testing data:
    MSE: 1.9498515939631034e-06.
Epoch #68
  Metrics on training data:
    MSE: 1.517231680736586e-06.
  Metrics on testing data:
    MSE: 1.704708552097145e-06.
Epoch #69
  Metrics on training data:
    MSE: 1.3328746035767836e-06.
  Metrics on testing data:
    MSE: 1.4940039818611694e-06.
Epoch #70
  Metrics on training data:
    MSE: 1.1735303360183025e-06.
  Metrics on testing data:
    MSE: 1.31224544475117e-06.
Epoch #71
  Metrics on training data:
    MSE: 1.035328523357748e-06.
  Metrics on testing data:
    MSE: 1.1550952194738784e-06.
Epoch #72
  Metrics on training data:
    MSE: 9.150818414127571e-07.
  Metrics on testing

Epoch #128
  Metrics on training data:
    MSE: 4.6352583993325425e-09.
  Metrics on testing data:
    MSE: 4.784819207515056e-09.
Epoch #129
  Metrics on training data:
    MSE: 4.275855225444047e-09.
  Metrics on testing data:
    MSE: 4.410143805699818e-09.
Epoch #130
  Metrics on training data:
    MSE: 3.94444876761213e-09.
  Metrics on testing data:
    MSE: 4.06534894636934e-09.
Epoch #131
  Metrics on training data:
    MSE: 3.6392178159871946e-09.
  Metrics on testing data:
    MSE: 3.747995247493918e-09.
Epoch #132
  Metrics on training data:
    MSE: 3.3583360536937334e-09.
  Metrics on testing data:
    MSE: 3.4563694129019495e-09.
Epoch #133
  Metrics on training data:
    MSE: 3.099692502672724e-09.
  Metrics on testing data:
    MSE: 3.1877971373717173e-09.
Epoch #134
  Metrics on training data:
    MSE: 2.861362258244071e-09.
  Metrics on testing data:
    MSE: 2.9407210000442774e-09.
Epoch #135
  Metrics on training data:
    MSE: 2.6417823484337077e-09.
  Metrics on t

    MSE: 3.4565027923205704e-11.
  Metrics on testing data:
    MSE: 3.471832890622473e-11.
Epoch #192
  Metrics on training data:
    MSE: 3.203290757647359e-11.
  Metrics on testing data:
    MSE: 3.217186933524019e-11.
Epoch #193
  Metrics on training data:
    MSE: 2.9682818702969627e-11.
  Metrics on testing data:
    MSE: 2.981288826919837e-11.
Epoch #194
  Metrics on training data:
    MSE: 2.7512450651023812e-11.
  Metrics on testing data:
    MSE: 2.7624598789022237e-11.
Epoch #195
  Metrics on training data:
    MSE: 2.549306772625659e-11.
  Metrics on testing data:
    MSE: 2.559044642858055e-11.
Epoch #196
  Metrics on training data:
    MSE: 2.3622531014622083e-11.
  Metrics on testing data:
    MSE: 2.371262734779389e-11.
Epoch #197
  Metrics on training data:
    MSE: 2.1887592432934255e-11.
  Metrics on testing data:
    MSE: 2.196685888744554e-11.
Epoch #198
  Metrics on training data:
    MSE: 2.0277907825105856e-11.
  Metrics on testing data:
    MSE: 2.0349472107383

Epoch #254
  Metrics on training data:
    MSE: 4.122152651771932e-13.
  Metrics on testing data:
    MSE: 4.207804352347744e-13.
Epoch #255
  Metrics on training data:
    MSE: 3.9309563091595445e-13.
  Metrics on testing data:
    MSE: 4.012946930049416e-13.
Epoch #256
  Metrics on training data:
    MSE: 3.754965902016266e-13.
  Metrics on testing data:
    MSE: 3.8422607109844797e-13.
Epoch #257
  Metrics on training data:
    MSE: 3.593093162411465e-13.
  Metrics on testing data:
    MSE: 3.680996750899529e-13.
Epoch #258
  Metrics on training data:
    MSE: 3.4448453204577456e-13.
  Metrics on testing data:
    MSE: 3.5308634813678574e-13.
Epoch #259
  Metrics on training data:
    MSE: 3.307916278134232e-13.
  Metrics on testing data:
    MSE: 3.394682538543359e-13.
Epoch #260
  Metrics on training data:
    MSE: 3.1806233536692263e-13.
  Metrics on testing data:
    MSE: 3.273776378025922e-13.
Epoch #261
  Metrics on training data:
    MSE: 3.064947655482403e-13.
  Metrics on t

In [8]:
import measurements
lr = 50
batch_size = 200
phys_bs = 5000
num_epochs = 300


msrs = measurements.Measurement(verbose=True)
msrs.add_train_recorder('MSE', phys_bs, verbose=True)
msrs.add_test_recorder('MSE', phys_bs, verbose=True)

lora_t = MLPLoRA.get_mlp_lora([fixed_W1, fixed_W2], [10, 10], 'relu')
lora_t = (lora_t[0].cuda(), lora_t[1])
loss_fn = torch.nn.MSELoss()
opt_constr = lambda params: torch.optim.SGD(params, lr=10)

num_com_rounds = 10
num_epoch_per_round = 60

MLPLoRA.train_block_sub_lora(data_to_tensor(train_data), data_to_tensor(test_data), lora_t, [[5, 5], [5, 5]], 'relu',
                             loss_fn, opt_constr, batch_size, num_com_rounds, num_epoch_per_round, msrs)

Epoch #60
  Metrics on training data:
    MSE: 0.015292063355445862.
  Metrics on testing data:
    MSE: 0.016333363950252533.
Epoch #120
  Metrics on training data:
    MSE: 0.012728547677397728.
  Metrics on testing data:
    MSE: 0.013928985223174095.
Epoch #180
  Metrics on training data:
    MSE: 0.012270603328943253.
  Metrics on testing data:
    MSE: 0.013374208472669125.
Epoch #240
  Metrics on training data:
    MSE: 0.01231649611145258.
  Metrics on testing data:
    MSE: 0.013440415263175964.
Epoch #300
  Metrics on training data:
    MSE: 0.012253068387508392.
  Metrics on testing data:
    MSE: 0.013366270810365677.
Epoch #360
  Metrics on training data:
    MSE: 0.012224958278238773.
  Metrics on testing data:
    MSE: 0.013400171883404255.


KeyboardInterrupt: 

In [33]:
def generate_target_func_multi(input_dim, hidden_dim, output_dim, rank_W1, rank_W2, num_funcs):
    fixed_W1 = 1e-6 * np.random.normal(size=(input_dim, hidden_dim)) / np.sqrt(input_dim + hidden_dim)
    fixed_W2 = 1e-6 * np.random.normal(size=(hidden_dim, output_dim)) / np.sqrt(output_dim + hidden_dim)

    print('Operator Norm of fixed W1: ', np.linalg.norm(fixed_W1, ord=2))
    print('Operator Norm of fixed W2: ', np.linalg.norm(fixed_W2, ord=2))
    
    target_funcs = []
    for _ in range(num_funcs):

        A1 = np.random.normal(size=(input_dim, rank_W1)) / np.sqrt(input_dim)
        B1 = np.random.normal(size=(rank_W1, hidden_dim)) / np.sqrt(hidden_dim)
        A2 = np.random.normal(size=(hidden_dim, rank_W2)) / np.sqrt(hidden_dim)
        B2 = np.random.normal(size=(rank_W2, output_dim)) / np.sqrt(output_dim)

        W1 = fixed_W1 + 6 * A1 @ B1
        W2 = fixed_W2 + 6 * A2 @ B2
        
        target_funcs.append(lambda X: relu(X @ W1) @ W2)
    
    return target_funcs, fixed_W1, fixed_W2

In [34]:
input_dim = 500
hidden_dim = 500
output_dim = 500
num_train_samples = 10000
num_test_samples = 2000
rank_W1 = 10
rank_W2 = 10
num_workers = 10

target_funcs, fixed_W1, fixed_W2 = generate_target_func_multi(input_dim, hidden_dim, output_dim, rank_W1, rank_W2, num_workers)
train_data_list = []
test_data_list = []
for t_func in target_funcs:
    cur_train, cur_test = generate_data(input_dim, int(num_train_samples / num_workers), int(num_test_samples / num_workers), t_func)
    train_data_list.append(cur_train)
    test_data_list.append(cur_test)
    
all_train_X = np.concatenate([x for x, _ in train_data_list], axis=0)
all_train_Y = np.concatenate([y for _, y in train_data_list], axis=0)
all_test_X = np.concatenate([x for x, _ in test_data_list], axis=0)
all_test_Y = np.concatenate([y for _, y in test_data_list], axis=0)

train_data = (all_train_X, all_train_Y)
test_data = (all_test_X, all_test_Y)

Operator Norm of fixed W1:  1.418432343326069e-06
Operator Norm of fixed W2:  1.3963590336249437e-06


In [24]:
import measurements
lr = 50
batch_size = 200
phys_bs = 5000
num_epochs = 300


msrs = measurements.Measurement(verbose=True)
msrs.add_train_recorder('MSE', phys_bs, verbose=True)
msrs.add_test_recorder('MSE', phys_bs, verbose=True)

lora_t = MLPLoRA.get_mlp_lora([fixed_W1, fixed_W2], [20, ], 'relu')
lora_t = (lora_t[0].cuda(), lora_t[1])
loss_fn = torch.nn.MSELoss()
opt_constr = lambda params: torch.optim.SGD(params, lr=50)

num_com_rounds = 10
num_epoch_per_round = 60

train_data_tensor_list = [data_to_tensor(data_t) for data_t in train_data_list]

MLPLoRA.train_block_sub_lora_multi(train_data_tensor_list, data_to_tensor(train_data), data_to_tensor(test_data), lora_t, [[8, 8], [8, 8]], 'relu',
                             loss_fn, opt_constr, batch_size, num_com_rounds, num_epoch_per_round, msrs)

Epoch #60
  Metrics on training data:
    MSE: 0.2602985203266144.
  Metrics on testing data:
    MSE: 0.2631555497646332.
Epoch #120
  Metrics on training data:
    MSE: 0.2566103935241699.
  Metrics on testing data:
    MSE: 0.2593747079372406.
Epoch #180
  Metrics on training data:
    MSE: 0.2481585592031479.
  Metrics on testing data:
    MSE: 0.2506304979324341.
Epoch #240
  Metrics on training data:
    MSE: 0.23727284371852875.
  Metrics on testing data:
    MSE: 0.23927879333496094.
Epoch #300
  Metrics on training data:
    MSE: 0.22722892463207245.
  Metrics on testing data:
    MSE: 0.22865332663059235.
Epoch #360
  Metrics on training data:
    MSE: 0.2257036566734314.
  Metrics on testing data:
    MSE: 0.2269832193851471.
Epoch #420
  Metrics on training data:
    MSE: 0.22377130389213562.
  Metrics on testing data:
    MSE: 0.22496439516544342.
Epoch #480
  Metrics on training data:
    MSE: 0.22128358483314514.
  Metrics on testing data:
    MSE: 0.22223785519599915.
E

In [35]:
msrs = measurements.Measurement(verbose=True)
msrs.add_train_recorder('MSE', phys_bs, verbose=True)
msrs.add_test_recorder('MSE', phys_bs, verbose=True)

lora, lora_layers = MLPLoRA.get_mlp_lora([fixed_W1, fixed_W2], [10, 10], 'relu')
lora = lora.cuda()
optimizer = torch.optim.SGD(lora.parameters(), lr=100)
loss_fn = torch.nn.MSELoss()

MLPLoRA.train(data_to_tensor(train_data), data_to_tensor(test_data), lora, loss_fn, optimizer, batch_size, num_epochs, msrs)

Epoch #0
  Metrics on training data:
    MSE: 0.11342192441225052.
  Metrics on testing data:
    MSE: 0.11741086095571518.
Epoch #1
  Metrics on training data:
    MSE: 0.11328649520874023.
  Metrics on testing data:
    MSE: 0.11727248132228851.
Epoch #2
  Metrics on training data:
    MSE: 0.11309223622083664.
  Metrics on testing data:
    MSE: 0.11707468330860138.
Epoch #3
  Metrics on training data:
    MSE: 0.11276525259017944.
  Metrics on testing data:
    MSE: 0.11674168705940247.
Epoch #4
  Metrics on training data:
    MSE: 0.11218676716089249.
  Metrics on testing data:
    MSE: 0.1161518320441246.
Epoch #5
  Metrics on training data:
    MSE: 0.11115474253892899.
  Metrics on testing data:
    MSE: 0.11509814113378525.
Epoch #6
  Metrics on training data:
    MSE: 0.10933151841163635.
  Metrics on testing data:
    MSE: 0.11323495209217072.
Epoch #7
  Metrics on training data:
    MSE: 0.10619838535785675.
  Metrics on testing data:
    MSE: 0.11003153771162033.
Epoch #8


Epoch #67
  Metrics on training data:
    MSE: 0.06490674614906311.
  Metrics on testing data:
    MSE: 0.06963704526424408.
Epoch #68
  Metrics on training data:
    MSE: 0.06383872032165527.
  Metrics on testing data:
    MSE: 0.06856056302785873.
Epoch #69
  Metrics on training data:
    MSE: 0.06260247528553009.
  Metrics on testing data:
    MSE: 0.06731867045164108.
Epoch #70
  Metrics on training data:
    MSE: 0.06123872101306915.
  Metrics on testing data:
    MSE: 0.06595244258642197.
Epoch #71
  Metrics on training data:
    MSE: 0.05982641875743866.
  Metrics on testing data:
    MSE: 0.0645405501127243.
Epoch #72
  Metrics on training data:
    MSE: 0.058469224721193314.
  Metrics on testing data:
    MSE: 0.06318625807762146.
Epoch #73
  Metrics on training data:
    MSE: 0.05725983902812004.
  Metrics on testing data:
    MSE: 0.061982281506061554.
Epoch #74
  Metrics on training data:
    MSE: 0.056244250386953354.
  Metrics on testing data:
    MSE: 0.06097826361656189

    MSE: 0.03928486257791519.
  Metrics on testing data:
    MSE: 0.04318656027317047.
Epoch #134
  Metrics on training data:
    MSE: 0.03921029716730118.
  Metrics on testing data:
    MSE: 0.043100230395793915.
Epoch #135
  Metrics on training data:
    MSE: 0.03913300111889839.
  Metrics on testing data:
    MSE: 0.043010469526052475.
Epoch #136
  Metrics on training data:
    MSE: 0.039052631705999374.
  Metrics on testing data:
    MSE: 0.04291699454188347.
Epoch #137
  Metrics on training data:
    MSE: 0.038968902081251144.
  Metrics on testing data:
    MSE: 0.04281936585903168.
Epoch #138
  Metrics on training data:
    MSE: 0.03888150677084923.
  Metrics on testing data:
    MSE: 0.04271724820137024.
Epoch #139
  Metrics on training data:
    MSE: 0.03879005089402199.
  Metrics on testing data:
    MSE: 0.0426102876663208.
Epoch #140
  Metrics on training data:
    MSE: 0.038694098591804504.
  Metrics on testing data:
    MSE: 0.04249805957078934.
Epoch #141
  Metrics on tra

Epoch #200
  Metrics on training data:
    MSE: 0.02414620853960514.
  Metrics on testing data:
    MSE: 0.026981204748153687.
Epoch #201
  Metrics on training data:
    MSE: 0.023990068584680557.
  Metrics on testing data:
    MSE: 0.02680891938507557.
Epoch #202
  Metrics on training data:
    MSE: 0.023832498118281364.
  Metrics on testing data:
    MSE: 0.026634691283106804.
Epoch #203
  Metrics on training data:
    MSE: 0.02367330528795719.
  Metrics on testing data:
    MSE: 0.026458144187927246.
Epoch #204
  Metrics on training data:
    MSE: 0.023512130603194237.
  Metrics on testing data:
    MSE: 0.02627907693386078.
Epoch #205
  Metrics on training data:
    MSE: 0.023348428308963776.
  Metrics on testing data:
    MSE: 0.026097269728779793.
Epoch #206
  Metrics on training data:
    MSE: 0.023181775584816933.
  Metrics on testing data:
    MSE: 0.02591228298842907.
Epoch #207
  Metrics on training data:
    MSE: 0.02301201783120632.
  Metrics on testing data:
    MSE: 0.02

Epoch #267
  Metrics on training data:
    MSE: 0.01150631345808506.
  Metrics on testing data:
    MSE: 0.012880058959126472.
Epoch #268
  Metrics on training data:
    MSE: 0.011364806443452835.
  Metrics on testing data:
    MSE: 0.012723589316010475.
Epoch #269
  Metrics on training data:
    MSE: 0.011224761605262756.
  Metrics on testing data:
    MSE: 0.012568806298077106.
Epoch #270
  Metrics on training data:
    MSE: 0.011086056008934975.
  Metrics on testing data:
    MSE: 0.012415486387908459.
Epoch #271
  Metrics on training data:
    MSE: 0.010948739014565945.
  Metrics on testing data:
    MSE: 0.012263961136341095.
Epoch #272
  Metrics on training data:
    MSE: 0.010812926106154919.
  Metrics on testing data:
    MSE: 0.01211429201066494.
Epoch #273
  Metrics on training data:
    MSE: 0.010678568854928017.
  Metrics on testing data:
    MSE: 0.011966301128268242.
Epoch #274
  Metrics on training data:
    MSE: 0.0105457017198205.
  Metrics on testing data:
    MSE: 0.

In [32]:
msrs = measurements.Measurement(verbose=True)
msrs.add_train_recorder('MSE', phys_bs, verbose=True)
msrs.add_test_recorder('MSE', phys_bs, verbose=True)

mlp = MLPLoRA.fully_connected_net(input_dim, output_dim, [hidden_dim], 'relu').cuda()
optimizer = torch.optim.SGD(mlp.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss()

MLPLoRA.train(data_to_tensor(train_data), data_to_tensor(test_data), mlp, loss_fn, optimizer, batch_size, num_epochs, msrs)

Epoch #0
  Metrics on training data:
    MSE: 0.2620835304260254.
  Metrics on testing data:
    MSE: 0.26590004563331604.
Epoch #1
  Metrics on training data:
    MSE: 0.2573451101779938.
  Metrics on testing data:
    MSE: 0.2612677812576294.
Epoch #2
  Metrics on training data:
    MSE: 0.254048615694046.
  Metrics on testing data:
    MSE: 0.25804969668388367.
Epoch #3
  Metrics on training data:
    MSE: 0.25159478187561035.
  Metrics on testing data:
    MSE: 0.2556613087654114.
Epoch #4
  Metrics on training data:
    MSE: 0.24962179362773895.
  Metrics on testing data:
    MSE: 0.2537543475627899.
Epoch #5
  Metrics on training data:
    MSE: 0.24794022738933563.
  Metrics on testing data:
    MSE: 0.2521304786205292.
Epoch #6
  Metrics on training data:
    MSE: 0.24641546607017517.
  Metrics on testing data:
    MSE: 0.2506639361381531.
Epoch #7
  Metrics on training data:
    MSE: 0.24496327340602875.
  Metrics on testing data:
    MSE: 0.24927657842636108.
Epoch #8
  Metric

    MSE: 0.10432411730289459.
  Metrics on testing data:
    MSE: 0.1093147024512291.
Epoch #67
  Metrics on training data:
    MSE: 0.10378407686948776.
  Metrics on testing data:
    MSE: 0.10879385471343994.
Epoch #68
  Metrics on training data:
    MSE: 0.10326500982046127.
  Metrics on testing data:
    MSE: 0.10829459130764008.
Epoch #69
  Metrics on training data:
    MSE: 0.10276620835065842.
  Metrics on testing data:
    MSE: 0.10781523585319519.
Epoch #70
  Metrics on training data:
    MSE: 0.1022857204079628.
  Metrics on testing data:
    MSE: 0.10735505819320679.
Epoch #71
  Metrics on training data:
    MSE: 0.10182318091392517.
  Metrics on testing data:
    MSE: 0.10691263526678085.
Epoch #72
  Metrics on training data:
    MSE: 0.10137735307216644.
  Metrics on testing data:
    MSE: 0.10648726671934128.
Epoch #73
  Metrics on training data:
    MSE: 0.10094769299030304.
  Metrics on testing data:
    MSE: 0.1060781329870224.
Epoch #74
  Metrics on training data:
   

Epoch #135
  Metrics on training data:
    MSE: 0.08848083019256592.
  Metrics on testing data:
    MSE: 0.09456223249435425.
Epoch #136
  Metrics on training data:
    MSE: 0.0883832722902298.
  Metrics on testing data:
    MSE: 0.09447570145130157.
Epoch #137
  Metrics on training data:
    MSE: 0.08828867226839066.
  Metrics on testing data:
    MSE: 0.09439250826835632.
Epoch #138
  Metrics on training data:
    MSE: 0.08819665014743805.
  Metrics on testing data:
    MSE: 0.09431156516075134.
Epoch #139
  Metrics on training data:
    MSE: 0.0881032720208168.
  Metrics on testing data:
    MSE: 0.0942293107509613.
Epoch #140
  Metrics on training data:
    MSE: 0.08801336586475372.
  Metrics on testing data:
    MSE: 0.09415014833211899.
Epoch #141
  Metrics on training data:
    MSE: 0.08792410790920258.
  Metrics on testing data:
    MSE: 0.09407199174165726.
Epoch #142
  Metrics on training data:
    MSE: 0.08783537149429321.
  Metrics on testing data:
    MSE: 0.09399444609880

Epoch #202
  Metrics on training data:
    MSE: 0.08373244106769562.
  Metrics on testing data:
    MSE: 0.09048926085233688.
Epoch #203
  Metrics on training data:
    MSE: 0.0836765244603157.
  Metrics on testing data:
    MSE: 0.09044252336025238.
Epoch #204
  Metrics on training data:
    MSE: 0.08362238854169846.
  Metrics on testing data:
    MSE: 0.09039729088544846.
Epoch #205
  Metrics on training data:
    MSE: 0.08356702327728271.
  Metrics on testing data:
    MSE: 0.09035121649503708.
Epoch #206
  Metrics on training data:
    MSE: 0.08351138979196548.
  Metrics on testing data:
    MSE: 0.09030510485172272.
Epoch #207
  Metrics on training data:
    MSE: 0.08345808833837509.
  Metrics on testing data:
    MSE: 0.09026042371988297.
Epoch #208
  Metrics on training data:
    MSE: 0.08340266346931458.
  Metrics on testing data:
    MSE: 0.09021435678005219.
Epoch #209
  Metrics on training data:
    MSE: 0.08334919810295105.
  Metrics on testing data:
    MSE: 0.090170040726

Epoch #270
  Metrics on training data:
    MSE: 0.0804310068488121.
  Metrics on testing data:
    MSE: 0.08779694139957428.
Epoch #271
  Metrics on training data:
    MSE: 0.08039099723100662.
  Metrics on testing data:
    MSE: 0.08776562660932541.
Epoch #272
  Metrics on training data:
    MSE: 0.08034685999155045.
  Metrics on testing data:
    MSE: 0.08773041516542435.
Epoch #273
  Metrics on training data:
    MSE: 0.08030465245246887.
  Metrics on testing data:
    MSE: 0.0876968652009964.
Epoch #274
  Metrics on training data:
    MSE: 0.08026278764009476.
  Metrics on testing data:
    MSE: 0.08766406774520874.
Epoch #275
  Metrics on training data:
    MSE: 0.08022008091211319.
  Metrics on testing data:
    MSE: 0.08762966841459274.
Epoch #276
  Metrics on training data:
    MSE: 0.08017876744270325.
  Metrics on testing data:
    MSE: 0.08759710192680359.
Epoch #277
  Metrics on training data:
    MSE: 0.08013834804296494.
  Metrics on testing data:
    MSE: 0.0875652879476