# Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
import numpy as np
import random
import torch
from datasets import load_dataset


# Set random seed for reproducibility
def set_seed(seed):
    """
    Set the random seed for NumPy and PyTorch for reproducibility.

    Args:
        seed (int): The random seed.
    """
    np.random.seed(seed)
    torch.random.manual_seed(seed)


# Wrapper class for tokenized input IDs
class TokenizerWrapper:
    """
    Wrapper class for tokenized input IDs.

    Args:
        input_ids (tensor): The tokenized input IDs from the tokenizer.
    """
    def __init__(self, input_ids):
        self.input_ids = input_ids
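
# Illustrative sketch (not part of the original pipeline): the wrapper lets a bare
# tensor of token IDs stand in wherever callers only need an `.input_ids` attribute,
# as with the C4 validation set below. The vocabulary size here is arbitrary.
#
#   ids = torch.randint(0, 50257, (1, 2048))
#   wrapped = TokenizerWrapper(ids)
#   assert torch.equal(wrapped.input_ids, ids)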


# Load and process the PTB (Penn Treebank) dataset
def get_ptb(nsamples, seed, seqlen, tokenizer):
    """
    Load and process the PTB (Penn Treebank) dataset.

    Args:
        nsamples (int): Number of samples to extract.
        seed (int): Random seed for reproducibility.
        seqlen (int): Sequence length for each sample.
        tokenizer (Tokenizer): Tokenizer to use for text encoding.

    Returns:
        tuple: A tuple containing trainloader (list of input and target pairs) and the encoded validation set.
    """
    # Load train and validation splits
    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')

    # Encode datasets; ptb_text_only exposes the raw text under the 'sentence' column
    trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['sentence']), return_tensors='pt')

    # Generate nsamples random chunks of length seqlen from the training stream
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        # Mask every position except the last one, so only the final token
        # contributes to a loss that ignores the -100 label
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc
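
# Sketch of what the returned calibration samples look like (shapes inferred from the
# slicing above; `tok` is a stand-in for any Hugging Face tokenizer):
#
#   trainloader, testenc = get_ptb(nsamples=4, seed=0, seqlen=2048, tokenizer=tok)
#   inp, tar = trainloader[0]
#   # inp: LongTensor of shape [1, 2048]
#   # tar: same values as inp, except tar[:, :-1] is filled with -100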


# Load and process the Wikitext-2 dataset
def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """
    Load and process the Wikitext-2 dataset.

    Args:
        nsamples (int): Number of samples to generate from the training set.
        seed (int): Random seed for reproducibility.
        seqlen (int): Sequence length for generated samples.
        tokenizer (Tokenizer): Tokenizer instance for encoding texts.

    Returns:
        tuple: A tuple containing trainloader (list of input and target pairs) and the encoded test set.
    """
    # Load train and test splits from local raw files
    # (uncomment the lines below to pull them from the Hub instead)
    # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    # testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    traindata = load_dataset('text', data_files='datasets/wikitext/wiki.train.raw', split="train")
    testdata = load_dataset('text', data_files='datasets/wikitext/wiki.test.raw', split="train")

    # Encode datasets
    trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    # Generate nsamples random chunks of length seqlen from the training stream
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        # Only the final token of each target is left unmasked
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


# Load and process the C4 (Common Crawl) dataset
def get_c4(nsamples, seed, seqlen, tokenizer):
    """
    Load and process the C4 (Common Crawl) dataset.

    Args:
        nsamples (int): Number of samples to generate from the training set.
        seed (int): Random seed for reproducibility.
        seqlen (int): Sequence length for generated samples.
        tokenizer (Tokenizer): Tokenizer instance for encoding texts.

    Returns:
        tuple: A tuple containing trainloader (list of input and target pairs) and the encoded validation set.
    """
    # Load train and validation shards from the Hub
    # (uncomment the lines below to read local copies instead)
    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
    # traindata = load_dataset('json', data_files={'train': 'datasets/c4/c4-train.00000-of-01024.json.gz'}, split='train')
    # valdata = load_dataset('json', data_files={'validation': 'datasets/c4/c4-validation.00000-of-00008.json.gz'}, split='validation')

    # Generate nsamples random chunks of length seqlen from the training set
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Keep sampling documents until one is longer than seqlen tokens
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        # Only the final token of each target is left unmasked
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    # Prepare the validation set: join the first 1100 documents, then cap at 256 * seqlen tokens
    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc
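
# Hedged note on how the wrapped validation set is typically consumed downstream
# (an assumption about the caller, not something this file enforces): perplexity
# evaluation usually walks valenc.input_ids in seqlen-sized windows, e.g.
#
#   _, valenc = get_c4(nsamples=128, seed=0, seqlen=2048, tokenizer=tok)
#   nwindows = valenc.input_ids.numel() // 2048
#   for k in range(nwindows):
#       batch = valenc.input_ids[:, k * 2048:(k + 1) * 2048]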


# Select the appropriate loader based on the dataset name
def get_loaders(name='wikitext2', nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    """
    Select the appropriate loader based on the dataset name.

    Args:
        name (str): The name of the dataset ('wikitext2', 'c4', or 'ptb').
        nsamples (int): Number of samples to generate from the training set.
        seed (int): Random seed for reproducibility.
        seqlen (int): Sequence length for generated samples.
        tokenizer (Tokenizer): Tokenizer instance for encoding texts.

    Returns:
        tuple: A tuple containing trainloader (list of input and target pairs) and the encoded validation/test set.
    """
    # Dispatch to the matching loader based on the 'name' argument
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    elif "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    elif "ptb" in name:
        return get_ptb(nsamples, seed, seqlen, tokenizer)
    raise ValueError(f"Unknown dataset: {name}")
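
# Usage sketch (assumes a Hugging Face tokenizer; "gpt2" below is an illustrative
# placeholder, not a model this repository pins):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
#   trainloader, testenc = get_loaders('wikitext2', nsamples=128, seed=0,
#                                      seqlen=2048, tokenizer=tok)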


if __name__ == "__main__":
    # Smoke test; a real tokenizer must be supplied for the loaders to run end to end
    get_loaders('wikitext2', seed=0, seqlen=2048, tokenizer=None)

# Note:
# This script loads and processes text datasets for language models. It includes
# loaders for PTB (Penn Treebank), Wikitext-2, and C4 (Common Crawl). Each loader
# returns a trainloader (a list of (input, target) pairs) and an encoded
# validation/test set.