In [17]:
import collections
import pathlib
import re
import string
import pickle
import tensorflow as tf
import pandas as pd
import numpy as np
import random

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import preprocessing
from tensorflow.keras import utils
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

import tensorflow_datasets as tfds
import tensorflow_text as tf_text

In [18]:
from grabber import ChatBoxGrabber  # kept from the original; presumably needed for unpickling — verify
# NOTE(review): pickle.load can execute arbitrary code; only load data.pickle
# if it comes from a trusted source.
# TODO: seed `random` here for a reproducible shuffle (and thus train/test split).
with open('data.pickle', 'rb') as pickle_file:
    records = pickle.load(pickle_file)
# Each stored record exposes `.message` and `.user`; keep them as (message, user) pairs.
dataset = [(record.message, record.user) for record in records.values()]
random.shuffle(dataset)

In [19]:
# Drop users with fewer than 1000 messages so every remaining class has enough
# examples. Counting once with Counter replaces a full rescan of `dataset` per
# user with a single pass; the kept pairs and their order are unchanged.
MIN_MESSAGES_PER_USER = 1000
message_counts = collections.Counter(pair[1] for pair in dataset)
dataset = [pair for pair in dataset if message_counts[pair[1]] >= MIN_MESSAGES_PER_USER]

In [20]:
# Keep only messages shorter than 200 characters so they fit vectorize()'s
# 200-slot encoding.
dataset = list(filter(lambda pair: len(pair[0]) < 200, dataset))

In [21]:
# Separate the (message, user) pairs into parallel lists of texts and senders.
messages = [message for message, _ in dataset]
senders = [sender for _, sender in dataset]

In [22]:
# Character vocabulary: index 0 is reserved for '' (the padding value
# vectorize() leaves in unused slots); every other character gets the next
# free index in order of first appearance.
vocab = {'': 0}
for char in (c for text in messages for c in text):
    vocab.setdefault(char, len(vocab))

In [23]:
# Give each distinct raw user id a dense class index in order of first
# appearance, then replace every sender with its class index.
userVocab = {user: index for index, user in enumerate(dict.fromkeys(senders))}
senders = [userVocab[sender] for sender in senders]

In [36]:
def vectorize(string, length=200, vocabulary=None):
    """Encode a message (or an iterable of messages) as fixed-length char ids.

    Each character is replaced by its index in ``vocabulary`` (the notebook's
    global ``vocab`` by default). The result is right-padded with 0, the index
    reserved for ``''``.

    Args:
        string: a ``str`` (including subclasses such as ``np.str_``), or an
            iterable of strings to encode element-wise.
        length: number of slots per encoded vector. Longer strings are
            truncated to this many characters.
        vocabulary: char -> int mapping; defaults to the global ``vocab``.
            Characters missing from it are encoded as 0 (the padding index).

    Returns:
        An integer ``np.ndarray`` of shape ``(length,)`` for a single string,
        or a list of such arrays for an iterable of strings.
    """
    if vocabulary is None:
        vocabulary = vocab
    # isinstance (not type ==) so str subclasses like np.str_ are encoded as
    # whole strings instead of being split into per-character vectors.
    if not isinstance(string, str):
        return [vectorize(item, length, vocabulary) for item in string]
    encoded = np.zeros(length, dtype=int)
    # Truncate rather than raise IndexError on over-long input. Map unseen
    # characters to 0 rather than raise KeyError, so new text can be scored.
    for position, char in enumerate(string[:length]):
        encoded[position] = vocabulary.get(char, 0)
    return encoded

In [25]:
# vectorize() maps a list of messages element-wise, so this stacks every
# encoded message into one (num_messages, 200) integer matrix.
vectors = np.array(vectorize(messages))
vectors.shape

(99575, 200)

In [26]:
# Slice the (already shuffled) rows into contiguous splits:
# rows [0, 10000) -> test, [10000, 30000) -> validate, [30000, end) -> train.
test, validate, train = (
    tf.data.Dataset.from_tensor_slices((vectors[start:stop], senders[start:stop]))
    for start, stop in ((None, 10000), (10000, 30000), (30000, None))
)

In [27]:
# tf.data pipeline of raw (message text, encoded sender) pairs.
# NOTE(review): this rebinds `dataset`, replacing the list of pairs built
# earlier, so re-running the earlier filtering cells afterwards will likely break.
dataset = tf.data.Dataset.from_tensor_slices(
    (messages, np.asarray(senders))
)

In [28]:
# Character-level CNN:
#   - 64-d char embeddings,
#   - four identical strided 1-D convolutions,
#   - global max pooling,
#   - a final Dense layer with one raw logit per sender class (hence from_logits=True below).
model = tf.keras.Sequential(
    [layers.Embedding(len(vocab), 64, mask_zero=True)]
    + [
        layers.Conv1D(64, 5, padding="valid", activation="relu", strides=2)
        for _ in range(4)
    ]
    + [
        layers.GlobalMaxPooling1D(),
        layers.Dense(len(set(senders))),
    ]
)
model.compile(
    optimizer='adam',
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [29]:
# Show the training split's element structure (shapes and dtypes).
train

<TensorSliceDataset shapes: ((200,), ()), types: (tf.int32, tf.int32)>

In [30]:
# Train on rows 30000+ (the same rows as the `train` split) and report
# validation metrics each epoch on rows 10000-30000 (the `validate` split).
# Pass the ndarray directly: `.tolist()` only built a slow nested Python list.
# Labels are converted to an ndarray too, so Keras sees matching input types.
history = model.fit(
    vectors[30000:],
    np.array(senders[30000:]),
    validation_data=(vectors[10000:30000], np.array(senders[10000:30000])),
    epochs=10,
)
history

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x25efc2a95b0>

In [31]:
# Per-user sanity check. For every sender class:
#   - sum the logits of SAMPLES_PER_USER random messages from that sender;
#   - print whether the class itself gets the top score.
# `val` collects the message counts of the users that were missed.
# NOTE(review): samples come from all of `vectors`, including the training rows
# (index >= 30000), so this overstates accuracy. Sampling from the test rows
# (index < 10000) would be a fairer check.
SAMPLES_PER_USER = 8
MIN_VECTORS_PER_USER = 1000

# Group the encoded messages by sender once, instead of rescanning `senders` per user.
vectors_by_user = collections.defaultdict(list)
for vector, sender in zip(vectors, senders):
    vectors_by_user[sender].append(vector)

val = []
for user in userVocab.values():
    vects = vectors_by_user[user]
    if len(vects) < MIN_VECTORS_PER_USER:
        print("None!")
        continue
    temp = sum(model.predict(np.array(random.sample(vects, SAMPLES_PER_USER))))
    is_correct = max(temp) == temp[user]
    print(is_correct)
    if not is_correct:
        val.append(len(vects))
# max() of an empty list raises ValueError; show None when no user was missed.
max(val) if val else None

True
True
True
True
False
True
True
True
True
False
False
False
True
True


4693

In [40]:
# Score a few hand-written messages. Summing their logits gives one score per
# sender class; the largest entry's position is the class index (see
# userVocab) of the most likely sender.
sum(
    model.predict(
        np.array(
            vectorize([
                "	@arsha :تیکه بود؟ :39:",
                "@Cripher :21::24:",
                "@Cripher تو حالت خوبه؟:21:	",
            ])
        )
    )
)

array([  2.6109245 ,   3.1919065 ,  -0.14117494, -12.799496  ,
        -4.9312677 ,  -8.419474  ,  -0.5565512 ,   9.837082  ,
        11.354764  ,  -7.4957457 , -23.453403  , -15.742021  ,
         0.7645676 ,   1.3965693 ], dtype=float32)

In [33]:
# Show the raw-user-id -> class-index mapping used to encode `senders`.
userVocab

{635: 0,
 2903: 1,
 2274: 2,
 2454: 3,
 3074: 4,
 321: 5,
 2905: 6,
 2914: 7,
 2803: 8,
 1714: 9,
 1034: 10,
 2883: 11,
 986: 12,
 2210: 13}

In [90]:
# Number of messages whose encoded sender (class index) is 3; see userVocab
# for the corresponding raw user id.
counter = senders.count(3)
counter

2777