In [2]:
import numpy as np
import os

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker import TextClassifierDataLoader

In [3]:
import numpy as np, pandas as pd
from tqdm import tqdm_notebook
import json

from pathlib import Path
import os
import fasttext
import csv

## Prepare the data

In [4]:
def load_data(file):

    with open(f'{file}', 'r', encoding='utf-8') as f:
        input_lines = f.read().split('\n')

    data = {
        "title": [],
        "description": [],
        "recent_posts": [],
    }

    for line in input_lines:
        try:
            line = json.loads(line)
        except:
            pass

        if line != '':
            try:
                data['title'].append(line['title'])
                data['description'].append(line['description'])
                data['recent_posts'].append('\n'.join(line['recent_posts']))
            except:
                print('Parse error')

    data = pd.DataFrame(data)
    data['label'] = file.split('/')[-1].split('.')[0]
    
    return data

In [5]:
data = pd.DataFrame(columns = ['title', 'description', 'recent_posts', 'label'])
PATH = Path('tgparser/RU_TGSTAT_DATA/')
for file in os.listdir(PATH):
    data = pd.concat([data, load_data(str(PATH / file))]).reset_index(drop=True)
    
data['recent_posts'] = data['recent_posts'].apply(lambda x: x.replace('\u200b', ''))
data['label'] = data['label'].apply(lambda x: x.split('_')[0])

Parse error
Parse error
Parse error


In [6]:
mapper = {'Art & Design': '__label__0',
 'Bets & Gambling': '__label__1',
 'Books': '__label__2',
 'Business & Entrepreneurship': '__label__3',
 'Cars & Other Vehicles': '__label__4',
 'Celebrities & Lifestyle': '__label__5',
 'Cryptocurrencies': '__label__6',
 'Culture & Events': '__label__7',
 'Curious Facts': '__label__8',
 'Directories of Channels & Bots': '__label__9',
 'Economy & Finance': '__label__10',
 'Education': '__label__11',
 'Erotic Content': '__label__12',
 'Fashion & Beauty': '__label__13',
 'Fitness': '__label__14',
 'Food & Cooking': '__label__15',
 'Foreign Languages': '__label__16',
 'Health & Medicine': '__label__17',
 'History': '__label__18',
 'Hobbies & Activities': '__label__19',
 'Home & Architecture': '__label__20',
 'Humor & Memes': '__label__21',
 'Investments': '__label__22',
 'Job Listings': '__label__23',
 'Kids & Parenting': '__label__24',
 'Marketing & PR': '__label__25',
 'Motivation & Self-Development': '__label__26',
 'Movies': '__label__27',
 'Music': '__label__28',
 'Offers & Promotions': '__label__29',
 'Pets': '__label__30',
 'Politics & Incidents': '__label__31',
 'Psychology & Relationships': '__label__32',
 'Real Estate': '__label__33',
 'Recreation & Entertainment': '__label__34',
 'Religion & Spirituality': '__label__35',
 'Science': '__label__36',
 'Sports': '__label__37',
 'Technology & Internet': '__label__38',
 'Travel & Tourism': '__label__39',
 'Video Games': '__label__40',
 'Other': '__label__41'}

reverse_mapper = {v: k for k, v in mapper.items()}

In [7]:
# Remove emojis
import re


def deEmojify(text):    
    regex_pattern = re.compile("["
        u"\U0001F600-\U0001F64F"  # emoticons
        u"\U0001F300-\U0001F5FF"  # symbols & pictographs
        u"\U0001F680-\U0001F6FF"  # transport & map symbols
        u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
        u"\U00002500-\U00002BEF"  # chinese char
        u"\U00002702-\U000027B0"
        u"\U00002702-\U000027B0"
        u"\U000024C2-\U0001F251"
        u"\U0001f926-\U0001f937"
        u"\U00010000-\U0010ffff"
        u"\u2640-\u2642" 
        u"\u2600-\u2B55"
        u"\u200d"
        u"\u23cf"
        u"\u23e9"
        u"\u231a"
        u"\ufe0f"  # dingbats
        u"\u3030"                    
                      "]+", re.UNICODE)
    return regex_pattern.sub(r'',text)

data['recent_posts'] = data['recent_posts'].apply(deEmojify)
data['title'] = data['title'].apply(deEmojify)
data['description'] = data['description'].apply(deEmojify)

In [8]:
data['recent_posts'] = data['recent_posts'].apply(lambda x: x.lower())
data['title'] = data['title'].apply(lambda x: x.lower())
data['description'] = data['description'].apply(lambda x: x.lower())

In [9]:
# Remove adds from all posts

ALL_POSTS = []
for v in tqdm_notebook(data['recent_posts'].apply(lambda x: x.split('\n'))):
    ALL_POSTS.extend(v)
post_counts = pd.Series(ALL_POSTS).value_counts().sort_values(ascending = False)


def filter_posts(posts, threshold = 5):
    posts = posts.split('\n')
    filtered_posts = []
    for post in posts:
        if post_counts[post] < threshold:
            filtered_posts.append(post)
    return '\n'.join(filtered_posts)

for i in range(data.shape[0]):
    data.iloc[i, 2] = filter_posts(data.iloc[i, 2])

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


  0%|          | 0/7070 [00:00<?, ?it/s]

In [10]:
def removeEmail(text):
    pattern = re.compile("((\w+)(\.|_)?(\w*)@(\w+)(\.(\w+))+)")
    return pattern.sub(r'', text)

data['recent_posts'] = data['recent_posts'].apply(removeEmail)
data['title'] = data['title'].apply(removeEmail)
data['description'] = data['description'].apply(removeEmail)

In [11]:
def removeUsername(text):
    pattern = re.compile("(@(\w+))")
    return pattern.sub(r'', text)

data['recent_posts'] = data['recent_posts'].apply(removeUsername)
data['title'] = data['title'].apply(removeUsername)
data['description'] = data['description'].apply(removeUsername)

In [12]:
def removeLinks(text):
    pattern = re.compile("(https?://[^ ]+)")
    return pattern.sub(r'', text)

data['recent_posts'] = data['recent_posts'].apply(removeLinks)
data['title'] = data['title'].apply(removeLinks)
data['description'] = data['description'].apply(removeLinks)

In [13]:
from sklearn.model_selection import train_test_split

train, test = train_test_split(data, shuffle = True, train_size = 0.7)

train = train.reset_index(drop = True)
test = test.reset_index(drop = True)

In [14]:
def sample_data(data):
    output_data = []
    N_SAMPLES = 10
    for i in range(data.shape[0]):
        for j in range(N_SAMPLES):
            titleText = data.iloc[i, 0] + data.iloc[i, 1]
            posts = '\n'.join(pd.Series(data.iloc[i, 2].split('\n')).sample(n = 5, replace = True).values)
            output_data.append([mapper[data.iloc[i, 3]], titleText + posts])
    return pd.DataFrame(output_data)

In [15]:
print(train.shape, test.shape)

(4949, 4) (2121, 4)


In [16]:
train = sample_data(train)
test = sample_data(test)

In [17]:
print(train.shape, test.shape)

(49490, 2) (21210, 2)


In [18]:
train.columns = ['label', 'text']
test.columns = ['label', 'text']

In [19]:
train.to_csv('data/train.csv', index = None)
test.to_csv('data/test.csv', index = None)

## Train model

In [20]:
#spec = model_spec.get('mobilebert_classifier')
spec=  model_spec.get('average_word_vec')

In [21]:
%%time

train_data = TextClassifierDataLoader.from_csv(
      filename='data/train.csv',
      text_column='text',
      label_column='label',
      model_spec=spec,
      is_training=True)

CPU times: user 45.1 s, sys: 626 ms, total: 45.7 s
Wall time: 45.9 s


In [22]:
%%time

test_data = TextClassifierDataLoader.from_csv(
      filename='data/test.csv',
      text_column='text',
      label_column='label',
      model_spec=spec,
      is_training=False)

CPU times: user 12.3 s, sys: 116 ms, total: 12.4 s
Wall time: 12.4 s


In [25]:
model = text_classifier.create(train_data, model_spec=spec, epochs = 20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [26]:
loss, acc = model.evaluate(test_data)



## Export tflite model

In [37]:
model.export(export_dir='models/')

Finished populating metadata and associated file to the model:
models/model.tflite
The metadata json file has been saved to:
models/model.json
The associated file that has been been packed to the model is:
['vocab.txt', 'labels.txt']


  "tflite model is still allowed.".format(f))
