In [None]:
!pip install datasets

In [7]:
from datasets import load_dataset

ds = load_dataset("Kwaai/IMDB_Sentiment")

In [8]:
print(ds)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


In [14]:
from datasets import concatenate_datasets, Dataset
import pandas as pd

# combine train + test a full dataset
full_dataset = concatenate_datasets([ds['train'], ds['test']])  # 50,000

# count all
total = len(full_dataset)

# count unique text
unique = len(set(full_dataset['text']))

print(f"Number of total text: {total}")
print(f"Number of unique text: {unique}")
print(f"Is there any duplicate text? : {total != unique}")

Number of total text: 50000
Number of unique text: 49582
Is there any duplicate text? : True


In [15]:
df = full_dataset.to_pandas()

# unique
df_unique = df.drop_duplicates(subset="text")

# transfer to dataset
full_dataset = Dataset.from_pandas(df_unique)

# split dataset
split1 = full_dataset.train_test_split(test_size=0.3, seed=42)
train = split1['train']
temp = split1['test']

split2 = temp.train_test_split(test_size=1/3, seed=42)
validation = split2['train']
test = split2['test']


dataset = {
    "train": train,
    "validation": validation,
    "test": test
}

print("Train size:", len(dataset["train"]))
print("Validation size:", len(dataset["validation"]))
print("Test size:", len(dataset["test"]))

Train size: 34707
Validation size: 9916
Test size: 4959


In [17]:
import pandas as pd

df = pd.DataFrame(dataset["train"][:10])
print(df)

                                                text  label  __index_level_0__
0  Anarchy and lawlessness reign supreme in the p...      1              18135
1  Before I begin, a "little" correction: IMDb st...      0              28917
2  You know Jason, you know Freddy, and you know ...      0              10723
3  Creative use of modern and mystical elements: ...      1              15970
4  In the trivia section for Pet Sematary, it men...      1              23713
5  Despite a totally misleading advertising campa...      0              33491
6  Well, were to start? This is by far one of the...      0              12150
7  What's written on the poster is: "At birth he ...      0               8495
8  Many of the earlier comments are right on the ...      1              40256
9  i love this show. i hate when it goes to seaso...      1              38548


In [19]:
dataset['train'].to_csv("imdb_train.csv")
dataset['validation'].to_csv("imdb_validation.csv")
dataset['test'].to_csv("imdb_test.csv")

Creating CSV from Arrow format:   0%|          | 0/35 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

6669770

In [None]:
import pandas as pd

# check duplicate text
def check_data_leakage(train_path, val_path, test_path, text_column="text"):
    df_train = pd.read_csv(train_path)
    df_val = pd.read_csv(val_path)
    df_test = pd.read_csv(test_path)

    train_texts = set(df_train[text_column].dropna().unique())
    val_texts = set(df_val[text_column].dropna().unique())
    test_texts = set(df_test[text_column].dropna().unique())

    # union check
    train_val_overlap = train_texts & val_texts
    train_test_overlap = train_texts & test_texts
    val_test_overlap = val_texts & test_texts

    # result
    print(f"Train ∩ Validation: {len(train_val_overlap)} overlapping samples")
    print(f"Train ∩ Test:       {len(train_test_overlap)} overlapping samples")
    print(f"Validation ∩ Test:  {len(val_test_overlap)} overlapping samples")

    if train_test_overlap:
        print("\n🚨 示例重复文本（Train ∩ Test）:")
        for i, text in enumerate(list(train_test_overlap)[:5]):
            print(f"- {text[:100]}...")  # 只显示前 80 字符

    return {
        "train_val_overlap": train_val_overlap,
        "train_test_overlap": train_test_overlap,
        "val_test_overlap": val_test_overlap
    }

check_data_leakage(
    train_path="imdb_train.csv",
    val_path="imdb_validation.csv",
    test_path="imdb_test.csv"
)


Train ∩ Validation: 0 overlapping samples
Train ∩ Test:       0 overlapping samples
Validation ∩ Test:  0 overlapping samples


{'train_val_overlap': set(),
 'train_test_overlap': set(),
 'val_test_overlap': set()}