Skip to content

Commit

Permalink
added deconfounder structure as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
josegcpa committed Apr 4, 2024
1 parent 6be686a commit ddd7af6
Showing 1 changed file with 8 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
n_cat_deconfounder: int | List[int] = None,
n_cont_deconfounder: int = None,
exclude_surrogate_variables: bool = False,
deconfounder_structure: list[int] = None,
*args,
**kwargs,
):
Expand All @@ -37,11 +38,14 @@ def __init__(
confounders. Defaults to None (no continuous deconfounding).
exclude_surrogate_variables (bool, optional): whether to exclude
surrogate variables
deconfounder_structure (list[int], optional): structure of the
deconfounder structure. Defaults to None (linear classifier).
"""
self.n_features_deconfounder = n_features_deconfounder
self.n_cat_deconfounder = n_cat_deconfounder
self.n_cont_deconfounder = n_cont_deconfounder
self.exclude_surrogate_variables = exclude_surrogate_variables
self.deconfounder_structure = deconfounder_structure
if self.exclude_surrogate_variables:
kwargs["output_features"] = 512 - self.n_features_deconfounder
super().__init__(*args, **kwargs)
Expand All @@ -52,6 +56,8 @@ def __init__(
self.n_cat_deconfounder = []
if self.n_cont_deconfounder is None:
self.n_cont_deconfounder = 0
if self.deconfounder_structure is None:
self.deconfounder_structure = []

self.init_deconfounding_layers()
self.gp = GlobalPooling()
Expand All @@ -67,14 +73,14 @@ def init_deconfounding_layers(self):
MLP(
self.n_features_deconfounder,
n_class,
[self.n_features_deconfounder],
self.deconfounder_structure,
)
)
if self.n_cont_deconfounder > 0:
self.confound_regressions = MLP(
self.n_features_deconfounder,
self.n_cont_deconfounder,
[self.n_features_deconfounder],
self.deconfounder_structure,
)

def forward(
Expand Down

0 comments on commit ddd7af6

Please sign in to comment.