From 975b08807c40579708093a03918b4679a91a07e6 Mon Sep 17 00:00:00 2001
From: Prem
Date: Tue, 19 Oct 2021 21:06:19 +0530
Subject: [PATCH] update docs

---
 docs/api.rst                                  |   6 +
 docs/instructions/inference.rst               |   2 +-
 docs/instructions/self_supervised.rst         |   2 +-
 docs/instructions/training.rst                |   2 +-
 docs/requirements.txt                         |   1 +
 openhands/apis/dpc.py                         |   2 +-
 openhands/datasets/isolated/autsl.py          |   4 +-
 openhands/datasets/isolated/csl.py            |   4 +-
 openhands/datasets/isolated/devisign.py       |   4 +-
 openhands/datasets/isolated/gsl.py            |   4 +-
 openhands/datasets/isolated/include.py        |   4 +-
 openhands/datasets/isolated/lsa64.py          |   4 +-
 openhands/datasets/isolated/ms_asl.py         |   4 +-
 openhands/datasets/isolated/wlasl.py          |   4 +-
 openhands/datasets/pose_transforms.py         |  16 +--
 openhands/datasets/ssl/dpc_dataset.py         |  27 ++++-
 .../{encoder => common}/transformer_layers.py |   0
 openhands/models/decoder/bert_hf.py           |  15 ++-
 openhands/models/decoder/fc.py                |  21 +++-
 openhands/models/decoder/rnn.py               |  19 ++-
 openhands/models/encoder/__init__.py          |   6 +-
 openhands/models/encoder/cnn2d.py             |   3 +
 openhands/models/encoder/cnn3d.py             |   3 +
 .../models/encoder/graph/decoupled_gcn.py     |  25 ++--
 .../models/encoder/graph/pose_flattener.py    |  21 ++--
 openhands/models/encoder/graph/sgn.py         |  33 +++++-
 openhands/models/encoder/graph/st_gcn.py      |  24 ++--
 openhands/models/ssl/__init__.py              |   3 +
 openhands/models/ssl/dpc_rnn.py               | 111 +++++++++++++-----
 29 files changed, 281 insertions(+), 93 deletions(-)
 rename openhands/models/{encoder => common}/transformer_layers.py (100%)
 create mode 100644 openhands/models/ssl/__init__.py

diff --git a/docs/api.rst b/docs/api.rst
index 576faca..9482786 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -22,6 +22,12 @@ Decoders
 .. automodule:: openhands.models.decoder
    :members:
 
+SSL-Models
+^^^^^^^^^^
+
+.. automodule:: openhands.models.ssl
+   :members:
+
 Datasets
 --------
 
diff --git a/docs/instructions/inference.rst b/docs/instructions/inference.rst
index 24ba0ea..3b30f14 100644
--- a/docs/instructions/inference.rst
+++ b/docs/instructions/inference.rst
@@ -15,7 +15,7 @@ Computing accuracy using test set
 .. code:: python
 
     import omegaconf
-    from openhands.core.inference import InferenceModel
+    from openhands.apis.inference import InferenceModel
 
     cfg = omegaconf.OmegaConf.load("path/to/config.yaml")
     model = InferenceModel(cfg=cfg)
diff --git a/docs/instructions/self_supervised.rst b/docs/instructions/self_supervised.rst
index 9f3ebca..7616eb7 100644
--- a/docs/instructions/self_supervised.rst
+++ b/docs/instructions/self_supervised.rst
@@ -43,7 +43,7 @@ Finally, run the following snippet to perform the pretraining:
 .. code:: python
 
     import omegaconf
-    from openhands.core.dpc import PretrainingModelDPC
+    from openhands.apis.dpc import PretrainingModelDPC
 
     cfg = omegaconf.OmegaConf.load("path/to/config.yaml")
     trainer = PretrainingModelDPC(cfg=cfg)
diff --git a/docs/instructions/training.rst b/docs/instructions/training.rst
index 4dedd9a..b22f8ec 100644
--- a/docs/instructions/training.rst
+++ b/docs/instructions/training.rst
@@ -13,7 +13,7 @@ After you have a config ready, run the following python snippet:
 .. code:: python
 
     import omegaconf
-    from openhands.core.classification_model import ClassificationModel
+    from openhands.apis.classification_model import ClassificationModel
     from openhands.core.exp_utils import get_trainer
 
     cfg = omegaconf.OmegaConf.load("path/to/config.yaml")
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 05bc3eb..92ac715 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -22,3 +22,4 @@ scikit_learn==1.0
 sphinx
 myst-parser
 sphinx_rtd_theme
+sphinx_copybutton
diff --git a/openhands/apis/dpc.py b/openhands/apis/dpc.py
index 968e8d7..ea4a04a 100644
--- a/openhands/apis/dpc.py
+++ b/openhands/apis/dpc.py
@@ -33,7 +33,7 @@ def calc_topk_accuracy(output, target, topk=(1,)):
 
 def process_output(mask):
     """task mask as input, compute the target for contrastive loss"""
-    (B, NP, B2, NS) = mask.size()  # [B, P, SQ, B, N, SQ]
+    (B, NP, B2, NS) = mask.size()  # [B, P, B, N]
     target = (mask == 1).float()
     target.requires_grad = False
     return target, (B, B2, NS, NP)
diff --git a/openhands/datasets/isolated/autsl.py b/openhands/datasets/isolated/autsl.py
index 14dd17d..4c6da33 100644
--- a/openhands/datasets/isolated/autsl.py
+++ b/openhands/datasets/isolated/autsl.py
@@ -5,9 +5,9 @@
 
 class AUTSLDataset(BaseIsolatedDataset):
     """
-    Turkish Isolated Sign language dataset from the paper
+    Turkish Isolated Sign language dataset from the paper:
 
-    > [AUTSL: A Large Scale Multi-modal Turkish Sign Language Dataset and Baseline Methods](https://arxiv.org/abs/2008.00932)
+    `AUTSL: A Large Scale Multi-modal Turkish Sign Language Dataset and Baseline Methods <https://arxiv.org/abs/2008.00932>`_
     """
 
     def read_glosses(self):
         class_mappings_df = pd.read_csv(self.class_mappings_file_path)
diff --git a/openhands/datasets/isolated/csl.py b/openhands/datasets/isolated/csl.py
index 9793474..e1c449b 100644
--- a/openhands/datasets/isolated/csl.py
+++ b/openhands/datasets/isolated/csl.py
@@ -5,9 +5,9 @@
 
 class CSLDataset(BaseIsolatedDataset):
     """
-    Chinese Isolated Sign language dataset from the paper
+    Chinese Isolated Sign language dataset from the paper:
 
-    > [Attention-Based 3D-CNNs for Large-Vocabulary Sign Language Recognition](https://ieeexplore.ieee.org/document/8466903)
+    `Attention-Based 3D-CNNs for Large-Vocabulary Sign Language Recognition <https://ieeexplore.ieee.org/document/8466903>`_
     """
 
     def read_glosses(self):
         self.glosses = []
diff --git a/openhands/datasets/isolated/devisign.py b/openhands/datasets/isolated/devisign.py
index b96a18a..c34c991 100644
--- a/openhands/datasets/isolated/devisign.py
+++ b/openhands/datasets/isolated/devisign.py
@@ -7,9 +7,9 @@
 
 class DeviSignDataset(BaseIsolatedDataset):
     """
-    Chinese Isolated Sign language dataset from the paper
+    Chinese Isolated Sign language dataset from the paper:
 
-    > [The devisign large vocabulary of chinese sign language database and baseline evaluations]
+    `The devisign large vocabulary of chinese sign language database and baseline evaluations`
     """
 
     def read_glosses(self):
         self.glosses = []
diff --git a/openhands/datasets/isolated/gsl.py b/openhands/datasets/isolated/gsl.py
index 323bb3c..143ecbb 100644
--- a/openhands/datasets/isolated/gsl.py
+++ b/openhands/datasets/isolated/gsl.py
@@ -5,9 +5,9 @@
 
 class GSLDataset(BaseIsolatedDataset):
     """
-    Greek Isolated Sign language dataset from the paper
+    Greek Isolated Sign language dataset from the paper:
 
-    > [A Comprehensive Study on Deep Learning-based Methods for Sign Language Recognition](https://ieeexplore.ieee.org/document/8466903)
+    `A Comprehensive Study on Deep Learning-based Methods for Sign Language Recognition <https://ieeexplore.ieee.org/document/8466903>`_
     """
 
     def read_glosses(self):
         self.glosses = [
diff --git a/openhands/datasets/isolated/include.py b/openhands/datasets/isolated/include.py
index 333c247..ba8d006 100644
--- a/openhands/datasets/isolated/include.py
+++ b/openhands/datasets/isolated/include.py
@@ -5,9 +5,9 @@
 
 class INCLUDEDataset(BaseIsolatedDataset):
     """
-    Indian Isolated Sign language dataset from the paper
+    Indian Isolated Sign language dataset from the paper:
 
-    > [INCLUDE: A Large Scale Dataset for Indian Sign Language Recognition](https://dl.acm.org/doi/10.1145/3394171.3413528)
+    `INCLUDE: A Large Scale Dataset for Indian Sign Language Recognition <https://dl.acm.org/doi/10.1145/3394171.3413528>`_
     """
 
     def read_glosses(self):
         # TODO: Separate the classes into a separate file?
diff --git a/openhands/datasets/isolated/lsa64.py b/openhands/datasets/isolated/lsa64.py
index ae6282e..60634f2 100644
--- a/openhands/datasets/isolated/lsa64.py
+++ b/openhands/datasets/isolated/lsa64.py
@@ -6,9 +6,9 @@
 
 class LSA64Dataset(BaseIsolatedDataset):
     """
-    Argentinian Isolated Sign language dataset from the paper
+    Argentinian Isolated Sign language dataset from the paper:
 
-    > [LSA64: An Argentinian Sign Language Dataset](http://sedici.unlp.edu.ar/bitstream/handle/10915/56764/Documento_completo.pdf-PDFA.pdf)
+    `LSA64: An Argentinian Sign Language Dataset <http://sedici.unlp.edu.ar/bitstream/handle/10915/56764/Documento_completo.pdf-PDFA.pdf>`_
     """
 
     def read_glosses(self):
         df = pd.read_csv(self.class_mappings_file_path, delimiter="|", header=None)
diff --git a/openhands/datasets/isolated/ms_asl.py b/openhands/datasets/isolated/ms_asl.py
index 4f85a80..8ad0063 100644
--- a/openhands/datasets/isolated/ms_asl.py
+++ b/openhands/datasets/isolated/ms_asl.py
@@ -5,9 +5,9 @@
 
 class MSASLDataset(BaseIsolatedDataset):
     """
-    American Isolated Sign language dataset from the paper
+    American Isolated Sign language dataset from the paper:
 
-    > [MS-ASL: A Large-Scale Data Set and Benchmark for Understanding American Sign Language](https://arxiv.org/abs/1812.01053)
+    `MS-ASL: A Large-Scale Data Set and Benchmark for Understanding American Sign Language <https://arxiv.org/abs/1812.01053>`_
     """
 
     def read_glosses(self):
         # TODO: Separate the classes into a separate file?
diff --git a/openhands/datasets/isolated/wlasl.py b/openhands/datasets/isolated/wlasl.py
index d429edf..bf6c4f6 100644
--- a/openhands/datasets/isolated/wlasl.py
+++ b/openhands/datasets/isolated/wlasl.py
@@ -5,9 +5,9 @@
 
 class WLASLDataset(BaseIsolatedDataset):
     """
-    American Isolated Sign language dataset from the paper
+    American Isolated Sign language dataset from the paper:
 
-    > [Word-level Deep Sign Language Recognition from Video: A New Large-scale Dataset and Methods Comparison](https://arxiv.org/abs/1910.11006)
+    `Word-level Deep Sign Language Recognition from Video: A New Large-scale Dataset and Methods Comparison <https://arxiv.org/abs/1910.11006>`_
     """
 
     def read_glosses(self):
         with open(self.split_file, "r") as f:
diff --git a/openhands/datasets/pose_transforms.py b/openhands/datasets/pose_transforms.py
index 65b14b3..f4056e1 100644
--- a/openhands/datasets/pose_transforms.py
+++ b/openhands/datasets/pose_transforms.py
@@ -95,6 +95,7 @@ def __call__(self, data:dict):
 class PoseSelect:
     """
     Select the given index keypoints from all keypoints.
+
     Args:
         preset (str | None, optional): can be used to specify existing presets - `mediapipe_holistic_minimal_27` or `mediapipe_holistic_top_body_59`
             If None, then the `pose_indexes` argument indexes will be used to select. Default: ``None``
@@ -135,7 +136,7 @@ def __call__(self, data:dict):
 # Adopted from: https://github.com/AmitMY/pose-format/
 class ShearTransform:
     """
-    Applies [2D shear transform](https://en.wikipedia.org/wiki/Shear_matrix)
+    Applies `2D shear <https://en.wikipedia.org/wiki/Shear_matrix>`_ transformation
 
     Args:
         shear_std (float): std to use for shear transformation. Default: 0.2
@@ -168,7 +169,7 @@ def __call__(self, data:dict):
 
 class RotatationTransform:
     """
-    Applies [2D rotation transformation](https://en.wikipedia.org/wiki/Rotation_matrix).
+    Applies `2D rotation <https://en.wikipedia.org/wiki/Rotation_matrix>`_ transformation.
 
     Args:
         rotation_std (float): std to use for rotation transformation. Default: 0.2
@@ -205,7 +206,8 @@ def __call__(self, data):
 
 class ScaleTransform:
     """
-    Applies [Scaling](https://en.wikipedia.org/wiki/Scaling_(geometry)) transformation
+    Applies `Scaling <https://en.wikipedia.org/wiki/Scaling_(geometry)>`_ transformation
+
     Args:
         scale_std (float): std to use for Scaling transformation. Default: 0.2
@@ -452,10 +454,10 @@ def __call__(self, data):
 class TemporalSample:
     """
     Randomly choose Uniform and Temporal subsample
-    If subsample_mode==2, randomly sub-sampling or uniform-sampling is done
-    If subsample_mode==0, only uniform-sampling (for test sets)
-    If subsample_mode==1, only sub-sampling (to reproduce results of some papers that use only subsampling)
-
+    - If subsample_mode==2, randomly sub-sampling or uniform-sampling is done
+    - If subsample_mode==0, only uniform-sampling (for test sets)
+    - If subsample_mode==1, only sub-sampling (to reproduce results of some papers that use only subsampling)
+
     Args:
         num_frames (int): Number of frames to subsample.
         subsample_mode (int): Mode to choose.
diff --git a/openhands/datasets/ssl/dpc_dataset.py b/openhands/datasets/ssl/dpc_dataset.py
index 2da798e..44a54ad 100644
--- a/openhands/datasets/ssl/dpc_dataset.py
+++ b/openhands/datasets/ssl/dpc_dataset.py
@@ -11,6 +11,18 @@ from ...core.data import create_pose_transforms
 
 class WindowedDatasetHDF5(torch.utils.data.DataLoader):
+    """
+    Windowed dataset loader from HDF5 for the SL-DPC model.
+
+    Args:
+        root_dir (str): Directory which contains the data.
+        file_format (str): File type. Default: ``h5``.
+        transforms (obj | None): Compose object with transforms or None. Default: ``None``.
+        seq_len (int): Sequence length for each window. Default: 10.
+        num_seq (int): Total number of windows. Default: 7.
+        downsample (int): Number of frames to skip per timestep when sampling. Default: 3.
+        num_channels (int): Number of input channels. Default: 2.
+    """
     def __init__(
         self,
         root_dir,
@@ -113,13 +125,26 @@ def get_weights_for_balanced_sampling(self):
 
 class WindowedDatasetPickle(torch.utils.data.DataLoader):
+    """
+    Windowed dataset loader from pickle files for the SL-DPC model.
+    This module is for loading finetuning datasets.
+
+    Args:
+        root_dir (str): Directory which contains the data.
+        file_format (str): File type. Default: ``pkl``.
+        transforms (obj | None): Compose object with transforms or None. Default: ``None``.
+        seq_len (int): Sequence length for each window. Default: 10.
+        num_seq (int): Total number of windows. Default: 10.
+        downsample (int): Number of frames to skip per timestep when sampling. Default: 1.
+        num_channels (int): Number of input channels. Default: 2.
+    """
     def __init__(
         self,
         root_dir,
         file_format='pkl',
         transforms=None,
         seq_len=10,
-        num_seq=6,
+        num_seq=10,
         downsample=1,
         num_channels=2,
     ):
diff --git a/openhands/models/encoder/transformer_layers.py b/openhands/models/common/transformer_layers.py
similarity index 100%
rename from openhands/models/encoder/transformer_layers.py
rename to openhands/models/common/transformer_layers.py
diff --git a/openhands/models/decoder/bert_hf.py b/openhands/models/decoder/bert_hf.py
index 76ea540..f7a1006 100644
--- a/openhands/models/decoder/bert_hf.py
+++ b/openhands/models/decoder/bert_hf.py
@@ -35,6 +35,15 @@ def forward(self, x):
 
 class BERT(nn.Module):
+    """
+    BERT decoder module.
+
+    Args:
+        n_features (int): Number of features in the input.
+        num_class (int): Number of classes for classification.
+        config (dict): Configuration set for BERT layer.
+
+    """
     def __init__(self, n_features, num_class, config):
         """
         pooling_type -> ["max","avg","att","cls"]
@@ -71,7 +80,11 @@ def __init__(self, n_features, num_class, config):
 
     def forward(self, x):
         """
-        x.shape: (batch_size, T, n_features)
+        Args:
+            x (torch.Tensor): Input tensor of shape: (batch_size, T, n_features)
+
+        Returns:
+            torch.Tensor: logits for classification.
         """
         x = self.l1(x)
         if self.cls_token:
diff --git a/openhands/models/decoder/fc.py b/openhands/models/decoder/fc.py
index f7ff409..e2d6b3c 100644
--- a/openhands/models/decoder/fc.py
+++ b/openhands/models/decoder/fc.py
@@ -3,7 +3,15 @@
 
 class FC(nn.Module):
-    def __init__(self, n_features, num_class, dropout_ratio=0.2, batch_norm=False, **kwargs):
+    """
+    Fully connected layer head.
+
+    Args:
+        n_features (int): Number of features in the input.
+        num_class (int): Number of classes for classification.
+        dropout_ratio (float): Dropout ratio to use. Default: 0.2.
+        batch_norm (bool): Whether to use batch norm or not. Default: ``False``.
+    """
+    def __init__(self, n_features, num_class, dropout_ratio=0.2, batch_norm=False):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout_ratio)
         self.bn = batch_norm
@@ -15,9 +23,14 @@ def __init__(self, n_features, num_class, dropout_ratio=0.2, batch_norm=False):
         nn.init.normal_(self.classifier.weight, 0, math.sqrt(2.0 / num_class))
 
     def forward(self, x):
-        '''
-        x.shape: (batch_size, n_features)
-        '''
+        """
+        Args:
+            x (torch.Tensor): Input tensor of shape: (batch_size, n_features)
+
+        Returns:
+            torch.Tensor: logits for classification.
+        """
+
         x = self.dropout(x)
         if self.bn:
             x = self.bn(x)
diff --git a/openhands/models/decoder/rnn.py b/openhands/models/decoder/rnn.py
index 2892269..1ae3cf7 100644
--- a/openhands/models/decoder/rnn.py
+++ b/openhands/models/decoder/rnn.py
@@ -5,6 +5,19 @@
 
 class RNNClassifier(nn.Module):
+    """
+    RNN head for classification.
+
+    Args:
+        n_features (int): Number of features in the input.
+        num_class (int): Number of classes for classification.
+        rnn_type (str): GRU or LSTM. Default: ``GRU``.
+        hidden_size (int): Hidden dim to use for RNN. Default: 512.
+        num_layers (int): Number of layers of RNN to use. Default: 1.
+        bidirectional (bool): Whether to use bidirectional RNN or not. Default: ``True``.
+        use_attention (bool): Whether to use attention for pooling or not. Default: ``False``.
+
+    """
     def __init__(
         self,
         n_features,
         num_class,
@@ -32,7 +45,11 @@ def __init__(
 
     def forward(self, x):
         """
-        x.shape: (batch_size, T, n_features)
+        Args:
+            x (torch.Tensor): Input tensor of shape: (batch_size, T, n_features)
+
+        Returns:
+            torch.Tensor: logits for classification.
         """
 
         self.rnn.flatten_parameters()
diff --git a/openhands/models/encoder/__init__.py b/openhands/models/encoder/__init__.py
index 2930508..b65d041 100644
--- a/openhands/models/encoder/__init__.py
+++ b/openhands/models/encoder/__init__.py
@@ -1,5 +1,9 @@
 from .graph.pose_flattener import PoseFlattener
 from .graph.decoupled_gcn import DecoupledGCN
 from .graph.st_gcn import STGCN
+from .graph.sgn import SGN
 
-__all__ = ["PoseFlattener", "DecoupledGCN", "STGCN"]
+from .cnn2d import CNN2D
+from .cnn3d import CNN3D
+
+__all__ = ["PoseFlattener", "DecoupledGCN", "STGCN", "SGN", "CNN2D", "CNN3D"]
diff --git a/openhands/models/encoder/cnn2d.py b/openhands/models/encoder/cnn2d.py
index 9f99e36..2c32259 100644
--- a/openhands/models/encoder/cnn2d.py
+++ b/openhands/models/encoder/cnn2d.py
@@ -20,6 +20,9 @@ def __init__(self, in_channels=3, backbone="resnet18", pretrained=True):
         self.backbone.fc = nn.Identity()
 
     def forward(self, x):
+        """
+        Forward pass: applies the 2D CNN to each frame and stacks the per-frame embeddings.
+        """
         b, c, t, h, w = x.shape
         cnn_embeds = []
         for i in range(t):
diff --git a/openhands/models/encoder/cnn3d.py b/openhands/models/encoder/cnn3d.py
index 93143cd..e6cac7d 100644
--- a/openhands/models/encoder/cnn3d.py
+++ b/openhands/models/encoder/cnn3d.py
@@ -52,6 +52,9 @@ def __init__(self, in_channels, backbone, pretrained=True, **kwargs):
         self.n_out_features = 400  # list(self.backbone.modules())[-2].out_features
 
     def forward(self, x):
+        """
+        Forward pass through the 3D CNN backbone.
+        """
         x = self.backbone(x)
         return x.transpose(0, 1)  # Batch-first
diff --git a/openhands/models/encoder/graph/decoupled_gcn.py b/openhands/models/encoder/graph/decoupled_gcn.py
index 9a3fcfb..e413e3e 100644
--- a/openhands/models/encoder/graph/decoupled_gcn.py
+++ b/openhands/models/encoder/graph/decoupled_gcn.py
@@ -326,8 +326,8 @@ def forward(self, x, keep_prob):
 class DecoupledGCN(nn.Module):
     """
     ST-GCN backbone with Decoupled GCN layers, Self Attention and DropGraph proposed in the paper:
-
-    > [Skeleton Aware Multi-modal Sign Language Recognition](https://arxiv.org/pdf/2103.08833.pdf)
+    `Skeleton Aware Multi-modal Sign Language Recognition
+    <https://arxiv.org/pdf/2103.08833.pdf>`_
 
     Args:
         in_channels (int): Number of channels in the input data.
         groups (int): Number of Decouple groups to use. Default: 8.
         block_size (int): Block size used for Temporal masking in Dropgraph. Default: 41.
         n_out_features (int): Output Embedding dimension. Default: 256.
-    Shape:
-        - Input: :math:`(N, in_channels, T_{in}, V_{in})`
-        - Output: :math:`(N, n_out_features)` where
-            :math:`N` is a batch size,
-            :math:`T_{in}` is a length of input sequence,
-            :math:`V_{in}` is the number of graph nodes,
-            :math:`n_out_features` is the `n_out_features' value,
     """
 
     def __init__(
@@ -411,6 +404,20 @@ def __init__(
         bn_init(self.data_bn, 1)
 
     def forward(self, x, keep_prob=0.9):
+        """
+        Args:
+            x (torch.Tensor): Input graph sequence of shape :math:`(N, in\_channels, T_{in}, V_{in})`
+            keep_prob (float): The probability to keep each node. Default: 0.9.
+
+        Returns:
+            torch.Tensor: Output embedding of shape :math:`(N, n\_out\_features)`
+
+        where:
+            - :math:`N` is a batch size,
+            - :math:`T_{in}` is a length of input sequence,
+            - :math:`V_{in}` is the number of graph nodes,
+            - :math:`n\_out\_features` is the ``n_out_features`` value.
+        """
         N, C, T, V = x.size()
         x = x.permute(0, 3, 1, 2).contiguous().view(N, V * C, T)
         x = self.data_bn(x)
diff --git a/openhands/models/encoder/graph/pose_flattener.py b/openhands/models/encoder/graph/pose_flattener.py
index b77ee97..4000838 100644
--- a/openhands/models/encoder/graph/pose_flattener.py
+++ b/openhands/models/encoder/graph/pose_flattener.py
@@ -9,13 +9,7 @@ class PoseFlattener(nn.Module):
     Args:
         in_channels (int): Number of channels in the input data.
         num_points (int): Number of spatial joints
-
-    Shape:
-        - Input: :math:`(N, in_channels, T_{in}, V_{in})`
-        - Output: :math:`(N, T_{in}, in_channels * V_{in})` where
-            :math:`N` is a batch size,
-            :math:`T_{in}` is a length of input sequence,
-            :math:`V_{in}` is the number of graph nodes,
+
     """
     def __init__(self, in_channels=3, num_points=27):
         super().__init__()
@@ -23,10 +17,17 @@ def __init__(self, in_channels=3, num_points=27):
 
     def forward(self, x):
         """
-        x.shape: (B, C, T, V)
-
+        Args:
+            x (torch.Tensor): Input tensor of shape :math:`(N, in_channels, T_{in}, V_{in})`
+
         Returns:
-            out.shape: (B, T, C*V)
+            torch.Tensor: Tensor with channel dimension flattened, of shape :math:`(N, T_{in}, in\_channels * V_{in})`
+
+        where:
+            - :math:`N` is a batch size,
+            - :math:`T_{in}` is a length of input sequence,
+            - :math:`V_{in}` is the number of graph nodes,
+
         """
         x = x.permute(0, 2, 1, 3)
         return torch.flatten(x, start_dim=2)
diff --git a/openhands/models/encoder/graph/sgn.py b/openhands/models/encoder/graph/sgn.py
index 4e7b9c7..67cdc34 100644
--- a/openhands/models/encoder/graph/sgn.py
+++ b/openhands/models/encoder/graph/sgn.py
@@ -106,10 +106,21 @@ def forward(self, x1):
 
 class SGN(nn.Module):
     """
-    https://arxiv.org/pdf/1904.01189.pdf
+    SGN model proposed in
+    `Semantics-Guided Neural Networks for Efficient Skeleton-Based Human Action Recognition
+    <https://arxiv.org/pdf/1904.01189.pdf>`_
+
+    Note:
+        The model supports inputs only with a fixed number of frames.
+
+    Args:
+        n_frames (int): Number of frames in the input sequence.
+        num_points (int): Number of spatial points in a graph.
+        in_channels (int): Number of channels in the input data. Default: 2.
+        bias (bool): Whether to use bias or not. Default: ``True``.
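+
+    Example (a minimal sketch; the batch size and frame/joint counts below are
+    hypothetical, chosen only to illustrate the expected shapes)::
+
+        import torch
+
+        model = SGN(n_frames=60, num_points=27, in_channels=2)
+        pose = torch.randn(4, 2, 60, 27)  # (N, in_channels, n_frames, num_points)
+        emb = model(pose)                 # (N, n_out_features)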
""" - def __init__(self, n_frames, num_points, in_channels=3, bias=True): + def __init__(self, n_frames, num_points, in_channels=2, bias=True): super(SGN, self).__init__() self.dim1 = 256 @@ -150,6 +161,21 @@ def __init__(self, n_frames, num_points, in_channels=3, bias=True): nn.init.constant_(self.gcn3.w.cnn.weight, 0) def forward(self, input): + """ + Args: + input (torch.Tensor): Input tensor of shape :math:`(N, in\_channels, T_{in}, V_{in})` + + Returns: + torch.Tensor: Output embedding of shape :math:`(N, n\_out\_features)` + + where + - :math:`N` is a batch size, + - :math:`T_{in}` is a length of input sequence, + - :math:`V_{in}` is the number of graph nodes, + - :math:`n\_out\_features` is the output embedding dimension. + + """ + # B, C, T, V input = input.permute(0, 2, 3, 1) @@ -182,6 +208,9 @@ def forward(self, input): return output def one_hot(self, bs, spa, tem): + """ + get one-hot encodings + """ y = torch.arange(spa).unsqueeze(-1) y_onehot = torch.FloatTensor(spa, spa) y_onehot.zero_() diff --git a/openhands/models/encoder/graph/st_gcn.py b/openhands/models/encoder/graph/st_gcn.py index ba25601..0ecb001 100644 --- a/openhands/models/encoder/graph/st_gcn.py +++ b/openhands/models/encoder/graph/st_gcn.py @@ -136,7 +136,6 @@ def __init__( self.relu = nn.ReLU(inplace=True) def forward(self, x, A): - res = self.residual(x) x, A = self.gcn(x, A) x = self.tcn(x) + res @@ -157,15 +156,6 @@ class STGCN(nn.Module): edge_importance_weighting (bool): If ``True``, adds a learnable importance weighting to the edges of the graph. Default: True. n_out_features (int): Output Embedding dimension. Default: 256. kwargs (dict): Other parameters for graph convolution units. - - Shape: - - Input: :math:`(N, in\_channels, T_{in}, V_{in}, M_{in})` - - - Output: :math:`(N, n\_out\_features)` where - :math:`N` is a batch size, - :math:`T_{in}` is a length of input sequence, - :math:`V_{in}` is the number of graph nodes, - :math:`n\_out\_features` is the output embedding dimension, """ def __init__(self, in_channels, graph_args, edge_importance_weighting, n_out_features = 256, **kwargs): super().__init__() @@ -204,6 +194,20 @@ def __init__(self, in_channels, graph_args, edge_importance_weighting, n_out_fea self.edge_importance = [1] * len(self.st_gcn_networks) def forward(self, x): + """ + Args: + x (torch.Tensor): Input tensor of shape :math:`(N, in\_channels, T_{in}, V_{in})` + + Returns: + torch.Tensor: Output embedding of shape :math:`(N, n\_out\_features)` + + where + - :math:`N` is a batch size, + - :math:`T_{in}` is a length of input sequence, + - :math:`V_{in}` is the number of graph nodes, + - :math:`n\_out\_features` is the output embedding dimension. 
+
+        """
         N, C, T, V = x.size()
         x = x.permute(0, 3, 1, 2).contiguous()  # NCTV -> NVCT
         x = x.view(N, V * C, T)
diff --git a/openhands/models/ssl/__init__.py b/openhands/models/ssl/__init__.py
new file mode 100644
index 0000000..b70d171
--- /dev/null
+++ b/openhands/models/ssl/__init__.py
@@ -0,0 +1,3 @@
+from .dpc_rnn import DPC_RNN_Pretrainer, DPC_RNN_Finetuner
+
+__all__ = ["DPC_RNN_Pretrainer", "DPC_RNN_Finetuner"]
diff --git a/openhands/models/ssl/dpc_rnn.py b/openhands/models/ssl/dpc_rnn.py
index ce457c9..a858f47 100644
--- a/openhands/models/ssl/dpc_rnn.py
+++ b/openhands/models/ssl/dpc_rnn.py
@@ -5,8 +5,49 @@
 
 # Adopted from: https://github.com/TengdaHan/DPC
 
+def load_weights_from_pretrained(model, pretrained_model_path):
+    ckpt = torch.load(pretrained_model_path)
+    ckpt_dict = ckpt["state_dict"].items()
+    pretrained_dict = {k.replace("model.", ""): v for k, v in ckpt_dict}
+
+    model_dict = model.state_dict()
+    tmp = {}
+    print("\n=======Check Weights Loading======")
+    print("Weights not used from pretrained file:")
+    for k, v in pretrained_dict.items():
+        if k in model_dict:
+            tmp[k] = v
+        else:
+            print(k)
+    print("---------------------------")
+    print("Weights not loaded into new model:")
+    for k, v in model_dict.items():
+        if k not in pretrained_dict:
+            print(k)
+    print("===================================\n")
+    del pretrained_dict
+    model_dict.update(tmp)
+    del tmp
+    model.load_state_dict(model_dict)
+    model.to(dtype=torch.float)
+    return model
+
 
 class DPC_RNN_Pretrainer(nn.Module):
+    """
+    SL-DPC model pretraining module.
+
+    Args:
+        pred_steps (int): Number of future prediction steps. Default: 3.
+        in_channels (int): Number of channels in the input data. Default: 2.
+        hidden_channels (int): Hidden channels for ST-GCN backbone. Default: 64.
+        hidden_dim (int): Output dimension from ST-GCN backbone. Default: 256.
+        dropout (float): Dropout ratio for the ST-GCN backbone.
+        graph_args (dict): Parameters for Spatio-temporal graph construction.
+        edge_importance_weighting (bool): If ``True``, adds a learnable importance
+            weighting to the edges of the graph. Default: True.
+        kwargs (dict): Other parameters for graph convolution units.
+
+    """
     def __init__(
         self,
         pred_steps=3,
@@ -46,6 +87,17 @@ def __init__(
         self._initialize_weights(self.network_pred)
 
     def forward(self, block):
+        """
+        Args:
+            block (torch.Tensor): Input data of shape :math:`(N, W, T, V, in\_channels)`,
+            where:
+                - :math:`N` is a batch size,
+                - :math:`W` is the number of windows,
+                - :math:`T` is a length of input sequence,
+                - :math:`V` is the number of graph nodes,
+                - :math:`in\_channels` is the number of channels.
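+
+        Example (shape sketch only; the window/frame counts follow the SSL dataset
+        defaults above and the instance name ``pretrainer`` is hypothetical)::
+
+            block = torch.randn(4, 7, 10, 27, 2)  # (N, W, T, V, in_channels)
+            out = pretrainer(block)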
+
+        """
         block = block.permute(0, 1, 4, 2, 3)  # B, N, T, V, C -> B, N, C, T, V
         B, N, C, T, V = block.shape
         block = block.view(B * N, C, T, V)
@@ -113,34 +165,26 @@ def _initialize_weights(self, module):
                 nn.init.orthogonal_(param, 1)
 
 
-def load_weights_from_pretrained(model, pretrained_model_path):
-    ckpt = torch.load(pretrained_model_path)
-    ckpt_dict = ckpt["state_dict"].items()
-    pretrained_dict = {k.replace("model.", ""): v for k, v in ckpt_dict}
-
-    model_dict = model.state_dict()
-    tmp = {}
-    print("\n=======Check Weights Loading======")
-    print("Weights not used from pretrained file:")
-    for k, v in pretrained_dict.items():
-        if k in model_dict:
-            tmp[k] = v
-        else:
-            print(k)
-    print("---------------------------")
-    print("Weights not loaded into new model:")
-    for k, v in model_dict.items():
-        if k not in pretrained_dict:
-            print(k)
-    print("===================================\n")
-    del pretrained_dict
-    model_dict.update(tmp)
-    del tmp
-    model.load_state_dict(model_dict)
-    model.to(dtype=torch.float)
-    return model
-
 class DPC_RNN_Finetuner(nn.Module):
+    """
+    SL-DPC fine-tuning module.
+
+    This module is proposed in
+    `OpenHands: Making Sign Language Recognition Accessible with Pose-based Pretrained Models across Languages
+    <https://arxiv.org/abs/2110.05877>`_
+
+    Args:
+        num_class (int): Number of classes to classify.
+        pred_steps (int): Number of future prediction steps. Default: 3.
+        in_channels (int): Number of channels in the input data. Default: 2.
+        hidden_channels (int): Hidden channels for ST-GCN backbone. Default: 64.
+        hidden_dim (int): Output dimension from ST-GCN backbone. Default: 256.
+        dropout (float): Dropout ratio for the ST-GCN backbone.
+        graph_args (dict): Parameters for Spatio-temporal graph construction.
+        edge_importance_weighting (bool): If ``True``, adds a learnable importance
+            weighting to the edges of the graph. Default: True.
+        kwargs (dict): Other parameters for graph convolution units.
+
+    """
     def __init__(
         self,
         num_class=60,
@@ -181,6 +225,19 @@ def __init__(
         self._initialize_weights(self.final_fc)
 
     def forward(self, block):
+        """
+        Args:
+            block (torch.Tensor): Input data of shape :math:`(N, W, T, V, in\_channels)`,
+            where:
+                - :math:`N` is a batch size,
+                - :math:`W` is the number of windows,
+                - :math:`T` is a length of input sequence,
+                - :math:`V` is the number of graph nodes,
+                - :math:`in\_channels` is the number of channels.
+
+        Returns:
+            torch.Tensor: logits for classification.
+        """
         B, N, C, T, V = block.shape
         block = block.view(B * N, C, T, V)
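
A minimal usage sketch for the SSL API documented above (module paths follow this diff; the ``graph_args`` value and the checkpoint path are hypothetical placeholders):

.. code:: python

    import torch
    from openhands.models.ssl import DPC_RNN_Finetuner
    from openhands.models.ssl.dpc_rnn import load_weights_from_pretrained

    # graph_args is assumed here; see the ST-GCN graph construction for the real schema
    model = DPC_RNN_Finetuner(num_class=60, in_channels=2, graph_args={"layout": "mediapipe-27"})
    model = load_weights_from_pretrained(model, "path/to/pretrained.ckpt")

    # (N, W, T, V, C): batch of 2 clips, 10 windows of 10 frames, 27 joints, 2D keypoints
    block = torch.randn(2, 10, 10, 27, 2)
    logits = model(block)  # classification logits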