[#41] add draft ABSADatasetReader class
raymondng76 committed Dec 14, 2021
1 parent be7be52 commit ed89c1d
Showing 1 changed file with 70 additions and 0 deletions.
sgnlp/models/sentic_asgcn/utils.py (70 additions, 0 deletions)
@@ -1,5 +1,6 @@
import argparse
import json
import math
import pickle
import random
from typing import List
@@ -8,6 +9,7 @@

import numpy as np
import torch
from transformers import PreTrainedTokenizer

from data_class import SenticASGCNTrainArgs

@@ -220,3 +222,71 @@ def __iter__(self):
random.shuffle(self.batches)
for idx in range(self.batch_len):
yield self.batches[idx]


class ABSADataset(object):
"""
Data class to hold dataset for training.
"""

def __init__(self, data):
self.data = data

def __getitem__(self, index):
return self.data[index]

def __len__(self):
return len(self.data)


class ABSADatasetReader:
    """
    Reader class to load raw ABSA datasets together with their dependency
    graph and dependency tree pickles.
    """
def __init__(
self,
dataset_file_names: List[str],
tokenizer: PreTrainedTokenizer,
embed_dim: int = 300,
):
        self.dataset_file_names = dataset_file_names
        self.embed_dim = embed_dim
self.tokenizer = tokenizer
# TODO: Figure out how to include the embedding matrix here
# self.embedding_matrix = build_embedding_matrix()

@staticmethod
def __read_data__(file_name: str, tokenizer: PreTrainedTokenizer):
# Read raw data, graph data and tree data
with open(
file_name, "r", encoding="utf-8", newline="\n", errors="ignore"
) as fin:
lines = fin.readlines()
with open(f"{file_name}.graph", "rb") as fin_graph:
idx2graph = pickle.load(fin_graph)
with open(f"{file_name}.tree", "rb") as fin_tree:
idx2tree = pickle.load(fin_tree)

        # Prep all data; each sample spans three consecutive lines:
        # sentence with a "$T$" aspect placeholder, aspect term, polarity label
all_data = []
for i in range(0, len(lines), 3):
text_left, _, text_right = [
s.lower().strip() for s in lines[i].partition("$T$")
]
aspect = lines[i + 1].lower().strip()
polarity = lines[i + 2].lower().strip()
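            # Each tokenizer call below returns a transformers BatchEncoding
            # (input_ids, attention_mask, ...), not a bare list of indices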
text_indices = tokenizer(f"{text_left} {aspect} {text_right}")
context_indices = tokenizer(f"{text_left} {text_right}")
aspect_indices = tokenizer(aspect)
left_indices = tokenizer(text_left)
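            # Shift polarity labels from {-1, 0, 1} to {0, 1, 2}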
polarity = int(polarity) + 1
dependency_graph = idx2graph[i]
dependency_tree = idx2tree[i]

data = {
"text_indices": text_indices,
"context_indices": context_indices,
"aspect_indices": aspect_indices,
"left_indices": left_indices,
"polarity": polarity,
"dependency_graph": dependency_graph,
"dependency_tree": dependency_tree,
}
all_data.append(data)
return all_data
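
A made-up sample in the raw format the parsing loop above expects (sentence with a $T$ aspect placeholder, then the aspect term, then its polarity label):

    the $T$ was tasty but the service was slow
    food
    1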

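A minimal usage sketch of the draft classes; the tokenizer choice and dataset path below are illustrative assumptions, not part of this commit:

    from transformers import AutoTokenizer

    # Assumed tokenizer; any transformers PreTrainedTokenizer fits the signature
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    # Hypothetical path; matching "<path>.graph" and "<path>.tree" pickles
    # must exist alongside the raw file for __read_data__ to load
    samples = ABSADatasetReader.__read_data__("datasets/rest14_train.raw", tokenizer)
    dataset = ABSADataset(samples)
    print(len(dataset), dataset[0]["polarity"])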