Skip to content

Commit

Permalink
[#41] add initial data_class
Browse files Browse the repository at this point in the history
  • Loading branch information
raymondng76 committed Dec 10, 2021
1 parent 5292df7 commit 1600754
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions sgnlp/models/asgcn/data_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from dataclasses import dataclass, field


@dataclass
class SenticGCNTrainArgs:
initializer: str = field(
default="xavier_uniform", metadata={"help": "Type of initalizer to use."}
)
optimizer: str = field(
default="adam", metadata={"help": "Type of optimizer to use."}
)
learning_rate: float = field(
default=0.001, metadata={"help": "Default learning rate for training."}
)
l2reg: float = field(default=0.00001, metadata={"help": "Default l2reg value."})
epochs: int = field(default=100, metadata={"help": "Number of epochs to train."})
batch_size: int = field(default=32, metadata={"help": "Training batch size."})
log_step: int = field(default=5, metadata={"help": "Default log step."})
embed_dim: int = field(
default=300, metadata={"help": "Number of neurons for embed layer."}
)
hidden_dim: int = field(
default=300, metadata={"help": "Number of neurons for hidden layer."}
)
polarities_dim: int = field(
default=3, metadata={"help": "Default dimension for polarities."}
)
save: bool = field(
default=False, metadata={"help": "Flag to indicate if results should be saved."}
)
seed: int = field(
default=776, metadata={"help": "Default random seed for training."}
)
device: str = field(
default="cuda", metadata={"help": "Type of compute device to use for training."}
)

def __post_init__(self):
assert self.initializer in [
"xavier_uniform",
"xavier_uniform",
"orthogonal",
], "Invalid initializer type!"
assert self.optimizer in [
"adadelta",
"adagrad",
"adam",
"adamax",
"asgd",
"rmsprop",
"sgd",
], "Invalid optimizer"
assert self.device in ["cuda", "cpu"], "Invalid device type."

0 comments on commit 1600754

Please sign in to comment.