Skip to content

Commit

Permalink
[#41] add model init
Browse files Browse the repository at this point in the history
  • Loading branch information
raymondng76 committed Dec 10, 2021
1 parent 0333bab commit 436240e
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 4 deletions.
3 changes: 3 additions & 0 deletions sgnlp/models/sentic_asgcn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class SenticASGCNConfig(PreTrainedConfig):
Args:
embed_dim (:obj:`int`, defaults to 300): Embedding dimension size.
hidden_dim (:obj:`int`, defaults to 300): Size of hidden dimension.
dropout (:obj:`float`, defaults to 0.3): Droput percentage.
polarities_dim (:obj:`int`, defaults to 3): Size of output dimension represeting available polarities (e.g. Positive, Negative, Neutral).
device (:obj:`torch.device`, defaults to torch.device('cuda`)): Type of torch device.
Expand All @@ -27,11 +28,13 @@ def __init__(
embed_dim=300,
hidden_dim=300,
polarities_dim=3,
dropout=0.3,
device=torch.device("cuda"),
**kwargs
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.hidden_dim = hidden_dim
self.dropout = dropout
self.polarities_dim = polarities_dim
self.device = device
1 change: 1 addition & 0 deletions sgnlp/models/sentic_asgcn/config/sentic_asgcn_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"embed_dim": 300,
"hidden_dim": 300,
"polarities_dim": 3,
"dropout": 0.3,
"save": true,
"seed": 776,
"device": "cuda"
Expand Down
3 changes: 3 additions & 0 deletions sgnlp/models/sentic_asgcn/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class SenticASGCNTrainArgs:
hidden_dim: int = field(
default=300, metadata={"help": "Number of neurons for hidden layer."}
)
dropout: float = field(
default=0.3, metadata={"help": "Default value for dropout percentages."}
)
polarities_dim: int = field(
default=3, metadata={"help": "Default dimension for polarities."}
)
Expand Down
25 changes: 22 additions & 3 deletions sgnlp/models/sentic_asgcn/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,39 @@
from transformers import PreTrainedModel
from transformers.file_utils import ModelOutput

from .modules.dynamic_rnn import DynamicLSTM
from .modules.gcn import GraphConvolution
from .config import SenticASGCNConfig


@dataclass
class SenticASGCNModelOutput(ModelOutput):
pass


class SenticASGCNPreTrainedModel(PreTrainedModel):
# config_class =
"""
An abstract class to handle weights initialization and a simple interface for download and loading pretrained models.
"""

config_class = SenticASGCNConfig
base_model_prefix = "sentic_asgcn"

def _init_weights(self, module):
pass


class SenticASGCNModel(SenticASGCNPreTrainedModel):
def __init__(self, config):
pass
def __init__(self, config: SenticASGCNConfig):
super().__init__(config)
self.text_lstm = DynamicLSTM(
config.embed_dim,
config.hidden_dim,
num_layers=1,
batch_first=True,
bidirectional=True,
)
self.gc1 = GraphConvolution(2 * config.hidden_dim, 2 * config.hidden_dim)
self.gc2 = GraphConvolution(2 * config.hidden_dim, 2 * config.hidden_dim)
self.fc = nn.Linear(2 * config.hidden_dim, config.polarities_dim)
self.text_embed_dropout = nn.Dropout(config.dropout)
4 changes: 3 additions & 1 deletion sgnlp/models/sentic_asgcn/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .data_class import SenticASGCNTrainArgs
from .utils import parse_args_and_load_config
from .utils import parse_args_and_load_config, set_random_seed


def train_model(cfg: SenticASGCNTrainArgs):
Expand All @@ -8,4 +8,6 @@ def train_model(cfg: SenticASGCNTrainArgs):

if __name__ == "__main__":
cfg = parse_args_and_load_config()
if cfg.seed is not None:
set_random_seed(cfg.seed)
train_model(cfg)

0 comments on commit 436240e

Please sign in to comment.