In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import feather
import pickle

In [2]:
# Restore the pickled (all_topics, selected_topics) pair.
# pickle executes arbitrary code on load, so only use a file you produced yourself.
with open('topics.bin', 'rb') as topics_file:
    all_topics, selected_topics = pickle.load(topics_file)

In [3]:
# Read the train / validation / test splits from their feather files.
df_train, df_val, df_test = map(
    feather.read_dataframe,
    ('df_train.feather', 'df_val.feather', 'df_test.feather'),
)

In [4]:
def get_y(df, topics):
    """Build a binary row-by-topic indicator matrix.

    Parameters
    ----------
    df : pandas.DataFrame
        Must expose a ``topics`` column whose entries are comma-separated
        topic names. Missing entries (None / NaN) are treated as "no topics".
    topics : sequence of str
        Topic names that define the columns of the result, in order.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(df), len(topics))`` and dtype ``uint8`` where
        ``y[r, c] == 1`` iff row ``r`` lists ``topics[c]``. Topic names in
        ``df`` that are not in ``topics`` are ignored.
    """
    topic_idx = {t: i for (i, t) in enumerate(topics)}
    y = np.zeros((len(df), len(topics)), dtype='uint8')

    # Use a distinct loop name so the `topics` parameter is not shadowed.
    for row, row_topics in enumerate(df.topics):
        # A missing label has no `.split`; leave its row all zeros.
        if pd.isnull(row_topics):
            continue
        for t in row_topics.split(','):
            col = topic_idx.get(t)
            if col is not None:
                y[row, col] = 1
    return y

In [6]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from sklearn.metrics import log_loss, f1_score

In [7]:
# Load a previously pickled vectorizer.
# NOTE(review): `vec` is reassigned to a fresh TfidfVectorizer in a later cell,
# so this load looks superseded -- confirm whether it is still needed.
with open('models/tfidf_vec.bin', 'rb') as vec_file:
    vec = pickle.load(vec_file)

In [87]:
# Stack train and validation rows into one frame with a fresh 0..n-1 index.
df_all = pd.concat([df_train, df_val], ignore_index=True)

In [88]:
# TF-IDF over word 1- to 3-grams with English stop words removed;
# min_df=10 drops terms that occur in fewer than 10 documents.
vec = TfidfVectorizer(
    stop_words='english',
    ngram_range=(1, 3),
    min_df=10,
)

In [89]:
# Fit the vocabulary / IDF weights on train+val text only, then reuse the
# fitted vectorizer to transform the test text.
X_train = vec.fit_transform(df_all['body'])
X_test = vec.transform(df_test['body'])

In [90]:
# Binary label matrix: one row per document in df_all, one column per selected topic.
y_train = get_y(df=df_all, topics=selected_topics)

In [106]:
# `sklearn.cross_validation` was deprecated in scikit-learn 0.18 and removed in
# 0.20; `sklearn.model_selection.KFold` is its replacement.
from sklearn.model_selection import KFold

# 3-fold shuffled split over the rows of df_all, with a fixed seed.
# The new KFold.split() returns a one-shot generator, while the training cell
# iterates `cv` once per topic, so the (train_idx, val_idx) pairs are kept in a list.
cv = list(
    KFold(n_splits=3, shuffle=True, random_state=1).split(np.arange(len(df_all)))
)

In [107]:
from time import time

In [108]:
# Train one binary LinearSVC per selected topic.
# For every topic this produces:
#   * out-of-fold decision scores for the training rows (via the folds in `cv`),
#   * a model refitted on all training rows,
#   * decision scores for the test rows from that refitted model.
models = {}       # topic -> LinearSVC fitted on all of X_train

train_preds = {}  # topic -> out-of-fold decision scores, one per row of X_train
test_preds = {}   # topic -> decision scores, one per row of X_test

t = time()

svm_params = {
    'penalty': 'l1',
    'dual': False,  # LinearSVC supports the L1 penalty only in the primal formulation
    'C': 1.0,
    'random_state': 1,
}

for i in range(y_train.shape[1]):
    t0 = time()
    topic = selected_topics[i]

    y = y_train[:, i]
    try:
        train_pred = np.zeros(len(y), dtype='float32')

        # Each row is scored by a model that did not see it during fitting.
        for train_idx, val_idx in cv:
            svm = LinearSVC(**svm_params)
            svm.fit(X_train[train_idx], y[train_idx])
            train_pred[val_idx] = svm.decision_function(X_train[val_idx])

        svm = LinearSVC(**svm_params)
        svm.fit(X_train, y)
        test_pred = svm.decision_function(X_test)
    except Exception as e:
        # Best effort: skip a topic that cannot be fitted, but say why.
        # Catching Exception (not a bare except) keeps Ctrl-C / kernel
        # interrupt working during this long-running loop.
        print('got error for %s, skipping it: %r' % (topic, e))
        continue

    # Store all three results together so a failure part-way through a topic
    # never leaves the dictionaries out of sync with each other.
    models[topic] = svm
    train_preds[topic] = train_pred
    test_preds[topic] = test_pred

    print('%s, took %.3fs' % (topic, time() - t0))

print('overall took %.3fs' % (time() - t))

afghanistan, took 80.721s
aid, took 101.524s
algerianhostagecrisis, took 25.385s
alqaida, took 98.818s
alshabaab, took 31.521s
antiwar, took 32.056s
arabandmiddleeastprotests, took 130.147s
armstrade, took 41.443s
australiansecurityandcounterterrorism, took 27.061s
belgium, took 28.165s
bigdata, took 26.812s
biometrics, took 27.567s
bokoharam, took 26.584s
bostonmarathonbombing, took 25.204s
britisharmy, took 41.036s
cameroon, took 24.286s
carers, took 26.461s
chemicalweapons, took 25.766s
clusterbombs, took 23.546s
cobra, took 24.212s
conflictanddevelopment, took 72.526s
controversy, took 27.393s
criminaljustice, took 163.303s
cybercrime, took 34.035s
cyberwar, took 29.692s
dataprotection, took 67.941s
defence, took 144.224s
deflation, took 24.497s
drones, took 32.665s
drugs, took 41.922s
drugspolicy, took 40.150s
drugstrade, took 50.151s
earthquakes, took 29.450s
ebola, took 27.104s
economy, took 162.188s
egypt, took 49.075s
encryption, took 25.852s
energy, took 47.797s
espionage, to

In [109]:
# Assemble the per-topic out-of-fold score vectors into one matrix whose
# columns follow the order of selected_topics (rows = documents).
per_topic_scores = []
for topic_name in selected_topics:
    per_topic_scores.append(train_preds[topic_name].astype('float32'))
pred_total = np.array(per_topic_scores).T

In [110]:
# Display the (documents x topics) matrix of out-of-fold decision scores.
pred_total

array([[-0.20290829, -1.42606413, -1.23267102, ..., -1.12577271,
        -1.06694591, -1.08752131],
       [ 1.15480232, -1.56295514, -1.2080282 , ..., -1.1272527 ,
        -1.07851553, -1.07916272],
       [-1.1033839 , -1.45609128, -1.23267102, ..., -1.13343871,
        -1.09774947, -1.18163681],
       ..., 
       [-1.2271198 , -1.97697723, -1.22208393, ..., -1.08837831,
        -1.108024  , -1.09676635],
       [-1.32761407, -1.29466021, -1.2080282 , ..., -1.15244091,
        -1.13327777, -1.10450971],
       [-1.21049392, -1.41006315, -1.2080282 , ..., -1.13947082,
        -1.13505507, -1.09904253]], dtype=float32)

In [111]:
# Micro-averaged F1 of the out-of-fold scores when a label is predicted
# wherever the decision score is >= 0.
f1_score(y_train, pred_total >= 0, average='micro')

0.74804079680565916

In [114]:
# Sweep the decision threshold over 11 evenly spaced values in [-1, 0] and
# record (micro-F1, threshold) pairs so that max(f1s) returns the best one.
f1s = []

for threshold in np.linspace(-1, 0, 11):
    predicted = pred_total >= threshold
    score = f1_score(y_train, predicted, average='micro')
    f1s.append((score, threshold))
    print('t=%0.2f, f1=%.4f' % (threshold, score))

t=-1.00, f1=0.3743
t=-0.90, f1=0.5336
t=-0.80, f1=0.6322
t=-0.70, f1=0.6960
t=-0.60, f1=0.7361
t=-0.50, f1=0.7593
t=-0.40, f1=0.7717
t=-0.30, f1=0.7752
t=-0.20, f1=0.7713
t=-0.10, f1=0.7621
t=0.00, f1=0.7480


In [115]:
# Best (micro-F1, threshold) pair; tuples compare on F1 first.
max(f1s)

(0.77517980270698839, -0.29999999999999993)

In [122]:
# Decision-score cut-off used to turn SVM scores into 0/1 labels.
# -0.3 is the best value found by the threshold sweep on out-of-fold scores.
DECISION_THRESHOLD = -0.3

all_zeros = np.zeros(X_test.shape[0], dtype='uint8')

# Collect every column first and build the frame once; inserting columns one
# at a time into an empty DataFrame fragments it when there are many topics.
final_columns = {'id': df_test['key'].values}
column_order = ['id']

for t in all_topics:
    if t not in final_columns:
        column_order.append(t)
    if t in test_preds:
        # A model exists for this topic: threshold its test scores.
        final_columns[t] = (test_preds[t] >= DECISION_THRESHOLD).astype('uint8')
    else:
        # No model was trained for this topic: predict "absent" for every row.
        final_columns[t] = all_zeros

# `columns=` pins the order to 'id' followed by all_topics order;
# `index=` keeps the same row index as df_test.
df_final_pred = pd.DataFrame(final_columns, index=df_test.index, columns=column_order)

In [123]:
# Sanity check: total number of positive labels predicted across all topic columns.
df_final_pred.drop('id', axis=1).sum().sum()

12494

In [124]:
# Sanity check: number of positive labels predicted for each topic column.
df_final_pred.drop('id', axis=1).sum()

activism                                    0
afghanistan                               106
aid                                        65
algerianhostagecrisis                       4
alqaida                                   110
alshabaab                                  28
antiwar                                     7
arabandmiddleeastprotests                 172
armstrade                                  80
australianguncontrol                        0
australiansecurityandcounterterrorism       6
bastilledaytruckattack                      0
belgium                                   124
berlinchristmasmarketattack                 0
bigdata                                     8
biometrics                                  2
bokoharam                                  38
bostonmarathonbombing                      60
britisharmy                                 2
brusselsattacks                             0
cameroon                                    3
carers                            

In [125]:
# Write the predictions to CSV; index=False leaves the row index out of the file.
df_final_pred.to_csv('svm_sub3.csv', index=False)

In [126]:
# Persist the fitted per-topic models together with their train and test
# decision scores as a single pickled tuple.
# NOTE(review): the file name spells 'smv' rather than 'svm'; it is left as is
# because other code may read this exact path -- confirm before renaming.
with open('smv_models_pred.bin', 'wb') as out_file:
    artifacts = (models, train_preds, test_preds)
    pickle.dump(artifacts, out_file)