# 数据导入

### 删去中文列

In [1]:
import pandas as pd
import numpy as np

train = pd.read_csv('./data/train.csv')

del train['title1_zh']
del train['title2_zh']

train.head()

Unnamed: 0,id,tid1,tid2,title1_en,title2_en,label
0,0,0,1,There are two new old-age insurance benefits f...,"Police disprove ""bird's nest congress each per...",unrelated
1,3,2,3,"""If you do not come to Shenzhen, sooner or lat...",Shenzhen's GDP outstrips Hong Kong? Shenzhen S...,unrelated
2,1,2,4,"""If you do not come to Shenzhen, sooner or lat...",The GDP overtopped Hong Kong? Shenzhen clarifi...,unrelated
3,2,2,5,"""If you do not come to Shenzhen, sooner or lat...",Shenzhen's GDP topped Hong Kong last year? She...,unrelated
4,9,6,7,"""How to discriminate oil from gutter oil by me...",It took 30 years of cooking oil to know that o...,agreed


In [2]:
test = pd.read_csv('./data/test.csv')

del test['title1_zh']
del test['title2_zh']

test.head()

Unnamed: 0,id,tid1,tid2,title1_en,title2_en
0,321187,167562,59521,egypt 's presidential election failed to win m...,Lyon! Lyon officials have denied that Felipe F...
1,321190,167564,91315,A message from Saddam Hussein after he was cap...,The Top 10 Americans believe that the Lizard M...
2,321189,167563,167564,Will the United States wage war on Iraq withou...,A message from Saddam Hussein after he was cap...
3,321193,167564,160994,A message from Saddam Hussein after he was cap...,The hanging Saddam is a surrogate? This man's ...
4,321191,167564,15084,A message from Saddam Hussein after he was cap...,Chinese loquat loquat plaster in America? Pure...


In [3]:
solution = pd.read_csv('./data/solution.csv')

solution.head()

Unnamed: 0,Id,Expected,Weight,Usage
0,347448,unrelated,0.0625,Public
1,347449,unrelated,0.0625,Private
2,359100,unrelated,0.0625,Public
3,359101,unrelated,0.0625,Private
4,359102,unrelated,0.0625,Private


# 数据预处理

### 特殊值处理

* 特殊符号

In [4]:
# 若字符串中全是特殊符号，则删除该行

import re

def is_special(s, threshold=0.4):
    non_alnum_chars = re.findall(r'[^a-zA-Z0-9\s]', s)
    non_alnum_ratio = len(non_alnum_chars) / len(s)
    
    return non_alnum_ratio > threshold

special_1 = train['title1_en'].apply(is_special)
special_2 = train['title2_en'].apply(is_special)

train = train[~special_1 & ~special_2]

* 重复值

In [5]:
# 若字符串中有重复10次以上的单词，词组，或是长串连续字符则删除该行

def is_repeated(s, min_repeats=6):
    char_pattern = r'(.)\1{' + str(min_repeats - 1) + ',}'
    phase_pattern = r'\b(\w+\s?\w*)\b(?:\W+\1\b){' + str(min_repeats - 1) + ',}'
    word_pattern = r'\b(\w+)\b(?:.*?\b\1\b){' + str(min_repeats - 1) + ',}'
    return bool(re.search(r'\b(\w+\s?\w*)\b(?:\W+\1\b){' + str(min_repeats) + ',}', s)) or bool(re.search(char_pattern, s)) or bool(re.search(phase_pattern, s)) or bool(re.search(word_pattern, s))

repeated_1 = train['title1_en'].apply(is_repeated)
repeated_2 = train['title2_en'].apply(is_repeated)

train = train[~repeated_1 & ~repeated_2]

* UNK

In [6]:
# 发现数据中有个别行中存在大量“UNK”，大概是由于使用模型翻译时词典中找不到适配的词汇导致的
# 为了防止这些样本对模型训练产生干扰，我们可以将这些样本所在行删除
# 样本中出现多于8个“UNK”的行将被删除

mask_1 = train['title1_en'].apply(lambda x: x.lower().split().count('unk') > 8)
mask_2 = train['title2_en'].apply(lambda x: x.lower().split().count('unk') > 8)

# 删除满足条件的行
train = train[~mask_1 & ~mask_2]

In [7]:
lengths_1 = train['title1_en'].apply(len)
max_1 = lengths_1.max()
index1 = lengths_1.idxmax()
print(train.iloc[index1]['title1_en'])
print(len(train.iloc[index1]['title1_en']))
print(len(train[train['id']==index1]['title1_en']))
print(train[train['id']==index1]['title1_en'])
print(max_1)

and saw a UFO.
14
1
38559    The "new changes in college entrance examinati...
Name: title1_en, dtype: object
502


In [9]:
print(len(train['title1_en'][index1]), train['title1_en'][index1])
print(train[train['id']==115330]['title1_en'])

502 After the incident of Gao Yun-cheung's sexual assault scandal broke out, his co-starring with Bing Bing (formerly: Winning the World) was doomed. It had been reported that Tangde wanted to replace the remake, but later on the Tang side denied the cancellation of Gao Yun-xiang's relevant parts. “ A replica of Li Chen ” It's just a technical test. 
It has been reported that Gao Yun-cheung had previously signed a morality clause with the Ba-Qing Sect, citing the impact of the incident, Gao Yun-cheung
Series([], Name: title1_en, dtype: object)


### 文本清理

* 去除标点符号

In [10]:
def remove_punctuation(x):
    x = re.sub(r'[^\w\s]','',x)
    return x

* 转成小写

In [11]:
def to_lowercase(x):
    return x.lower()

* 去除停用词

In [12]:
# import nltk
# nltk.download('stopwords')

# from nltk.corpus import stopwords

# def remove_stopwords(text):
#     stop_words = set(stopwords.words('english'))
#     return ' '.join([word for word in text.split() if word not in stop_words])

* 去除多余空格

In [13]:
def remove_extra_spaces(text):
    return re.sub(r'\s+', ' ', text).strip()

* 整合步骤

In [14]:
def clean_text(text):
    text = remove_punctuation(text)
    text = to_lowercase(text)
    # text = remove_stopwords(text)
    text = remove_extra_spaces(text)
    return text

* 应用

In [15]:
train['title1_en'] = train['title1_en'].apply(clean_text)
train['title2_en'] = train['title2_en'].apply(clean_text)

train.head()

Unnamed: 0,id,tid1,tid2,title1_en,title2_en,label
0,0,0,1,there are two new oldage insurance benefits fo...,police disprove birds nest congress each perso...,unrelated
1,3,2,3,if you do not come to shenzhen sooner or later...,shenzhens gdp outstrips hong kong shenzhen sta...,unrelated
3,2,2,5,if you do not come to shenzhen sooner or later...,shenzhens gdp topped hong kong last year shenz...,unrelated
4,9,6,7,how to discriminate oil from gutter oil by mea...,it took 30 years of cooking oil to know that o...,agreed
5,4,2,8,if you do not come to shenzhen sooner or later...,shenzhens gdp overtakes hong kong bureau of st...,unrelated


In [16]:
test['title1_en'] = test['title1_en'].apply(clean_text)
test['title2_en'] = test['title2_en'].apply(clean_text)

test.head()

Unnamed: 0,id,tid1,tid2,title1_en,title2_en
0,321187,167562,59521,egypt s presidential election failed to win mi...,lyon lyon officials have denied that felipe fe...
1,321190,167564,91315,a message from saddam hussein after he was cap...,the top 10 americans believe that the lizard m...
2,321189,167563,167564,will the united states wage war on iraq withou...,a message from saddam hussein after he was cap...
3,321193,167564,160994,a message from saddam hussein after he was cap...,the hanging saddam is a surrogate this mans mo...
4,321191,167564,15084,a message from saddam hussein after he was cap...,chinese loquat loquat plaster in america pure ...


### 标签编码

In [17]:
from sklearn.preprocessing import LabelEncoder

'''
agreed = 0
disagreed = 1
unrelated = 2
'''

# 初始化 LabelEncoder
label_encoder = LabelEncoder()

train['label_encoded'] = label_encoder.fit_transform(train['label'])

train.head()

Unnamed: 0,id,tid1,tid2,title1_en,title2_en,label,label_encoded
0,0,0,1,there are two new oldage insurance benefits fo...,police disprove birds nest congress each perso...,unrelated,2
1,3,2,3,if you do not come to shenzhen sooner or later...,shenzhens gdp outstrips hong kong shenzhen sta...,unrelated,2
3,2,2,5,if you do not come to shenzhen sooner or later...,shenzhens gdp topped hong kong last year shenz...,unrelated,2
4,9,6,7,how to discriminate oil from gutter oil by mea...,it took 30 years of cooking oil to know that o...,agreed,0
5,4,2,8,if you do not come to shenzhen sooner or later...,shenzhens gdp overtakes hong kong bureau of st...,unrelated,2


In [18]:
solution['Expected'] = label_encoder.fit_transform(solution['Expected'])

solution.head()

Unnamed: 0,Id,Expected,Weight,Usage
0,347448,2,0.0625,Public
1,347449,2,0.0625,Private
2,359100,2,0.0625,Public
3,359101,2,0.0625,Private
4,359102,2,0.0625,Private


### 文本向量化

* 加载数据集（字典化）

In [30]:
from torch.utils.data import Dataset

class AFQMC(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)
    
    def load_data(self, data_file):
        data_file = data_file.reset_index(drop=True)  # 重置索引
        Data = data_file.to_dict(orient='index')
        return Data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if idx not in self.data:
            raise KeyError(f"Key {idx} not found in dataset")
        return self.data[idx]

train_dict = AFQMC(train)
test_dict = AFQMC(test)

print(train_dict[0])
print(test_dict[0])

{'id': 0, 'tid1': 0, 'tid2': 1, 'title1_en': 'there are two new oldage insurance benefits for old people in rural areas have you got them', 'title2_en': 'police disprove birds nest congress each person gets 50000 yuan still old people insist on going to beijing', 'label': 'unrelated', 'label_encoded': 2}
{'id': 321187, 'tid1': 167562, 'tid2': 59521, 'title1_en': 'egypt s presidential election failed to win millions of votes in egypt s presidential election', 'title2_en': 'lyon lyon officials have denied that felipe federico has joined liverpool is it true that the price has not been agreed'}


In [20]:
# 如果数据集非常巨大，难以一次性加载到内存中，我们也可以继承 IterableDataset 类构建迭代型数据集

# from torch.utils.data import IterableDataset
# import json

# class IterableAFQMC(IterableDataset):
#     def __init__(self, data_file):
#         self.data_file = data_file

#     def __iter__(self):
#         df = self.data_file
#         for _, row in df.iterrows():
#             sample = row.to_dict()
#             yield sample


# try:
#     train_dict = IterableAFQMC(train)
#     print(next(iter(train_dict)))
# except Exception as e:
#     print(f"Error: {e}")

In [31]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")

def collote_fn(batch_samples):
    try:
        batch_sentence_1, batch_sentence_2 = [], []
        batch_label = []
        # for i in range(len(batch_samples)):
        #     batch_sentence_1.append(batch_samples.iloc[i]['title1_en'])
        #     batch_sentence_2.append(batch_samples.iloc[i]['title2_en'])
        #     batch_label.append(int(batch_samples.iloc[i]['label_encoded']))
        for sample in batch_samples:
            batch_sentence_1.append(sample['title1_en'])
            batch_sentence_2.append(sample['title2_en'])
            batch_label.append(int(sample['label_encoded']))
        X = tokenizer(
            batch_sentence_1, 
            batch_sentence_2, 
            padding=True, 
            truncation=True, 
            return_tensors="pt"
        )
        y = torch.tensor(batch_label)
        return X, y
    except Exception as e:
        print(f"Error in collote_fn: {e}")
        raise

train_dataloader = DataLoader(train_dict, batch_size=4, collate_fn=collote_fn)

batch_X, batch_y = next(iter(train_dataloader))
print('batch_X shape:', {k: v.shape for k, v in batch_X.items()})
print('batch_y shape:', batch_y.shape)
print(batch_X)
print(batch_y)

batch_X shape: {'input_ids': torch.Size([4, 66]), 'token_type_ids': torch.Size([4, 66]), 'attention_mask': torch.Size([4, 66])}
batch_y shape: torch.Size([4])
{'input_ids': tensor([[  101,  1175,  1132,  1160,  1207,  1385,  2553,  5986,  6245,  1111,
          1385,  1234,  1107,  3738,  1877,  1138,  1128,  1400,  1172,   102,
          2021,  4267, 20080, 24157,  4939, 10175, 16821,  1296,  1825,  3370,
         13837,  1568,   194,  8734,  1253,  1385,  1234, 19831,  1113,  1280,
          1106,  1129, 23784,  2118,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  1191,  1128,  1202,  1136,  1435,  1106,  1131, 19411, 10436,
         10639,  1137,  1224,  1240,  1488,  1209,  1145,  1435,  1107,  1750,
          1190,  1275,  1201,  1131, 19411, 10436,  1679,  8008,   176,  1181,
          1643,  1209, 13908, 16358,  2118,   180,  4553,   102,

# 训练模型