In [2]:
import numpy as onp
import jax.numpy as np
from jax import random, grad, vmap, jit
from jax.experimental import optimizers
from jax.flatten_util import ravel_pytree
from jax import lax

import itertools
from functools import partial
from torch.utils import data
from tqdm import trange

import scipy.io
from scipy.interpolate import griddata

import matplotlib.pyplot as plt

In [2]:
L = 2

In [3]:
def input_encoding(t, x):
    w = 2.0 * np.pi / L
    k = np.arange(1, M + 1)
    out = np.hstack([t, 1, 
                     np.cos(k * w * x), np.sin(k * w * x)])
    return out

In [13]:
M = 5

In [14]:
t = 2


In [15]:
x = 2

In [16]:
out = input_encoding(t, x)

In [18]:
print(out.shape)

(12,)


In [4]:
t = np.linspace(0, 1, 10)
x = np.linspace(-1, 1, 24)



In [5]:
t

DeviceArray([0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
             0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],            dtype=float32)

In [6]:
x

DeviceArray([-1.        , -0.9130435 , -0.826087  , -0.7391305 ,
             -0.6521739 , -0.5652174 , -0.47826087, -0.39130437,
             -0.30434778, -0.21739128, -0.13043478, -0.04347825,
              0.04347825,  0.13043475,  0.21739125,  0.30434787,
              0.39130437,  0.47826087,  0.5652174 ,  0.6521739 ,
              0.7391304 ,  0.826087  ,  0.9130435 ,  1.        ],            dtype=float32)

In [7]:
tt ,xx = np.meshgrid(t, x)

In [8]:
tt

DeviceArray([[0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
              0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],
             [0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
              0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],
             [0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
              0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],
             [0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
              0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],
             [0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
              0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],
             [0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
              0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],
             [0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
              0.5555556 , 0.6666

In [9]:
xx

DeviceArray([[-1.        , -1.        , -1.        , -1.        ,
              -1.        , -1.        , -1.        , -1.        ,
              -1.        , -1.        ],
             [-0.9130435 , -0.9130435 , -0.9130435 , -0.9130435 ,
              -0.9130435 , -0.9130435 , -0.9130435 , -0.9130435 ,
              -0.9130435 , -0.9130435 ],
             [-0.826087  , -0.826087  , -0.826087  , -0.826087  ,
              -0.826087  , -0.826087  , -0.826087  , -0.826087  ,
              -0.826087  , -0.826087  ],
             [-0.7391305 , -0.7391305 , -0.7391305 , -0.7391305 ,
              -0.7391305 , -0.7391305 , -0.7391305 , -0.7391305 ,
              -0.7391305 , -0.7391305 ],
             [-0.6521739 , -0.6521739 , -0.6521739 , -0.6521739 ,
              -0.6521739 , -0.6521739 , -0.6521739 , -0.6521739 ,
              -0.6521739 , -0.6521739 ],
             [-0.5652174 , -0.5652174 , -0.5652174 , -0.5652174 ,
              -0.5652174 , -0.5652174 , -0.5652174 , -0.5652174 ,
   

In [10]:
tt.reshape(-1, 1)

DeviceArray([[0.        ],
             [0.11111111],
             [0.22222222],
             [0.33333334],
             [0.44444445],
             [0.5555556 ],
             [0.6666667 ],
             [0.7777778 ],
             [0.8888889 ],
             [1.        ],
             [0.        ],
             [0.11111111],
             [0.22222222],
             [0.33333334],
             [0.44444445],
             [0.5555556 ],
             [0.6666667 ],
             [0.7777778 ],
             [0.8888889 ],
             [1.        ],
             [0.        ],
             [0.11111111],
             [0.22222222],
             [0.33333334],
             [0.44444445],
             [0.5555556 ],
             [0.6666667 ],
             [0.7777778 ],
             [0.8888889 ],
             [1.        ],
             [0.        ],
             [0.11111111],
             [0.22222222],
             [0.33333334],
             [0.44444445],
             [0.5555556 ],
             [0.6666667 ],
 

In [11]:
xx.reshape(-1, 1)

DeviceArray([[-1.        ],
             [-1.        ],
             [-1.        ],
             [-1.        ],
             [-1.        ],
             [-1.        ],
             [-1.        ],
             [-1.        ],
             [-1.        ],
             [-1.        ],
             [-0.9130435 ],
             [-0.9130435 ],
             [-0.9130435 ],
             [-0.9130435 ],
             [-0.9130435 ],
             [-0.9130435 ],
             [-0.9130435 ],
             [-0.9130435 ],
             [-0.9130435 ],
             [-0.9130435 ],
             [-0.826087  ],
             [-0.826087  ],
             [-0.826087  ],
             [-0.826087  ],
             [-0.826087  ],
             [-0.826087  ],
             [-0.826087  ],
             [-0.826087  ],
             [-0.826087  ],
             [-0.826087  ],
             [-0.7391305 ],
             [-0.7391305 ],
             [-0.7391305 ],
             [-0.7391305 ],
             [-0.7391305 ],
             [-0.739

In [1]:
import numpy as np

In [2]:
gp = np.array([0.8611363116,-0.8611363116,0.3399810436,-0.3399810436])
gc = np.array([0.3478548451,0.3478548451,0.6521451549,0.6521451549])

In [72]:
# 生成数据集，训练点是高斯点，并且给出对应的高斯权重
l = np.linspace(-1, 1, 26)[:, None]
l = np.hstack([l[:-1], l[1:]])
c = (l[:, 1] - l[:, 0])/2
c = c[:, None]
gp = gp[None, :]
d = ((l[:, 1] + l[:, 0])/2)
d = d[:, None]
n_p = c * gp + d  # n_p是100个高斯积分点
n_p = n_p.reshape(100, 1)
gcl = c * gc     # gcl是这100个高斯积分点对应的权重  
gcl = gcl.reshape(100, 1)
t = np.linspace(0, 1, 128)[:, None]
t, x = np.meshgrid(t, n_p) #   mesh size (100, 128)
print(t.shape, x.shape)
txg = np.hstack([t.reshape(-1, 1), x.reshape(-1, 1)])

(100, 128) (100, 128)


In [27]:
txg.shape

(12800, 2)

In [28]:
txg[:129, :]

array([[ 0.        , -0.92555455],
       [ 0.00787402, -0.92555455],
       [ 0.01574803, -0.92555455],
       [ 0.02362205, -0.92555455],
       [ 0.03149606, -0.92555455],
       [ 0.03937008, -0.92555455],
       [ 0.04724409, -0.92555455],
       [ 0.05511811, -0.92555455],
       [ 0.06299213, -0.92555455],
       [ 0.07086614, -0.92555455],
       [ 0.07874016, -0.92555455],
       [ 0.08661417, -0.92555455],
       [ 0.09448819, -0.92555455],
       [ 0.1023622 , -0.92555455],
       [ 0.11023622, -0.92555455],
       [ 0.11811024, -0.92555455],
       [ 0.12598425, -0.92555455],
       [ 0.13385827, -0.92555455],
       [ 0.14173228, -0.92555455],
       [ 0.1496063 , -0.92555455],
       [ 0.15748031, -0.92555455],
       [ 0.16535433, -0.92555455],
       [ 0.17322835, -0.92555455],
       [ 0.18110236, -0.92555455],
       [ 0.18897638, -0.92555455],
       [ 0.19685039, -0.92555455],
       [ 0.20472441, -0.92555455],
       [ 0.21259843, -0.92555455],
       [ 0.22047244,

In [58]:
txg[:, 0].reshape(100, 128)

array([[0.        , 0.00787402, 0.01574803, ..., 0.98425197, 0.99212598,
        1.        ],
       [0.        , 0.00787402, 0.01574803, ..., 0.98425197, 0.99212598,
        1.        ],
       [0.        , 0.00787402, 0.01574803, ..., 0.98425197, 0.99212598,
        1.        ],
       ...,
       [0.        , 0.00787402, 0.01574803, ..., 0.98425197, 0.99212598,
        1.        ],
       [0.        , 0.00787402, 0.01574803, ..., 0.98425197, 0.99212598,
        1.        ],
       [0.        , 0.00787402, 0.01574803, ..., 0.98425197, 0.99212598,
        1.        ]])

In [59]:
txg[:, 1].reshape(100, 128)

array([[-0.92555455, -0.92555455, -0.92555455, ..., -0.92555455,
        -0.92555455, -0.92555455],
       [-0.99444545, -0.99444545, -0.99444545, ..., -0.99444545,
        -0.99444545, -0.99444545],
       [-0.94640076, -0.94640076, -0.94640076, ..., -0.94640076,
        -0.94640076, -0.94640076],
       ...,
       [ 0.92555455,  0.92555455,  0.92555455, ...,  0.92555455,
         0.92555455,  0.92555455],
       [ 0.97359924,  0.97359924,  0.97359924, ...,  0.97359924,
         0.97359924,  0.97359924],
       [ 0.94640076,  0.94640076,  0.94640076, ...,  0.94640076,
         0.94640076,  0.94640076]])

In [50]:
txg.shape

(12800, 2)

In [51]:
p1 = txg[:, 1].reshape(100, 128)

In [60]:
p1[:, 0]

array([-0.92555455, -0.92555455, -0.99444545, -0.94640076, -0.97359924,
       -0.97359924, -0.84555455, -0.91444545, -0.86640076, -0.89359924,
       -0.89359924, -0.76555455, -0.83444545, -0.78640076, -0.78640076,
       -0.81359924, -0.68555455, -0.75444545, -0.70640076, -0.70640076,
       -0.73359924, -0.60555455, -0.67444545, -0.67444545, -0.62640076,
       -0.65359924, -0.52555455, -0.59444545, -0.59444545, -0.54640076,
       -0.57359924, -0.44555455, -0.51444545, -0.51444545, -0.46640076,
       -0.49359924, -0.36555455, -0.36555455, -0.43444545, -0.38640076,
       -0.41359924, -0.28555455, -0.28555455, -0.35444545, -0.30640076,
       -0.33359924, -0.33359924, -0.20555455, -0.27444545, -0.22640076,
       -0.25359924, -0.25359924, -0.12555455, -0.19444545, -0.14640076,
       -0.14640076, -0.17359924, -0.04555455, -0.11444545, -0.06640076,
       -0.06640076, -0.09359924,  0.03444545, -0.03444545,  0.01359924,
        0.01359924, -0.01359924,  0.11444545,  0.04555455,  0.04

In [61]:
import math
quard = np.cos(math.pi * txg[:, 1].reshape(100, 128))

In [62]:
quard.shape

(100, 128)

In [63]:
gc.shape

(4,)

In [64]:
gcl.shape

(25, 4)

In [65]:
gcl = gcl.reshape(100, 1)

In [67]:
quard.shape, gcl.shape

((100, 128), (100, 1))

In [68]:
sum1 = quard ** 2 * gcl

In [69]:
sum2 = np.sum(sum1, axis = 0)

In [70]:
sum2.shape

(128,)

In [71]:
sum2

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [73]:
import math
quard = np.cos(math.pi * txg[:, 1].reshape(100, 128))
sum1 = quard ** 2 * gcl
sum2 = np.sum(sum1, axis = 0)

In [74]:
sum2

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [75]:
a = np.array([1,2,3])
a.ndim

1

In [76]:
a.shape

(3,)

In [77]:
a.shape[0]

3

In [78]:
a.reshape(3, 1)

array([[1],
       [2],
       [3]])

In [79]:
a


array([1, 2, 3])

In [80]:
a.shape

(3,)

In [81]:
b = a.reshape(3, 1)

In [82]:
b.shape

(3, 1)

In [83]:
b.shape[1]

1

In [1]:
import torch

In [2]:
from torch import nn

In [3]:
conv2d = nn.Conv2d(1, 1, kernel_size = (1, 2), bias = False)

In [4]:
conv2d.weight

Parameter containing:
tensor([[[[ 0.5213, -0.3201]]]], requires_grad=True)

In [6]:
conv2d.weight.data

tensor([[[[ 0.5213, -0.3201]]]])

In [7]:
conv2d.weight.data = torch.tensor([[[[ 1, 2]]]])

In [8]:
conv2d.weight.data 

tensor([[[[1, 2]]]])

In [9]:
conv2d.weight

Parameter containing:
tensor([[[[1, 2]]]], requires_grad=True)

In [10]:
conv2d = nn.Conv2d(1, 1, kernel_size = (1, 2), bias = False)
X = torch.ones((6, 8))
X[:, 2:6] = 0

In [16]:
Y = torch.zeros((6, 7))

In [17]:
Y[:, 1] = 1

In [18]:
Y

tensor([[0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.]])

In [19]:
Y[:, -2] = -1

In [24]:
X

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

In [27]:
X.is_contiguous()

True

In [28]:
X.resize_(1,1,6,8)

tensor([[[[1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.]]]])

In [29]:
X

tensor([[[[1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.],
          [1., 1., 0., 0., 0., 0., 1., 1.]]]])

In [31]:
Y.resize_(1,1,6,7)

tensor([[[[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
          [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
          [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
          [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
          [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
          [ 0.,  1.,  0.,  0.,  0., -1.,  0.]]]])

In [32]:
conv2d

Conv2d(1, 1, kernel_size=(1, 2), stride=(1, 1), bias=False)

In [34]:
for i in range(10):
    y = conv2d(X)
    l = (y - Y) ** 2
    conv2d.zero_grad()
    l.sum().backward()
    conv2d.weight.data -= 3e-2 * conv2d.weight.grad
    print(conv2d.weight.data)
    print(f"epoch:{i}", f"loss:{l.sum():.3e}")

tensor([[[[-0.0676, -0.6161]]]])
epoch:0 loss:2.639e+01
tensor([[[[ 0.8090, -0.2620]]]])
epoch:1 loss:1.333e+01
tensor([[[[ 0.4839, -0.9215]]]])
epoch:2 loss:7.077e+00
tensor([[[[ 0.9848, -0.6347]]]])
epoch:3 loss:3.933e+00
tensor([[[[ 0.7382, -1.0183]]]])
epoch:4 loss:2.273e+00
tensor([[[[ 1.0341, -0.8100]]]])
epoch:5 loss:1.354e+00
tensor([[[[ 0.8605, -1.0397]]]])
epoch:6 loss:8.259e-01
tensor([[[[ 1.0398, -0.8964]]]])
epoch:7 loss:5.118e-01
tensor([[[[ 0.9222, -1.0369]]]])
epoch:8 loss:3.207e-01
tensor([[[[ 1.0328, -0.9410]]]])
epoch:9 loss:2.024e-01


In [35]:
for i in range(10):
    y = conv2d(X)
    l = (y - Y) ** 2
    conv2d.zero_grad()
    l.sum().backward()
    conv2d.weight.data[:] -= 3e-2 * conv2d.weight.grad
    print(conv2d.weight.data)
    print(f"epoch:{i}", f"loss:{l.sum():.3e}")

tensor([[[[ 0.9549, -1.0283]]]])
epoch:0 loss:1.284e-01
tensor([[[[ 1.0240, -0.9653]]]])
epoch:1 loss:8.169e-02
tensor([[[[ 0.9731, -1.0201]]]])
epoch:2 loss:5.209e-02
tensor([[[[ 1.0166, -0.9790]]]])
epoch:3 loss:3.326e-02
tensor([[[[ 0.9836, -1.0136]]]])
epoch:4 loss:2.125e-02
tensor([[[[ 1.0111, -0.9871]]]])
epoch:5 loss:1.359e-02
tensor([[[[ 0.9898, -1.0090]]]])
epoch:6 loss:8.691e-03
tensor([[[[ 1.0073, -0.9919]]]])
epoch:7 loss:5.560e-03
tensor([[[[ 0.9936, -1.0059]]]])
epoch:8 loss:3.558e-03
tensor([[[[ 1.0048, -0.9949]]]])
epoch:9 loss:2.276e-03


In [37]:
X.shape

torch.Size([1, 1, 6, 8])

In [38]:
(1, 1) + X.shape

(1, 1, 1, 1, 6, 8)

In [40]:
X.shape[2:]

torch.Size([6, 8])

In [41]:
a = torch.arange(24).reshape(2, 3, 4)

In [42]:
a

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [43]:
b = torch.arange(40).reshape(2, 4, 5)

In [44]:
b

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]]])

In [45]:
c = [ x @ y for x, y in zip(a,b) ]

In [46]:
c

[tensor([[ 70,  76,  82,  88,  94],
         [190, 212, 234, 256, 278],
         [310, 348, 386, 424, 462]]),
 tensor([[1510, 1564, 1618, 1672, 1726],
         [1950, 2020, 2090, 2160, 2230],
         [2390, 2476, 2562, 2648, 2734]])]

In [47]:
sum(c)

tensor([[1580, 1640, 1700, 1760, 1820],
        [2140, 2232, 2324, 2416, 2508],
        [2700, 2824, 2948, 3072, 3196]])

In [49]:
K = torch.tensor([[[0,1 ], [2, 3]], [[1, 2], [3, 4]]]).float()

In [53]:
K1 = torch.stack((K, K + 1, K + 2), 0)