# Lab 02
## Introduction
The goal of this project is to build a sentiment classifier on the IMDB sentiment dataset. The IMDB sentiment [dataset](https://huggingface.co/datasets/imdb) is a collection of 50K movie reviews, annotated as positive or negative, and split into two sets of equal size: a training set and a test set. Both sets have an equal number of positive and negative reviews.

## The dataset

In [21]:
from datasets import load_dataset


dataset = load_dataset("imdb")

ModuleNotFoundError: No module named 'datasets'

1. How many splits does the dataset have?
2. How big are these splits?

In [2]:
dataset.num_rows

{'train': 25000, 'test': 25000, 'unsupervised': 50000}

The dataset has 3 splits: train, test and unsupervised. They contain 25,000, 25,000 and 50,000 examples respectively.

3. What is the proportion of each class in the supervised splits?

In [3]:
# Count the positive reviews (label == 1) in each supervised split
positive_train = dataset['train'].filter(lambda example: example['label'] == 1).num_rows
positive_test = dataset['test'].filter(lambda example: example['label'] == 1).num_rows
print(f"In the train and test splits, there are respectively {positive_train}/25000 and {positive_test}/25000 positive ratings")




In the train and test splits, there are respectively 12500/25000 and 12500/25000 positive ratings


Both supervised splits have an equal number of positive and negative reviews: each class makes up 50% of the train split and 50% of the test split.
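
The count above only covers the positive class. As a minimal sketch, the proportion of every class can be computed directly from a list of labels. The helper name `class_proportions` and the toy labels are illustrative assumptions, not part of the original notebook.

```python
from collections import Counter


def class_proportions(labels):
    """Return {label: fraction of examples carrying that label}."""
    counts = Counter(labels)
    total = len(labels)
    return {label: counts[label] / total for label in sorted(counts)}


# Toy labels, made up for illustration (1 = positive, 0 = negative)
toy_labels = [0, 1, 1, 0, 1, 0]
print(class_proportions(toy_labels))
```

Assuming the dataset has been loaded as in the first cell, the same helper could be applied to a real split with `class_proportions(dataset['train']['label'])`.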

## Naive Bayes classifier
### 1. Processing function

In [25]:
import string
def process(txt):
    """
    Converts all uppercase letters to lowercase, replaces all punctuation marks with spaces, and returns the processed string.
    """
    
    lowercase_txt = txt.lower()
    
    # create a translation table using maketrans method
    replace_punctuation = str.maketrans(string.punctuation, ' '*len(string.punctuation))
    
    # use the translate method to replace the punctuation
    processed_txt = lowercase_txt.translate(replace_punctuation)
    return processed_txt
process("What's your name? I'm Ba-yes")

'what s your name  i m ba yes'

### 2. Our Naive Bayes
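
As a reading aid for the cell that follows, the decision rule and the add-one (Laplace) smoothed likelihood it implements can be written out as below. The symbols are notation introduced here rather than names from the notebook: $V$ is the alphabetic vocabulary, $N_c$ is the total number of tokens in the training documents of class $c$, $N_{\text{doc},c}$ and $N_{\text{doc}}$ are document counts, and $d$ is the set of distinct tokens in the test document.

```latex
\hat{c} = \operatorname*{arg\,max}_{c} \Big[ \log P(c) + \sum_{w \in V \cap d} \log P(w \mid c) \Big]

P(c) = \frac{N_{\text{doc},c}}{N_{\text{doc}}}, \qquad
P(w \mid c) = \frac{\operatorname{count}(w, c) + 1}{N_c + |V|}
```

Because the test document is reduced to a set of words, each distinct word contributes to the sum once, however many times it occurs in the review.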

In [49]:
from collections import Counter
from collections import defaultdict

import math

def train_naive_bayes(documents, classes):
    # Typed signature, for reference:
    # train_naive_bayes(documents: list[dict], classes: list[int]) -> tuple[dict[int, float], dict[tuple[str, int], float], list[str]]
    """
    Trains a Naive Bayes classifier on a list of labeled documents.

    Args:
    - documents (list): A list of dictionaries containing the text and label for each document.
    - classes (list): A list of the possible class labels (0 and/or 1).

    Returns:
    - log_prior (dict): A dictionary containing the log prior probabilities for each class.
    - log_likelihood (defaultdict): A defaultdict containing the log likelihood probabilities
    for each word in the vocabulary given each class.
    - vocabulary (list): A list of words in the vocabulary.
    """
    n_doc = len(documents)
    log_prior = {}
    whole_vocabulary = set([word for d in documents for word in process(d['text']).split()])
    vocabulary = sorted({s for s in whole_vocabulary if s.isalpha()})
    log_likelihood = defaultdict(lambda: math.log(1/len(vocabulary)))
    bigdoc = {}
    word_counts = {}
    # Calculate P(c) terms
    for c in classes:
        n_c = len([d for d in documents if d['label'] == c])
        log_prior[c] = math.log(n_c / n_doc)

        # Calculate P(w|c) terms with add-one (Laplace) smoothing
        bigdoc = [process(d['text']).split() for d in documents if d['label'] == c]
        word_counts[c] = Counter([word for doc in bigdoc for word in doc])
        total_count = sum(word_counts[c].values())
        for word in vocabulary:
            count_w_c = word_counts[c][word]
            log_likelihood[(word, c)] = math.log((count_w_c + 1) / (total_count + len(vocabulary)))

    return log_prior, log_likelihood, vocabulary


def test_naive_bayes(test_txt, log_prior, log_likelihood, classes, vocabulary):
    """
    Predicts the label of a given test document using a trained Naive Bayes classifier.
    """
    # Calculate sum[c] for each class c
    sum_c = {}
    test_words = set(test_txt.split())
    for c in classes:
        sum_c[c] = log_prior[c]
    for word, c in log_likelihood.keys():
        if word in test_words:
            sum_c[c] += log_likelihood[(word, c)]
    
    # Return the class with highest sum[c]
    return max(sum_c, key=sum_c.get)

def test_accuracy(test_set, log_prior, log_likelihood, classes, vocabulary):
    """
    Tests the accuracy of a trained Naive Bayes classifier on a given test set.

    Returns:
    - accuracy: The accuracy of the classifier on the test set as a fraction between 0 and 1.
    """
    true = 0
    total = 0
    for test_doc in test_set:
        # Use the classes passed in by the caller instead of a hard-coded [0, 1]
        test_class = test_naive_bayes(process(test_doc["text"]), log_prior, log_likelihood, classes, vocabulary)
        if test_doc["label"] == test_class:
            true += 1
        total += 1
    return true/total
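
To make the effect of add-one smoothing concrete, here is a small self-contained sketch on a toy corpus. The corpus and the helper `smoothed_log_likelihood` are assumptions made up for illustration. The helper uses the same formula as `train_naive_bayes`, but it is a simplified stand-in rather than the notebook's function.

```python
import math
from collections import Counter

# Toy corpus, made up for illustration (1 = positive, 0 = negative)
toy_docs = [
    ("good great film", 1),
    ("great acting", 1),
    ("bad boring film", 0),
]
vocab = sorted({w for text, _ in toy_docs for w in text.split()})


def smoothed_log_likelihood(docs, label, vocab):
    """Return {word: log P(word | label)} using add-one smoothing."""
    counts = Counter(w for text, y in docs if y == label for w in text.split())
    total = sum(counts.values())
    return {w: math.log((counts[w] + 1) / (total + len(vocab))) for w in vocab}


pos = smoothed_log_likelihood(toy_docs, 1, vocab)
neg = smoothed_log_likelihood(toy_docs, 0, vocab)

# "boring" never occurs in a positive document, yet its probability is non-zero
print(math.exp(pos["boring"]))
# The smoothed probabilities of a class still sum to 1 over the vocabulary
print(sum(math.exp(v) for v in pos.values()))
```

Without the `+ 1`, a word unseen in one class would have probability zero, and `math.log(0)` would raise an error. In this toy corpus every token is alphabetic, so the probabilities sum to exactly 1; in the notebook the denominator also counts non-alphabetic tokens that are filtered out of the vocabulary, so the sum can fall slightly below 1.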

### 3. With scikit-learn

In [37]:
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline

text_clf = Pipeline([
    ('vect', CountVectorizer()),  # Turn each text into a vector of word counts
    ('clf', MultinomialNB()),  # Multinomial Naive Bayes classifier (trained later by fit)
])
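
The pipeline above is only a definition; nothing is trained until `fit` is called. As a minimal sketch of how such a pipeline behaves, the example below fits an identical one on a tiny toy corpus. The texts, labels and the name `toy_clf` are assumptions made up for illustration, and scikit-learn must be installed.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Toy corpus, made up for illustration (1 = positive, 0 = negative)
toy_texts = ["good great film", "great acting good", "bad awful film", "awful boring bad"]
toy_labels = [1, 1, 0, 0]

toy_clf = Pipeline([
    ("vect", CountVectorizer()),  # text -> word-count vector
    ("clf", MultinomialNB()),     # Naive Bayes on the counts (alpha=1.0 by default)
])

# fit() runs the vectorizer and the classifier in sequence on the training texts
toy_clf.fit(toy_texts, toy_labels)

# predict() applies the same vectorization before classifying
predictions = toy_clf.predict(["good great", "awful boring"])
print(predictions)
```

The pipeline accepts raw strings at both training and prediction time, which is why the later cells can pass `dataset['train']['text']` to `fit` and `little_test['text']` to `score` directly.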

### 4. Compare

#### Training

In [38]:
little_train = dataset['train'].shard(num_shards=100, index=0)
little_test = dataset['test'].shard(num_shards=100, index=0)
log_prior, log_likelihood, vocabulary = train_naive_bayes(dataset['train'], [0, 1])
text_clf.fit(dataset['train']['text'], dataset['train']['label'])

NameError: name 'dataset' is not defined

In [78]:
print(test_accuracy(little_test, log_prior, log_likelihood, [0, 1], vocabulary))
print(text_clf.score(little_test['text'], little_test['label']))

0.82
0.8
