-
Notifications
You must be signed in to change notification settings - Fork 0
/
corpusmaker.py
73 lines (55 loc) · 1.99 KB
/
corpusmaker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torchtext
import random
import json
class TextCorpus:
    """
    A class to store text corpus.

    Attributes:
    - `raw`: list[str], raw text data
    - `name`: str | None, name of the corpus (None until a loader assigns one)

    Methods:
    - `random_split`: randomly split the raw data into train and test set (split ratio is 8:2 by default)
    - `raw_number`: return the number of raw data
    """

    def __init__(self):
        self.raw: list[str] = []
        # Loader functions assign this after construction; None means "unnamed".
        self.name: str | None = None

    def random_split(self, test_rate: float = 0.2) -> tuple[list[str], list[str]]:
        """
        Randomly shuffle a copy of `raw` and split it into (train, test).

        Args:
        - test_rate: float in [0, 1], fraction of samples assigned to the test set

        Returns:
        - (train, test): train holds floor(len(raw) * (1 - test_rate)) samples,
          test holds the rest. `self.raw` itself is not reordered.

        Raises:
        - ValueError: if test_rate is outside [0, 1] (or is NaN)
        """
        if not 0.0 <= test_rate <= 1.0:
            raise ValueError(f'test_rate should be in [0, 1], got {test_rate}')
        # Shuffle a copy so the caller's `raw` list keeps its order.
        shuffled = self.raw.copy()
        random.shuffle(shuffled)
        # Round off float noise before truncating: 10 * (1 - 0.9) evaluates to
        # 0.9999999999999998, which a bare int() would turn into 0 instead of 1.
        split_idx = int(round(len(shuffled) * (1 - test_rate), 9))
        return shuffled[:split_idx], shuffled[split_idx:]

    def raw_number(self) -> int:
        """Return the number of raw text samples stored in the corpus."""
        return len(self.raw)
def IMDB(num_total: int) -> 'TextCorpus':
    """
    Randomly select `num_total` samples from the IMDB train split, and return them as a TextCorpus object.

    Args:
    - num_total: int, number of samples to select (0 <= num_total <= size of the train split)

    Returns:
    - text_corpus: TextCorpus whose `name` is 'IMDB'

    Raises:
    - ValueError: if num_total is negative or larger than the train split
    """
    train_iter, _ = torchtext.datasets.IMDB()
    # Materialize the iterable so it can be counted and shuffled in place.
    samples = list(train_iter)
    if not 0 <= num_total <= len(samples):
        raise ValueError(f'num_total should be between 0 and {len(samples)}, got {num_total}')
    # Shuffle-then-slice (rather than random.sample) keeps seeded runs reproducible
    # with previously generated corpora.
    random.shuffle(samples)
    text_corpus = TextCorpus()
    # torchtext's IMDB yields (label, text) pairs; keep only the text.
    text_corpus.raw = [item[1] for item in samples[:num_total]]
    text_corpus.name = 'IMDB'
    return text_corpus
def Pile(pile_dataset_path: str, num_total: int) -> 'TextCorpus':
    """
    Randomly select `num_total` samples from a local Pile JSON file, and return them as a TextCorpus object.

    Args:
    - pile_dataset_path: str, path to a UTF-8 JSON file containing a list of objects, each with a 'text' key
    - num_total: int, number of samples to select (0 <= num_total <= number of records)

    Returns:
    - text_corpus: TextCorpus whose `name` is 'Pile'

    Raises:
    - ValueError: if num_total is negative or larger than the number of records
    """
    # NOTE(review): json.load parses a single JSON document (e.g. one array);
    # a JSON Lines file (one object per line) would need per-line parsing instead.
    with open(pile_dataset_path, 'r', encoding='utf-8') as file:
        records = json.load(file)
    texts = [record['text'] for record in records]
    if not 0 <= num_total <= len(texts):
        raise ValueError(f'num_total should be between 0 and {len(texts)}, got {num_total}')
    random.shuffle(texts)
    text_corpus = TextCorpus()
    # Slicing already yields a new list; no copy comprehension needed.
    text_corpus.raw = texts[:num_total]
    text_corpus.name = 'Pile'
    return text_corpus