<a href="https://colab.research.google.com/github/ThousandAI/Application-of-AI/blob/main/class08/LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from sklearn.metrics import mean_absolute_percentage_error ,mean_squared_error

In [2]:
class StockDataset(Dataset):
    def __init__(self,start,end,seq_length,dim,mode,train_dataset=None):
        self.seq_length = seq_length
        self.data = np.array(pd.read_csv("/content/drive/MyDrive/Colab-Notebooks/Application-of-AI/class08/stock.csv"))[start:end,:-1]
        if mode == 'train':
            self.sc = StandardScaler()
            self.sc_close = StandardScaler()
            self.sc_close.fit(self.data[:,1:2])
            self.data = self.sc.fit_transform(self.data)
        elif mode == 'test':
            self.data = train_dataset.sc.fit_transform(self.data)
        self.n_samples = len(self.data) - self.seq_length
        self.x = np.zeros((self.n_samples,seq_length,dim))
        self.y = np.zeros((self.n_samples, 1))
        for i in range(len(self.data)-self.seq_length):
            self.x[i,:,:] = self.data[i:i+self.seq_length,:]
            self.y[i,0] = self.data[i+seq_length,1]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return self.n_samples


In [3]:
# hyper parameters
input_size = 11 # 股價維度
seq_length = 90 # 序列長度
hidden_size = 256
num_layers = 2
epochs = 100
batch_size = 16
learning_rate = 1e-3
start = 0
end = 5000
test_start = 5000
test_end = 5900
# device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# dataset
train_dataset = StockDataset(start=start,end=end,seq_length=seq_length,dim=input_size,mode="train")
test_dataset = StockDataset(start=test_start,end=test_end,seq_length=seq_length,dim=input_size,mode="test",
                            train_dataset=train_dataset)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle= True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size,1)

    def forward(self, x):
        # out: tensor of shape (batch_size, seq_length, hidden_size)
        # out: (n, 15, 768)
        out,_ = self.lstm(x)


        # Decode the hidden state of the last time step
        out = out[:, -1, :]
        # out: (n, 768)

        out = self.fc(out)
        # out: (n, 1)
        return out

In [6]:
lstm = LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers).to(device)
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

In [7]:
# Train the model

mape_loss = 0
rmse_loss = 0
for epoch in tqdm(range(epochs)):
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.type(torch.FloatTensor).to(device)
        labels = labels.type(torch.FloatTensor).to(device)

        # Forward pass
        outputs = lstm(inputs)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.8f}')

  1%|          | 1/100 [00:03<05:14,  3.17s/it]

Epoch [1/100], Loss: 0.00411349


  2%|▏         | 2/100 [00:06<05:17,  3.23s/it]

Epoch [2/100], Loss: 0.00172368


  3%|▎         | 3/100 [00:09<05:16,  3.26s/it]

Epoch [3/100], Loss: 0.00213022


  4%|▍         | 4/100 [00:13<05:29,  3.44s/it]

Epoch [4/100], Loss: 0.00455120


  5%|▌         | 5/100 [00:16<05:24,  3.42s/it]

Epoch [5/100], Loss: 0.00293254


  6%|▌         | 6/100 [00:19<05:04,  3.24s/it]

Epoch [6/100], Loss: 0.00977520


  7%|▋         | 7/100 [00:22<04:35,  2.97s/it]

Epoch [7/100], Loss: 0.00418183


  8%|▊         | 8/100 [00:24<04:16,  2.79s/it]

Epoch [8/100], Loss: 0.00200666


  9%|▉         | 9/100 [00:26<04:02,  2.67s/it]

Epoch [9/100], Loss: 0.00098812


 10%|█         | 10/100 [00:29<03:52,  2.58s/it]

Epoch [10/100], Loss: 0.00200387


 11%|█         | 11/100 [00:31<03:44,  2.53s/it]

Epoch [11/100], Loss: 0.00382396


 12%|█▏        | 12/100 [00:34<03:39,  2.49s/it]

Epoch [12/100], Loss: 0.00136258


 13%|█▎        | 13/100 [00:36<03:34,  2.47s/it]

Epoch [13/100], Loss: 0.00299088


 14%|█▍        | 14/100 [00:38<03:30,  2.45s/it]

Epoch [14/100], Loss: 0.00316770


 15%|█▌        | 15/100 [00:41<03:27,  2.44s/it]

Epoch [15/100], Loss: 0.00727386


 16%|█▌        | 16/100 [00:43<03:24,  2.43s/it]

Epoch [16/100], Loss: 0.00445392


 17%|█▋        | 17/100 [00:46<03:21,  2.43s/it]

Epoch [17/100], Loss: 0.00096863


 18%|█▊        | 18/100 [00:48<03:18,  2.42s/it]

Epoch [18/100], Loss: 0.00201821


 19%|█▉        | 19/100 [00:51<03:15,  2.42s/it]

Epoch [19/100], Loss: 0.00801796


 20%|██        | 20/100 [00:53<03:13,  2.42s/it]

Epoch [20/100], Loss: 0.00333672


 21%|██        | 21/100 [00:55<03:10,  2.41s/it]

Epoch [21/100], Loss: 0.00400475


 22%|██▏       | 22/100 [00:58<03:07,  2.41s/it]

Epoch [22/100], Loss: 0.01099756


 23%|██▎       | 23/100 [01:00<03:05,  2.41s/it]

Epoch [23/100], Loss: 0.00452971


 24%|██▍       | 24/100 [01:03<03:03,  2.41s/it]

Epoch [24/100], Loss: 0.00239428


 25%|██▌       | 25/100 [01:05<03:00,  2.41s/it]

Epoch [25/100], Loss: 0.00727915


 26%|██▌       | 26/100 [01:07<02:58,  2.41s/it]

Epoch [26/100], Loss: 0.00670714


 27%|██▋       | 27/100 [01:10<02:56,  2.41s/it]

Epoch [27/100], Loss: 0.00558580


 28%|██▊       | 28/100 [01:12<02:53,  2.41s/it]

Epoch [28/100], Loss: 0.00334257


 29%|██▉       | 29/100 [01:15<02:51,  2.41s/it]

Epoch [29/100], Loss: 0.00231743


 30%|███       | 30/100 [01:17<02:48,  2.41s/it]

Epoch [30/100], Loss: 0.00456720


 31%|███       | 31/100 [01:19<02:46,  2.41s/it]

Epoch [31/100], Loss: 0.00111003


 32%|███▏      | 32/100 [01:22<02:43,  2.41s/it]

Epoch [32/100], Loss: 0.00387139


 33%|███▎      | 33/100 [01:24<02:41,  2.41s/it]

Epoch [33/100], Loss: 0.00441524


 34%|███▍      | 34/100 [01:27<02:39,  2.42s/it]

Epoch [34/100], Loss: 0.01158470


 35%|███▌      | 35/100 [01:29<02:36,  2.41s/it]

Epoch [35/100], Loss: 0.00319517


 36%|███▌      | 36/100 [01:31<02:34,  2.41s/it]

Epoch [36/100], Loss: 0.00243566


 37%|███▋      | 37/100 [01:34<02:31,  2.41s/it]

Epoch [37/100], Loss: 0.00589590


 38%|███▊      | 38/100 [01:36<02:29,  2.41s/it]

Epoch [38/100], Loss: 0.00528283


 39%|███▉      | 39/100 [01:39<02:27,  2.41s/it]

Epoch [39/100], Loss: 0.00350257


 40%|████      | 40/100 [01:41<02:24,  2.41s/it]

Epoch [40/100], Loss: 0.00179579


 41%|████      | 41/100 [01:44<02:22,  2.41s/it]

Epoch [41/100], Loss: 0.00223771


 42%|████▏     | 42/100 [01:46<02:19,  2.41s/it]

Epoch [42/100], Loss: 0.00185094


 43%|████▎     | 43/100 [01:49<02:19,  2.45s/it]

Epoch [43/100], Loss: 0.00098048


 44%|████▍     | 44/100 [01:51<02:18,  2.46s/it]

Epoch [44/100], Loss: 0.00343280


 45%|████▌     | 45/100 [01:53<02:14,  2.45s/it]

Epoch [45/100], Loss: 0.00382999


 46%|████▌     | 46/100 [01:56<02:11,  2.44s/it]

Epoch [46/100], Loss: 0.00080311


 47%|████▋     | 47/100 [01:58<02:09,  2.44s/it]

Epoch [47/100], Loss: 0.00243219


 48%|████▊     | 48/100 [02:01<02:06,  2.43s/it]

Epoch [48/100], Loss: 0.00326532


 49%|████▉     | 49/100 [02:03<02:05,  2.45s/it]

Epoch [49/100], Loss: 0.00062021


 50%|█████     | 50/100 [02:06<02:04,  2.50s/it]

Epoch [50/100], Loss: 0.00489188


 51%|█████     | 51/100 [02:08<02:01,  2.47s/it]

Epoch [51/100], Loss: 0.00116648


 52%|█████▏    | 52/100 [02:11<01:57,  2.46s/it]

Epoch [52/100], Loss: 0.00127771


 53%|█████▎    | 53/100 [02:13<01:54,  2.44s/it]

Epoch [53/100], Loss: 0.00254981


 54%|█████▍    | 54/100 [02:15<01:51,  2.43s/it]

Epoch [54/100], Loss: 0.00160120


 55%|█████▌    | 55/100 [02:18<01:49,  2.43s/it]

Epoch [55/100], Loss: 0.00492039


 56%|█████▌    | 56/100 [02:20<01:47,  2.43s/it]

Epoch [56/100], Loss: 0.00114653


 57%|█████▋    | 57/100 [02:23<01:44,  2.43s/it]

Epoch [57/100], Loss: 0.00107803


 58%|█████▊    | 58/100 [02:25<01:41,  2.42s/it]

Epoch [58/100], Loss: 0.00306937


 59%|█████▉    | 59/100 [02:28<01:39,  2.42s/it]

Epoch [59/100], Loss: 0.00292581


 60%|██████    | 60/100 [02:30<01:36,  2.42s/it]

Epoch [60/100], Loss: 0.00131139


 61%|██████    | 61/100 [02:32<01:34,  2.42s/it]

Epoch [61/100], Loss: 0.00123742


 62%|██████▏   | 62/100 [02:35<01:31,  2.42s/it]

Epoch [62/100], Loss: 0.00251680


 63%|██████▎   | 63/100 [02:37<01:29,  2.42s/it]

Epoch [63/100], Loss: 0.00215892


 64%|██████▍   | 64/100 [02:40<01:27,  2.42s/it]

Epoch [64/100], Loss: 0.00328598


 65%|██████▌   | 65/100 [02:42<01:24,  2.42s/it]

Epoch [65/100], Loss: 0.00130516


 66%|██████▌   | 66/100 [02:45<01:22,  2.42s/it]

Epoch [66/100], Loss: 0.00188830


 67%|██████▋   | 67/100 [02:47<01:19,  2.42s/it]

Epoch [67/100], Loss: 0.00294519


 68%|██████▊   | 68/100 [02:49<01:17,  2.43s/it]

Epoch [68/100], Loss: 0.00287161


 69%|██████▉   | 69/100 [02:52<01:15,  2.42s/it]

Epoch [69/100], Loss: 0.01492217


 70%|███████   | 70/100 [02:54<01:12,  2.42s/it]

Epoch [70/100], Loss: 0.00288531


 71%|███████   | 71/100 [02:57<01:10,  2.42s/it]

Epoch [71/100], Loss: 0.00231131


 72%|███████▏  | 72/100 [02:59<01:07,  2.42s/it]

Epoch [72/100], Loss: 0.00135111


 73%|███████▎  | 73/100 [03:01<01:05,  2.42s/it]

Epoch [73/100], Loss: 0.00060088


 74%|███████▍  | 74/100 [03:04<01:03,  2.43s/it]

Epoch [74/100], Loss: 0.00257963


 75%|███████▌  | 75/100 [03:06<01:00,  2.42s/it]

Epoch [75/100], Loss: 0.00156051


 76%|███████▌  | 76/100 [03:09<00:58,  2.42s/it]

Epoch [76/100], Loss: 0.00163358


 77%|███████▋  | 77/100 [03:11<00:55,  2.42s/it]

Epoch [77/100], Loss: 0.00127366


 78%|███████▊  | 78/100 [03:14<00:53,  2.42s/it]

Epoch [78/100], Loss: 0.00173595


 79%|███████▉  | 79/100 [03:16<00:50,  2.42s/it]

Epoch [79/100], Loss: 0.00070481


 80%|████████  | 80/100 [03:18<00:48,  2.42s/it]

Epoch [80/100], Loss: 0.00242017


 81%|████████  | 81/100 [03:21<00:45,  2.42s/it]

Epoch [81/100], Loss: 0.00174248


 82%|████████▏ | 82/100 [03:23<00:43,  2.42s/it]

Epoch [82/100], Loss: 0.00032553


 83%|████████▎ | 83/100 [03:26<00:41,  2.42s/it]

Epoch [83/100], Loss: 0.00286773


 84%|████████▍ | 84/100 [03:28<00:38,  2.42s/it]

Epoch [84/100], Loss: 0.00250523


 85%|████████▌ | 85/100 [03:31<00:36,  2.42s/it]

Epoch [85/100], Loss: 0.00180711


 86%|████████▌ | 86/100 [03:33<00:33,  2.42s/it]

Epoch [86/100], Loss: 0.00096489


 87%|████████▋ | 87/100 [03:35<00:31,  2.43s/it]

Epoch [87/100], Loss: 0.00083648


 88%|████████▊ | 88/100 [03:38<00:29,  2.42s/it]

Epoch [88/100], Loss: 0.00113140


 89%|████████▉ | 89/100 [03:40<00:26,  2.42s/it]

Epoch [89/100], Loss: 0.00145410


 90%|█████████ | 90/100 [03:43<00:24,  2.42s/it]

Epoch [90/100], Loss: 0.00158677


 91%|█████████ | 91/100 [03:45<00:21,  2.42s/it]

Epoch [91/100], Loss: 0.00447952


 92%|█████████▏| 92/100 [03:48<00:19,  2.46s/it]

Epoch [92/100], Loss: 0.00245576


 93%|█████████▎| 93/100 [03:50<00:17,  2.55s/it]

Epoch [93/100], Loss: 0.00102967


 94%|█████████▍| 94/100 [03:53<00:15,  2.51s/it]

Epoch [94/100], Loss: 0.00135104


 95%|█████████▌| 95/100 [03:55<00:12,  2.49s/it]

Epoch [95/100], Loss: 0.00047032


 96%|█████████▌| 96/100 [03:58<00:09,  2.48s/it]

Epoch [96/100], Loss: 0.00046218


 97%|█████████▋| 97/100 [04:00<00:07,  2.46s/it]

Epoch [97/100], Loss: 0.00150521


 98%|█████████▊| 98/100 [04:03<00:04,  2.45s/it]

Epoch [98/100], Loss: 0.00149321


 99%|█████████▉| 99/100 [04:05<00:02,  2.44s/it]

Epoch [99/100], Loss: 0.00047735


100%|██████████| 100/100 [04:07<00:00,  2.48s/it]

Epoch [100/100], Loss: 0.00079641





In [8]:
with torch.no_grad():
    count  = 0
    for ind,(inputs,labels) in enumerate(test_dataloader):
        inputs = inputs.type(torch.FloatTensor).to(device)
        labels = labels.type(torch.FloatTensor).to(device)

        outputs = lstm(inputs)
        y_pred = train_dataset.sc_close.inverse_transform(outputs.cpu())
        y_true = train_dataset.sc_close.inverse_transform(labels.cpu())
        mape_loss += mean_absolute_percentage_error(y_true, y_pred)
        rmse_loss += mean_squared_error(y_true, y_pred,squared=False)
        count += 1
    print(f"mape: {mape_loss / count}")
    print(f"rmse: {rmse_loss / count}")
    print(f"count: {count}")

mape: 0.0474502483153465
rmse: 18.155281804723497
count: 51
