-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5292df7
commit 1600754
Showing
1 changed file
with
53 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from dataclasses import dataclass, field | ||
|
||
|
||
@dataclass | ||
class SenticGCNTrainArgs: | ||
initializer: str = field( | ||
default="xavier_uniform", metadata={"help": "Type of initalizer to use."} | ||
) | ||
optimizer: str = field( | ||
default="adam", metadata={"help": "Type of optimizer to use."} | ||
) | ||
learning_rate: float = field( | ||
default=0.001, metadata={"help": "Default learning rate for training."} | ||
) | ||
l2reg: float = field(default=0.00001, metadata={"help": "Default l2reg value."}) | ||
epochs: int = field(default=100, metadata={"help": "Number of epochs to train."}) | ||
batch_size: int = field(default=32, metadata={"help": "Training batch size."}) | ||
log_step: int = field(default=5, metadata={"help": "Default log step."}) | ||
embed_dim: int = field( | ||
default=300, metadata={"help": "Number of neurons for embed layer."} | ||
) | ||
hidden_dim: int = field( | ||
default=300, metadata={"help": "Number of neurons for hidden layer."} | ||
) | ||
polarities_dim: int = field( | ||
default=3, metadata={"help": "Default dimension for polarities."} | ||
) | ||
save: bool = field( | ||
default=False, metadata={"help": "Flag to indicate if results should be saved."} | ||
) | ||
seed: int = field( | ||
default=776, metadata={"help": "Default random seed for training."} | ||
) | ||
device: str = field( | ||
default="cuda", metadata={"help": "Type of compute device to use for training."} | ||
) | ||
|
||
def __post_init__(self): | ||
assert self.initializer in [ | ||
"xavier_uniform", | ||
"xavier_uniform", | ||
"orthogonal", | ||
], "Invalid initializer type!" | ||
assert self.optimizer in [ | ||
"adadelta", | ||
"adagrad", | ||
"adam", | ||
"adamax", | ||
"asgd", | ||
"rmsprop", | ||
"sgd", | ||
], "Invalid optimizer" | ||
assert self.device in ["cuda", "cpu"], "Invalid device type." |