In [65]:
from datasets import load_dataset, load_from_disk, Dataset

In [None]:
# Fetch every split of the Chinese news-title dataset, then report the
# overall structure followed by the size of each split.
datasets = load_dataset(path="madao33/new-title-chinese")
print(datasets)
print(len(datasets["train"]), len(datasets["validation"]), sep="\n")

In [None]:
# The second argument selects the "boolq" configuration of the "super_glue"
# dataset.
boolq_dataset = load_dataset(path="super_glue", name="boolq")
# Display the metadata attached to the training split.
boolq_dataset["train"].info

In [None]:
# A list of split expressions is passed, so the printed length below counts
# the returned pieces rather than rows.
datasets = load_dataset(
    "madao33/new-title-chinese",
    split=["train[:90%]", "validation[90%:]"],
)
print(datasets, len(datasets), sep="\n")

In [None]:
# Ask for the training split only.
datasets = load_dataset(path="madao33/new-title-chinese", split="train")

In [None]:
# Schema of the split: column names and their feature types.
print(datasets.column_names)
print(datasets.features)

# Slice the first five rows.
print(datasets[:5])
# Column access covers the whole column, so print its size plus a short
# preview instead of dumping every title into the cell output.
print(len(datasets["title"]))
print(datasets["title"][:5])

In [None]:
# Continue with the training split of the BoolQ dataset loaded earlier.
datasets = boolq_dataset["train"]

In [None]:
# Number of rows in the dataset currently bound to `datasets`.
len(datasets)

In [None]:
# Names of the columns in the current dataset.
datasets.column_names

In [None]:
# Feature (type) description of each column in the current dataset.
datasets.features

In [None]:
# Hold out 10% of the BoolQ training split as a test set, stratified on the
# "label" column.
# The split is taken from `boolq_dataset["train"]` instead of `datasets` so the
# cell can be re-run: after one run `datasets` holds the split result, which
# no longer offers `train_test_split`. A fixed seed makes the shuffle
# reproducible across kernel restarts.
datasets = boolq_dataset["train"].train_test_split(
    test_size=0.1,
    stratify_by_column="label",
    seed=42,
)

In [None]:
# Display what `train_test_split` produced.
datasets

In [None]:
# Check the Python type of a single split.
type(datasets["train"])

In [None]:
# Select rows 0 and 1 by explicit index.
datasets["train"].select([0, 1])

In [None]:
# Select the first five rows; any iterable of indices is accepted here.
datasets["train"].select(range(5))

In [39]:
# Reload the full news-title dataset (all splits) and display its summary.
datasets = load_dataset(path="madao33/new-title-chinese")
datasets

DatasetDict({
    train: Dataset({
        features: ['title', 'content'],
        num_rows: 5850
    })
    validation: Dataset({
        features: ['title', 'content'],
        num_rows: 1679
    })
})

In [40]:
# Column names of the training split.
datasets["train"].column_names

['title', 'content']

In [41]:
# Keep only the training rows whose title contains "中国" (China).
datasets["train"].filter(lambda example: "中国" in example["title"])

Filter:   0%|          | 0/5850 [00:00<?, ? examples/s]

Dataset({
    features: ['title', 'content'],
    num_rows: 544
})

In [78]:
from transformers import AutoTokenizer
# Load the tokenizer published with the "bert-base-chinese" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

In [59]:
def collate_fn(examples, tokenizer=tokenizer, max_length=512):
    """Tokenize a batch of articles and their titles.

    Despite its name, this is used as a batched preprocessing function for
    `Dataset.map(..., batched=True)`, not as a DataLoader collate function.

    Args:
        examples: Batch with "content" and "title" entries, each holding the
            values for all rows in the batch.
        tokenizer: Tokenizer applied to both columns. The default is bound
            when this cell runs, so re-run the cell after replacing the
            global `tokenizer`.
        max_length: Maximum number of tokens kept per article and per title.

    Returns:
        The tokenizer output for the articles, padded to `max_length`, with an
        extra "labels" entry holding the token ids of the titles.
    """
    outputs = tokenizer(
        examples["content"],
        add_special_tokens=True,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )
    # Truncate titles as well so a label sequence can never exceed the length
    # limit applied to the inputs.
    title_encoding = tokenizer(examples["title"], max_length=max_length, truncation=True)
    outputs["labels"] = title_encoding["input_ids"]
    return outputs


# Keep the training rows whose title contains "中国" (China), tokenize them in
# batches, and drop the original columns from the result.
processed_datasets = (
    datasets["train"]
    .filter(lambda row: "中国" in row["title"])
    .map(
        collate_fn,
        batched=True,
        remove_columns=datasets["train"].column_names,
    )
)

Filter:   0%|          | 0/5850 [00:00<?, ? examples/s]

Map:   0%|          | 0/544 [00:00<?, ? examples/s]

In [61]:
# Persist the tokenized dataset to a local directory.
processed_datasets.save_to_disk("./processed_data")

Saving the dataset (0/1 shards):   0%|          | 0/544 [00:00<?, ? examples/s]

In [62]:
# Read the dataset back from the directory written by `save_to_disk`.
processed_datasets = load_from_disk("./processed_data")

In [63]:
# Display the reloaded dataset.
processed_datasets

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 544
})

In [69]:
# Build a Dataset directly from one local CSV file and report its summary
# and row count.
datasets = Dataset.from_csv("ChnSentiCorp_htl_all.csv")
print(datasets, len(datasets), sep="\n")

Dataset({
    features: ['label', 'review'],
    num_rows: 7766
})
7766


In [70]:
# Use the generic "csv" loader on a directory instead of a single file.
datasets = load_dataset(path="csv", data_dir="ChnSentiCorp_htl", split="train")
print(datasets, len(datasets), sep="\n")

Dataset({
    features: ['label', 'review'],
    num_rows: 23298
})
23298


In [85]:
# Load the single CSV file through the generic "csv" loader.
datasets = load_dataset(path="csv", data_files="ChnSentiCorp_htl_all.csv", split="train")
print(len(datasets))
# Drop rows whose review text is missing, then show the new size and the
# first remaining row.
datasets = datasets.filter(lambda row: row["review"] is not None)
print(len(datasets), datasets[0], sep="\n")

7766
7765
{'label': 1, 'review': '距离川沙公路较近,但是公交指示不对,如果是"蔡陆线"的话,会非常麻烦.建议用别的路线.房间较为简单.'}


In [91]:
import torch
def process_function(examples, tokenizer=tokenizer, max_length=512):
    """Tokenize a batch of reviews and attach their labels.

    Intended for `Dataset.map(..., batched=True)`.

    Args:
        examples: Batch with "review" and "label" entries, each holding the
            values for all rows in the batch.
        tokenizer: Tokenizer applied to the reviews. The default is bound
            when this cell runs, so re-run the cell after replacing the
            global `tokenizer`.
        max_length: Maximum number of tokens kept per review.

    Returns:
        The tokenizer output (unpadded, truncated to `max_length`) with an
        extra "labels" entry copied from the "label" column.
    """
    outputs = tokenizer(examples["review"], max_length=max_length, truncation=True)
    # `Dataset.map` stores the returned values itself, so the labels are passed
    # through as-is rather than being wrapped in a tensor first.
    outputs["labels"] = examples["label"]
    return outputs

# Tokenize every review in batches and drop the original columns from the
# result, then display the new dataset.
datasets = datasets.map(
    process_function,
    batched=True,
    remove_columns=datasets.column_names,
)
datasets

Map:   0%|          | 0/7765 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 7765
})

In [93]:
from transformers import DataCollatorWithPadding
# Collator that pads the examples of a batch using the tokenizer's padding
# settings when the batch is assembled.
collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [94]:
from torch.utils.data import DataLoader
# Batches of four examples, padded per batch by `collator`.
# The shuffle is driven by a seeded generator so the batch order (and any
# output derived from it) can be reproduced across runs.
dataLoader = DataLoader(
    datasets,
    shuffle=True,
    batch_size=4,
    collate_fn=collator,
    generator=torch.Generator().manual_seed(42),
)

In [96]:
# Peek at the first four batches and print the shape of their input ids.
# `range(4)` comes first in `zip`, so no fifth batch is ever requested.
for _, batch in zip(range(4), dataLoader):
    print(batch["input_ids"].shape)

torch.Size([4, 327])
torch.Size([4, 137])
torch.Size([4, 232])
torch.Size([4, 512])
