In [None]:
import numpy as np
import random

from datasets import load_dataset
from string import punctuation
from sklearn import preprocessing

from matplotlib import pyplot as plt

#Libs imports
import lib.dataLoader as dl
import lib.model as bayes

from IPython.display import Image

### Load dataset and pop the useless unsupervised part

The IMDB dataset contains 3 splits:<br>
    - The train split, which contains 25,000 elements<br>
    - The test split, which contains 25,000 elements<br>
    - The unsupervised split, which contains 50,000 elements

In [None]:
# Load every split of the IMDB dataset, then discard the unlabelled
# "unsupervised" split: only "train" and "test" are used in this notebook.
dataset = load_dataset("imdb")
del dataset["unsupervised"]

# Last expression of the cell: show the remaining splits.
dataset

### Set up the transforms that remove punctuation and lowercase the texts

In [None]:
# Attach the project's punctuation-removal transform to both labelled splits.
for split_name in ("train", "test"):
    dataset[split_name].set_transform(dl.removePonctTransform)

### Split the datasets into texts and labels in order to train on them

In [None]:
# Unpack the four values returned by the project helper: raw texts and labels
# for the train and test splits.
X_train_text, X_test_text, y_train, y_test = dl.splitDataset(dataset)

### Get the most relevant words in order to encode the texts

In [None]:
# Number of most relevant words kept as the vocabulary; this value comes from
# the sweep over vocabulary sizes shown further down in the notebook.
vocabulary_size = 3400
words = dl.listMostRevelantWord(X_train_text, vocabulary_size)

### Encode the texts

In [None]:
# Encode the train texts, then the test texts, against the same vocabulary.
X_train, X_test = (
    dl.textToRevelantMatrice(texts, words)
    for texts in (X_train_text, X_test_text)
)

### Create bayes model

In [None]:
# Instantiate the classifier through the project helper; it is fitted in the
# next cell.
model = bayes.createModel()

### Train the bayes model with our train encoded dataset

In [None]:
# Fit on the encoded training texts and their labels. The helper's return
# value is rebound to `model`, which the evaluation cell uses.
model = bayes.trainModel(model, X_train, y_train)

### Draw results on the test dataset

In [None]:
# Evaluate the fitted model on the encoded test set.
# NOTE(review): the meaning of the trailing `True` is defined in lib.model —
# presumably it turns on drawing/printing of the results; confirm there.
bayes.testModel(model, X_test, y_test, True)

Accuracy is a good way to measure the performance of our model here, because our dataset is large and our learning is supervised: every test review has a known label to compare the prediction against.

### As part of our analysis, we ran the code below once. It takes a long time to run, so it is disabled here and we display the result we recorded:

In [None]:
"""
words_ = dl.listSortRevelantWord(X_train_text)

no_ponct_result_list = []

for i in range (100, 5000, 100):
    words = dl.getNFirstWords(words_, i)

    X_train = dl.textToRevelantMatrice(X_train_text, words)
    X_test = dl.textToRevelantMatrice(X_test_text, words)

    model = bayes.createModel()
    model = bayes.trainModel(model, X_train, y_train)

    no_ponct_result_list.append(bayes.countDif(model, X_test, y_test))
    
x = np.linspace(100, 5000, num=len(no_ponct_result_list))

plt.plot(x, no_ponct_result_list)
plt.title("Prediction curve according to the number of word used without ponctuation")
plt.xlabel("Number of most revelant words")
plt.ylabel("Number of fail prediction")

plt.savefig("figures/no_ponct_preds")
"""
Image("figures/no_ponct_preds.png")

### This is how we deduced that we had to take the first 3400 words

### We repeat the same process with a better preprocessing step: we remove punctuation, upper case and stop words

In [None]:
# Second pipeline: same steps as the first one, but the texts go through the
# project's stop-word-removal transform instead of the punctuation-only one.

# Start from a fresh copy of the dataset and, as in the first pipeline, drop
# the unlabelled "unsupervised" split so both pipelines see the same splits.
dataset = load_dataset("imdb")
dataset.pop("unsupervised")
dataset['train'].set_transform(dl.removeStopWordsTransform)
dataset['test'].set_transform(dl.removeStopWordsTransform)
X_train_text, X_test_text, y_train, y_test = dl.splitDataset(dataset)

# Vocabulary size for this preprocessing; the value comes from the sweep shown
# in the next code cell of the notebook.
N_RELEVANT_WORDS_NO_STOPWORDS = 2600
words = dl.listMostRevelantWord(X_train_text, N_RELEVANT_WORDS_NO_STOPWORDS)

# Encode both text collections against the same vocabulary.
X_train = dl.textToRevelantMatrice(X_train_text, words)
X_test = dl.textToRevelantMatrice(X_test_text, words)

model = bayes.createModel()
model = bayes.trainModel(model, X_train, y_train)

# NOTE(review): the meaning of the trailing `True` is defined in lib.model —
# presumably it turns on drawing/printing of the results; confirm there.
bayes.testModel(model, X_test, y_test, True)

We get a better accuracy here: the preprocessing removes stop words, i.e. words that do not add context to our text and only act as noise (misleading information) in our case.

### As before, we ran this code to deduce that we should take the first 2600 words:

In [None]:
"""
words_ = dl.listSortRevelantWord(X_train_text)

no_stopwords_result_list = []

for i in range (100, 5000, 100):
    words = dl.getNFirstWords(words_, i)

    X_train = dl.textToRevelantMatrice(X_train_text, words)
    X_test = dl.textToRevelantMatrice(X_test_text, words)

    model = bayes.createModel()
    model = bayes.trainModel(model, X_train, y_train)

    no_stopwords_result_list.append(bayes.countDif(model, X_test, y_test))

x = np.linspace(100, 5000, num=len(no_stopwords_result_list))

plt.plot(x, no_stopwords_result_list)
plt.title("Prediction curve according to the number of word used without stop words")
plt.xlabel("Number of most revelant words")
plt.ylabel("Number of fail prediction")
plt.savefig("figures/no_stopwords_preds")
"""

Image("figures/no_stopwords_preds.png")