In [1]:
# encoding: utf-8
"""
@author : zhirui zhou
@contact: evilpsycho42@gmail.com
@time   : 2019/11/6 13:49
"""
from torch import nn
import torch


class SimpleSeq2Seq(nn.Module):

    def __init__(self, target_dim, hidden_size, activation='Tanh', dropout=0.0):
        super(SimpleSeq2Seq, self).__init__()
        self.target_dim = target_dim
        self.hidden_size = hidden_size

        self.encoder = nn.LSTM(target_dim, hidden_size, num_layers=1, bias=True, batch_first=True)
        self.decoder = nn.LSTM(target_dim, hidden_size, num_layers=1, bias=True, batch_first=True)
        self.out = nn.Sequential(
            nn.Linear(hidden_size, target_dim),
            getattr(nn, activation)(),
            nn.Dropout(dropout))

    def forward(self, enc_inputs, dec_inputs):
        batch, dec_lens, _ = dec_inputs.shape
        enc_outputs, enc_hidden = self.encoder(enc_inputs)
        dec_hidden = enc_hidden
        dec_outputs, _ = self.decoder(dec_inputs, dec_hidden)
        preds = self.out(dec_outputs)
        return preds

    def predict(self, enc_inputs, predict_steps):
        with torch.no_grad():
            batch, _, _ = enc_inputs.shape
            enc_outputs, enc_hidden = self.encoder(enc_inputs)
            dec_inputs = enc_inputs[:, -1, :].unsqueeze(1)
            dec_hidden = enc_hidden
            preds = torch.zeros(batch, predict_steps, 1)
            for i in range(predict_steps):
                dec_outputs, dec_hidden = self.decoder(dec_inputs, dec_hidden)
                dec_inputs = self.out(dec_outputs)
                preds[:, i] = dec_inputs
            return preds
    
from torch.utils.data import Dataset, DataLoader
import numpy as np


class TorchSimpleSeriesDataSet(Dataset):

    def __init__(self, series, enc_lens, dec_lens):
        self.s = series.astype('float32')
        self.el = enc_lens
        self.dl = dec_lens

    def __len__(self):
        return len(self.s) - self.el - self.dl + 1

    def __getitem__(self, item):
        enc_inputs = self.s[item: item + self.el].reshape(-1, 1)
        dec_inputs = self.s[item + self.el - 1: item + self.el + self.dl - 1].reshape(-1, 1)
        dec_outputs = self.s[item + self.el: item + self.el + self.dl].reshape(-1, 1)
        return (enc_inputs, dec_inputs), dec_outputs


def walk_forward_split(series_index, n_test, enc_lens, dec_lens):
    train_index = series_index[: -n_test]
    valid_index = series_index[-(dec_lens + n_test - 1 + enc_lens):]
    return train_index, valid_index


def log_sin_curve(total_lens):
    source = np.sin(np.arange(total_lens)) + np.log(np.arange(1, total_lens + 1))
    noise = np.random.normal(0, 0.5, size=total_lens)
    x = source + noise
    return x, source


def create_dataset(x, enc_lens, dec_lens, n_valid, n_test, normalization=True):
    idxes = np.arange(len(x))
    train_idx, tmp_idx = walk_forward_split(idxes, n_test+n_valid, enc_lens, dec_lens)
    valid_idx, test_idx = walk_forward_split(tmp_idx, n_test, enc_lens, dec_lens)
    x_train, x_valid, x_test = x[train_idx], x[valid_idx], x[test_idx]

    if normalization:
        mu = x_train.mean()
        std = x_train.std()
        x_train = (x_train - mu) / std
        x_valid = (x_valid - mu) / std
        x_test = (x_test - mu) / std

    train = TorchSimpleSeriesDataSet(x_train, enc_lens, dec_lens)
    valid = TorchSimpleSeriesDataSet(x_valid, enc_lens, dec_lens)
    test = TorchSimpleSeriesDataSet(x_test, enc_lens, dec_lens)
    return train, valid, test


In [2]:
x, source = log_sin_curve(200)

In [6]:
enc_lens = 20
dec_lens = 10
n_valid =10
n_test = 10

trainset, validset, testset = create_dataset(x, enc_lens, dec_lens, n_valid, n_test)

batch_size = 12

traindl = DataLoader(trainset, shuffle=True, batch_size=batch_size)
validdl = DataLoader(validset, shuffle=False, batch_size=batch_size)
testdl = DataLoader(testset, shuffle=False, batch_size=batch_size)

model = SimpleSeq2Seq(1, 10)

In [22]:
from fastai.basic_train import Learner, AdamW
from fastai.basic_data import DataBunch


In [11]:
bunch = DataBunch(traindl, validdl, testdl)

In [12]:
learner = Learner(bunch, model, loss_func=nn.MSELoss())

In [17]:
learner.model.cpu()

SimpleSeq2Seq(
  (encoder): LSTM(1, 10, batch_first=True)
  (decoder): LSTM(1, 10, batch_first=True)
  (out): Sequential(
    (0): Linear(in_features=10, out_features=1, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.0, inplace=False)
  )
)

In [20]:
learner.lr_find()

AttributeError: 'Learner' object has no attribute 'lr_find'

In [18]:
learner.validate(testdl)

[0.49292627]

In [13]:
learner.fin

epoch,train_loss,valid_loss,time
0,0.221285,0.363092,00:00
1,0.218922,0.369334,00:00
2,0.216972,0.371345,00:00
3,0.2169,0.368116,00:00
4,0.214972,0.376091,00:00
5,0.213489,0.378226,00:00
6,0.211887,0.354735,00:00
7,0.211516,0.386417,00:00
8,0.210286,0.407192,00:00
9,0.210018,0.352613,00:00
