# Exercise 4. Text Representation Part 2



In this exercise we will apply the following models to the stemmed data from Exercise 2:

1.   Word2Vec
2.   Doc2vec
3.   BERT

At the end, we will derive a corpus with each of them which can be used in downstream tasks such as classification and clustering (see next exercises).


In [1]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/d8/b2/57495b5309f09fa501866e225c84532d1fd89536ea62406b2181933fb418/transformers-4.5.1-py3-none-any.whl (2.1MB)
[K     |▏                               | 10kB 28.7MB/s eta 0:00:01[K     |▎                               | 20kB 34.5MB/s eta 0:00:01[K     |▌                               | 30kB 20.8MB/s eta 0:00:01[K     |▋                               | 40kB 22.9MB/s eta 0:00:01[K     |▉                               | 51kB 25.1MB/s eta 0:00:01[K     |█                               | 61kB 27.8MB/s eta 0:00:01[K     |█▏                              | 71kB 18.9MB/s eta 0:00:01[K     |█▎                              | 81kB 20.3MB/s eta 0:00:01[K     |█▍                              | 92kB 18.8MB/s eta 0:00:01[K     |█▋                              | 102kB 18.9MB/s eta 0:00:01[K     |█▊                              | 112kB 18.9MB/s eta 0:00:01[K     |██                              | 

In [2]:
# Import packages
import pickle
import pandas as pd
from gensim.models import Word2Vec
from scipy.spatial.distance import cosine
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
import tensorflow as tf
import torch
from transformers import BertTokenizer, BertModel
from keras.preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

## 0. Load data

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Import dataset
data_lemma=pickle.load(open("/content/drive/MyDrive/TWSM_Data/Lemma.pkl", "rb"))
print(data_lemma[0])

car wonder enlighten car see day door sport car look late early call bricklin door small addition bumper separate rest body know tellme model engine specs year production car history info funky looking car mail thank


## 1. Word2Vec


In this section we will train the word2vec model on the lemmatized data. 


In [5]:
# Prepare the dataset for the word2vec model
corpus_gen=[doc.split() for doc in data_lemma]

# Train the model for embeddings of size 100 considering words appearing in more than 566 documents, default window=5
model = Word2Vec(corpus_gen, size=100, min_count=566)
model.save('word2vec.model')

In [None]:
print([i for i in sorted(model.wv.vocab.keys())])

['able', 'accept', 'access', 'act', 'action', 'actually', 'add', 'address', 'advance', 'ago', 'agree', 'air', 'allow', 'american', 'answer', 'anti', 'anybody', 'appear', 'apple', 'application', 'apply', 'appreciate', 'apr', 'april', 'area', 'argument', 'arm', 'armenian', 'armenians', 'article', 'ask', 'assume', 'atheist', 'attack', 'attempt', 'available', 'away', 'bad', 'base', 'bear', 'begin', 'belief', 'believe', 'better', 'bible', 'big', 'bike', 'bit', 'black', 'board', 'body', 'book', 'box', 'break', 'bring', 'build', 'bus', 'business', 'buy', 'call', 'canada', 'car', 'card', 'care', 'carry', 'case', 'cause', 'center', 'certain', 'certainly', 'change', 'cheap', 'check', 'child', 'chip', 'choose', 'christ', 'christian', 'christians', 'church', 'city', 'claim', 'clear', 'clinton', 'clipper', 'close', 'code', 'color', 'com', 'come', 'command', 'comment', 'common', 'company', 'condition', 'consider', 'contact', 'contain', 'continue', 'control', 'copy', 'correct', 'cost', 'country', 'co

In [None]:
# Embedding for 'car'
vector = model.wv['car']
vector

array([ 1.4692358 , -0.8023282 , -0.13441241,  0.98153615,  0.7599832 ,
       -0.40347973,  1.0198705 , -1.5354207 , -0.9257218 , -1.0203286 ,
        2.215068  ,  0.6632183 , -0.18657741, -1.9154484 ,  0.44888872,
       -0.02129542,  0.798015  ,  0.37241736, -0.9987019 , -0.11381808,
        1.187318  ,  0.87888634,  0.7570844 , -1.2267259 ,  0.18649325,
       -0.83251554,  0.41996965,  0.52846485, -0.25219265,  0.4602982 ,
        0.49639538, -0.14283605, -0.30835825, -0.32588053,  0.39061022,
       -0.6104033 ,  0.7260111 ,  0.17973503, -0.89412355, -0.06654264,
       -0.2113204 ,  0.61939704, -1.3517871 , -1.120821  , -1.8280628 ,
       -0.1121858 , -0.88791466, -1.8047341 , -0.11037601,  1.8896706 ,
        0.69367045, -0.6998514 , -1.2943509 , -0.42140684,  0.11713248,
        1.9989212 ,  1.23207   , -0.23176973,  0.6543441 ,  0.05240316,
        0.2936794 , -0.62970936, -1.8247054 ,  0.63928753, -0.51164234,
        0.20888922, -1.4701684 ,  0.9333078 , -0.9873637 ,  0.15

In [None]:
# Most similar representations to 'car' based on cosine similarity
model.wv.most_similar('car')

[('bike', 0.5988668203353882),
 ('buy', 0.5692394971847534),
 ('friend', 0.5084896087646484),
 ('get', 0.4868040978908539),
 ('light', 0.4818478226661682),
 ('guy', 0.44782453775405884),
 ('sell', 0.44778281450271606),
 ('hit', 0.44717222452163696),
 ('price', 0.44604820013046265),
 ('figure', 0.4446602463722229)]

In [None]:
# Embeddings' arithmetics
model.wv.most_similar(positive=['bike', 'machine'], topn=1)

[('fast', 0.6370933651924133)]

In the following we will derive the corpus. Note that word2vec (as opposed to doc2vec) generates one embedding for each word in the document. These then need to be aggregated at a document level. The simplest way is to determine the average over all words, but you can also use other aggregators.

In [6]:
# Document representation for the text
corpus_w2v=[[model.wv[word] for word in doc if word in model.wv.vocab.keys()] for doc in corpus_gen]
positive=[i for i in range(len(corpus_gen)) if len(corpus_w2v[i])>0]

corpus_w2v2=[corpus_w2v[i] for i in positive]
data_lemma2=[data_lemma[i] for i in positive]

# Document average representation
corpus_w2v_avg_clean=[sum(words)/len(words) for words in corpus_w2v2]

# This corpus can be used later in clustering and classification tasks
print(corpus_w2v_avg_clean[10])

[-0.22174568  0.08236665  0.19761358  0.4131328  -0.24650982  0.05626864
 -0.23169933  0.05433622 -0.04930507  0.25527987 -0.28210744 -0.05326257
  0.49531323  0.4935223  -0.47186023 -0.3237659   0.09915254  0.24354464
  0.18197747  0.30309755 -0.21782619  0.15350541  0.08337717  0.05905351
  0.06262911 -0.00747794 -0.3849919  -0.04637662  0.13874954 -0.43448052
 -0.32311612  0.1476274   0.32036453 -0.28070176 -0.33536777  0.08565568
  0.40192673 -0.04286042  0.29775962 -0.3766572   0.4520554   0.17600219
 -0.24697562  0.16031866  0.13887359 -0.11053925  0.05216295  0.08006312
 -0.40994582  0.20044981  0.10320445  0.362728    0.10350568  0.31597996
  0.39345226 -0.27563915 -0.1057001  -0.5896429  -0.30657113 -0.2834169
  0.40558192 -0.49381116  0.19384031 -0.29545584 -0.09935518  0.33083814
  0.3685488   0.06066651 -0.08403632 -0.01092415 -0.27837294 -0.04147279
 -0.18091689  0.31867737  0.13949792 -0.292108    0.01628524 -0.36495912
  0.40776235  0.0194951   0.12885498 -0.0934267   0.

In [None]:
len(corpus_w2v)

11314

In [None]:
len(corpus_w2v_avg_clean)

11298

In [None]:
len(data_lemma2)

11298

In [7]:
model.wv.similar_by_vector(corpus_w2v_avg_clean[0])

[('car', 0.860916256904602),
 ('friend', 0.6163879036903381),
 ('month', 0.5582113862037659),
 ('buy', 0.5579794049263),
 ('bike', 0.555242657661438),
 ('get', 0.5453658103942871),
 ('see', 0.4985682964324951),
 ('price', 0.4835365414619446),
 ('look', 0.47569817304611206),
 ('figure', 0.4657348394393921)]

In [None]:
# Most simlar words to the document based on average representation
# This can be used to evaluate different aggregation methods and also provides interpretation of the document representation
print([token for (token,_) in model.wv.similar_by_vector(corpus_w2v_avg_clean[0])])

# cosine similarity to other documents
result=[(1 - cosine(corpus_w2v_avg_clean[0],corpus_w2v_avg_clean[i])) for i in range(1,len(corpus_w2v_avg_clean))]
most_similar=data_lemma2[result.index(max(result))+1]
print(data_lemma2[0])
print('')
print(most_similar)

['car', 'friend', 'buy', 'see', 'bike', 'get', 'look', 'month', 'lot', 'remember']
car wonder enlighten car see day door sport car look late early call bricklin door small addition bumper separate rest body know tellme model engine specs year production car history info funky looking car mail thank

aussie need info car show australia car enthusiast australia particularly interested american muscle car make amc ford chrysler mopar usa weeks june chicago sun thursday denver friday sunday austin texas monday friday oklahoma city friday monday anaheim california tuesday thursday las vegas nevada friday sunday grand canion monday tuesday june las angeles san diego vicinity wednesday june sunday june june south lake tahoe cal sunday june wednesday june reno thursday june san fransisco thursday june sunday june wonder send information car show swap meets drag meet model car show period anybody tell pomona swap meet year place visit car museum private collection collection bit information app

In [None]:
len(result)

11297

In [None]:
# Corpus as data frame that can be used in downstream tasks such as classification
corpus_w2v_avg_df=pd.DataFrame(corpus_w2v_avg_clean)
corpus_w2v_avg_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
0,0.351774,-0.270618,-0.099167,0.440975,0.147224,-0.04785,0.347204,-0.370328,-0.306052,-0.53304,0.784328,0.352205,-0.078587,-0.83025,0.189646,-0.017968,0.253096,0.125804,-0.177684,0.034206,0.132728,0.088732,-0.123832,-0.398519,0.014783,-0.270238,0.117907,0.391219,0.158757,0.173804,0.131754,-0.084592,-0.40822,-0.056334,-0.280491,-0.08024,0.259069,0.260064,-0.286499,-0.022011,...,0.080258,-0.143452,-0.554667,0.198921,-0.189796,0.216042,-0.360367,0.454792,-0.272879,0.330489,0.399188,0.305223,-0.085042,0.435105,0.282648,0.01771,0.235993,-0.084351,-0.208245,0.192481,-0.13251,0.138845,0.169404,0.136261,-0.032994,-0.172526,-0.098615,0.172062,-0.217515,-0.404443,0.4921,-0.176473,0.651371,0.46328,-0.441628,-0.175508,0.143901,0.502239,-0.043058,-0.491982
1,0.008163,0.162567,0.056294,0.110691,0.077435,0.173565,0.030174,0.191341,0.136779,-0.149317,0.083286,0.085803,0.400404,-0.347317,0.033176,-0.430936,0.143813,0.261639,-0.115579,-0.062114,-0.088979,-0.347597,-0.098744,-0.244849,-0.298315,0.557203,0.301427,0.090291,-0.113943,0.139821,0.007065,-0.216194,-0.249048,0.365799,0.004725,0.116574,-0.111931,0.211696,0.011096,-0.09235,...,-0.00641,0.043633,-0.382989,0.39463,-0.005884,-0.19418,-0.281401,0.188782,-0.145757,0.141406,-0.065778,0.286781,-0.13732,0.399824,0.202061,0.221534,0.074902,-0.190985,0.102994,0.424236,0.098024,0.057525,0.282348,0.437015,0.181157,0.176196,-0.049588,0.107599,0.250988,-0.072748,0.096912,0.354876,0.271327,0.113638,-0.151813,0.036743,0.010965,0.239501,0.214709,-0.38445
2,0.18277,-0.100424,0.019315,0.334073,0.107887,0.124762,-0.113445,0.359388,-0.018454,-0.086599,0.106137,-0.052467,0.183271,-0.273118,0.272357,0.148054,0.107701,0.01081,-0.053615,-0.050433,-0.063988,-0.272693,-0.25213,-0.030147,0.038941,0.213088,0.046555,0.011535,-0.005292,0.148623,0.0758,-0.08096,-0.206258,0.305082,-0.191143,0.226707,0.109818,-0.082238,0.061433,0.001232,...,0.198873,0.365456,-0.46816,0.261627,-0.177639,-0.010834,-0.067605,0.086463,-0.112576,0.312297,0.121522,0.208696,-0.096691,0.161278,0.142064,0.319702,0.152462,-0.155592,0.113821,0.149915,0.03235,-0.038154,-0.102814,0.364744,-0.000543,0.108923,0.129811,0.255257,0.161428,0.131891,0.146105,-0.147394,0.027552,-0.035525,-0.278403,0.124541,0.122495,0.117965,0.404937,-0.070498
3,-0.294891,0.21189,0.061987,-0.018384,-0.002906,-0.085502,-0.344018,0.004979,0.080335,-0.303516,0.162338,-0.112059,0.218604,-0.104201,0.311974,0.140519,0.176662,0.22641,-0.109118,0.018592,0.088117,-0.04605,0.132062,-0.438827,-0.211266,0.215469,0.046284,-0.195149,-0.235789,-0.169567,0.054132,-0.428622,-0.03286,0.047562,0.092482,0.606193,0.226337,-0.272084,-0.146256,0.248828,...,-0.134156,0.007131,-0.340334,0.18829,0.200545,-0.219851,0.00186,-0.140909,0.125717,0.002194,-0.06258,-0.022784,0.048569,0.217845,0.114779,0.001259,0.041265,-0.32713,-0.017834,0.16617,-0.104674,-0.128521,0.199192,0.021575,0.033676,-0.036859,0.058534,0.236719,0.294604,0.16678,-0.229979,-0.164436,0.149876,0.2124,-0.437927,0.156925,-0.166397,0.117495,0.07816,-0.37075
4,-0.217928,-0.257206,0.003999,-0.010809,0.019237,0.06663,-0.420439,0.188918,0.172918,-0.07668,0.010337,-0.287048,0.222951,-0.186517,0.071781,0.229308,-0.063299,-0.108122,0.03744,0.042467,-0.086294,0.102709,-0.073687,0.047708,-0.205369,0.355595,0.095391,-0.552123,-0.062507,-0.062732,-0.33627,-0.186323,-0.155956,0.449449,-0.010485,0.703118,0.079991,-0.158513,-0.078529,-0.17644,...,-0.187346,0.243681,-0.329644,0.08031,0.007752,-0.151513,0.08125,-0.358037,-0.188895,0.193158,-0.279578,-0.007327,0.048084,0.107231,0.098582,0.161298,-0.211673,-0.148438,0.013817,0.146621,-0.012112,-0.090765,-0.194526,0.116612,0.153502,0.167925,0.123219,0.106242,0.257059,0.200487,-0.486795,0.091707,0.08233,0.116555,-0.291279,0.382398,-0.324631,-0.012788,0.022908,-0.307021


In [None]:
len(corpus_w2v_avg_df)

11298

In [None]:
pickle.dump(corpus_w2v_avg_df, open("/content/drive/MyDrive/TWSM_Data/WordtoVecModel.pkl", "wb"))

## 2. Doc2Vec

In [None]:
# Run doc2vec on the tagged texts
documents = [TaggedDocument(doc, [i]) for i, doc in enumerate(corpus_gen)]
model2 = Doc2Vec(documents, vector_size=100, min_count=566)

In [None]:
# Embedding for the first document
vector = model2.infer_vector(corpus_gen[0])
vector

array([-0.05122636, -0.06806806, -0.01494412,  0.01662492, -0.0822029 ,
       -0.02819622, -0.02895089, -0.06420164, -0.04436788, -0.16889516,
        0.06832569,  0.00772923,  0.0206603 , -0.10237534,  0.06752286,
       -0.03361692,  0.00732967, -0.01266144, -0.03883754, -0.01195568,
        0.01666871, -0.06581076, -0.03252764, -0.03990467,  0.02596428,
       -0.05007704, -0.02736584,  0.06371695,  0.01925586, -0.04054766,
        0.02015869, -0.08511249, -0.01322431, -0.00986988,  0.02593201,
        0.05694926,  0.02162211, -0.00424531,  0.02517235,  0.04707081,
       -0.03879257,  0.07847797, -0.06109332, -0.06253222, -0.0751923 ,
       -0.0168055 ,  0.05226422, -0.06491261, -0.03435625,  0.0289027 ,
        0.01662089, -0.00954358, -0.05242245,  0.05766357,  0.09904396,
        0.04716758,  0.00344208,  0.02078508, -0.02033357, -0.00767193,
       -0.01240751, -0.0629037 , -0.06261925,  0.04491214,  0.03900221,
        0.01412876, -0.034928  ,  0.04426256, -0.0885656 ,  0.05

In [None]:

# cosine similarity to other documents
result=[(1 - cosine(vector,model2.infer_vector(corpus_gen[i]))) for i in range(1,len(corpus_gen))]
most_similar=data_lemma[result.index(max(result))+1]

print(most_similar)

car wonder enlighten car see day door sport car look late early call bricklin door small addition bumper separate rest body know tellme model engine specs year production car history info funky looking car mail thank

trading car pay pointer article fibercom com rrg rtp fibercom com rhonda gaines write plan purchase new car trading mazda get year pay take account purchase new car dealership pay car add pay purchase price new car explain know bank credit union finance company hold loan present car current payoff cost trading current car new car subtract payoff trade dealer give turn negative number need reconsider deal subtract difference price new car size loan need new car dealer care pay loan old car money pick new car work year ago ohio thank rhonda joseph staudt telxon corp joes telxon com box usenet like tetris people akron remember read heller


In [None]:
len(corpus_gen)

11314

In [None]:
# Final corpus for classification
corpus_d2v=pd.DataFrame([model2.infer_vector(doc) for doc in corpus_gen])

In [None]:
corpus_d2v.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
0,-0.033262,-0.049674,-0.016017,-0.001603,-0.072937,-0.005914,-0.026601,-0.045852,-0.01693,-0.113542,0.031533,0.018571,0.016846,-0.076718,0.076404,-0.023871,-0.0255,-0.00935,-0.024001,-0.007181,0.020286,-0.081678,-0.031725,-0.052242,0.01991,-0.020617,-0.027672,0.039216,0.023487,-0.037598,0.004926,-0.072705,0.001625,-0.003011,-0.001036,0.049824,0.028008,0.002267,0.019915,0.065455,...,-0.025305,-0.02311,-0.058257,0.040698,0.008378,-0.013546,-0.008874,0.044304,-0.06403,0.055633,0.012384,0.01302,-0.019584,0.092485,0.051242,-0.019956,-0.005542,-0.038789,0.012394,0.112057,0.005942,-0.018436,0.020149,0.001158,0.028916,0.032556,-0.0546,0.023579,-0.005249,-0.010699,-0.040269,-0.002415,0.076182,0.014719,-0.048158,-0.020135,-0.013903,0.057317,-0.0199,-0.095958
1,0.033676,-0.049887,0.027683,0.026024,0.009057,0.031365,0.043312,0.064511,0.000866,0.025403,-0.002454,0.018382,0.03913,-0.00685,-0.01435,-0.055728,0.030716,0.000306,0.003765,-0.016803,-0.014938,-0.03831,-0.002253,-0.022165,-0.052017,0.081576,0.03585,-0.030607,-0.053455,0.018713,-0.038902,-0.024012,-0.053146,0.049171,0.023251,-0.009203,-0.067071,0.063215,-0.02507,-0.054228,...,0.021282,-0.060994,-0.010744,0.019836,-0.016958,-0.017521,0.002989,-0.019094,-0.016233,-0.008472,0.036808,0.034624,-0.040458,0.01958,0.060933,0.037337,0.013321,0.005449,0.009052,0.081936,0.01938,0.027454,0.003585,0.051145,0.02589,0.026966,-0.018686,-0.037749,0.004478,0.014192,-0.030557,0.052317,0.027663,-0.025419,-0.022255,0.00958,-0.03164,0.009603,0.051909,-0.120971
2,0.042794,-0.064218,0.037469,0.093485,0.003541,0.061478,-0.071684,0.193311,-0.005121,0.01881,-0.039404,-0.109424,-0.021143,-0.039642,0.025812,0.021862,0.05147,-0.07958,0.020494,0.047878,-0.039136,-0.059271,-0.165079,0.074292,-0.091949,0.158054,-0.017687,-0.000769,0.042526,0.075507,0.00149,-0.064246,-0.012298,0.086518,-0.076071,0.138779,0.048218,-0.011923,-0.065729,-0.025097,...,0.055347,0.144704,-0.073482,0.006659,0.015396,0.022823,-0.119585,-0.052392,0.034922,0.130359,0.145847,-0.0485,0.007007,-0.040287,0.057798,0.059772,-0.017194,-0.084103,0.049472,0.164592,0.057187,0.071866,-0.005872,0.117578,0.090215,0.102958,0.000502,-0.034454,-0.038481,0.036949,0.084788,-0.06647,-0.023559,0.05539,-0.037519,0.115474,0.023027,0.075927,0.156555,-0.033438
3,0.058012,0.08599,0.036921,0.074976,0.007012,-0.034835,-0.044319,-0.029044,0.054942,-0.036614,0.044074,-0.005479,0.03393,0.002474,-0.007639,0.031253,0.040478,0.107811,0.012438,0.020723,0.018761,-0.085365,0.060698,-0.024213,-0.011794,0.072897,0.023021,0.043049,-0.02976,-0.006028,0.019067,-0.041945,0.039625,0.007159,0.011404,0.06818,0.027937,-0.010359,0.01289,0.086288,...,-0.026592,0.030936,-0.044287,0.021562,0.034602,-0.039755,-0.01935,0.015115,0.038782,0.00792,-0.006176,0.012778,-0.052613,0.022979,0.021288,-0.009055,0.059402,-0.070373,-0.008902,0.040143,-0.011268,-0.01951,-0.044945,0.032114,-0.023231,0.075448,-0.042095,0.04781,0.098185,0.032535,-0.033829,0.014599,-0.026878,0.022225,-0.126712,0.017064,-0.001792,0.023652,0.031867,-0.029106
4,0.010468,-0.031585,0.066272,0.047936,-0.008234,-0.043216,-0.079181,0.037457,-0.020777,-0.045592,0.013054,-0.006131,0.090572,-0.037595,0.025614,0.006592,-0.075415,-0.015045,0.025322,0.072738,-0.026502,-0.113079,0.017322,-0.013613,0.008994,0.104365,0.014429,-0.032148,0.002166,-0.008551,-0.008472,-0.011973,-0.009424,0.084865,-0.003347,0.104144,0.068802,0.031158,-0.037503,0.027462,...,-0.048934,0.036901,-0.04948,0.013555,-0.092805,0.009015,0.100562,-0.073604,-0.044452,0.039012,-0.015884,0.002602,-0.02734,0.038119,0.017468,0.08468,-0.03127,-0.013489,0.009836,0.146343,-0.038083,-0.079363,-0.114853,0.083276,0.001444,0.085579,-0.015446,0.033404,0.061184,0.052265,-0.072259,0.04362,-0.01797,0.092339,-0.110279,0.027145,0.00193,0.029041,0.047018,-0.018687


In [None]:
pickle.dump(corpus_d2v, open("/content/drive/MyDrive/TWSM_Data/DoctoVecModel.pkl", "wb"))

## 3. BERT


Confirm that GPU is detected:

In [None]:
# Get the GPU device name.
device_name = tf.test.gpu_device_name()

# The device name should look like the following:
if device_name == '/device:GPU:0':
    print('Found GPU at: {}'.format(device_name))
else:
    raise SystemError('GPU device not found')

Found GPU at: /device:GPU:0


Assign the GPU device to torch:

In [None]:
# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: Tesla K80


In order to apply BERT, we need to derive three data objects for the text data:
1. Add [CLS] at the beginning and [SEP] at the end of each text. [SEP] is a legacy from teh model training. The result for [CLS] is then used later as document representation for classification tasks.
2. Tokenize the texts using BERT tokenizer
3. Pad or truncate the text to the maximum length (maximum 512)
4. Map the remaining tokens to BERT dictionary 





In [None]:
# 1. Add [CLS] at the beginning and [SEP] at the end of each text.
sentences = ["[CLS] " + query + " [SEP]" for query in data_lemma]
print(sentences[0])

[CLS] car wonder enlighten car see day door sport car look late early call bricklin door small addition bumper separate rest body know tellme model engine specs year production car history info funky looking car mail thank [SEP]


In [None]:
# 2. Tokenize the texts using BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenized_texts = [tokenizer.tokenize(sent) for sent in sentences]
print(tokenized_texts[0])

['[CLS]', 'car', 'wonder', 'en', '##light', '##en', 'car', 'see', 'day', 'door', 'sport', 'car', 'look', 'late', 'early', 'call', 'brick', '##lin', 'door', 'small', 'addition', 'bumper', 'separate', 'rest', 'body', 'know', 'tell', '##me', 'model', 'engine', 'spec', '##s', 'year', 'production', 'car', 'history', 'info', 'funky', 'looking', 'car', 'mail', 'thank', '[SEP]']


In [None]:
# Show token IDs based on BERT's training
print(tokenizer.convert_tokens_to_ids(tokenized_texts[0]))

[101, 2482, 4687, 4372, 7138, 2368, 2482, 2156, 2154, 2341, 4368, 2482, 2298, 2397, 2220, 2655, 5318, 4115, 2341, 2235, 2804, 21519, 3584, 2717, 2303, 2113, 2425, 4168, 2944, 3194, 28699, 2015, 2095, 2537, 2482, 2381, 18558, 24151, 2559, 2482, 5653, 4067, 102]


In order to determine the maximum sequence length, we look at the list statistics.

In [None]:
leng=[]
for t in tokenized_texts:
  leng.append(len(t))
df=pd.DataFrame(leng)
df.describe()

Unnamed: 0,0
count,11314.0
mean,170.796182
std,394.87109
min,4.0
25%,63.0
50%,103.0
75%,166.0
max,8235.0


In [None]:
df.quantile([.95, .99])

Unnamed: 0,0
0.95,412.0
0.99,1304.09


In [None]:
# 3. Pad the text to the maximum length, max 512

# Pad sequences that are less than MAX_LEN, if more, remove from the end
sentences_padded = pad_sequences(tokenized_texts,  dtype=object,maxlen=412,  value='[PAD]', truncating="post",padding="post", return_tensors = 'pt')
print(sentences_padded[0])

['[CLS]' 'car' 'wonder' 'en' '##light' '##en' 'car' 'see' 'day' 'door'
 'sport' 'car' 'look' 'late' 'early' 'call' 'brick' '##lin' 'door' 'small'
 'addition' 'bumper' 'separate' 'rest' 'body' 'know' 'tell' '##me' 'model'
 'engine' 'spec' '##s' 'year' 'production' 'car' 'history' 'info' 'funky'
 'looking' 'car' 'mail' 'thank' '[SEP]' '[PAD]' '[PAD]' '[PAD]' '[PAD]'
 '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]'
 '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]'
 '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]'
 '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]'
 '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]'
 '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]'
 '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]'
 '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]'
 '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' '[PAD]' 

In [None]:
#4. Map the tokens to BERT dictionary 
# Convert the tokens to their index numbers in the BERT vocabulary
sentences_converted = [tokenizer.convert_tokens_to_ids(s) for s in sentences_padded]
print(sentences_converted[0])

[101, 2482, 4687, 4372, 7138, 2368, 2482, 2156, 2154, 2341, 4368, 2482, 2298, 2397, 2220, 2655, 5318, 4115, 2341, 2235, 2804, 21519, 3584, 2717, 2303, 2113, 2425, 4168, 2944, 3194, 28699, 2015, 2095, 2537, 2482, 2381, 18558, 24151, 2559, 2482, 5653, 4067, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [None]:
# Create attention masks
masks = []

# Create a mask of 1s for each token followed by 0s for padding
for seq in sentences_converted:
  seq_mask = [float(i>0) for i in seq]
  masks.append(seq_mask)

In [None]:
# 5. Generate embeddings

#Convert all of our data into torch tensors, the required datatype for our model

inputs = torch.LongTensor(sentences_converted)
masks = torch.LongTensor(masks)

In [None]:
inputs.size()

torch.Size([11314, 412])

In [None]:
masks.size()

torch.Size([11314, 412])

In [None]:
# Apply Pretrained model to the sentences
model = BertModel.from_pretrained('bert-base-uncased')

In [None]:
model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [None]:
# Set the batch size.  
batch_size = 16  

# Create the DataLoader.
prediction_data = TensorDataset(inputs, masks)
prediction_sampler = SequentialSampler(prediction_data)
prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)

In [None]:
result=[]
i=0
for batch in prediction_dataloader:
  #print(i)
  # Add batch to GPU
  batch = tuple(t.to(device) for t in batch)


  # Unpack the inputs from our dataloader
  b_input_ids, b_input_mask = batch

  # Telling the model not to compute or store gradients, saving memory and 
  
  with torch.no_grad():
      # Forward pass, calculate embeddings
      outputs = model(b_input_ids)

  embeddings = outputs.pooler_output #CLS embeddings for the batch

  # Move em to CPU
  embeddings = embeddings.detach().cpu().numpy()
  
  # Store predictions and true labels
  result.append(embeddings)
  i=i+1


print('    DONE.')

In [None]:
#708 batches*16 texts with embedding size 768

768

In [None]:
final=[]
for b in result:
   for e in b:
      final.append(e)

In [None]:
# Final corpus
corpus_bert_df=pd.DataFrame(final)
corpus_bert_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767
0,-0.120891,-0.280994,-0.963245,0.419622,0.778366,-0.196418,-0.313293,0.23287,-0.862499,-0.955871,0.223536,0.886177,0.327486,0.866635,-0.249378,0.088374,0.439759,0.016313,0.064368,0.730549,0.462973,0.999983,-0.40831,0.307045,0.337647,0.908904,-0.068019,-0.055471,0.25283,0.328617,0.393255,0.119468,-0.622634,-0.270945,-0.978561,-0.187582,0.180457,0.12795,-0.101807,-0.25029,...,0.42801,-0.24705,-0.033118,-0.254504,-0.322066,-0.023203,-0.252481,-0.285563,0.193448,0.033795,0.999958,-0.736765,-0.848266,-0.186036,-0.336793,0.301686,-0.462011,-0.999997,0.255416,-0.859095,0.80337,-0.318131,0.902804,-0.754084,0.307099,-0.061251,0.62403,0.844899,-0.087035,-0.475938,0.161414,-0.928535,0.863065,-0.074633,-0.00226,-0.675741,0.185203,-0.855487,-0.17854,-0.240451
1,-0.183047,-0.320446,-0.967441,0.510457,0.806681,-0.212261,-0.239628,0.210489,-0.880438,-0.945462,0.122012,0.914262,0.214913,0.882294,-0.172117,0.000281,0.382229,0.015465,0.056818,0.700129,0.484552,0.999983,-0.460843,0.301718,0.323201,0.921945,-0.133789,-0.070382,0.238025,0.354671,0.378942,0.101619,-0.541697,-0.292166,-0.978894,-0.117593,0.185518,0.117982,-0.145663,-0.236753,...,0.481052,-0.270803,-0.04876,-0.217497,-0.346388,0.035731,-0.253299,-0.326463,0.175167,-0.009337,0.999946,-0.775669,-0.875688,-0.17833,-0.360491,0.283788,-0.45156,-0.999997,0.240596,-0.880769,0.822632,-0.391792,0.911289,-0.800638,0.229179,-0.054229,0.605605,0.866382,-0.10912,-0.466791,0.17515,-0.939977,0.888713,-0.089563,-0.091914,-0.731656,0.169998,-0.887561,-0.124455,-0.239807
2,-0.447093,-0.470325,-0.965326,0.46833,0.782669,-0.307996,-0.073427,0.399358,-0.898938,-0.996745,-0.013201,0.910904,0.592074,0.867216,0.129251,-0.309642,0.301718,-0.296469,0.227609,0.730816,0.562351,0.999993,-0.372397,0.492953,0.458219,0.939361,-0.354881,0.247452,0.618205,0.554835,0.12269,0.32039,-0.797047,-0.401551,-0.978959,-0.659457,0.303599,-0.156353,-0.21498,-0.144441,...,0.302321,-0.322595,-0.211455,-0.266762,-0.021651,-0.307902,-0.443819,-0.391257,0.400275,0.207513,0.999979,-0.793533,-0.903334,-0.252147,-0.451421,0.494141,-0.54382,-1.0,0.288275,-0.916234,0.872236,-0.40584,0.887067,-0.890248,-0.080239,-0.253946,0.607473,0.890332,-0.33116,-0.502541,0.636328,-0.899832,0.891958,0.028261,-0.23509,-0.60619,0.71639,-0.910605,-0.424803,-0.059388
3,-0.098515,-0.385608,-0.987904,0.494115,0.869421,-0.238608,-0.245373,0.274326,-0.947566,-0.950335,-0.018988,0.949433,-0.0147,0.939705,-0.249542,-0.130132,0.267692,-0.008533,0.047951,0.734438,0.511568,0.999997,-0.642022,0.352938,0.385997,0.963973,-0.198282,-0.205098,0.149731,0.367538,0.343082,0.175255,-0.408339,-0.325079,-0.989443,-0.010332,0.279057,0.168347,-0.1956,-0.278317,...,0.652145,-0.299753,-0.139638,-0.173683,-0.443671,-0.027985,-0.349703,-0.388918,0.165237,0.035473,0.999991,-0.882225,-0.953739,-0.24563,-0.413192,0.379641,-0.529157,-1.0,0.287768,-0.922833,0.899929,-0.592388,0.953749,-0.889039,0.282065,-0.129607,0.664973,0.92745,-0.143531,-0.5332,0.266515,-0.961368,0.948298,-0.195683,-0.299023,-0.849391,0.298872,-0.932983,-0.178504,-0.309659
4,-0.180124,-0.413501,-0.984492,0.474194,0.865759,-0.281325,-0.225921,0.305099,-0.935964,-0.974516,0.008402,0.932913,0.18234,0.928759,-0.195707,-0.15407,0.334482,-0.087378,0.10497,0.755661,0.508255,0.999997,-0.580859,0.43008,0.409025,0.948664,-0.238847,-0.117083,0.270822,0.436732,0.326114,0.217338,-0.518259,-0.354096,-0.987442,-0.17187,0.272495,0.096164,-0.216041,-0.242049,...,0.606793,-0.327609,-0.150611,-0.184259,-0.386002,-0.137363,-0.39325,-0.38306,0.206357,0.095422,0.999989,-0.869633,-0.944107,-0.264682,-0.423192,0.436482,-0.529781,-1.0,0.28089,-0.92591,0.897921,-0.559192,0.944515,-0.896119,0.235804,-0.174416,0.641793,0.917698,-0.214031,-0.523474,0.428578,-0.948966,0.931038,-0.180842,-0.408895,-0.821256,0.473733,-0.940421,-0.238981,-0.261727


In [None]:
pickle.dump(corpus_bert_df, open("/content/drive/MyDrive/TWSM_Data/BertModel.pkl", "wb"))