In [1]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# torch.manual_seed(1)    # reproducible

# Hyper Parameters
TIME_STEP = 10      # rnn time step
INPUT_SIZE = 1      # rnn input size
LR = 0.02           # learning rate

# show data
steps = np.linspace(0, np.pi*2, 100, dtype=np.float32)  # float32 for converting torch FloatTensor
x_np = np.sin(steps)
y_np = np.cos(steps)
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()


class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=32,     # rnn hidden unit
            num_layers=1,       # number of rnn layer
            batch_first=True,   # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)
        )
        self.out = nn.Linear(32, 1)

    def forward(self, x, h_state):
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)

        outs = []    # save all predictions
        for time_step in range(r_out.size(1)):    # calculate output for each time step
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state

        # instead, for simplicity, you can replace above codes by follows
        # r_out = r_out.view(-1, 32)
        # outs = self.out(r_out)
        # outs = outs.view(-1, TIME_STEP, 1)
        # return outs, h_state
        
        # or even simpler, since nn.Linear can accept inputs of any dimension 
        # and returns outputs with same dimension except for the last
        # outs = self.out(r_out)
        # return outs

rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all cnn parameters
loss_func = nn.MSELoss()

h_state = None      # for initial hidden state

plt.figure(1, figsize=(12, 5))
plt.ion()           # continuously plot

for step in range(10):
    start, end = step * np.pi, (step+1)*np.pi   # time range
    # use sin predicts cos
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32, endpoint=False)  # float32 for converting torch FloatTensor
    x_np = np.sin(steps)
    y_np = np.cos(steps)

    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])    # shape (batch, time_step, input_size)
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
    print(x.size(),y.size())
    prediction, h_state = rnn(x, h_state)   # rnn output
    # !! next step is important !!
    h_state = h_state.data        # repack the hidden state, break the connection from last iteration
#     print(prediction)
#     print(h_state)
    loss = loss_func(prediction, y)         # calculate loss
    optimizer.zero_grad()                   # clear gradients for this training step
    loss.backward()                         # backpropagation, compute gradients
    optimizer.step()                        # apply gradients

    # plotting
    plt.plot(steps, y_np.flatten(), 'r-')
    plt.plot(steps, prediction.data.numpy().flatten(), 'b-')
    plt.draw(); plt.pause(0.05)

plt.ioff()
plt.show()

<Figure size 640x480 with 1 Axes>

RNN(
  (rnn): RNN(1, 32, batch_first=True)
  (out): Linear(in_features=32, out_features=1, bias=True)
)
torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[0.1356],
         [0.1382],
         [0.1275],
         [0.1018],
         [0.0956],
         [0.0842],
         [0.0794],
         [0.0846],
         [0.0953],
         [0.1117]]], grad_fn=<StackBackward>)
tensor([[[-0.1796,  0.0100,  0.3208, -0.0202,  0.2393,  0.4497,  0.2105,
           0.1059, -0.0060,  0.1434,  0.0569,  0.3607, -0.3411,  0.0944,
          -0.1815, -0.2199,  0.3678, -0.0986, -0.1879, -0.1978,  0.0959,
           0.0793,  0.3179, -0.0416,  0.1104, -0.2481,  0.0006, -0.1988,
          -0.1125, -0.0597, -0.1772, -0.2338]]])


<Figure size 1200x500 with 1 Axes>

torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[-0.1653],
         [-0.2419],
         [-0.1909],
         [-0.1614],
         [-0.1300],
         [-0.1127],
         [-0.1109],
         [-0.1253],
         [-0.1536],
         [-0.1916]]], grad_fn=<StackBackward>)
tensor([[[ 0.1612,  0.1669,  0.3009,  0.2364,  0.0261,  0.1794,  0.2416,
           0.2336,  0.0103,  0.1669, -0.2134,  0.4030, -0.2663,  0.1163,
          -0.3165, -0.3507,  0.1620,  0.0698, -0.4437, -0.1183, -0.1427,
          -0.0931,  0.1382, -0.3424,  0.1972, -0.3510, -0.0856, -0.0955,
           0.0585, -0.2476, -0.3974, -0.2929]]])


<Figure size 640x480 with 1 Axes>

torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[-0.1865],
         [-0.1150],
         [-0.0572],
         [-0.0771],
         [-0.1128],
         [-0.1371],
         [-0.1382],
         [-0.1176],
         [-0.0753],
         [-0.0150]]], grad_fn=<StackBackward>)
tensor([[[ 0.0613,  0.0931,  0.0939,  0.0051,  0.0082, -0.1573,  0.0718,
           0.2372,  0.0843,  0.3476, -0.1774, -0.0294, -0.0088,  0.1176,
          -0.2440, -0.1118, -0.0528,  0.0594, -0.2991,  0.0181, -0.3140,
          -0.2333, -0.0734, -0.1516,  0.1116, -0.2464, -0.0468,  0.1504,
           0.4136, -0.0799, -0.4051, -0.1416]]])


<Figure size 640x480 with 1 Axes>

torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[0.2848],
         [0.4324],
         [0.5289],
         [0.5910],
         [0.6436],
         [0.6753],
         [0.6862],
         [0.6721],
         [0.6355],
         [0.5791]]], grad_fn=<StackBackward>)
tensor([[[-0.3110, -0.1758,  0.3462, -0.0338,  0.3475,  0.0385, -0.1986,
          -0.0242, -0.1737,  0.4917,  0.0229, -0.3788, -0.2446, -0.2670,
           0.1318,  0.1031,  0.0678, -0.2254, -0.1970, -0.1955, -0.1974,
          -0.4459, -0.1484, -0.1831,  0.0223, -0.2857, -0.2577, -0.1059,
           0.3414, -0.0857,  0.0587, -0.2553]]])


<Figure size 640x480 with 1 Axes>

torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[ 0.3071],
         [ 0.1960],
         [ 0.0889],
         [ 0.0362],
         [ 0.0019],
         [-0.0216],
         [-0.0309],
         [-0.0217],
         [ 0.0043],
         [ 0.0455]]], grad_fn=<StackBackward>)
tensor([[[-0.0796,  0.0202, -0.0680, -0.1403,  0.0318, -0.0869,  0.0367,
           0.2003,  0.2156,  0.2739, -0.0725,  0.0028,  0.0369,  0.1651,
          -0.2597, -0.1124,  0.0614,  0.0491, -0.1865,  0.0507, -0.1945,
          -0.1590, -0.0730, -0.0541,  0.0547, -0.1653, -0.0588,  0.1509,
           0.3428, -0.0376, -0.3631, -0.1119]]])


<Figure size 640x480 with 1 Axes>

torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[0.0171],
         [0.0261],
         [0.0417],
         [0.1036],
         [0.1648],
         [0.1989],
         [0.2146],
         [0.2152],
         [0.1951],
         [0.1551]]], grad_fn=<StackBackward>)
tensor([[[-0.1978, -0.0585,  0.1924, -0.1646,  0.2286,  0.1719, -0.0050,
          -0.0465, -0.0127,  0.4318,  0.0867, -0.2133, -0.0465, -0.0222,
          -0.0230,  0.0366,  0.1667, -0.0875, -0.1698, -0.0436, -0.1477,
          -0.2180,  0.0429, -0.0647, -0.0211, -0.1941, -0.0776,  0.0561,
           0.2387,  0.0092, -0.0862, -0.0620]]])


<Figure size 640x480 with 1 Axes>

torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[ 0.0191],
         [-0.0649],
         [-0.1636],
         [-0.2452],
         [-0.3106],
         [-0.3584],
         [-0.3860],
         [-0.3939],
         [-0.3861],
         [-0.3621]]], grad_fn=<StackBackward>)
tensor([[[-0.0409,  0.1564, -0.3483, -0.4113, -0.1421,  0.0795,  0.2641,
           0.2703,  0.5115,  0.1441, -0.0131,  0.2929,  0.2429,  0.5137,
          -0.4860, -0.2679,  0.2687,  0.1946, -0.1693,  0.2706, -0.0750,
           0.1737,  0.1151,  0.0319, -0.0127, -0.1264,  0.1371,  0.3348,
           0.1661,  0.0195, -0.5501,  0.0900]]])


<Figure size 640x480 with 1 Axes>

torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[-0.3305],
         [-0.3221],
         [-0.2839],
         [-0.2046],
         [-0.1189],
         [-0.0423],
         [ 0.0293],
         [ 0.0813],
         [ 0.1001],
         [ 0.0861]]], grad_fn=<StackBackward>)
tensor([[[-1.4383e-01, -4.9913e-02,  1.8894e-01, -1.4515e-01,  2.1505e-01,
           2.2697e-01,  3.3415e-03, -1.0656e-01, -2.4659e-02,  4.0497e-01,
           1.0289e-01, -1.8112e-01, -7.4947e-03, -2.6365e-02,  2.4735e-04,
           4.0188e-02,  1.5087e-01, -5.6155e-02, -1.2019e-01, -5.2025e-02,
          -1.2645e-01, -1.7885e-01,  9.3042e-02, -2.3282e-02, -2.6089e-02,
          -1.2256e-01, -5.9817e-02,  4.1941e-02,  2.0374e-01,  5.8645e-02,
          -7.2370e-02, -2.6900e-02]]])


<Figure size 640x480 with 1 Axes>

torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[ 0.0476],
         [-0.0035],
         [-0.0995],
         [-0.2217],
         [-0.3453],
         [-0.4433],
         [-0.5113],
         [-0.5520],
         [-0.5705],
         [-0.5693]]], grad_fn=<StackBackward>)
tensor([[[-0.0398,  0.2429, -0.5270, -0.6016, -0.2809,  0.2095,  0.4127,
           0.3275,  0.6540,  0.0224,  0.0492,  0.4569,  0.3641,  0.7066,
          -0.6206, -0.3686,  0.4168,  0.2776, -0.1782,  0.4092, -0.0258,
           0.3804,  0.2475,  0.1334, -0.0526, -0.1721,  0.2793,  0.4764,
           0.0128,  0.0473, -0.6583,  0.2645]]])


<Figure size 640x480 with 1 Axes>

torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
tensor([[[-0.5318],
         [-0.5109],
         [-0.4375],
         [-0.3173],
         [-0.1645],
         [ 0.0142],
         [ 0.2010],
         [ 0.3596],
         [ 0.4695],
         [ 0.5298]]], grad_fn=<StackBackward>)
tensor([[[-0.1135, -0.2357,  0.5523,  0.3157,  0.4726,  0.2072, -0.3172,
          -0.3683, -0.4499,  0.5140,  0.0447, -0.4488, -0.2874, -0.5745,
           0.4245,  0.2205, -0.1483, -0.2868,  0.0693, -0.3748, -0.1049,
          -0.5138, -0.0777, -0.1645,  0.0008,  0.0416, -0.3431, -0.3652,
           0.2787,  0.0981,  0.3670, -0.2805]]])


<Figure size 640x480 with 1 Axes>