In [19]:
import torch
from torch import nn, optim
import torch.nn.functional as F

from skorch import NeuralNetRegressor


class Regressor(nn.Module):
    """Regress one scalar from one or two sequence feature streams.

    Each stream is encoded by its own ``nn.GRU`` (``batch_first=True``). The
    output at the final time step of each stream is taken, the per-stream
    vectors are concatenated, passed through dropout, and projected to a
    single value by a linear layer.

    Args:
        hidden_dim: Hidden size of each GRU.
        feats1_dim: Feature size of the first (required) input stream.
        feats2_dim: Feature size of the optional second stream. When falsy
            (``None`` by default) no second GRU is built and the model
            expects only ``feats1``.
        drop_prob: Dropout probability applied before the linear head.
    """

    def __init__(self, hidden_dim, feats1_dim, feats2_dim=None, drop_prob=0):
        super().__init__()

        # Construction order (gru1, gru2, dropout, linear) is kept so that
        # parameter initialisation order and state_dict keys are unchanged.
        self.gru1 = nn.GRU(feats1_dim, hidden_dim, batch_first=True)
        self.gru2 = (
            nn.GRU(feats2_dim, hidden_dim, batch_first=True) if feats2_dim else None
        )
        self.dropout = nn.Dropout(drop_prob)
        n_streams = 1 if self.gru2 is None else 2
        self.linear = nn.Linear(hidden_dim * n_streams, 1)

    def forward(self, feats1, feats2=None):
        """Return predictions of shape ``(batch, 1)``.

        Args:
            feats1: Batched sequence tensor ``(batch, seq_len1, feats1_dim)``.
            feats2: Batched sequence tensor ``(batch, seq_len2, feats2_dim)``.
                Required if and only if the model was built with a second
                stream. ``seq_len2`` may differ from ``seq_len1``, because
                only the last time step of each stream is used.

        Raises:
            ValueError: If the presence of ``feats2`` does not match the
                model configuration.
        """
        if (feats2 is None) != (self.gru2 is None):
            raise ValueError(
                "feats2 must be provided if and only if the model was built "
                "with feats2_dim"
            )

        seq1, _ = self.gru1(feats1)
        # Take the last time step per stream *before* concatenating, so the
        # two streams need not share a sequence length.
        last = seq1[:, -1, :]

        if feats2 is not None:
            seq2, _ = self.gru2(feats2)
            last = torch.cat((last, seq2[:, -1, :]), dim=1)

        last = self.dropout(last)
        return self.linear(last)

N_SAMPLES = 1024
SEQ_LEN = 32  # time steps per sample; used for both feature streams
HIDDEN_DIM = 512
FEATS1_DIM = 128
FEATS2_DIM = 256

# Synthetic demo data. The targets are drawn independently of X, so there
# is no input/target relationship for the model to learn.
X = {
    'feats1': torch.randn(N_SAMPLES, SEQ_LEN, FEATS1_DIM),
    'feats2': torch.randn(N_SAMPLES, SEQ_LEN, FEATS2_DIM),
}

targets = torch.rand(N_SAMPLES, 1)

# ``module__<name>`` arguments are forwarded by skorch to Regressor.__init__.
# When X is a dict, skorch passes its entries to Regressor.forward as keyword
# arguments, so the keys above must match forward's parameter names.
net = NeuralNetRegressor(module=Regressor,
                         module__hidden_dim=HIDDEN_DIM,
                         module__feats1_dim=FEATS1_DIM,
                         module__feats2_dim=FEATS2_DIM,
                         module__drop_prob=0.2,
                         batch_size=128,
                         max_epochs=20)

net.fit(X, targets)

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1        0.2897        0.1827  4.1738
      2        0.2000        0.1339  4.1319
      3        0.1531        0.1085  4.0821
      4        0.1274        0.0957  4.1040
      5        0.1158        0.0892  4.0711
      6        0.1073        0.0862  4.1459
      7        0.1022        0.0850  4.1629
      8        0.0968        0.0845  4.1509
      9        0.0959        0.0844  4.3773
     10        0.0927        0.0844  4.5688
     11        0.0907        0.0846  4.2865
     12        0.0903        0.0848  4.0701
     13        0.0892        0.0850  4.2915
     14        0.0872        0.0851  4.0911
     15        0.0876        0.0852  4.0422
     16        0.0854        0.0854  4.0472
     17        0.0859        0.0856  4.0552
     18        0.0871        0.0857  4.1050
     19        0.0845        0.0860  4.0013
     20        0.0809        0.0862  4.1020


<class 'skorch.regressor.NeuralNetRegressor'>[initialized](
  module_=Regressor(
    (gru1): GRU(128, 512, batch_first=True)
    (gru2): GRU(256, 512, batch_first=True)
    (dropout): Dropout(p=0.2)
    (linear): Linear(in_features=1024, out_features=1, bias=True)
  ),
)