In [1]:
from sklearn.mixture import GaussianMixture
from gensim.models.doc2vec import TaggedDocument
from gensim.models.doc2vec import Doc2Vec
from tqdm import tqdm
import numpy as np
import torch
from transformers import CamembertForSequenceClassification, CamembertTokenizer
from utils import *

In [2]:
df_articles = load_newspaper()

In [3]:
train_dataset, test_dataset = extract_train_test_dataset(df_articles)

In [5]:
tagged_docs = []
for index, row in tqdm(train_dataset.iterrows()):
    text = row['body'].lower().split()
    tag = row['label']
    tagged_docs.append(TaggedDocument(text, [tag]))

17000it [00:01, 13471.97it/s]


In [18]:
d2v = Doc2Vec(documents=tagged_docs, epochs=40)
d2v.build_vocab(tagged_docs)
d2v.train(tagged_docs, total_examples=d2v.corpus_count, epochs=d2v.epochs)

In [4]:
model = CamembertForSequenceClassification.from_pretrained("camembert-base", num_labels=6)
tokenizer = CamembertTokenizer.from_pretrained('camembert-base',do_lower_case=True)

Some weights of the model checkpoint at camembert-base were not used when initializing CamembertForSequenceClassification: ['lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing CamembertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CamembertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of CamembertForSequenceClassification were not initialized from the model checkpoint at camembert-base and are newly initialized: ['classifier.out_proj.bias

In [19]:
gmm = GaussianMixture(n_components=11, max_iter=200)

In [16]:
train_vectors = []
y_train = []
for index, row in tqdm(train_dataset.iterrows()):
    text = row['body'].lower()
    parts = split_text_in_parts(text)
    n = 0
    rep = 0
    for part in parts:
        tok_text = torch.tensor([tokenizer(' '.join(part))['input_ids']])
        res = model.roberta.embeddings(tok_text)
        n += len(part)
        rep += len(part)*torch.mean(res, dim=1)[0].detach().numpy()
    tag = str(index)
    if n > 0:
        y_train.append(row['label'])
        train_vectors.append(rep / n)

17000it [05:23, 52.63it/s]


In [17]:
test_vectors = []
y_test = []
for index, row in tqdm(test_dataset.iterrows()):
    text = row['body'].lower()
    parts = split_text_in_parts(text)
    n = 0
    rep = 0
    for part in parts:
        tok_text = torch.tensor([tokenizer(' '.join(part))['input_ids']])
        res = model.roberta.embeddings(tok_text)
        n += len(part)
        rep += len(part)*torch.mean(res, dim=1)[0].detach().numpy()
    tag = str(index)
    if n > 0:
        y_test.append(row['label'])
        test_vectors.append(rep / n)

5500it [01:35, 57.36it/s]


In [None]:
train_vectors = []
y_train = []
for index, row in tqdm(test_dataset.iterrows()):
    text = row['body'].lower().split()
    tag = str(index)
    y_train.append(row['label'])
    train_vectors.append(d2v.infer_vector(text))

In [21]:
vectors = []
y_true = []
for index, row in tqdm(test_dataset.iterrows()):
    text = row['body'].lower().split()
    tag = str(index)
    y_true.append(row['label'])
    vectors.append(d2v.infer_vector(text))

5500it [00:57, 95.78it/s] 


In [20]:
gmm.fit(train_vectors, y_train)

In [22]:
y_pred = gmm.predict(test_vectors)

In [50]:
acc = np.sum(np.array(y_pred) == np.array(y_true)) / len(y_pred)
acc

0.49672727272727274

In [51]:
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(np.array(y_true), np.array(y_pred))
cm

array([[570,  17, 287,   0,  24, 102],
       [  9, 878,  19,   0,  63,  31],
       [ 32,  45, 571,   0, 152, 200],
       [418,   5,   9,   0,  26,  42],
       [ 55,  27,  64,   0, 669, 185],
       [  9, 256, 669,   0,  22,  44]], dtype=int64)