# Exercise 5. Text Classification

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In this exercise we will build a classification model for the newsgroup dataset.

We will apply the following steps:

* A.	Document representation with tf-idf, word2vec and BERT
* B.	Naïve Bayes classification model
* C.	Random Forest
* D.	Grid search


In [3]:
# Import packages
import pandas as pd
import pickle
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import classification_report, accuracy_score
from sklearn.ensemble import RandomForestClassifier

# A.	Document representation with tf-idf, word2vec and BERT

In [4]:
# Load the stmnmed texts
data_stem=pd.DataFrame(pickle.load(open("/content/drive/MyDrive/TWSM_Data/Stemmed.pkl", "rb")))
# Load the word2vec embeddings and teh list of removed values
data_w2v=pickle.load(open("/content/drive/MyDrive/TWSM_Data/WordtoVecModel.pkl", "rb"))
positive=pickle.load(open('/content/drive/MyDrive/TWSM_Data/positive.pkl', 'rb'))
# Load the BERT mebeddings
data_BERT=pickle.load(open("/content/drive/MyDrive/TWSM_Data/BertModel.pkl", "rb"))

In [5]:
# Transform teh stememd data
data_stem=data_stem.set_axis(['preprocessed'], axis=1)
data_stem.head()

Unnamed: 0,preprocessed
0,car wonder enlighten car saw dai door sport ca...
1,clock poll final final clock report acceler cl...
2,question folk mac plu final gave ghost weekend...
3,weitek robert kyanko rob rjck uucp wrote abrax...
4,shuttl launch question articl cowcb world std ...


In [6]:
data_w2v.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
0,0.351774,-0.270618,-0.099167,0.440975,0.147224,-0.04785,0.347204,-0.370328,-0.306052,-0.53304,0.784328,0.352205,-0.078587,-0.83025,0.189646,-0.017968,0.253096,0.125804,-0.177684,0.034206,0.132728,0.088732,-0.123832,-0.398519,0.014783,-0.270238,0.117907,0.391219,0.158757,0.173804,0.131754,-0.084592,-0.40822,-0.056334,-0.280491,-0.08024,0.259069,0.260064,-0.286499,-0.022011,...,0.080258,-0.143452,-0.554667,0.198921,-0.189796,0.216042,-0.360367,0.454792,-0.272879,0.330489,0.399188,0.305223,-0.085042,0.435105,0.282648,0.01771,0.235993,-0.084351,-0.208245,0.192481,-0.13251,0.138845,0.169404,0.136261,-0.032994,-0.172526,-0.098615,0.172062,-0.217515,-0.404443,0.4921,-0.176473,0.651371,0.46328,-0.441628,-0.175508,0.143901,0.502239,-0.043058,-0.491982
1,0.008163,0.162567,0.056294,0.110691,0.077435,0.173565,0.030174,0.191341,0.136779,-0.149317,0.083286,0.085803,0.400404,-0.347317,0.033176,-0.430936,0.143813,0.261639,-0.115579,-0.062114,-0.088979,-0.347597,-0.098744,-0.244849,-0.298315,0.557203,0.301427,0.090291,-0.113943,0.139821,0.007065,-0.216194,-0.249048,0.365799,0.004725,0.116574,-0.111931,0.211696,0.011096,-0.09235,...,-0.00641,0.043633,-0.382989,0.39463,-0.005884,-0.19418,-0.281401,0.188782,-0.145757,0.141406,-0.065778,0.286781,-0.13732,0.399824,0.202061,0.221534,0.074902,-0.190985,0.102994,0.424236,0.098024,0.057525,0.282348,0.437015,0.181157,0.176196,-0.049588,0.107599,0.250988,-0.072748,0.096912,0.354876,0.271327,0.113638,-0.151813,0.036743,0.010965,0.239501,0.214709,-0.38445
2,0.18277,-0.100424,0.019315,0.334073,0.107887,0.124762,-0.113445,0.359388,-0.018454,-0.086599,0.106137,-0.052467,0.183271,-0.273118,0.272357,0.148054,0.107701,0.01081,-0.053615,-0.050433,-0.063988,-0.272693,-0.25213,-0.030147,0.038941,0.213088,0.046555,0.011535,-0.005292,0.148623,0.0758,-0.08096,-0.206258,0.305082,-0.191143,0.226707,0.109818,-0.082238,0.061433,0.001232,...,0.198873,0.365456,-0.46816,0.261627,-0.177639,-0.010834,-0.067605,0.086463,-0.112576,0.312297,0.121522,0.208696,-0.096691,0.161278,0.142064,0.319702,0.152462,-0.155592,0.113821,0.149915,0.03235,-0.038154,-0.102814,0.364744,-0.000543,0.108923,0.129811,0.255257,0.161428,0.131891,0.146105,-0.147394,0.027552,-0.035525,-0.278403,0.124541,0.122495,0.117965,0.404937,-0.070498
3,-0.294891,0.21189,0.061987,-0.018384,-0.002906,-0.085502,-0.344018,0.004979,0.080335,-0.303516,0.162338,-0.112059,0.218604,-0.104201,0.311974,0.140519,0.176662,0.22641,-0.109118,0.018592,0.088117,-0.04605,0.132062,-0.438827,-0.211266,0.215469,0.046284,-0.195149,-0.235789,-0.169567,0.054132,-0.428622,-0.03286,0.047562,0.092482,0.606193,0.226337,-0.272084,-0.146256,0.248828,...,-0.134156,0.007131,-0.340334,0.18829,0.200545,-0.219851,0.00186,-0.140909,0.125717,0.002194,-0.06258,-0.022784,0.048569,0.217845,0.114779,0.001259,0.041265,-0.32713,-0.017834,0.16617,-0.104674,-0.128521,0.199192,0.021575,0.033676,-0.036859,0.058534,0.236719,0.294604,0.16678,-0.229979,-0.164436,0.149876,0.2124,-0.437927,0.156925,-0.166397,0.117495,0.07816,-0.37075
4,-0.217928,-0.257206,0.003999,-0.010809,0.019237,0.06663,-0.420439,0.188918,0.172918,-0.07668,0.010337,-0.287048,0.222951,-0.186517,0.071781,0.229308,-0.063299,-0.108122,0.03744,0.042467,-0.086294,0.102709,-0.073687,0.047708,-0.205369,0.355595,0.095391,-0.552123,-0.062507,-0.062732,-0.33627,-0.186323,-0.155956,0.449449,-0.010485,0.703118,0.079991,-0.158513,-0.078529,-0.17644,...,-0.187346,0.243681,-0.329644,0.08031,0.007752,-0.151513,0.08125,-0.358037,-0.188895,0.193158,-0.279578,-0.007327,0.048084,0.107231,0.098582,0.161298,-0.211673,-0.148438,0.013817,0.146621,-0.012112,-0.090765,-0.194526,0.116612,0.153502,0.167925,0.123219,0.106242,0.257059,0.200487,-0.486795,0.091707,0.08233,0.116555,-0.291279,0.382398,-0.324631,-0.012788,0.022908,-0.307021


In [7]:
data_BERT.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767
0,-0.120891,-0.280994,-0.963245,0.419622,0.778366,-0.196418,-0.313293,0.23287,-0.862499,-0.955871,0.223536,0.886177,0.327486,0.866635,-0.249378,0.088374,0.439759,0.016313,0.064368,0.730549,0.462973,0.999983,-0.40831,0.307045,0.337647,0.908904,-0.068019,-0.055471,0.25283,0.328617,0.393255,0.119468,-0.622634,-0.270945,-0.978561,-0.187582,0.180457,0.12795,-0.101807,-0.25029,...,0.42801,-0.24705,-0.033118,-0.254504,-0.322066,-0.023203,-0.252481,-0.285563,0.193448,0.033795,0.999958,-0.736765,-0.848266,-0.186036,-0.336793,0.301686,-0.462011,-0.999997,0.255416,-0.859095,0.80337,-0.318131,0.902804,-0.754084,0.307099,-0.061251,0.62403,0.844899,-0.087035,-0.475938,0.161414,-0.928535,0.863065,-0.074633,-0.00226,-0.675741,0.185203,-0.855487,-0.17854,-0.240451
1,-0.183047,-0.320446,-0.967441,0.510457,0.806681,-0.212261,-0.239628,0.210489,-0.880438,-0.945462,0.122012,0.914262,0.214913,0.882294,-0.172117,0.000281,0.382229,0.015465,0.056818,0.700129,0.484552,0.999983,-0.460843,0.301718,0.323201,0.921945,-0.133789,-0.070382,0.238025,0.354671,0.378942,0.101619,-0.541697,-0.292166,-0.978894,-0.117593,0.185518,0.117982,-0.145663,-0.236753,...,0.481052,-0.270803,-0.04876,-0.217497,-0.346388,0.035731,-0.253299,-0.326463,0.175167,-0.009337,0.999946,-0.775669,-0.875688,-0.17833,-0.360491,0.283788,-0.45156,-0.999997,0.240596,-0.880769,0.822632,-0.391792,0.911289,-0.800638,0.229179,-0.054229,0.605605,0.866382,-0.10912,-0.466791,0.17515,-0.939977,0.888713,-0.089563,-0.091914,-0.731656,0.169998,-0.887561,-0.124455,-0.239807
2,-0.447093,-0.470325,-0.965326,0.46833,0.782669,-0.307996,-0.073427,0.399358,-0.898938,-0.996745,-0.013201,0.910904,0.592074,0.867216,0.129251,-0.309642,0.301718,-0.296469,0.227609,0.730816,0.562351,0.999993,-0.372397,0.492953,0.458219,0.939361,-0.354881,0.247452,0.618205,0.554835,0.12269,0.32039,-0.797047,-0.401551,-0.978959,-0.659457,0.303599,-0.156353,-0.21498,-0.144441,...,0.302321,-0.322595,-0.211455,-0.266762,-0.021651,-0.307902,-0.443819,-0.391257,0.400275,0.207513,0.999979,-0.793533,-0.903334,-0.252147,-0.451421,0.494141,-0.54382,-1.0,0.288275,-0.916234,0.872236,-0.40584,0.887067,-0.890248,-0.080239,-0.253946,0.607473,0.890332,-0.33116,-0.502541,0.636328,-0.899832,0.891958,0.028261,-0.23509,-0.60619,0.71639,-0.910605,-0.424803,-0.059388
3,-0.098515,-0.385608,-0.987904,0.494115,0.869421,-0.238608,-0.245373,0.274326,-0.947566,-0.950335,-0.018988,0.949433,-0.0147,0.939705,-0.249542,-0.130132,0.267692,-0.008533,0.047951,0.734438,0.511568,0.999997,-0.642022,0.352938,0.385997,0.963973,-0.198282,-0.205098,0.149731,0.367538,0.343082,0.175255,-0.408339,-0.325079,-0.989443,-0.010332,0.279057,0.168347,-0.1956,-0.278317,...,0.652145,-0.299753,-0.139638,-0.173683,-0.443671,-0.027985,-0.349703,-0.388918,0.165237,0.035473,0.999991,-0.882225,-0.953739,-0.24563,-0.413192,0.379641,-0.529157,-1.0,0.287768,-0.922833,0.899929,-0.592388,0.953749,-0.889039,0.282065,-0.129607,0.664973,0.92745,-0.143531,-0.5332,0.266515,-0.961368,0.948298,-0.195683,-0.299023,-0.849391,0.298872,-0.932983,-0.178504,-0.309659
4,-0.180124,-0.413501,-0.984492,0.474194,0.865759,-0.281325,-0.225921,0.305099,-0.935964,-0.974516,0.008402,0.932913,0.18234,0.928759,-0.195707,-0.15407,0.334482,-0.087378,0.10497,0.755661,0.508255,0.999997,-0.580859,0.43008,0.409025,0.948664,-0.238847,-0.117083,0.270822,0.436732,0.326114,0.217338,-0.518259,-0.354096,-0.987442,-0.17187,0.272495,0.096164,-0.216041,-0.242049,...,0.606793,-0.327609,-0.150611,-0.184259,-0.386002,-0.137363,-0.39325,-0.38306,0.206357,0.095422,0.999989,-0.869633,-0.944107,-0.264682,-0.423192,0.436482,-0.529781,-1.0,0.28089,-0.92591,0.897921,-0.559192,0.944515,-0.896119,0.235804,-0.174416,0.641793,0.917698,-0.214031,-0.523474,0.428578,-0.948966,0.931038,-0.180842,-0.408895,-0.821256,0.473733,-0.940421,-0.238981,-0.261727


In [8]:
# Load the dataset to get class information.
df = pd.read_json('https://raw.githubusercontent.com/selva86/datasets/master/newsgroups.json')
df.head()

Unnamed: 0,content,target,target_names
0,From: lerxst@wam.umd.edu (where's my thing)\nS...,7,rec.autos
1,From: guykuo@carson.u.washington.edu (Guy Kuo)...,4,comp.sys.mac.hardware
2,From: twillis@ec.ecn.purdue.edu (Thomas E Will...,4,comp.sys.mac.hardware
3,From: jgreen@amber (Joe Green)\nSubject: Re: W...,1,comp.graphics
4,From: jcm@head-cfa.harvard.edu (Jonathan McDow...,14,sci.space


In [9]:
data_stem['target']=df.target
data_stem['target_names']=df.target_names
data_stem.head()

Unnamed: 0,preprocessed,target,target_names
0,car wonder enlighten car saw dai door sport ca...,7,rec.autos
1,clock poll final final clock report acceler cl...,4,comp.sys.mac.hardware
2,question folk mac plu final gave ghost weekend...,4,comp.sys.mac.hardware
3,weitek robert kyanko rob rjck uucp wrote abrax...,1,comp.graphics
4,shuttl launch question articl cowcb world std ...,14,sci.space


In [14]:
data_w2v['target']=[df['target'].loc[i] for i in positive]
data_w2v['target_names']=[df['target_names'].loc[i] for i in positive]
data_w2v.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,target,target_names
0,0.351774,-0.270618,-0.099167,0.440975,0.147224,-0.04785,0.347204,-0.370328,-0.306052,-0.53304,0.784328,0.352205,-0.078587,-0.83025,0.189646,-0.017968,0.253096,0.125804,-0.177684,0.034206,0.132728,0.088732,-0.123832,-0.398519,0.014783,-0.270238,0.117907,0.391219,0.158757,0.173804,0.131754,-0.084592,-0.40822,-0.056334,-0.280491,-0.08024,0.259069,0.260064,-0.286499,-0.022011,...,-0.554667,0.198921,-0.189796,0.216042,-0.360367,0.454792,-0.272879,0.330489,0.399188,0.305223,-0.085042,0.435105,0.282648,0.01771,0.235993,-0.084351,-0.208245,0.192481,-0.13251,0.138845,0.169404,0.136261,-0.032994,-0.172526,-0.098615,0.172062,-0.217515,-0.404443,0.4921,-0.176473,0.651371,0.46328,-0.441628,-0.175508,0.143901,0.502239,-0.043058,-0.491982,7,rec.autos
1,0.008163,0.162567,0.056294,0.110691,0.077435,0.173565,0.030174,0.191341,0.136779,-0.149317,0.083286,0.085803,0.400404,-0.347317,0.033176,-0.430936,0.143813,0.261639,-0.115579,-0.062114,-0.088979,-0.347597,-0.098744,-0.244849,-0.298315,0.557203,0.301427,0.090291,-0.113943,0.139821,0.007065,-0.216194,-0.249048,0.365799,0.004725,0.116574,-0.111931,0.211696,0.011096,-0.09235,...,-0.382989,0.39463,-0.005884,-0.19418,-0.281401,0.188782,-0.145757,0.141406,-0.065778,0.286781,-0.13732,0.399824,0.202061,0.221534,0.074902,-0.190985,0.102994,0.424236,0.098024,0.057525,0.282348,0.437015,0.181157,0.176196,-0.049588,0.107599,0.250988,-0.072748,0.096912,0.354876,0.271327,0.113638,-0.151813,0.036743,0.010965,0.239501,0.214709,-0.38445,4,comp.sys.mac.hardware
2,0.18277,-0.100424,0.019315,0.334073,0.107887,0.124762,-0.113445,0.359388,-0.018454,-0.086599,0.106137,-0.052467,0.183271,-0.273118,0.272357,0.148054,0.107701,0.01081,-0.053615,-0.050433,-0.063988,-0.272693,-0.25213,-0.030147,0.038941,0.213088,0.046555,0.011535,-0.005292,0.148623,0.0758,-0.08096,-0.206258,0.305082,-0.191143,0.226707,0.109818,-0.082238,0.061433,0.001232,...,-0.46816,0.261627,-0.177639,-0.010834,-0.067605,0.086463,-0.112576,0.312297,0.121522,0.208696,-0.096691,0.161278,0.142064,0.319702,0.152462,-0.155592,0.113821,0.149915,0.03235,-0.038154,-0.102814,0.364744,-0.000543,0.108923,0.129811,0.255257,0.161428,0.131891,0.146105,-0.147394,0.027552,-0.035525,-0.278403,0.124541,0.122495,0.117965,0.404937,-0.070498,4,comp.sys.mac.hardware
3,-0.294891,0.21189,0.061987,-0.018384,-0.002906,-0.085502,-0.344018,0.004979,0.080335,-0.303516,0.162338,-0.112059,0.218604,-0.104201,0.311974,0.140519,0.176662,0.22641,-0.109118,0.018592,0.088117,-0.04605,0.132062,-0.438827,-0.211266,0.215469,0.046284,-0.195149,-0.235789,-0.169567,0.054132,-0.428622,-0.03286,0.047562,0.092482,0.606193,0.226337,-0.272084,-0.146256,0.248828,...,-0.340334,0.18829,0.200545,-0.219851,0.00186,-0.140909,0.125717,0.002194,-0.06258,-0.022784,0.048569,0.217845,0.114779,0.001259,0.041265,-0.32713,-0.017834,0.16617,-0.104674,-0.128521,0.199192,0.021575,0.033676,-0.036859,0.058534,0.236719,0.294604,0.16678,-0.229979,-0.164436,0.149876,0.2124,-0.437927,0.156925,-0.166397,0.117495,0.07816,-0.37075,1,comp.graphics
4,-0.217928,-0.257206,0.003999,-0.010809,0.019237,0.06663,-0.420439,0.188918,0.172918,-0.07668,0.010337,-0.287048,0.222951,-0.186517,0.071781,0.229308,-0.063299,-0.108122,0.03744,0.042467,-0.086294,0.102709,-0.073687,0.047708,-0.205369,0.355595,0.095391,-0.552123,-0.062507,-0.062732,-0.33627,-0.186323,-0.155956,0.449449,-0.010485,0.703118,0.079991,-0.158513,-0.078529,-0.17644,...,-0.329644,0.08031,0.007752,-0.151513,0.08125,-0.358037,-0.188895,0.193158,-0.279578,-0.007327,0.048084,0.107231,0.098582,0.161298,-0.211673,-0.148438,0.013817,0.146621,-0.012112,-0.090765,-0.194526,0.116612,0.153502,0.167925,0.123219,0.106242,0.257059,0.200487,-0.486795,0.091707,0.08233,0.116555,-0.291279,0.382398,-0.324631,-0.012788,0.022908,-0.307021,14,sci.space


In [10]:
data_BERT['target']=df.target
data_BERT['target_names']=df.target_names
data_BERT.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767,target,target_names
0,-0.120891,-0.280994,-0.963245,0.419622,0.778366,-0.196418,-0.313293,0.23287,-0.862499,-0.955871,0.223536,0.886177,0.327486,0.866635,-0.249378,0.088374,0.439759,0.016313,0.064368,0.730549,0.462973,0.999983,-0.40831,0.307045,0.337647,0.908904,-0.068019,-0.055471,0.25283,0.328617,0.393255,0.119468,-0.622634,-0.270945,-0.978561,-0.187582,0.180457,0.12795,-0.101807,-0.25029,...,-0.033118,-0.254504,-0.322066,-0.023203,-0.252481,-0.285563,0.193448,0.033795,0.999958,-0.736765,-0.848266,-0.186036,-0.336793,0.301686,-0.462011,-0.999997,0.255416,-0.859095,0.80337,-0.318131,0.902804,-0.754084,0.307099,-0.061251,0.62403,0.844899,-0.087035,-0.475938,0.161414,-0.928535,0.863065,-0.074633,-0.00226,-0.675741,0.185203,-0.855487,-0.17854,-0.240451,7,rec.autos
1,-0.183047,-0.320446,-0.967441,0.510457,0.806681,-0.212261,-0.239628,0.210489,-0.880438,-0.945462,0.122012,0.914262,0.214913,0.882294,-0.172117,0.000281,0.382229,0.015465,0.056818,0.700129,0.484552,0.999983,-0.460843,0.301718,0.323201,0.921945,-0.133789,-0.070382,0.238025,0.354671,0.378942,0.101619,-0.541697,-0.292166,-0.978894,-0.117593,0.185518,0.117982,-0.145663,-0.236753,...,-0.04876,-0.217497,-0.346388,0.035731,-0.253299,-0.326463,0.175167,-0.009337,0.999946,-0.775669,-0.875688,-0.17833,-0.360491,0.283788,-0.45156,-0.999997,0.240596,-0.880769,0.822632,-0.391792,0.911289,-0.800638,0.229179,-0.054229,0.605605,0.866382,-0.10912,-0.466791,0.17515,-0.939977,0.888713,-0.089563,-0.091914,-0.731656,0.169998,-0.887561,-0.124455,-0.239807,4,comp.sys.mac.hardware
2,-0.447093,-0.470325,-0.965326,0.46833,0.782669,-0.307996,-0.073427,0.399358,-0.898938,-0.996745,-0.013201,0.910904,0.592074,0.867216,0.129251,-0.309642,0.301718,-0.296469,0.227609,0.730816,0.562351,0.999993,-0.372397,0.492953,0.458219,0.939361,-0.354881,0.247452,0.618205,0.554835,0.12269,0.32039,-0.797047,-0.401551,-0.978959,-0.659457,0.303599,-0.156353,-0.21498,-0.144441,...,-0.211455,-0.266762,-0.021651,-0.307902,-0.443819,-0.391257,0.400275,0.207513,0.999979,-0.793533,-0.903334,-0.252147,-0.451421,0.494141,-0.54382,-1.0,0.288275,-0.916234,0.872236,-0.40584,0.887067,-0.890248,-0.080239,-0.253946,0.607473,0.890332,-0.33116,-0.502541,0.636328,-0.899832,0.891958,0.028261,-0.23509,-0.60619,0.71639,-0.910605,-0.424803,-0.059388,4,comp.sys.mac.hardware
3,-0.098515,-0.385608,-0.987904,0.494115,0.869421,-0.238608,-0.245373,0.274326,-0.947566,-0.950335,-0.018988,0.949433,-0.0147,0.939705,-0.249542,-0.130132,0.267692,-0.008533,0.047951,0.734438,0.511568,0.999997,-0.642022,0.352938,0.385997,0.963973,-0.198282,-0.205098,0.149731,0.367538,0.343082,0.175255,-0.408339,-0.325079,-0.989443,-0.010332,0.279057,0.168347,-0.1956,-0.278317,...,-0.139638,-0.173683,-0.443671,-0.027985,-0.349703,-0.388918,0.165237,0.035473,0.999991,-0.882225,-0.953739,-0.24563,-0.413192,0.379641,-0.529157,-1.0,0.287768,-0.922833,0.899929,-0.592388,0.953749,-0.889039,0.282065,-0.129607,0.664973,0.92745,-0.143531,-0.5332,0.266515,-0.961368,0.948298,-0.195683,-0.299023,-0.849391,0.298872,-0.932983,-0.178504,-0.309659,1,comp.graphics
4,-0.180124,-0.413501,-0.984492,0.474194,0.865759,-0.281325,-0.225921,0.305099,-0.935964,-0.974516,0.008402,0.932913,0.18234,0.928759,-0.195707,-0.15407,0.334482,-0.087378,0.10497,0.755661,0.508255,0.999997,-0.580859,0.43008,0.409025,0.948664,-0.238847,-0.117083,0.270822,0.436732,0.326114,0.217338,-0.518259,-0.354096,-0.987442,-0.17187,0.272495,0.096164,-0.216041,-0.242049,...,-0.150611,-0.184259,-0.386002,-0.137363,-0.39325,-0.38306,0.206357,0.095422,0.999989,-0.869633,-0.944107,-0.264682,-0.423192,0.436482,-0.529781,-1.0,0.28089,-0.92591,0.897921,-0.559192,0.944515,-0.896119,0.235804,-0.174416,0.641793,0.917698,-0.214031,-0.523474,0.428578,-0.948966,0.931038,-0.180842,-0.408895,-0.821256,0.473733,-0.940421,-0.238981,-0.261727,14,sci.space


In [16]:
data_stem=data_stem.loc[data_stem.target_names.isin(['soc.religion.christian', 'rec.sport.hockey', 'talk.politics.mideast', 'rec.motorcycles']) , :]
data_w2v=data_w2v.loc[data_w2v.target_names.isin(['soc.religion.christian', 'rec.sport.hockey', 'talk.politics.mideast', 'rec.motorcycles']) , :]
data_BERT=data_BERT.loc[data_BERT.target_names.isin(['soc.religion.christian', 'rec.sport.hockey', 'talk.politics.mideast', 'rec.motorcycles']) , :]


In [17]:
data_stem.target_names.unique()

array(['rec.motorcycles', 'rec.sport.hockey', 'soc.religion.christian',
       'talk.politics.mideast'], dtype=object)

In [18]:
# Initialise the transformer
vec=TfidfVectorizer(max_df=0.7, min_df=0.1)

# Apply it to the data
vec_tf=vec.fit_transform(data_stem.preprocessed).toarray()
# Generate data frame
data_stem2=pd.DataFrame(vec_tf)
data_stem2['target']=data_stem['target'].values
data_stem2['target_names']=data_stem.target_names.values
data_stem2.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,target,target_names
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.677322,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.335449,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.365552,0.0,0.234338,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.280907,0.0,0.0,0.0,0.0,0.0,8,rec.motorcycles
1,0.166234,0.133132,0.0,0.0,0.220031,0.0,0.0,0.0,0.0,0.0,0.0,0.256887,0.0,0.0,0.506391,0.0,0.0,0.0,0.265473,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.497055,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.234536,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.237439,0.0,0.0,0.0,0.333475,0.0,0.172936,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.117344,0.0,10,rec.sport.hockey
2,0.090782,0.072706,0.0,0.0,0.0,0.0,0.0,0.0,0.136505,0.0,0.0,0.0,0.0,0.0,0.138274,0.0,0.0,0.0,0.289957,0.0,0.0,0.0,0.0,0.0,0.0,0.285169,0.0,0.124001,0.599195,0.0,0.135724,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.11795,0.142584,0.11744,0.0,0.0,0.0,0.127439,0.0,0.0,0.141333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.454985,0.091058,0.0,0.094443,0.0,0.0,0.0,0.0,0.133906,0.0,0.0,0.0,0.0,0.0,0.064083,0.0,15,soc.religion.christian
3,0.174544,0.209682,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.132927,0.122023,0.0,0.0,0.278744,0.25977,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.130032,0.0,0.0,0.538798,0.136897,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.110437,0.0,0.0,0.0,0.0,0.0,0.230823,0.390982,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.087536,0.0,0.090791,0.0,0.0,0.120473,0.0,0.0,0.0,0.209864,0.0,0.0,0.0,0.184815,0.0,17,talk.politics.mideast
4,0.0,0.136988,0.0,0.0,0.0,0.260839,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.273162,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.264977,0.0,0.0,0.204252,0.0,0.258393,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.222237,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2262,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.343133,0.0,0.177945,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.120743,0.196274,10,rec.sport.hockey


# B. Naïve Bayes classification model

Before we start, we will split the data in test and training data. Note that you need to set the random seed to make your results reproducable.

## Test/train split

In [21]:
def text_train(df):
  return train_test_split(df.iloc[:, :-2], df.target,test_size = 0.20, random_state = 12)
# Split data into training and test sets
docs_train_s, docs_test_s, y_train_s, y_test_s = text_train(data_stem2)
docs_train_w2v, docs_test_w2v, y_train_w2v, y_test_w2v = text_train(data_w2v)
docs_train_BERT, docs_test_BERT, y_train_BERT, y_test_BERT = text_train(data_BERT)

Here, we will first apply a simple Naive Bayes classifier. Then we will examine its performance.

In [22]:
# Train the model
clf = MultinomialNB()
clf.fit(docs_train_s, y_train_s)

MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)

## Performance analysis

The accuracy determines the percentage of correctly predicted targets.

In [23]:
# Predict the Test set results, determine accuracy
y_pred_s = clf.predict(docs_test_s)
print('Train accuracy: ', clf.score(docs_train_s, y_train_s))
print('Test accuracy: ', accuracy_score(y_pred_s, y_test_s))

Train accuracy:  0.8702330508474576
Test accuracy:  0.8372093023255814


We have three classes here and therefore if we randomly assign a datapoint to a class we have a 25% chance to be right. This means that the classifier on the test set is much better than a random classifier. Also, the test and training accuracy are close to each other, which means that there is no overfitting.

The classification report helps find class perfromance.

In [24]:
# Show the confusion matrix
print(classification_report(y_test_s, y_pred_s))

              precision    recall  f1-score   support

           8       0.82      0.87      0.84       134
          10       0.90      0.87      0.88       118
          15       0.79      0.87      0.83       106
          17       0.84      0.74      0.79       115

    accuracy                           0.84       473
   macro avg       0.84      0.84      0.84       473
weighted avg       0.84      0.84      0.84       473



Generally, the classes are equally well-estimated. Only class 10 has higher precision and class 17 lower recall.

# C.	Random Forest

In [25]:
# Train the model
clf2 = RandomForestClassifier(random_state = 42)
clf2.fit(docs_train_s, y_train_s)
# Predict the Test set results, determine accuracy
y_pred_s = clf2.predict(docs_test_s)
print('Train accuracy RF: ', clf2.score(docs_train_s, y_train_s))
print('Test accuracy RF: ', accuracy_score(y_pred_s, y_test_s))
# Show the confusion matrix
print(classification_report(y_test_s, y_pred_s))
print(clf2.n_estimators)
print(clf2.min_samples_leaf)

Train accuracy RF:  0.9936440677966102
Test accuracy RF:  0.8879492600422833
              precision    recall  f1-score   support

           8       0.93      0.90      0.92       134
          10       0.87      0.91      0.89       118
          15       0.85      0.86      0.85       106
          17       0.89      0.88      0.89       115

    accuracy                           0.89       473
   macro avg       0.89      0.89      0.89       473
weighted avg       0.89      0.89      0.89       473

100
1


-> Overfitting

In [28]:
# Train the model
clf2 = RandomForestClassifier(random_state = 42)
clf2.fit(docs_train_w2v, y_train_w2v)
# Predict the Test set results, determine accuracy
y_pred_w2v = clf2.predict(docs_test_w2v)
print('Train accuracy RF: ', clf2.score(docs_train_w2v, y_train_w2v))
print('Test accuracy RF: ', accuracy_score(y_pred_w2v, y_test_w2v))
# Show the confusion matrix
print(classification_report(y_test_w2v, y_pred_w2v))
clf2

Train accuracy RF:  1.0
Test accuracy RF:  0.9044585987261147
              precision    recall  f1-score   support

           8       0.80      0.96      0.88       107
          10       0.97      0.91      0.94       135
          15       0.91      0.94      0.92       109
          17       0.94      0.82      0.87       120

    accuracy                           0.90       471
   macro avg       0.91      0.91      0.90       471
weighted avg       0.91      0.90      0.90       471



RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=100,
                       n_jobs=None, oob_score=False, random_state=42, verbose=0,
                       warm_start=False)

Strong overfitting

In [29]:
# Train the model
clf2 = RandomForestClassifier(random_state = 42)
clf2.fit(docs_train_BERT, y_train_BERT)
# Predict the Test set results, determine accuracy
y_pred_BERT = clf2.predict(docs_test_BERT)
print('Train accuracy RF: ', clf2.score(docs_train_BERT, y_train_BERT))
print('Test accuracy RF: ', accuracy_score(y_pred_BERT, y_test_BERT))
# Show the confusion matrix
print(classification_report(y_test_BERT, y_pred_BERT))
clf2

Train accuracy RF:  1.0
Test accuracy RF:  0.7124735729386892
              precision    recall  f1-score   support

           8       0.71      0.75      0.73       134
          10       0.67      0.79      0.72       118
          15       0.69      0.73      0.71       106
          17       0.82      0.57      0.68       115

    accuracy                           0.71       473
   macro avg       0.72      0.71      0.71       473
weighted avg       0.72      0.71      0.71       473



RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=100,
                       n_jobs=None, oob_score=False, random_state=42, verbose=0,
                       warm_start=False)

Also overfitting.

# D.	Grid search


In the following, we will apply Grid Search for the Random Forest Classifier to determine the minimum samples in a leaf and the number of trees. We are using a ten-fold-cross validation here.

In [30]:
# Initialise the set of parameters you would like to test in cross validation
param_grid = {
    'min_samples_leaf': [5,10],
    'n_estimators': [3,5]
}

In [31]:
# Build the grid search
rf = RandomForestClassifier(random_state = 42)
grid_search = GridSearchCV(estimator = rf, param_grid = param_grid, cv = 10)

In [32]:
# Fit the grid search to the data
grid_search.fit(docs_train_s, y_train_s)

GridSearchCV(cv=10, error_score=nan,
             estimator=RandomForestClassifier(bootstrap=True, ccp_alpha=0.0,
                                              class_weight=None,
                                              criterion='gini', max_depth=None,
                                              max_features='auto',
                                              max_leaf_nodes=None,
                                              max_samples=None,
                                              min_impurity_decrease=0.0,
                                              min_impurity_split=None,
                                              min_samples_leaf=1,
                                              min_samples_split=2,
                                              min_weight_fraction_leaf=0.0,
                                              n_estimators=100, n_jobs=None,
                                              oob_score=False, random_state=42,
                                 

In [33]:
# Best model
grid_search.best_params_

{'min_samples_leaf': 5, 'n_estimators': 5}

In [34]:
best_grid = grid_search.best_estimator_
y_pred_grid = best_grid.predict(docs_test_s)
print('Train accuracy: ', best_grid.score(docs_train_s, y_train_s))
print('Test accuracy: ', accuracy_score(y_test_s, y_pred_s))
print(classification_report(y_test_s, y_pred_s))

Train accuracy:  0.9004237288135594
Test accuracy:  0.8879492600422833
              precision    recall  f1-score   support

           8       0.93      0.90      0.92       134
          10       0.87      0.91      0.89       118
          15       0.85      0.86      0.85       106
          17       0.89      0.88      0.89       115

    accuracy                           0.89       473
   macro avg       0.89      0.89      0.89       473
weighted avg       0.89      0.89      0.89       473



The model has less overfitting, same accuracy.