In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader


from estimator_model.deepiv import Net, NetWrapper, DeepIV, MixtureDensityNetwork, MDNWrapper
from estimator_model.utils import BatchData, DiscreteOBatchData

## We divide this notebook into 3 sections, where we
1. ### test Net and NetWrapper in the first section
2. ### test MDN and MDNWrapper in the second section
3. ### test deepiv in the final section.

### First section:
- continuous input, continuous output

In [2]:
net = Net(x_d=1, w_d=1, out_d=1)
net_wrapped = NetWrapper(net, is_y_net=False)
x = torch.normal(0, 1, size=(1000, 1))
w = torch.ones(1000, 1)
def f(x, w):
    return x * x + torch.exp(x) + 3 * w
target = f(x, w)
net_wrapped.fit(
    x=x,
    w=w,
    target=target,
    device='cpu',
    epoch=10,
)
((net_wrapped.predict(x, w) - f(x, w)) / f(x, w)).mean(dim=0)

End of epoch 0 | current loss 0.4365367889404297
End of epoch 1 | current loss 0.19809490442276
End of epoch 2 | current loss 0.2828143835067749
End of epoch 3 | current loss 0.29769447445869446
End of epoch 4 | current loss 0.27174270153045654
End of epoch 5 | current loss 0.21536581218242645
End of epoch 6 | current loss 0.14523985981941223
End of epoch 7 | current loss 0.09029783308506012
End of epoch 8 | current loss 0.061437517404556274
End of epoch 9 | current loss 0.040765754878520966


tensor([-0.0203], grad_fn=<MeanBackward1>)

- contiuous input, discrete output

In [3]:
out_d = 5
net = Net(x_d=1, w_d=1, out_d=out_d, is_discrete_output=True)
sm = nn.Softmax(dim=1)
loss = nn.CrossEntropyLoss()
result = sm(net(torch.ones(5, 1), torch.randn(5, 1)))
l = loss(result, torch.eye(5, 5))
l.backward()

In [4]:
x = torch.normal(0, 1, size=(1000, 1))
w = torch.normal(1, 2, size=(1000, 1))
def f(x, w):
    xw = torch.cat((x, w), dim=1)
    weight = torch.normal(0, 1, size=(xw.shape[1], 1))
    label_sign = torch.einsum('nd,dc->nc', [xw, weight])
    label = (label_sign > 0).to(int).squeeze()
    return F.one_hot(label)
target = f(x, w)
net = Net(1, 1, 2, is_discrete_output=True)
net_wrapped = NetWrapper(net=net, is_y_net=False)
net_wrapped.fit(x, w, target=target, device='cpu', epoch=20)
# training loss
nn.NLLLoss()(torch.log(net_wrapped.predict_proba(x, w)), torch.argmax(f(x, w), dim=1))

End of epoch 0 | current loss 0.6432521939277649
End of epoch 1 | current loss 0.6109849810600281
End of epoch 2 | current loss 0.5806307196617126
End of epoch 3 | current loss 0.5509892106056213
End of epoch 4 | current loss 0.5219535231590271
End of epoch 5 | current loss 0.49392032623291016
End of epoch 6 | current loss 0.46716728806495667
End of epoch 7 | current loss 0.441588819026947
End of epoch 8 | current loss 0.41623300313949585
End of epoch 9 | current loss 0.3911704421043396
End of epoch 10 | current loss 0.36681240797042847
End of epoch 11 | current loss 0.34367358684539795
End of epoch 12 | current loss 0.3219335079193115
End of epoch 13 | current loss 0.3016554117202759
End of epoch 14 | current loss 0.28303760290145874
End of epoch 15 | current loss 0.26598307490348816
End of epoch 16 | current loss 0.25048691034317017
End of epoch 17 | current loss 0.23644299805164337
End of epoch 18 | current loss 0.22363027930259705
End of epoch 19 | current loss 0.21195419132709503


tensor(3.6213, grad_fn=<NllLossBackward0>)

- discrete input, discret output

In [5]:
net = Net(x_d=3, w_d=1, out_d=5, is_discrete_input=True, is_discrete_output=True)
sm = nn.Softmax(dim=1)
loss = nn.CrossEntropyLoss()
x_input = torch.randint(0, 3, size=(5,)).squeeze()
w_input = torch.ones(5, 1)
result = sm(net(x_input, w_input))
l = loss(result, torch.eye(5, 5))
l.backward()

In [6]:
x = torch.eye(1000, 5).index_select(dim=0, index=torch.randint(0, 5, size=(1000,)))
w = torch.normal(0, 1, size=(1000, 1))
def f(x, w):
    xw = torch.cat((x, w), dim=1)
    weight = torch.normal(0, 1, size=(xw.shape[1], 1))
    label_sign = torch.einsum('nd,dc->nc', [xw, weight])
    label = (label_sign > 0).to(int).squeeze()
    return F.one_hot(label)
target = f(x, w)
net = Net(x_d=5, w_d=1, out_d=2, is_discrete_output=True, is_discrete_input=True)
net_wrapped = NetWrapper(net)
net_wrapped.fit(x, w, target=target, device='cpu', epoch=10)
nn.NLLLoss()(torch.log(net_wrapped.predict_proba(x, w)), torch.argmax(f(x, w), dim=1))

End of epoch 0 | current loss 0.6982303857803345
End of epoch 1 | current loss 0.6928979158401489
End of epoch 2 | current loss 0.6881588697433472
End of epoch 3 | current loss 0.6837738156318665
End of epoch 4 | current loss 0.6794893741607666
End of epoch 5 | current loss 0.6751973628997803
End of epoch 6 | current loss 0.6706061959266663
End of epoch 7 | current loss 0.6655802726745605
End of epoch 8 | current loss 0.6601215600967407
End of epoch 9 | current loss 0.6541720032691956


tensor(0.6549, grad_fn=<NllLossBackward0>)

- discrete input, continuous output

In [7]:
net = Net(x_d=3, w_d=1, out_d=1, is_discrete_input=True, is_discrete_output=False)
sm = nn.Softmax(dim=1)
loss = nn.MSELoss()
x_input = torch.randint(0, 3, size=(5,)).squeeze()
w_input = torch.ones(5, 1)
result = sm(net(x_input, w_input))
l = loss(result, torch.ones(5, 1))
l.backward()

In [9]:
x = torch.eye(1000, 5).index_select(dim=0, index=torch.randint(0, 5, size=(1000,)))
w = torch.normal(0, 2, size=(1000, 1))
def f(x, w):
    weight = torch.normal(0, 1, size=(6, 1))
    xw = torch.cat((x, w), dim=1)
    target = torch.einsum('nd,dc->nc', [xw, weight])
    return target
target = f(x, w)
net = Net(x_d=5, w_d=1, out_d=1, is_discrete_input=True)
net_wrapped = NetWrapper(net)
net_wrapped.fit(x=x, w=w, target=target, device='cpu', epoch=10)
nn.MSELoss()(net_wrapped.predict(x=x, w=w), target)

End of epoch 0 | current loss 3.8664863109588623
End of epoch 1 | current loss 1.245396375656128
End of epoch 2 | current loss 0.17864452302455902
End of epoch 3 | current loss 0.10468478500843048
End of epoch 4 | current loss 0.08381427824497223
End of epoch 5 | current loss 0.07277137041091919
End of epoch 6 | current loss 0.06443522870540619
End of epoch 7 | current loss 0.05730566382408142
End of epoch 8 | current loss 0.05132754519581795
End of epoch 9 | current loss 0.04609334468841553


tensor(0.0462, grad_fn=<MseLossBackward0>)