In [1]:
from array import array
from sail.models.torch.tcn import TCNModel
import numpy as np
import torch
from sklearn.datasets import make_regression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def regression_data(n_samples, n_features):
    X, y = make_regression(n_samples, n_features, n_informative=10, bias=0, random_state=0)
    X, y = X.astype(np.float32), y.astype(np.float32).reshape(-1, 1)
    Xt = StandardScaler().fit_transform(X)
    yt = StandardScaler().fit_transform(y)
    return Xt, yt

In [3]:
class Dataset():
    
    def __init__(self, x, y):
        self.x, self.y = x, y
    
    def __len__(self):
        return self.x.shape[0]
    
    def __getitem__(self, i):
        return torch.from_numpy(self.x[i]), torch.from_numpy(self.y[i])

In [4]:
n_samples = 6000
n_features = 10
batch_size = 64

In [5]:
model = TCNModel(input_dim=n_features, output_dim=1, kernel_size=3, num_filters=3, num_layers=3, dilation_base=2, weight_norm=False, dropout=0.2)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [6]:
X, y = regression_data(n_samples, n_features)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

In [7]:
train_ds = Dataset(X_train, y_train)
test_ds = Dataset(X_test, y_test)

train_dl = DataLoader(train_ds, batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size)

In [8]:
for epoch in range(10):
    model.train()
    for xb,yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    model.eval()
    with torch.no_grad():
        loss, acc, n = 0., 0., 0.
        for xb, yb in test_dl:
            yb_pred = model(xb)
            loss += loss_fn(yb_pred, yb)
            n += len(xb)
        loss = loss/n * 100
    print(f"EPOCH#:{epoch} \t Loss: {loss:.4f}")

EPOCH#:0 	 Loss: 1.1281
EPOCH#:1 	 Loss: 0.4946
EPOCH#:2 	 Loss: 0.1329
EPOCH#:3 	 Loss: 0.0451
EPOCH#:4 	 Loss: 0.0259
EPOCH#:5 	 Loss: 0.0156
EPOCH#:6 	 Loss: 0.0095
EPOCH#:7 	 Loss: 0.0066
EPOCH#:8 	 Loss: 0.0047
EPOCH#:9 	 Loss: 0.0036
