In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import pandas as pd

class Data(Dataset):

    def __init__(self, dt: pd.DataFrame, window_size: int = 20):
        """
        feature:
        "收盘价", "开盘价", "最高价", "最低价", "成交量"
        """
        super().__init__()
        self.values = torch.from_numpy(dt.iloc[:, 2:].values).type(dtype=torch.float32)
        self.window_size = window_size
        self.num_feature = self.values.shape[1]

    def __getitem__(self, idx: int):
        # (N_batch, seq_length, num_feature)
        x = self.values[idx:(idx+self.window_size)].reshape((self.window_size, self.num_feature))
        y = self.values[idx+1:(idx+self.window_size+1), 0].reshape((self.window_size, 1))[-1]
        return x, y

    def __len__(self):
        return self.values.shape[0] - self.window_size

In [3]:
class AccRate:
    """
    https://flyai.com/d/StockPredict
    挑战赛上的反向误差率：
    """
    def __init__(self):
        pass

    def __call__(self, output: torch.Tensor, x: torch.Tensor):
        high = x[:, 2]
        low = x[:, 3]
        gt = x[:, 0]
        return 100 * (1 - torch.sum((output - gt) / (high - low))/x.shape[0])
    
acc_rate = AccRate()

In [4]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import torch

def plot(ground_truth: pd.Series, pred: np.ndarray, tag: int, save_file: str, error: float):
    if isinstance(pred, list):
        pred = np.asarray(pred)
    if isinstance(error, torch.Tensor):
        error = error.item()
    if not os.path.exists("output"):
        os.mkdir("output") 
    plt.figure()
    # plt.title(f"result: {error}")
    pre = np.ones_like(pred)
    post = np.ones_like(pred)
    pre[tag:] = pred[tag:] * np.inf
    post[:tag] = post[:tag] * np.inf

    plt.plot(ground_truth.values, label="ground truth", color="green")
    plt.plot(pred * pre, label="predication(train stage)", linestyle="--", color="purple")
    plt.plot(pred * post, label="predication(test stage)", linestyle="--", color="red")
    plt.legend()
    plt.axvline(x=tag, linestyle=":", color="cyan")
    plt.savefig(f"output/{save_file}.png", bbox_inches="tight")
    plt.close()

In [21]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from glob import glob
from typing import cast
import pandas as pd

EPOCHS = 100
WINDOW_SIZE = 40

class GRU(torch.nn.Module):
    def __init__(self, feature_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        # GRU + 全连接
        self.gru = nn.GRU(feature_dim, hidden_dim, layer_dim,
                         batch_first=True)
        self.fc1 = nn.Sequential(
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, inputs):
        x, _ = self.gru(inputs, None)# [batch, time_step, hidden_dim]
        x = self.fc1(x[:, -1, :]) #[batch, time_step, output_dim]
        return x

split_data_root = os.path.join(os.path.dirname(os.path.abspath('__file__')), "split")
# files = os.listdir(split_data_root)
# files = ['000622.SZ.csv']:
files = np.loadtxt("./comp_idx.txt", dtype=str).tolist()
if isinstance(files, str):
    files = [files] 
for c_idx in files:
        # 模型
        model = GRU(feature_dim=5, hidden_dim=20, layer_dim=1, output_dim=1)
        opt = torch.optim.Adam(model.parameters(), lr=1e-4)
        criterion = nn.SmoothL1Loss()
        if torch.cuda.is_available():
            model.cuda()

        # 数据
        csv = glob(os.path.join(split_data_root, f'{c_idx}.csv'))[0]
        dt = pd.read_csv(csv)
        dt[["收盘价","开盘价","最高价","最低价","成交量"]] = (dt[["收盘价","开盘价","最高价","最低价","成交量"]] - dt[["收盘价","开盘价","最高价","最低价","成交量"]].mean()) / dt[["收盘价","开盘价","最高价","最低价","成交量"]].std()
        length = len(dt)
        train_len = round(length * 0.8) 
        train_dt = dt[:train_len]
        test_dt = dt[(train_len-WINDOW_SIZE+1):]

        train_data = Data(train_dt, window_size=WINDOW_SIZE)
        train_data = DataLoader(train_data, batch_size=1, pin_memory=True, num_workers=4)

        test_data = Data(test_dt, window_size=WINDOW_SIZE)
        test_data = DataLoader(test_data, batch_size=1, pin_memory=True, num_workers=4)

        ground_truth = dt[WINDOW_SIZE:len(dt)-1].iloc[:, 2]

        outputs_train = []
        outputs_test = []
        
        # 训练
        import datetime
        starttime = datetime.datetime.now()

        tq = tqdm(range(EPOCHS), total=EPOCHS, desc="Training..")
        for ep in tq:
            for x, y in train_data:
                x, y = cast(torch.Tensor, x), cast(torch.Tensor, y)
                if torch.cuda.is_available():
                    x = x.cuda()
                    y = y.cuda()
                out= model(x)
                if ep == EPOCHS - 1:
                    with torch.no_grad():
                        outputs_train.extend(out.cpu().detach().squeeze(dim=1).numpy().tolist())
                loss = criterion(out, y)
                opt.zero_grad()
                loss.backward()
                opt.step()
                tq.set_postfix({
                    "loss": loss.item()
                })
        endtime = datetime.datetime.now()
        print(f'训练时间：{(endtime - starttime).seconds}')
        # 测试
        test_x = []
        with torch.no_grad():
            for x, y in tqdm(test_data, total=len(test_data), desc="Testing..."):
                x, y = cast(torch.Tensor, x), cast(torch.Tensor, y)
                if torch.cuda.is_available():
                    x = x.cuda()
                    y = y.cuda()
                out = model(x)
                outputs_test.extend(out.cpu().detach().squeeze(dim=1).numpy().tolist())
                test_x.extend(x[:, -1].cpu().detach().numpy().tolist())

    #     er = acc_rate(torch.as_tensor(outputs_test), torch.as_tensor(test_x))
    #     print(f'acc_rate:{er}')
        plot(ground_truth, outputs_train+outputs_test, train_len - WINDOW_SIZE, f"GRU_{c_idx}", error=0.1)

Training..: 100%|██████████| 100/100 [01:38<00:00,  1.01it/s, loss=4.03e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：98


Testing...: 100%|██████████| 47/47 [00:00<00:00, 164.76it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=7.08e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 170.67it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.00158]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 167.95it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.06it/s, loss=0.004]  
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 163.57it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=0.000923]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.42it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.02it/s, loss=0.0084] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 162.63it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.0225] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 159.61it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.02it/s, loss=0.000422]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 150.43it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.02it/s, loss=0.000134]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.56it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=0.000672]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 169.38it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=0.00308]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 169.86it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=0.0347] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 160.41it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=0.0279] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 163.52it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=0.00158]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 158.79it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.06it/s, loss=0.00445]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 171.25it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=0.000199]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 167.60it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.02it/s, loss=0.000306]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 157.65it/s]
Training..:   1%|          | 1/100 [00:01<01:34,  1.05it/s, loss=0.257]   Traceback (most recent call last):
  File "/home/fanrui/anaconda3/envs/tf22/lib/python3.6/multiprocessing/queues.py", line 230, in _feed
    close()
  File "/home/fanrui/anaconda3/envs/tf22/lib/python3.6/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/home/fanrui/anaconda3/envs/tf22/lib/python3.6/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.00813]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 177.44it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.0132] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.53it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=0.0293] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 177.72it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.0124] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 177.54it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.038]  
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.48it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.0195] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 178.25it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.05it/s, loss=0.012]  
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 167.84it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=0.00192]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 171.77it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=0.002]  
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 168.44it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.00362]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 176.90it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.000284]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.27it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.00657]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.18it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=0.0144] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.16it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.06it/s, loss=0.000479]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.62it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=1.05e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.33it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.000923]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.26it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.00109]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.16it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.00279]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.86it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.000823]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.82it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=0.008]  
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 157.39it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.00103]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.16it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=0.00521]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 151.11it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=0.0168] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.41it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.00595]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 177.13it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.06it/s, loss=0.00197]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 165.79it/s]
Training..:  17%|█▋        | 17/100 [00:16<01:20,  1.03it/s, loss=0.00334] IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.0168] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 177.25it/s]
Training..:   5%|▌         | 5/100 [00:05<01:32,  1.02it/s, loss=0.00938] IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Training..:  94%|█████████▍| 94/100 [01:29<00:05,  1.03it/s, loss=0.0733]  IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Training..:  78%|███████▊  | 78/100 [01:15<00:21,  1.04it/s, loss=0.00628] IOPub message rate exceeded.
The notebook server will temporarily

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.39it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.0216] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.06it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.0244] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 171.97it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.00891]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 172.97it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.000599]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 148.01it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=0.00683]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 158.51it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.06it/s, loss=0.0103] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.93it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.06it/s, loss=0.0203] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 177.42it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.00298]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.20it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.0844] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.42it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.06it/s, loss=0.0148] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.56it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.000564]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.64it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.00656]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.25it/s]
Training..: 100%|██████████| 100/100 [01:33<00:00,  1.07it/s, loss=0.00158]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：93


Testing...: 100%|██████████| 47/47 [00:00<00:00, 177.20it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=5.29e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.12it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=0.014]  
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.53it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.0321] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.15it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=0.000409]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 172.73it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=0.0059] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 162.48it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=0.00101]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.03it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.000741]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 172.69it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=0.000126]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.65it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=5.6e-6] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.58it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.00013]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 163.67it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.000244]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.86it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=6.05e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 170.96it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.05it/s, loss=0.000134]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.38it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=1.62e-6]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.04it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=1.21e-8]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.61it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.05it/s, loss=1.27e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.45it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.000189]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 176.14it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=8.66e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.68it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=4.19e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 161.41it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=1.93e-6]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.28it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=5.55e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 169.80it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=6.49e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.25it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.000392]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.12it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.000136]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.60it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=0.00131]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.90it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=1.69e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 128.47it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.02it/s, loss=3.5e-5] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 171.31it/s]
Training..: 100%|██████████| 100/100 [01:37<00:00,  1.03it/s, loss=0.000148]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：97


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.17it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=0.000243]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 171.62it/s]
Training..: 100%|██████████| 100/100 [01:38<00:00,  1.01it/s, loss=0.000154]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：98


Testing...: 100%|██████████| 47/47 [00:00<00:00, 171.62it/s]
Training..: 100%|██████████| 100/100 [01:32<00:00,  1.08it/s, loss=0.000221]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：92


Testing...: 100%|██████████| 47/47 [00:00<00:00, 172.99it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.00202]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 163.97it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=2.63e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.07it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=2.55e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 172.35it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=4.69e-6]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 169.00it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=8.84e-6]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 172.50it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=5.04e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 172.61it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=2.95e-7]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 171.57it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=0.000162]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 169.41it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=6.4e-5] 
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.07it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=0.000191]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.29it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=0.000171]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 169.48it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=6.12e-6]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 169.17it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=5.63e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.11it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s, loss=0.000558]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 172.20it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=0.000135]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 154.82it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=2.42e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 172.50it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s, loss=1.53e-7]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 174.83it/s]
Training..: 100%|██████████| 100/100 [01:34<00:00,  1.05it/s, loss=4.21e-6]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：94


Testing...: 100%|██████████| 47/47 [00:00<00:00, 166.41it/s]
Training..: 100%|██████████| 100/100 [01:35<00:00,  1.04it/s, loss=3.83e-6]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：95


Testing...: 100%|██████████| 47/47 [00:00<00:00, 173.50it/s]
Training..: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s, loss=2.63e-5]
Testing...:   0%|          | 0/47 [00:00<?, ?it/s]

训练时间：96


Testing...: 100%|██████████| 47/47 [00:00<00:00, 175.37it/s]


In [11]:
from thop import profile

model = GRU(feature_dim=5, hidden_dim=20, layer_dim=1, output_dim=1)
input = torch.randn(1, 40, 5)
flops, params = profile(model, inputs=(input, ))
print('flops:{}'.format(flops))
print('params:{}'.format(params))

[INFO] Register count_gru() for <class 'torch.nn.modules.rnn.GRU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class '__main__.GRU'>. Treat it as zero Macs and zero Params.[00m
flops:70420.0
params:1641.0
