Skip to content

Commit

Permalink
[#41] flesh out docstrings and typehints for utils method
Browse files Browse the repository at this point in the history
  • Loading branch information
raymondng76 committed Dec 14, 2021
1 parent ffc7d4d commit be7be52
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions sgnlp/models/sentic_asgcn/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import argparse
import json
from logging import error
import math
from pickle import load
import pickle
import random
import pathlib
from typing import Dict
from typing import Dict, Iterable, List

import numpy as np
import torch
Expand Down Expand Up @@ -115,6 +113,10 @@ def build_embedding_matrix(


class BucketIterator(object):
"""
Bucket iterator class which provides sorting and padding for input dataset, iterate thru dataset batches
"""

def __init__(
self, data, batch_size, sort_key="text_indices", shuffle=True, sort=True
):
Expand All @@ -124,7 +126,17 @@ def __init__(
self.batches = self.sort_and_pad(data, batch_size)
self.batch_len = len(self.batches)

def sort_and_pad(self, data, batch_size):
def sort_and_pad(self, data, batch_size: int) -> List[Dict[str, torch.tensor]]:
"""
Class method to sort and pad data batches
Args:
data ([type]): input data
batch_size (int): batch size
Returns:
List[Dict[str, torch.tensor]]: return a list of dictionaries of tensors
"""
num_batch = int(math.ceil(len(data) / batch_size))
sorted_data = (
sorted(data, key=lambda x: len(x[self.sort_key])) if self.sort else data
Expand All @@ -135,7 +147,16 @@ def sort_and_pad(self, data, batch_size):
]
return batches

def pad_data(self, batch_data):
def pad_data(self, batch_data: Iterable) -> Dict[str, torch.tensor]:
"""
Class method to pad data batches
Args:
batch_data (Iterable): An iterable for looping thru input dataset
Returns:
Dict[str, torch.tensor]: return dictionary of tensors from data batches
"""
batch_text_indices = []
batch_context_indices = []
batch_aspect_indices = []
Expand Down

0 comments on commit be7be52

Please sign in to comment.