In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from time import time
import os

In [2]:
from sklearn.linear_model import SGDClassifier as SC, LogisticRegression as LR
from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import f1_score
from sklearn.utils import shuffle

import fasttext.FastText as ft
import joblib

## 1. 数据读取

In [3]:
train_file_path = '../input/news-classification/Data/train_set.csv'
train_df = pd.read_csv(train_file_path, sep='\t')
train_df, valid_df = tts(train_df, test_size=0.15)
print(train_df.shape, valid_df.shape)

(170000, 2) (30000, 2)


## 2. 格式转换

fasttext对模型的输入有格式要求，因此在模型训练和预测时需要将数据转换为符合要求的格式

In [4]:
def text_prepare(file, data):
    with open(file, 'w', encoding='utf-8') as f:
        for i in range(len(data)):
            review = data.iloc[i]['text']
            cat = '__label__' + str(data.iloc[i]['label']) + ' , '
            f.write(cat+review+'\n')

In [5]:
t0 = time()

train_file = './train_data.txt'
valid_file = './valid_data.txt'
text_prepare(train_file, train_df)
text_prepare(valid_file, valid_df)
del train_df
del valid_df

print('saving complete')
print('processing time(s):',time()-t0)

saving complete
processing time(s): 68.53736734390259


## 3. 模型训练

将训练集数据对模型进行训练，并在验证集上进行验证

In [6]:
t0 = time()
model = ft.train_supervised(train_file, loss='softmax')
print('Processing time(s):', time()-t0)
print(model.test(valid_file))

Processing time(s): 70.1198456287384
(30000, 0.9182333333333333, 0.9182333333333333)


## 4. 模型输出

对参赛的数据进行预测并输出保存

In [7]:
test_file_path = '../input/news-classification/Data/test_a.csv'
test_df = pd.read_csv(test_file_path, sep='\t')
print(test_df.shape)

(50000, 1)


In [8]:
t0 = time()
test_txt = list(test_df['text'])
result = model.predict(test_txt)[0]
print(len(result))
print('Processing time(s):', time()-t0)

50000
Processing time(s): 5.276782035827637


In [9]:
result = pd.DataFrame(columns=['label'], data=result)
result['label'] = result['label'].apply(lambda x: x[-1])
result.to_csv('predict_a.csv', sep='\t', index=False)
print('saving complete')

saving complete
