In [None]:
!pip install datasets

In [3]:
from datasets import load_dataset

# Identifier of the dataset to load; kept in one place so it is easy to swap.
DATASET_ID = "Kwaai/IMDB_Sentiment"

# Load all splits of the dataset into a single dict-like object.
ds = load_dataset(DATASET_ID)

In [4]:
# Show which splits were loaded, together with their feature names and row counts.
print(ds)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


In [11]:
from datasets import concatenate_datasets, Dataset
import pandas as pd

def _text_counts(split):
    """Return (row count, distinct-text count, duplicate count) for one split."""
    n_rows = len(split)
    n_unique = len(set(split['text']))
    return n_rows, n_unique, n_rows - n_unique


def _print_counts(heading, n_rows, n_unique, n_dupes):
    """Print the heading followed by the three summary lines for one split."""
    print(heading)
    print(f"Number of total text: {n_rows}")
    print(f"Number of unique text: {n_unique}")
    print(f"Is there any duplicate text? : {n_rows != n_unique} and Duplicate numbers: {n_dupes}")


# Merge train and test into one full dataset (the unsupervised split is left out).
full_dataset = concatenate_datasets([ds['train'], ds['test']])  # 50,000

# Row totals, distinct texts and exact-duplicate counts per split and overall.
total_train, unique_train, du_train = _text_counts(ds['train'])
total_test, unique_test, du_test = _text_counts(ds['test'])
total, unique, du_total = _text_counts(full_dataset)

_print_counts("Original train set:", total_train, unique_train, du_train)
print("")
_print_counts("Original test set:", total_test, unique_test, du_test)
print("")
_print_counts("Total dataset:", total, unique, du_total)


Original train set:
Number of total text: 25000
Number of unique text: 24904
Is there any duplicate text? : True and Duplicate numbers: 96

Original test set:
Number of total text: 25000
Number of unique text: 24801
Is there any duplicate text? : True and Duplicate numbers: 199

Total dataset:
Number of total text: 50000
Number of unique text: 49582
Is there any duplicate text? : True and Duplicate numbers: 418


In [15]:
# De-duplicate the combined dataset on exact review text, then split it
# 70% / 20% / 10% into train / validation / test.
df = full_dataset.to_pandas()

# Keep the first occurrence of every review text.
# The remaining steps must be chained on the de-duplicated frame: starting
# again from `df` would silently throw the de-duplication away.
# "__index_level_0__" is only present when this cell is re-run on a dataset
# that a previous run already round-tripped through pandas, so dropping it has
# to tolerate the column being absent (errors="ignore").
df_unique = (
    df.drop_duplicates(subset="text")
    .drop(columns=["__index_level_0__"], errors="ignore")
    .reset_index(drop=True)
)
assert df_unique["text"].is_unique, "duplicate review texts survived de-duplication"

# Convert back to a Dataset; preserve_index=False keeps the pandas index from
# being stored as an extra feature column.
full_dataset = Dataset.from_pandas(df_unique, preserve_index=False)

# First split: 70% train, 30% held out.
split1 = full_dataset.train_test_split(test_size=0.3, seed=42)
train = split1['train']
temp = split1['test']

# Second split: one third of the held-out 30% becomes test (10% overall),
# the remaining two thirds become validation (20% overall).
split2 = temp.train_test_split(test_size=1/3, seed=42)
validation = split2['train']
test = split2['test']

# Every de-duplicated row must land in exactly one split.
assert len(train) + len(validation) + len(test) == len(full_dataset)

dataset = {
    "train": train,
    "validation": validation,
    "test": test
}

print("Train size:", len(dataset["train"]))
print("Validation size:", len(dataset["validation"]))
print("Test size:", len(dataset["test"]))
print("Total size:", len(dataset["train"]) + len(dataset["test"]) + len(dataset["validation"]))

Train size: 34707
Validation size: 9916
Test size: 4959
Total size: 49582


In [16]:
import pandas as pd

# Eyeball the first ten training rows as a small DataFrame.
head_rows = dataset["train"][:10]
df = pd.DataFrame(head_rows)
print(df)

                                                text  label
0  Anarchy and lawlessness reign supreme in the p...      1
1  Before I begin, a "little" correction: IMDb st...      0
2  You know Jason, you know Freddy, and you know ...      0
3  Creative use of modern and mystical elements: ...      1
4  In the trivia section for Pet Sematary, it men...      1
5  Despite a totally misleading advertising campa...      0
6  Well, were to start? This is by far one of the...      0
7  What's written on the poster is: "At birth he ...      0
8  Many of the earlier comments are right on the ...      1
9  i love this show. i hate when it goes to seaso...      1


In [17]:
# Write each split to its own CSV file in the working directory.
split_csv_paths = {
    "train": "imdb_train.csv",
    "validation": "imdb_validation.csv",
    "test": "imdb_test.csv",
}

# Looping (instead of ending the cell on a bare to_csv call) keeps to_csv's
# integer return value from being echoed as the cell's output.
for split_name, csv_path in split_csv_paths.items():
    dataset[split_name].to_csv(csv_path)

Creating CSV from Arrow format:   0%|          | 0/35 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

6641057

In [18]:
import pandas as pd

# Check whether identical texts appear in more than one of the saved splits.
def check_data_leakage(train_path, val_path, test_path, text_column="text",
                       max_samples=5, preview_chars=100):
    """Report exact-text overlap between train, validation and test CSV files.

    Each file is read with ``pd.read_csv``; missing values in ``text_column``
    are ignored. The sizes of the three pairwise intersections are printed,
    and when train and test overlap a few of the shared texts are printed too.

    Args:
        train_path: Path of the training-split CSV.
        val_path: Path of the validation-split CSV.
        test_path: Path of the test-split CSV.
        text_column: Name of the column holding the text to compare.
        max_samples: Maximum number of shared train/test texts to print.
        preview_chars: Number of leading characters shown per printed text.

    Returns:
        dict with keys ``"train_val_overlap"``, ``"train_test_overlap"`` and
        ``"val_test_overlap"``, each mapping to the set of shared texts.

    Raises:
        KeyError: If ``text_column`` is missing from one of the files.
    """
    def _load_texts(csv_path):
        # A set already collapses repeats; NaN rows are dropped so that
        # missing text is never counted as a shared sample.
        return set(pd.read_csv(csv_path)[text_column].dropna())

    train_texts = _load_texts(train_path)
    val_texts = _load_texts(val_path)
    test_texts = _load_texts(test_path)

    # Pairwise intersections: texts present in both splits of each pair.
    train_val_overlap = train_texts & val_texts
    train_test_overlap = train_texts & test_texts
    val_test_overlap = val_texts & test_texts

    # Summary
    print(f"Train ∩ Validation: {len(train_val_overlap)} overlapping samples")
    print(f"Train ∩ Test:       {len(train_test_overlap)} overlapping samples")
    print(f"Validation ∩ Test:  {len(val_test_overlap)} overlapping samples")

    if train_test_overlap:
        print("\nSample duplicate text (Train ∩ Test):")
        # Sort before slicing: raw set order of strings can differ between
        # interpreter runs, which would make the printed samples unstable.
        for text in sorted(train_test_overlap, key=str)[:max_samples]:
            # Only the first `preview_chars` characters are shown.
            print(f"- {str(text)[:preview_chars]}...")

    return {
        "train_val_overlap": train_val_overlap,
        "train_test_overlap": train_test_overlap,
        "val_test_overlap": val_test_overlap
    }

# Run the leakage check on the three exported split files; the returned
# overlap sets are shown as the cell's output.
split_files = {
    "train_path": "imdb_train.csv",
    "val_path": "imdb_validation.csv",
    "test_path": "imdb_test.csv",
}
check_data_leakage(**split_files)


Train ∩ Validation: 0 overlapping samples
Train ∩ Test:       0 overlapping samples
Validation ∩ Test:  0 overlapping samples


{'train_val_overlap': set(),
 'train_test_overlap': set(),
 'val_test_overlap': set()}