<a href="https://colab.research.google.com/github/ThousandAI/Application-of-AI/blob/main/class08/LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from sklearn.metrics import mean_absolute_percentage_error ,mean_squared_error

In [2]:
class StockDataset(Dataset):
    def __init__(self,start,end,seq_length,dim,mode,train_dataset=None):
        self.seq_length = seq_length
        self.data = np.array(pd.read_csv("/content/drive/MyDrive/Colab-Notebooks/Application-of-AI/class08/stock.csv"))[start:end,:-1]
        if mode == 'train':
            self.sc = StandardScaler()
            self.sc_close = StandardScaler()
            self.sc_close.fit(self.data[:,1:2])
            self.data = self.sc.fit_transform(self.data)
        elif mode == 'test':
            self.data = train_dataset.sc.fit_transform(self.data)
        self.n_samples = len(self.data) - self.seq_length
        self.x = np.zeros((self.n_samples,seq_length,dim))
        self.y = np.zeros((self.n_samples, 1))
        for i in range(len(self.data)-self.seq_length):
            self.x[i,:,:] = self.data[i:i+self.seq_length,:]
            self.y[i,0] = self.data[i+seq_length,1]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return self.n_samples


In [3]:
# hyper parameters
input_size = 11 # 股價維度
seq_length = 90 # 序列長度
hidden_size = 256
num_layers = 2
epochs = 100
batch_size = 16
learning_rate = 1e-3
start = 0
end = 5000
test_start = 5000
test_end = 5900
# device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# dataset
train_dataset = StockDataset(start=start,end=end,seq_length=seq_length,dim=input_size,mode="train")
test_dataset = StockDataset(start=test_start,end=test_end,seq_length=seq_length,dim=input_size,mode="test",
                            train_dataset=train_dataset)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle= True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size,1)

    def forward(self, x):
        # out: tensor of shape (batch_size, seq_length, hidden_size)
        # out: (n, 15, 768)
        out,_ = self.lstm(x)


        # Decode the hidden state of the last time step
        out = out[:, -1, :]
        # out: (n, 768)

        out = self.fc(out)
        # out: (n, 1)
        return out

In [6]:
lstm = LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers).to(device)
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

In [7]:
# Train the model

mape_loss = 0
rmse_loss = 0
for epoch in tqdm(range(epochs)):
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.type(torch.FloatTensor).to(device)
        labels = labels.type(torch.FloatTensor).to(device)

        # Forward pass
        outputs = lstm(inputs)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.8f}')

  1%|          | 1/100 [00:03<06:26,  3.91s/it]

Epoch [1/100], Loss: 0.00126820


  2%|▏         | 2/100 [00:08<07:13,  4.43s/it]

Epoch [2/100], Loss: 0.00467401


  3%|▎         | 3/100 [00:11<06:08,  3.80s/it]

Epoch [3/100], Loss: 0.01101921


  4%|▍         | 4/100 [00:15<06:06,  3.81s/it]

Epoch [4/100], Loss: 0.00254491


  5%|▌         | 5/100 [00:18<05:47,  3.66s/it]

Epoch [5/100], Loss: 0.02821011


  6%|▌         | 6/100 [00:22<05:29,  3.50s/it]

Epoch [6/100], Loss: 0.00459815


  7%|▋         | 7/100 [00:24<04:52,  3.14s/it]

Epoch [7/100], Loss: 0.01406800


  8%|▊         | 8/100 [00:26<04:27,  2.90s/it]

Epoch [8/100], Loss: 0.00400412


  9%|▉         | 9/100 [00:29<04:09,  2.75s/it]

Epoch [9/100], Loss: 0.01076903


 10%|█         | 10/100 [00:31<03:57,  2.64s/it]

Epoch [10/100], Loss: 0.00180501


 11%|█         | 11/100 [00:34<03:48,  2.57s/it]

Epoch [11/100], Loss: 0.00408265


 12%|█▏        | 12/100 [00:36<03:41,  2.52s/it]

Epoch [12/100], Loss: 0.00309705


 13%|█▎        | 13/100 [00:38<03:36,  2.49s/it]

Epoch [13/100], Loss: 0.00216733


 14%|█▍        | 14/100 [00:41<03:32,  2.47s/it]

Epoch [14/100], Loss: 0.00191173


 15%|█▌        | 15/100 [00:43<03:28,  2.45s/it]

Epoch [15/100], Loss: 0.00272395


 16%|█▌        | 16/100 [00:46<03:25,  2.45s/it]

Epoch [16/100], Loss: 0.00290818


 17%|█▋        | 17/100 [00:48<03:22,  2.44s/it]

Epoch [17/100], Loss: 0.00309534


 18%|█▊        | 18/100 [00:51<03:19,  2.43s/it]

Epoch [18/100], Loss: 0.00229798


 19%|█▉        | 19/100 [00:53<03:16,  2.43s/it]

Epoch [19/100], Loss: 0.01062003


 20%|██        | 20/100 [00:55<03:14,  2.43s/it]

Epoch [20/100], Loss: 0.00228619


 21%|██        | 21/100 [00:58<03:11,  2.42s/it]

Epoch [21/100], Loss: 0.00111260


 22%|██▏       | 22/100 [01:00<03:08,  2.42s/it]

Epoch [22/100], Loss: 0.00546099


 23%|██▎       | 23/100 [01:03<03:06,  2.42s/it]

Epoch [23/100], Loss: 0.00135614


 24%|██▍       | 24/100 [01:05<03:03,  2.42s/it]

Epoch [24/100], Loss: 0.00977796


 25%|██▌       | 25/100 [01:08<03:01,  2.42s/it]

Epoch [25/100], Loss: 0.00651542


 26%|██▌       | 26/100 [01:10<02:58,  2.42s/it]

Epoch [26/100], Loss: 0.00778525


 27%|██▋       | 27/100 [01:12<02:56,  2.42s/it]

Epoch [27/100], Loss: 0.00389105


 28%|██▊       | 28/100 [01:15<02:54,  2.42s/it]

Epoch [28/100], Loss: 0.00383164


 29%|██▉       | 29/100 [01:17<02:51,  2.42s/it]

Epoch [29/100], Loss: 0.00399799


 30%|███       | 30/100 [01:20<02:49,  2.42s/it]

Epoch [30/100], Loss: 0.00201225


 31%|███       | 31/100 [01:22<02:46,  2.41s/it]

Epoch [31/100], Loss: 0.01347353


 32%|███▏      | 32/100 [01:24<02:44,  2.42s/it]

Epoch [32/100], Loss: 0.00637990


 33%|███▎      | 33/100 [01:27<02:41,  2.42s/it]

Epoch [33/100], Loss: 0.00961489


 34%|███▍      | 34/100 [01:29<02:41,  2.44s/it]

Epoch [34/100], Loss: 0.00331725


 35%|███▌      | 35/100 [01:32<02:41,  2.49s/it]

Epoch [35/100], Loss: 0.00198668


 36%|███▌      | 36/100 [01:34<02:37,  2.47s/it]

Epoch [36/100], Loss: 0.00840443


 37%|███▋      | 37/100 [01:37<02:34,  2.45s/it]

Epoch [37/100], Loss: 0.00157361


 38%|███▊      | 38/100 [01:39<02:31,  2.44s/it]

Epoch [38/100], Loss: 0.00490683


 39%|███▉      | 39/100 [01:42<02:28,  2.43s/it]

Epoch [39/100], Loss: 0.00248960


 40%|████      | 40/100 [01:44<02:25,  2.43s/it]

Epoch [40/100], Loss: 0.02354202


 41%|████      | 41/100 [01:46<02:23,  2.43s/it]

Epoch [41/100], Loss: 0.00148393


 42%|████▏     | 42/100 [01:49<02:20,  2.42s/it]

Epoch [42/100], Loss: 0.00088679


 43%|████▎     | 43/100 [01:51<02:18,  2.42s/it]

Epoch [43/100], Loss: 0.00096167


 44%|████▍     | 44/100 [01:54<02:15,  2.42s/it]

Epoch [44/100], Loss: 0.00019221


 45%|████▌     | 45/100 [01:56<02:13,  2.42s/it]

Epoch [45/100], Loss: 0.00049548


 46%|████▌     | 46/100 [01:59<02:10,  2.42s/it]

Epoch [46/100], Loss: 0.00606500


 47%|████▋     | 47/100 [02:01<02:08,  2.42s/it]

Epoch [47/100], Loss: 0.00769970


 48%|████▊     | 48/100 [02:04<02:10,  2.50s/it]

Epoch [48/100], Loss: 0.00154918


 49%|████▉     | 49/100 [02:06<02:06,  2.49s/it]

Epoch [49/100], Loss: 0.00444686


 50%|█████     | 50/100 [02:09<02:07,  2.55s/it]

Epoch [50/100], Loss: 0.00338334


 51%|█████     | 51/100 [02:11<02:03,  2.51s/it]

Epoch [51/100], Loss: 0.00392538


 52%|█████▏    | 52/100 [02:14<01:59,  2.48s/it]

Epoch [52/100], Loss: 0.00095364


 53%|█████▎    | 53/100 [02:16<01:55,  2.46s/it]

Epoch [53/100], Loss: 0.00225099


 54%|█████▍    | 54/100 [02:18<01:52,  2.45s/it]

Epoch [54/100], Loss: 0.00461476


 55%|█████▌    | 55/100 [02:21<01:49,  2.43s/it]

Epoch [55/100], Loss: 0.00068710


 56%|█████▌    | 56/100 [02:23<01:46,  2.43s/it]

Epoch [56/100], Loss: 0.00192855


 57%|█████▋    | 57/100 [02:26<01:44,  2.43s/it]

Epoch [57/100], Loss: 0.00144335


 58%|█████▊    | 58/100 [02:28<01:41,  2.43s/it]

Epoch [58/100], Loss: 0.00236243


 59%|█████▉    | 59/100 [02:31<01:39,  2.42s/it]

Epoch [59/100], Loss: 0.00122972


 60%|██████    | 60/100 [02:33<01:37,  2.43s/it]

Epoch [60/100], Loss: 0.00151657


 61%|██████    | 61/100 [02:35<01:34,  2.42s/it]

Epoch [61/100], Loss: 0.00213268


 62%|██████▏   | 62/100 [02:38<01:32,  2.42s/it]

Epoch [62/100], Loss: 0.00223231


 63%|██████▎   | 63/100 [02:40<01:29,  2.42s/it]

Epoch [63/100], Loss: 0.00135059


 64%|██████▍   | 64/100 [02:43<01:27,  2.43s/it]

Epoch [64/100], Loss: 0.00149427


 65%|██████▌   | 65/100 [02:45<01:24,  2.42s/it]

Epoch [65/100], Loss: 0.00322628


 66%|██████▌   | 66/100 [02:48<01:22,  2.43s/it]

Epoch [66/100], Loss: 0.00099028


 67%|██████▋   | 67/100 [02:50<01:19,  2.42s/it]

Epoch [67/100], Loss: 0.00043573


 68%|██████▊   | 68/100 [02:52<01:17,  2.42s/it]

Epoch [68/100], Loss: 0.00070259


 69%|██████▉   | 69/100 [02:55<01:15,  2.43s/it]

Epoch [69/100], Loss: 0.00088727


 70%|███████   | 70/100 [02:57<01:12,  2.43s/it]

Epoch [70/100], Loss: 0.00065685


 71%|███████   | 71/100 [03:00<01:10,  2.43s/it]

Epoch [71/100], Loss: 0.00196927


 72%|███████▏  | 72/100 [03:02<01:08,  2.43s/it]

Epoch [72/100], Loss: 0.00101752


 73%|███████▎  | 73/100 [03:05<01:05,  2.42s/it]

Epoch [73/100], Loss: 0.00039463


 74%|███████▍  | 74/100 [03:07<01:03,  2.42s/it]

Epoch [74/100], Loss: 0.00143205


 75%|███████▌  | 75/100 [03:09<01:00,  2.42s/it]

Epoch [75/100], Loss: 0.00148997


 76%|███████▌  | 76/100 [03:12<00:58,  2.42s/it]

Epoch [76/100], Loss: 0.00250095


 77%|███████▋  | 77/100 [03:14<00:55,  2.42s/it]

Epoch [77/100], Loss: 0.00156698


 78%|███████▊  | 78/100 [03:17<00:53,  2.42s/it]

Epoch [78/100], Loss: 0.00232533


 79%|███████▉  | 79/100 [03:19<00:50,  2.42s/it]

Epoch [79/100], Loss: 0.00090487


 80%|████████  | 80/100 [03:21<00:48,  2.42s/it]

Epoch [80/100], Loss: 0.00143223


 81%|████████  | 81/100 [03:24<00:46,  2.42s/it]

Epoch [81/100], Loss: 0.00346395


 82%|████████▏ | 82/100 [03:26<00:43,  2.42s/it]

Epoch [82/100], Loss: 0.00050954


 83%|████████▎ | 83/100 [03:29<00:41,  2.42s/it]

Epoch [83/100], Loss: 0.00083260


 84%|████████▍ | 84/100 [03:31<00:38,  2.43s/it]

Epoch [84/100], Loss: 0.00150149


 85%|████████▌ | 85/100 [03:34<00:36,  2.43s/it]

Epoch [85/100], Loss: 0.00049746


 86%|████████▌ | 86/100 [03:36<00:33,  2.43s/it]

Epoch [86/100], Loss: 0.00124523


 87%|████████▋ | 87/100 [03:38<00:31,  2.42s/it]

Epoch [87/100], Loss: 0.00114362


 88%|████████▊ | 88/100 [03:41<00:29,  2.42s/it]

Epoch [88/100], Loss: 0.00050217


 89%|████████▉ | 89/100 [03:43<00:26,  2.42s/it]

Epoch [89/100], Loss: 0.00197233


 90%|█████████ | 90/100 [03:46<00:24,  2.42s/it]

Epoch [90/100], Loss: 0.00109689


 91%|█████████ | 91/100 [03:48<00:21,  2.42s/it]

Epoch [91/100], Loss: 0.00057214


 92%|█████████▏| 92/100 [03:51<00:19,  2.42s/it]

Epoch [92/100], Loss: 0.00092585


 93%|█████████▎| 93/100 [03:53<00:17,  2.51s/it]

Epoch [93/100], Loss: 0.00136609


 94%|█████████▍| 94/100 [03:56<00:14,  2.49s/it]

Epoch [94/100], Loss: 0.00072009


 95%|█████████▌| 95/100 [03:58<00:12,  2.47s/it]

Epoch [95/100], Loss: 0.00054244


 96%|█████████▌| 96/100 [04:01<00:09,  2.46s/it]

Epoch [96/100], Loss: 0.00025210


 97%|█████████▋| 97/100 [04:03<00:07,  2.45s/it]

Epoch [97/100], Loss: 0.00123182


 98%|█████████▊| 98/100 [04:05<00:04,  2.43s/it]

Epoch [98/100], Loss: 0.00094332


 99%|█████████▉| 99/100 [04:08<00:02,  2.43s/it]

Epoch [99/100], Loss: 0.00050559


100%|██████████| 100/100 [04:10<00:00,  2.51s/it]

Epoch [100/100], Loss: 0.00033199





In [8]:
with torch.no_grad():
    count  = 0
    for ind,(inputs,labels) in enumerate(test_dataloader):
        inputs = inputs.type(torch.FloatTensor).to(device)
        labels = labels.type(torch.FloatTensor).to(device)

        outputs = lstm(inputs)
        y_pred = train_dataset.sc_close.inverse_transform(outputs.cpu())
        y_true = train_dataset.sc_close.inverse_transform(labels.cpu())
        mape_loss += mean_absolute_percentage_error(y_true, y_pred)
        rmse_loss += mean_squared_error(y_true, y_pred,squared=False)
        count += 1
    print(f"mape: {mape_loss / count}")
    print(f"rmse: {rmse_loss / count}")
    print(f"count: {count}")

mape: 0.04810266321833367
rmse: 18.13821927229396
count: 51
