In [None]:
# Sketch of what I'm going to implement: a Wide & Deep model with a main and an auxiliary output

In [11]:
# Uncomment to install torchview (provides draw_graph, used below):
#!pip install torchview

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchview import draw_graph

class WideAndDeepModel(nn.Module):
    """Wide & Deep network with a main and an auxiliary sigmoid output.

    The deep input passes through two ReLU-activated fully connected layers.
    The wide input skips them and is concatenated with the deep features
    before the main head; the auxiliary head sees only the deep features.

    Args:
        num_wide_features: Size of the last dimension of the wide input.
        num_deep_features: Size of the last dimension of the deep input.
        units: Width of each hidden layer in the deep branch.
    """

    def __init__(self, num_wide_features, num_deep_features, units=30):
        super().__init__()
        # Wide branch.
        # NOTE(review): this layer is never used in ``forward`` -- the wide
        # input is concatenated raw. It is kept so existing ``state_dict``
        # checkpoints still load; its parameters receive no gradient.
        self.wide = nn.Linear(num_wide_features, 1)

        # Deep branch: two hidden layers of ``units`` neurons each.
        self.hidden1 = nn.Linear(num_deep_features, units)
        self.hidden2 = nn.Linear(units, units)

        # Output heads: main sees wide + deep features, aux sees deep only.
        self.main_output = nn.Linear(num_wide_features + units, 1)
        self.aux_output = nn.Linear(units, 1)

    def forward(self, inputs):
        """Run both branches and both output heads.

        Args:
            inputs: Pair ``(input_wide, input_deep)`` of tensors whose last
                dimensions are ``num_wide_features`` and ``num_deep_features``
                and whose leading (batch) dimensions match.

        Returns:
            Tuple ``(main_output, aux_output)``. Each has the inputs' leading
            dimensions plus a trailing dimension of size 1, with values
            squashed into [0, 1] by a sigmoid.
        """
        input_wide, input_deep = inputs
        hidden1 = F.relu(self.hidden1(input_deep))
        hidden2 = F.relu(self.hidden2(hidden1))
        # Concatenate along the feature (last) axis so inputs with any number
        # of leading batch dimensions work, not only 2-D (batch, features).
        concat = torch.cat([input_wide, hidden2], dim=-1)
        main_output = torch.sigmoid(self.main_output(concat))
        aux_output = torch.sigmoid(self.aux_output(hidden2))
        return main_output, aux_output

# Example usage: one batch of 5 samples, each with 5 wide and 10 deep features.
input_wide = torch.randn(5, 5)
input_deep = torch.randn(5, 10)
model = WideAndDeepModel(num_wide_features=5, num_deep_features=10)
main_output, aux_output = model((input_wide, input_deep))
for head_name, head_value in (("Main", main_output), ("Aux", aux_output)):
    print(f"{head_name} Output: {head_value}")

# Render the model graph to wide_and_deep_model.png on disk (it is written to
# a file, not displayed inline). The (wide, deep) pair is wrapped in a
# 1-tuple so that forward() receives it as its single ``inputs`` argument.
model_graph = draw_graph(
    model,
    input_data=((input_wide, input_deep),),
    expand_nested=True,
)
model_graph.visual_graph.render("wide_and_deep_model", format="png")

# Unit tests
def test_model_output_shape():
    """Check that both heads return one value per sample: shape (5, 1)."""
    # Build the inputs locally so the test does not depend on globals created
    # by the example-usage code. A private generator makes them deterministic
    # without disturbing the global RNG.
    gen = torch.Generator().manual_seed(0)
    input_wide = torch.randn(5, 5, generator=gen)
    input_deep = torch.randn(5, 10, generator=gen)
    model = WideAndDeepModel(num_wide_features=5, num_deep_features=10)
    main_output, aux_output = model((input_wide, input_deep))
    assert main_output.shape == (5, 1), f"Expected shape (5, 1), got {main_output.shape}"
    assert aux_output.shape == (5, 1), f"Expected shape (5, 1), got {aux_output.shape}"

def test_forward_pass():
    """Check that both sigmoid heads produce values inside [0, 1]."""
    # Build the inputs locally so the test does not depend on globals created
    # by the example-usage code. A private generator makes them deterministic
    # without disturbing the global RNG.
    gen = torch.Generator().manual_seed(1)
    input_wide = torch.randn(5, 5, generator=gen)
    input_deep = torch.randn(5, 10, generator=gen)
    model = WideAndDeepModel(num_wide_features=5, num_deep_features=10)
    main_output, aux_output = model((input_wide, input_deep))
    # Inclusive bounds: in float32 a sigmoid can saturate to exactly 0 or 1.
    assert torch.all((main_output >= 0) & (main_output <= 1)), "Main output values should be in range [0, 1]"
    assert torch.all((aux_output >= 0) & (aux_output <= 1)), "Aux output values should be in range [0, 1]"

# Run each check in turn. A failing assertion raises AssertionError and stops
# execution before the success message is printed.
for check in (test_model_output_shape, test_forward_pass):
    check()
print("All tests passed!")

Main Output: tensor([[0.5746],
        [0.4835],
        [0.6326],
        [0.5687],
        [0.6422]], grad_fn=<SigmoidBackward0>)
Aux Output: tensor([[0.4528],
        [0.4316],
        [0.4601],
        [0.4760],
        [0.4489]], grad_fn=<SigmoidBackward0>)
All tests passed!


<img src="wide_and_deep_model.png"/>