Skip to content

Commit

Permalink
add compact bin and bird animal example dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
dm-thu committed Aug 10, 2021
1 parent 13485ea commit 7b90b24
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 11 deletions.
4 changes: 3 additions & 1 deletion arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,9 @@ def add_data_args(parser):
group.add_argument('--dataset-type', type=str,
default='TokenizedDataset',
choices=['TokenizedDataset',
'TextCodeDataset'],
'TextCodeDataset',
'CompactBinaryDataset'
],
help='what type of dataset to use')

group.add_argument('--max-memory-length', type=int, default=2048,
Expand Down
4 changes: 2 additions & 2 deletions data_utils/configure_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def __init__(self, ds, **kwargs):
self.wrapped_data = ds

def __len__(self):
return len(self.wrapped_data) * 60
return len(self.wrapped_data) * 200

def __getitem__(self, index):
rng = random.Random(index)
Expand All @@ -301,7 +301,7 @@ def detect_new_datasets(args):
found = []
for _p in os.listdir(args.new_dataset_path):
p = os.path.join(args.new_dataset_path, _p)
if str(p).endswith('lmdb') and not str(os.path.abspath(p)) in current_datasets:
if (str(p).endswith('lmdb') or str(p).endswith('bin')) and not str(os.path.abspath(p)) in current_datasets:
found.append(p)
if len(found) == 0:
return None
Expand Down
33 changes: 32 additions & 1 deletion data_utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,26 @@ def __getitem__(self, idx):
row = pickle.loads(txn.get(key))

return self.process_fn(row)


class BinaryDataset(Dataset):
def __init__(self, path, process_fn, length_per_sample=64+1024, dtype='int32', preload=False, **kwargs):
assert length_per_sample is not None
self.length_per_sample = length_per_sample
self.dtype = np.dtype(dtype)
self.process_fn = process_fn
if preload:
self.bin = np.fromfile(path, dtype=self.dtype).reshape(-1, length_per_sample)
else:
with open(path, 'r') as fid:
nbytes = fid.seek(0, 2)
flen = fid.tell() // self.dtype.itemsize
self.bin = np.memmap(path, dtype=self.dtype, shape=(flen // length_per_sample, length_per_sample))

def __len__(self):
return self.bin.shape[0]

def __getitem__(self, index):
return self.process_fn(self.bin[index])

def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):

Expand Down Expand Up @@ -96,5 +115,17 @@ def process_fn(row):
return {'text': ret,
'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
}

elif dataset_type == 'CompactBinaryDataset':
DS_CLASS = BinaryDataset
def process_fn(row):
text, code = row[:64].astype(np.int64), row[64:].astype(np.int64) # must 64 + 1024
text = text[text>-1]
ret = TextCodeTemplate(text, code)
ret, attention_mask_sep = pad_to_len(ret)
return {'text': ret,
'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
}

return DS_CLASS(path, process_fn)

11 changes: 10 additions & 1 deletion data_utils/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,16 @@ def concat_codes(*codes):

def TextCodeTemplate(text, code):
tokenizer = get_tokenizer()
text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
if isinstance(text, str):
text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
else:
text_ids = np.concatenate(
(
np.array([tokenizer['[ROI1]']]),
text,
),
axis=0
)
code = tokenizer.wrap_code(code)
return concat_codes(text_ids, code)

Expand Down
7 changes: 5 additions & 2 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ wget https://cloud.tsinghua.edu.cn/f/71607a5dca69417baa8c/?dl=1 -O pretrained/vq
```
tar -xvf cogview-{base, sr, caption}.tar -C pretrained/cogview/
```
2. (Only for training tutorial, skip it for inference.) Download the Alibaba item-title image tokens dataset from our link at [Tianchi]()(*TODO*). Place the lmdb folder under `./data`.
2. (Only for training tutorial, skip it for inference.) Download a small "bird-and-animal" example dataset from our link at Tsinghua Cloud.
```
wget https://cloud.tsinghua.edu.cn/f/1e4963ec8ac84941ba68/?dl=1 -O data/bird_animal.bin
```

### Run CogView! (Model Inference)
We encapsulate the generation functions into scripts. See `generate_samples.py` and `arguments.py` for details.
Expand Down Expand Up @@ -95,7 +98,7 @@ The output is `{output_path}/scores.txt`, a line of a list of scores, following
Note: *In the released codes, for simplicity, we did not expose the raw API , which supports some advanced generation modes, e.g. text and part of image.*

## Training
Here we use a subset of our dataset from Alibaba item-title for tutorial.
Here we use a subset of our dataset from bird-and-animal for tutorial.
### Single Node
After downloading the dataset, directly run
```
Expand Down
6 changes: 3 additions & 3 deletions scripts/pretrain_single_node.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ HOST_FILE_PATH="hostfile_single"

config_json="$script_dir/ds_config.json"
gpt_options=" \
--experiment-name cogview-ali_fashion_tutorial-12-1024-16 \
--experiment-name cogview-bird_animal_tutorial-12-1024-16 \
--img-tokenizer-num-tokens 8192 \
--dataset-type TokenizedDataset \
--dataset-type CompactBinaryDataset \
--model-parallel-size ${MP_SIZE} \
--num-layers 12 \
--hidden-size 1024 \
--num-attention-heads 16 \
--save $main_dir/data/checkpoints \
--train-iters 20000 \
--resume-dataloader \
--train-data ./data/ali_vqvae_hard_biggerset_011.lmdb \
--train-data ./data/bird_animal.bin \
--split 949,50,1 \
--distributed-backend nccl \
--lr-decay-style cosine \
Expand Down
2 changes: 1 addition & 1 deletion scripts/text2image.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

# ==== tutorial settings: =====
# CHECKPOINT_PATH=data/checkpoints/cogview-ali_fashion_tutorial-12-1024-1606-14-06-09
# CHECKPOINT_PATH=data/checkpoints/cogview-bird_animal_tutorial-12-1024-1608-10-09-38
# NLAYERS=12
# NHIDDEN=1024
# NATT=16
Expand Down

0 comments on commit 7b90b24

Please sign in to comment.