In [1]:
import logging
import random
import pandas as pd
import numpy as np
import torch

# Root logger: INFO level, timestamped "<time> <LEVEL>: <message>" lines.
logging.basicConfig(
    format='%(asctime)-15s %(levelname)s: %(message)s',
    level=logging.INFO,
)

In [35]:
# Seed every RNG this notebook touches (stdlib, NumPy, torch CUDA, torch CPU)
# with a single value, in that order.
seed = 789
for seed_fn in (random.seed, np.random.seed, torch.cuda.manual_seed, torch.manual_seed):
    seed_fn(seed)

# Configuration for the 10-fold split of the training set.
data_file = './data/train_set.csv'
fold_num = 10

def all_data2fold(fold_num, data_file, num=10000, random_state=None):
    """Sample rows from a TSV file and split them into label-stratified folds.

    The file is read with pandas (tab-separated, UTF-8) and must provide
    ``text`` and ``label`` columns. ``num`` rows are sampled, grouped by
    label, and every label group is cut into ``fold_num`` nearly equal chunks
    with ``np.array_split``; chunk ``i`` of each label goes to fold ``i``.
    Each fold is then shuffled with NumPy's global RNG.

    Args:
        fold_num: Number of folds to produce.
        data_file: Path to the tab-separated input file.
        num: Number of rows to sample. A value larger than the number of rows
            in the file is capped to the file size instead of raising.
        random_state: Seed for the row sampling. When ``None`` the
            module-level ``seed`` is used.

    Returns:
        A list of ``fold_num`` dicts ``{'label': [...], 'text': [...]}`` whose
        two lists are aligned. Folds may be empty when there are fewer sampled
        rows than folds.
    """
    df = pd.read_csv(data_file, sep='\t', encoding='UTF-8')

    if random_state is None:
        random_state = seed  # fall back to the notebook-level seed

    # DataFrame.sample raises when n exceeds the population (replace=False),
    # so cap the request at the number of available rows.
    sample_size = min(num, len(df))
    if sample_size < num:
        logging.warning("Requested %d samples but only %d rows are available; using all rows.",
                        num, len(df))
    df = df.sample(n=sample_size, random_state=random_state)

    # Spread each label's rows across the folds so every fold keeps roughly
    # the same class proportions.
    fold_texts = [[] for _ in range(fold_num)]
    fold_labels = [[] for _ in range(fold_num)]
    for _, group in df.groupby('label'):
        for i, idx in enumerate(np.array_split(group.index.values, fold_num)):
            fold_texts[i].extend(df.loc[idx, 'text'].tolist())
            fold_labels[i].extend(df.loc[idx, 'label'].tolist())

    # Shuffle each fold so samples are not ordered by label. The pairs are
    # unpacked with comprehensions because `zip(*pairs)` fails on an empty fold.
    fold_data = []
    for texts, labels in zip(fold_texts, fold_labels):
        combined = list(zip(texts, labels))
        np.random.shuffle(combined)
        fold_data.append({
            'label': [label for _, label in combined],
            'text': [text for text, _ in combined],
        })

    logging.info("Fold lens %s", str([len(data['label']) for data in fold_data]))

    return fold_data

In [52]:
# Build the folds from a 200,000-row sample; the fold count comes from the
# `fold_num` configuration constant rather than a repeated literal.
fold_data = all_data2fold(fold_num, data_file, num=200_000)

2025-01-21 16:23:10,909 INFO: Fold lens [20007, 20004, 20003, 20002, 20002, 19999, 19998, 19997, 19994, 19994]


In [54]:
# Gather the word2vec corpus: the texts of folds 0 .. fold_id - 1.
fold_id = 9

train_texts = []
for fold_idx in range(fold_id):
    train_texts += fold_data[fold_idx]['text']

logging.info('Total %d docs.', len(train_texts))

2025-01-21 16:23:13,137 INFO: Total 180006 docs.


In [55]:
from gensim.models.word2vec import Word2Vec

logging.info('Start training...')

num_features = 128     # Word vector dimensionality
num_workers = 8       # Number of threads to run in parallel

# Tokenise into a new name so `train_texts` keeps its raw strings and this
# cell can be re-run without failing on already-split lists.
train_sentences = [doc.split() for doc in train_texts]
model = Word2Vec(train_sentences, workers=num_workers, vector_size=num_features)

# `model.init_sims(replace=True)` is deprecated in gensim 4.x; this is the
# supported call that L2-normalises the stored vectors in place.
model.wv.unit_normalize_all()

# save model
model.save("./word2vec.bin")

2025-01-21 16:23:16,148 INFO: Start training...
2025-01-21 16:24:12,511 INFO: collecting all words and their counts
2025-01-21 16:24:12,536 INFO: PROGRESS: at sentence #0, processed 0 words, keeping 0 word types
2025-01-21 16:24:16,563 INFO: PROGRESS: at sentence #10000, processed 9174040 words, keeping 5326 word types
2025-01-21 16:24:19,521 INFO: PROGRESS: at sentence #20000, processed 18142528 words, keeping 5671 word types
2025-01-21 16:24:23,043 INFO: PROGRESS: at sentence #30000, processed 27427548 words, keeping 5887 word types
2025-01-21 16:24:25,974 INFO: PROGRESS: at sentence #40000, processed 36464668 words, keeping 6024 word types
2025-01-21 16:24:28,950 INFO: PROGRESS: at sentence #50000, processed 45535452 words, keeping 6161 word types
2025-01-21 16:24:31,827 INFO: PROGRESS: at sentence #60000, processed 54537037 words, keeping 6282 word types
2025-01-21 16:24:35,018 INFO: PROGRESS: at sentence #70000, processed 63666411 words, keeping 6343 word types
2025-01-21 16:24:37

In [56]:
# Reload the saved model from disk.
model = Word2Vec.load("./word2vec.bin")

# Export its word vectors in the non-binary word2vec format.
word_vectors = model.wv
word_vectors.save_word2vec_format('./word2vec.txt', binary=False)

2025-01-21 16:32:22,823 INFO: loading Word2Vec object from ./word2vec.bin
2025-01-21 16:32:22,850 INFO: loading wv recursively from ./word2vec.bin.wv.* with mmap=None
2025-01-21 16:32:22,853 INFO: setting ignored attribute cum_table to None
2025-01-21 16:32:22,900 INFO: Word2Vec lifecycle event {'fname': './word2vec.bin', 'datetime': '2025-01-21T16:32:22.900913', 'gensim': '4.3.3', 'python': '3.12.8 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:48:34) [MSC v.1929 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.19045-SP0', 'event': 'loaded'}
2025-01-21 16:32:22,907 INFO: storing 5971x128 projection weights into ./word2vec.txt
