In [None]:
from collections import Counter
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_validate

#### Model for Misinformation Spreader Classification

In [None]:
# Load the preprocessed PAN dataframe.
# NOTE(review): unpickling executes arbitrary code, so only read a pickle
# that comes from a trusted source.
full_df_pan = pd.read_pickle("./data/preprocessed/full_df.pkl")

# Collapse each row's list of tweets into one space-separated document.
full_df_pan["tweets"] = full_df_pan["tweets"].apply(" ".join)

# Keep at most the first 30 entries of TopN_Tweets, then join them the same way.
full_df_pan["TopN_Tweets"] = full_df_pan["TopN_Tweets"].apply(
    lambda top_tweets: " ".join(top_tweets[:30])
)

In [None]:
# Inputs are the joined TopN_Tweets documents; `target` supplies the labels
# the classifier is fitted against.
X, y = full_df_pan["TopN_Tweets"], full_df_pan["target"]

In [None]:
# Bag-of-words counts feeding a shallow random forest (fixed seed so the fit
# is repeatable).
# Optional sanity check before fitting: cross_validate(pipe, X, y, cv=3)
pipe = Pipeline(steps=[
    ('cv', CountVectorizer()),
    ('clf', RandomForestClassifier(max_depth=4, random_state=0)),
])
pipe.fit(X, y)

In [None]:
from joblib import dump, load

# Persist the pipeline currently bound to `pipe` so predict_tweet can load it
# from this path later.
dump(value=pipe, filename='./model/pan.joblib')

#### Model for Disaster Classifier.

In [None]:
# Read the disaster-tweet CSV and expose its `text` column under the name
# `tweets`.
full_df_disaster = pd.read_csv(
    "./data/preprocessed/disaster/nlp_disaster.csv"
).rename(columns={"text": "tweets"})

In [None]:
# Each row is a single tweet; `target` supplies its label.
X, y = full_df_disaster["tweets"], full_df_disaster["target"]

# Bag-of-words counts feeding a support-vector classifier. This rebinds
# `pipe`, so the earlier pipeline is no longer reachable under that name.
# Optional sanity check before fitting: cross_validate(pipe, X, y, cv=3)
pipe = Pipeline(steps=[
    ('cv', CountVectorizer()),
    ('clf', SVC(random_state=0)),
])
pipe.fit(X, y)

In [None]:
# Persist the pipeline currently bound to `pipe` so predict_tweet can load it
# from this path later.
dump(value=pipe, filename='./model/disaster.joblib')

In [None]:
def predict_tweet(tweets, pan_clf=None, disaster_clf=None, disaster_threshold=0.02):
    """Classify a user's tweets with the spreader and disaster models.

    The spreader model scores all tweets as one space-joined document. The
    disaster model scores every tweet on its own, and the user is flagged as
    tweeting about a disaster when the share of tweets predicted as ``1`` is
    strictly greater than ``disaster_threshold``. The share and a
    human-readable verdict are printed.

    Parameters
    ----------
    tweets : iterable of str, or a single str
        The user's tweets. A bare string is treated as one tweet rather than
        being iterated character by character.
    pan_clf : object with ``predict``, optional
        Fitted spreader model. Loaded from ``./model/pan.joblib`` when
        omitted, so repeated callers can pass an already-loaded model.
    disaster_clf : object with ``predict``, optional
        Fitted per-tweet disaster model. Loaded from
        ``./model/disaster.joblib`` when omitted.
    disaster_threshold : float, default 0.02
        Minimum share (exclusive) of disaster-labelled tweets needed to flag
        the user.

    Returns
    -------
    tuple
        ``(pan_pred, disaster_pred)``: the label returned by the spreader
        model for the joined document, and ``1``/``0`` for the disaster flag.

    Raises
    ------
    ValueError
        If ``tweets`` is empty (previously a bare ``ZeroDivisionError``).
    """
    # A lone string would otherwise be split into single characters below.
    if isinstance(tweets, str):
        tweets = [tweets]
    else:
        tweets = list(tweets)
    if not tweets:
        raise ValueError("predict_tweet needs at least one tweet")

    # Only hit the disk when the caller did not supply a model.
    if pan_clf is None:
        pan_clf = load('./model/pan.joblib')
    if disaster_clf is None:
        disaster_clf = load('./model/disaster.joblib')

    pan_pred = pan_clf.predict([" ".join(tweets)])[0]

    # One batched call instead of one predict() per tweet.
    predictions = np.asarray(disaster_clf.predict(tweets))
    disaster_ratio = np.count_nonzero(predictions == 1) / len(tweets)
    disaster_pred = 1 if disaster_ratio > disaster_threshold else 0
    print(disaster_ratio)

    if pan_pred and disaster_pred:
        print("potential misinformation spreader tweeting about disaster.")
    elif pan_pred:
        print("potential misinformation spreader")
    elif disaster_pred:
        print("a regular user tweeting about disaster")
    else:
        print("a regular user with routine tweet ")

    return pan_pred, disaster_pred

In [None]:
# Placeholder tweets used to exercise predict_tweet.
tweets = [
    "tweet1",
    "tweet2",
]

In [None]:
# Run both saved models on the sample tweets; the returned pair is shown as
# the cell output.
predict_tweet(tweets=tweets)