# Imports

In [1]:
import numpy as np
import torch
import os
import torch.nn as nn

In [5]:
class PositiveNN3(nn.Module):
    """Three-hidden-layer MLP whose outputs are forced to be non-negative.

    Each hidden layer is ``Linear -> BatchNorm1d -> activation -> Dropout``;
    the output ``Linear`` is followed by ``ReLU``, so every output is >= 0.

    The ``nn.Sequential`` layout (indices 0..13) is fixed, so ``state_dict``
    keys such as ``layers.0.weight`` or ``layers.1.running_mean`` stay stable
    and previously saved checkpoints load unchanged.

    Args:
        input_dim: Number of input features.
        hidden1_dim, hidden2_dim, hidden3_dim: Widths of the hidden layers.
        output_dim: Number of outputs.
        activation_func: One of ``'Tanh'``, ``'ReLU'``, ``'Sigmoid'``, ``'ELU'``.
        dropout_1, dropout_2, dropout_3: Dropout probabilities applied after
            the first, second and third hidden layer respectively.

    Raises:
        ValueError: If ``activation_func`` is not one of the supported names.

    NOTE(review): input presumably has shape (batch, input_dim); BatchNorm1d
    in training mode needs more than one sample per batch.
    """

    # Map names to module *classes* so each hidden layer gets its own
    # activation instance rather than one module shared at three positions
    # (a shared module fires forward hooks once per position).
    _ACTIVATIONS = {
        'Tanh': nn.Tanh,
        'ReLU': nn.ReLU,
        'Sigmoid': nn.Sigmoid,
        'ELU': nn.ELU,
    }

    def __init__(self, input_dim, hidden1_dim, hidden2_dim, hidden3_dim,
                 output_dim, activation_func, dropout_1, dropout_2, dropout_3):
        super().__init__()

        activation_cls = self._ACTIVATIONS.get(activation_func)
        if activation_cls is None:
            raise ValueError(f"Invalid activation function: {activation_func}")

        widths = [input_dim, hidden1_dim, hidden2_dim, hidden3_dim]
        dropouts = [dropout_1, dropout_2, dropout_3]

        modules = []
        # Emit exactly four modules per hidden layer, in this order, so the
        # Sequential indices match the original hand-written layout.
        for in_dim, out_dim, p in zip(widths[:-1], widths[1:], dropouts):
            modules += [
                nn.Linear(in_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                activation_cls(),
                nn.Dropout(p),
            ]
        # Final ReLU enforces non-negative outputs.
        modules += [nn.Linear(hidden3_dim, output_dim), nn.ReLU()]

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the network to ``x`` and return the non-negative outputs."""
        return self.layers(x)
  

## 3 layers

In [6]:
# Selectors for which saved checkpoint to restore.
# NOTE(review): the cell heading says "3 layers", but the checkpoint filename
# labels this value "subject" — confirm what `num_l` actually identifies.
num_l, activation_function = 3, 'Tanh'

In [7]:
# Rebuild the training-time architecture and restore its best-validation
# weights. The activation comes from `activation_function` (not a hard-coded
# literal) so the architecture and the checkpoint filename cannot drift apart.
model = PositiveNN3(12, 10, 26, 30, 2, activation_function,
                    dropout_1=0.05, dropout_2=0, dropout_3=0.1)

checkpoint_path = f"best_val_model_params_subject{num_l}_{activation_function}_mohseni.pt"
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs; it also avoids torch.load's warning about
# full pickle loading. map_location="cpu" lets a GPU-saved checkpoint load
# on a CPU-only machine (the model here is built on CPU).
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
# Inference mode: BatchNorm uses its running statistics and Dropout is disabled.
model.eval()

  model.load_state_dict(torch.load(f"best_val_model_params_subject{num_l}_{activation_function}_mohseni.pt"))


PositiveNN3(
  (layers): Sequential(
    (0): Linear(in_features=12, out_features=10, bias=True)
    (1): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Tanh()
    (3): Dropout(p=0.05, inplace=False)
    (4): Linear(in_features=10, out_features=26, bias=True)
    (5): BatchNorm1d(26, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Tanh()
    (7): Dropout(p=0, inplace=False)
    (8): Linear(in_features=26, out_features=30, bias=True)
    (9): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Tanh()
    (11): Dropout(p=0.1, inplace=False)
    (12): Linear(in_features=30, out_features=2, bias=True)
    (13): ReLU()
  )
)