## Folds on SST Sentiment Dataset

This notebook is used to test the variation of testing result after varying the input using K-folds stratification, on Stanford Sentiment Treebank.

RoBERTa is used as the 5 way classifier.


###### This is the dataset of the following paper:

  Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank
  
 Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher Manning, Andrew Ng and Christopher Potts
 
 Conference on Empirical Methods in Natural Language Processing (EMNLP 2013)

In [1]:
# Importing necessary libraries
import pandas as pd
import numpy as np
from datetime import datetime
import random
import sklearn
import torch
import torch.nn as nn
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

from simpletransformers.classification import ClassificationModel

In [2]:
# procedure for getting the data sets and formatting them for the transformer
 
train=pd.read_csv('./SST_data/sst5_train_sentences.csv', names=['text', 'labels'] )

Eval=pd.read_csv('./SST_data/sst5_dev.csv' , names=['text', 'labels'] )

test=pd.read_csv('./SST_data/sst5_test.csv', names=['text', 'labels']  )

train

Unnamed: 0,text,labels
0,Reno himself can take credit for most of the m...,pos
1,"Despite the film 's shortcomings , the stories...",pos
2,"Despite its dry wit and compassion , the film ...",neg
3,The central character is n't complex enough to...,neu
4,Rifkin no doubt fancies himself something of a...,very neg
...,...,...
8529,A conventional but heartwarming tale .,very pos
8530,It has the air of a surprisingly juvenile lark...,neg
8531,The culmination of everyone 's efforts is give...,neu
8532,Overcomes its visual hideousness with a sharp ...,pos


In [3]:
Eval

Unnamed: 0,text,labels
0,( director ) O'Fallon manages to put some love...,very neg
1,A thinly veiled look at different aspects of C...,neu
2,If your taste runs to ` difficult ' films you ...,pos
3,( Leigh ) lays it on so thick this time that i...,neu
4,"A full world has been presented onscreen , not...",pos
...,...,...
1095,"Just as moving , uplifting and funny as ever .",pos
1096,Davis ... is so enamored of her own creation t...,neg
1097,"An exhilarating futuristic thriller-noir , Min...",very pos
1098,I got a headache watching this meaningless dow...,very neg


In [4]:
test=pd.read_csv('./SST_data/sst5_test.csv', names=['text', 'labels']  )
test

Unnamed: 0,text,labels
0,Maybe I found the proceedings a little bit too...,neg
1,"As with too many studio pics , plot mechanics ...",very neg
2,"Beers , who , when she 's given the right line...",neu
3,"Cute , funny , heartwarming digitally animated...",very pos
4,So what is the point ?,very neg
...,...,...
2205,It 's a glorious groove that leaves you wantin...,very pos
2206,It 's getting harder and harder to ignore the ...,neg
2207,"A real movie , about real people , that gives ...",pos
2208,"Sharp , lively , funny and ultimately sobering...",very pos


To keep closer to the standard way of doing K-folds I am including the validation and training set together.

The data will be spilt in 5 folds.  At every trial 4 will be used for training and the other will be used for validation.

In [5]:
train=train.append(Eval, ignore_index = True)
train

Unnamed: 0,text,labels
0,Reno himself can take credit for most of the m...,pos
1,"Despite the film 's shortcomings , the stories...",pos
2,"Despite its dry wit and compassion , the film ...",neg
3,The central character is n't complex enough to...,neu
4,Rifkin no doubt fancies himself something of a...,very neg
...,...,...
9629,"Just as moving , uplifting and funny as ever .",pos
9630,Davis ... is so enamored of her own creation t...,neg
9631,"An exhilarating futuristic thriller-noir , Min...",very pos
9632,I got a headache watching this meaningless dow...,very neg


#####  We now change the text labels to numeric( 0 to 4)

In [6]:

def labelsToNumbers(set):
    for row in range(len(set)):
        if set.iloc[row,1]=='very pos': set.iloc[row,1]=4
        if set.iloc[row,1]=='pos': set.iloc[row,1]=3
        if set.iloc[row,1]=='neu': set.iloc[row,1]=2
        if set.iloc[row,1]=='neg': set.iloc[row,1]=1
        if set.iloc[row,1]=='very neg': set.iloc[row,1]=0

    return set

train=labelsToNumbers(train)
test=labelsToNumbers(test)
test['labels']=test['labels'].astype('int64')

train

Unnamed: 0,text,labels
0,Reno himself can take credit for most of the m...,3
1,"Despite the film 's shortcomings , the stories...",3
2,"Despite its dry wit and compassion , the film ...",1
3,The central character is n't complex enough to...,2
4,Rifkin no doubt fancies himself something of a...,0
...,...,...
9629,"Just as moving , uplifting and funny as ever .",3
9630,Davis ... is so enamored of her own creation t...,1
9631,"An exhilarating futuristic thriller-noir , Min...",4
9632,I got a headache watching this meaningless dow...,0


In [26]:
# first we randomise the order of train

import random
 
randNum=[]

for row in range(len(train)):
    randNum.append(random.random())

train['RandNum']=randNum

train=train.sort_values(by=['RandNum'] )
train=train.drop(['RandNum'],axis=1)
del(randNum)
train
    

Unnamed: 0,text,labels
760,the phone rings and a voice tells you you 've ...,2
5169,"An often watchable , though goofy and lurid , ...",3
7343,Hawn and Sarandon form an acting bond that mak...,4
5695,"An unremarkable , modern action\/comedy buddy ...",2
3680,A hard look at one man 's occupational angst a...,3
...,...,...
5065,Not many movies have that kind of impact on me...,4
782,Intimate and panoramic .,3
2306,"Suspend your disbelief here and now , or you '...",1
1217,`` 13 Conversations About One Thing '' is an i...,4


In [27]:
count0=0
count1=0
count2=0
count3=0
count4=0
 

train0=[] #all stamtements that are very neg  (class 0)
train1=[]
train2=[]
train3=[]
train4=[] 


for row in range(len(train)):
        if train.iloc[row,1]==0: 
            count0+=1
            train0.append(train.iloc[row,:])
        if train.iloc[row,1]==1: 
            count1+=1
            train1.append(train.iloc[row,:])
        if train.iloc[row,1]==2: 
            count2+=1
            train2.append(train.iloc[row,:])
        if train.iloc[row,1]==3: 
            count3+=1
            train3.append(train.iloc[row,:])
        if train.iloc[row,1]==4: 
            count4+=1
            train4.append(train.iloc[row,:])


print('0s ', count0)
print('1s ', count1)
print('2s ', count2)
print('3s ', count3)
print('4s ', count4) 

            

0s  1090
1s  2215
2s  1623
3s  2319
4s  1287


In [29]:
def div5(myinteger):
    size_m5 =myinteger-(myinteger%5)
    QuantityToRemove=size_m5/5
    
    return QuantityToRemove, myinteger%5

C0div5=div5(count0)
C1div5=div5(count1)
C2div5=div5(count2)
C3div5=div5(count3)
C4div5=div5(count4)
 



print('To omit from class 0s ', C0div5)
print('To omit from class 1s ', C1div5)
print('To omit from class 2s ', C2div5)
print('To omit from class 3s ', C3div5)
print('To omit from class 4s ', C4div5) 

To omit from class 0s  (218.0, 0)
To omit from class 1s  (443.0, 0)
To omit from class 2s  (324.0, 3)
To omit from class 3s  (463.0, 4)
To omit from class 4s  (257.0, 2)


In [30]:
train0

[text      The characters are so generic and the plot so ...
 labels                                                    0
 Name: 1769, dtype: object,
 text      ... plays like a badly edited , 91-minute trai...
 labels                                                    0
 Name: 6037, dtype: object,
 text      How inept is Serving Sara ?
 labels                              0
 Name: 5551, dtype: object,
 text      With a completely predictable plot , you 'll s...
 labels                                                    0
 Name: 5763, dtype: object,
 text      Unfortunately , it 's also not very good .
 labels                                             0
 Name: 5446, dtype: object,
 text      The movie has a script ( by Paul Pender ) made...
 labels                                                    0
 Name: 104, dtype: object,
 text      Most of the problems with the film do n't deri...
 labels                                                    0
 Name: 2660, dtype: object,
 text    

 The order is already for each set is already randomised.
We will now, for sets train0 to train5 split each into 5 roughly equal parts.

In [31]:
#train[class][fold]
train01=[]
train02=[]
train03=[]
train04=[]
train05=[] 
   

for row in range(len(train0)):
    
    if row<C0div5[0]: train01.append(train0[row])
        
    if row>=C0div5[0] and row<(C0div5[0]*2): train02.append(train0[row])
        
    if row>=(C0div5[0]*2) and row<(C0div5[0]*3): train03.append(train0[row])
    
    if row>=(C0div5[0]*3) and row<(C0div5[0]*4): train04.append(train0[row])
        
    if row>=(C0div5[0]*4) and row<(C0div5[0]*5): train05.append(train0[row])
        
    if row>=(C0div5[0]*5):
        train01.append(train0[row])
         
     

     



In [32]:
train01

[text      The characters are so generic and the plot so ...
 labels                                                    0
 Name: 1769, dtype: object,
 text      ... plays like a badly edited , 91-minute trai...
 labels                                                    0
 Name: 6037, dtype: object,
 text      How inept is Serving Sara ?
 labels                              0
 Name: 5551, dtype: object,
 text      With a completely predictable plot , you 'll s...
 labels                                                    0
 Name: 5763, dtype: object,
 text      Unfortunately , it 's also not very good .
 labels                                             0
 Name: 5446, dtype: object,
 text      The movie has a script ( by Paul Pender ) made...
 labels                                                    0
 Name: 104, dtype: object,
 text      Most of the problems with the film do n't deri...
 labels                                                    0
 Name: 2660, dtype: object,
 text    

In [33]:
train02

[text      The film has a nearly terminal case of the cut...
 labels                                                    0
 Name: 4259, dtype: object,
 text      The most offensive thing about the movie is th...
 labels                                                    0
 Name: 1608, dtype: object,
 text      This is the case of a pregnant premise being w...
 labels                                                    0
 Name: 2485, dtype: object,
 text      There 's already been too many of these films ...
 labels                                                    0
 Name: 3775, dtype: object,
 text      A half-assed film .
 labels                      0
 Name: 1488, dtype: object,
 text      It 's mired in a shabby script that piles laye...
 labels                                                    0
 Name: 2910, dtype: object,
 text      director Hoffman , his writer and Kline 's age...
 labels                                                    0
 Name: 644, dtype: object,
 text      

In [34]:
        
train11=[]
train12=[]
train13=[]
train14=[]
train15=[]
train1remaining=[]

train21=[]
train22=[]
train23=[]
train24=[]
train25=[]
train2remaining=[]

train31=[]
train32=[]
train33=[]
train34=[]
train35=[]
train3remaining=[]

train41=[]
train42=[]
train43=[]
train44=[]
train45=[]
train4remaining=[]

 



In [35]:

for row in range(len(train1)):
    if row<C1div5[0]: train11.append(train1[row])
        
    if row>=C1div5[0] and row<(C1div5[0]*2): train12.append(train1[row])
        
    if row>=(C1div5[0]*2) and row<(C1div5[0]*3): train13.append(train1[row])
    
    if row>=(C1div5[0]*3) and row<(C1div5[0]*4): train14.append(train1[row])
        
    if row>=(C1div5[0]*4) and row<(C1div5[0]*5): train15.append(train1[row])
        
    if row>=(C1div5[0]*5):
        train11.append(train1[row])
         



In [36]:
for row in range(len(train2)):
    if row<C2div5[0]: train21.append(train2[row])
        
    if row>=C2div5[0] and row<(C2div5[0]*2): train22.append(train2[row])
        
    if row>=(C2div5[0]*2) and row<(C2div5[0]*3): train23.append(train2[row])
    
    if row>=(C2div5[0]*3) and row<(C2div5[0]*4): train24.append(train2[row])
        
    if row>=(C2div5[0]*4) and row<(C2div5[0]*5): train25.append(train2[row])
        
    if row>=(C2div5[0]*5):
        train21.append(train2[row])
         


In [37]:

for row in range(len(train3)):
    if row<C3div5[0]: train31.append(train3[row])
        
    if row>=C3div5[0] and row<(C3div5[0]*2): train32.append(train3[row])
        
    if row>=(C3div5[0]*2) and row<(C3div5[0]*3): train33.append(train3[row])
    
    if row>=(C3div5[0]*3) and row<(C3div5[0]*4): train34.append(train3[row])
        
    if row>=(C3div5[0]*4) and row<(C3div5[0]*5): train35.append(train3[row])
        
    if row>=(C3div5[0]*5):
        train31.append(train3[row])
         

In [38]:
for row in range(len(train4)):
    if row<C4div5[0]: train41.append(train4[row])
        
    if row>=C4div5[0] and row<(C4div5[0]*2): train42.append(train4[row])
        
    if row>=(C4div5[0]*2) and row<(C4div5[0]*3): train43.append(train4[row])
    
    if row>=(C4div5[0]*3) and row<(C4div5[0]*4): train44.append(train4[row])
        
    if row>=(C4div5[0]*4) and row<(C4div5[0]*5): train45.append(train4[row])
        
    if row>=(C4div5[0]*5):
        train41.append(train4[row])
         

In [39]:
pants1=pd.DataFrame(train01, columns=['text','labels'])
pants2=pd.DataFrame(train02, columns=['text','labels'])
pants3=pd.DataFrame(train03, columns=['text','labels'])
pants4=pd.DataFrame(train04, columns=['text','labels'])
pants5=pd.DataFrame(train05, columns=['text','labels'])


fake1=pd.DataFrame(train11, columns=['text','labels'])
fake2=pd.DataFrame(train12, columns=['text','labels'])
fake3=pd.DataFrame(train13, columns=['text','labels'])
fake4=pd.DataFrame(train14, columns=['text','labels'])
fake5=pd.DataFrame(train15, columns=['text','labels'])


Mfake1=pd.DataFrame(train21, columns=['text','labels'])
Mfake2=pd.DataFrame(train22, columns=['text','labels'])
Mfake3=pd.DataFrame(train23, columns=['text','labels'])
Mfake4=pd.DataFrame(train24, columns=['text','labels'])
Mfake5=pd.DataFrame(train25, columns=['text','labels'])


half1=pd.DataFrame(train31, columns=['text','labels'])
half2=pd.DataFrame(train32, columns=['text','labels'])
half3=pd.DataFrame(train33, columns=['text','labels'])
half4=pd.DataFrame(train34, columns=['text','labels'])
half5=pd.DataFrame(train35, columns=['text','labels'])


Mreal1=pd.DataFrame(train41, columns=['text','labels'])
Mreal2=pd.DataFrame(train42, columns=['text','labels'])
Mreal3=pd.DataFrame(train43, columns=['text','labels'])
Mreal4=pd.DataFrame(train44, columns=['text','labels'])
Mreal5=pd.DataFrame(train45, columns=['text','labels'])

 

frames1 = [pants2, pants3, pants4, pants5, fake2, fake3, fake4, fake5, Mfake2, Mfake3, Mfake4, Mfake5, half2, half3, half4, half5, Mreal2, Mreal3, Mreal4, Mreal5  ]

frames2 = [ pants1, pants3, pants4, pants5,fake1,  fake3, fake4, fake5,Mfake1, Mfake3, Mfake4, Mfake5,half1, half3, half4, half5,Mreal1, Mreal3, Mreal4, Mreal5 ]


frames3 = [pants1, pants2,  pants4, pants5,fake1, fake2, fake4, fake5,Mfake1, Mfake2, Mfake4, Mfake5,half1, half2,  half4, half5,Mreal1, Mreal2,  Mreal4, Mreal5 ]


frames4 = [ pants1, pants2, pants3, pants5,fake1, fake2, fake3, fake5,Mfake1, Mfake2, Mfake3, Mfake5,half1, half2, half3, half5,Mreal1, Mreal2, Mreal3,  Mreal5]

frames5 = [pants1, pants2, pants3, pants4, fake1, fake2, fake3, fake4,Mfake1, Mfake2, Mfake3, Mfake4, half1, half2, half3, half4,Mreal1, Mreal2, Mreal3, Mreal4]




train_fold1 = pd.concat(frames1)
train_fold2 = pd.concat(frames2)
train_fold3 = pd.concat(frames3)
train_fold4 = pd.concat(frames4)
train_fold5 = pd.concat(frames5)

#we set the omitted fold as the validation set

frames1=[pants1,fake1,Mfake1,half1,Mreal1 ]
valid1 = pd.concat(frames1)

frames2=[pants2,fake2,Mfake2,half2,Mreal2 ]
valid2 = pd.concat(frames2)

frames3=[pants3,fake3,Mfake3,half3,Mreal3 ]
valid3 = pd.concat(frames3)

frames4=[pants4,fake4,Mfake4,half4,Mreal4]
valid4 = pd.concat(frames4)

frames5=[pants5,fake5,Mfake5,half5,Mreal5 ]
valid5 = pd.concat(frames5)
 

In [40]:
def randomiseSet(set):
    #this function randomises the order of the set.
    #order shouldn't be an issue but better keep things close to the realistic realms
 
    randNum=[]

    for row in range(len(set)):
        randNum.append(random.random())

    set['RandNum']=randNum
    set=set.sort_values(by=['RandNum'] )
    set=set.drop(['RandNum'],axis=1)
    
    return set

In [41]:
train_fold1=randomiseSet(train_fold1)
train_fold2=randomiseSet(train_fold2)
train_fold3=randomiseSet(train_fold3)
train_fold4=randomiseSet(train_fold4)
train_fold5=randomiseSet(train_fold5)
valid1=randomiseSet(valid1)
valid2=randomiseSet(valid2)
valid3=randomiseSet(valid3)
valid4=randomiseSet(valid4)
valid5=randomiseSet(valid5)

In [44]:
train_fold1.to_excel('./folds/train_fold1.xls',index=False)
train_fold2.to_excel('./folds/train_fold2.xls',index=False)
train_fold3.to_excel('./folds/train_fold3.xls',index=False)
train_fold4.to_excel('./folds/train_fold4.xls',index=False)
train_fold5.to_excel('./folds/train_fold5.xls',index=False)

valid1.to_excel('./folds/valid1.xls',index=False)
valid2.to_excel('./folds/valid2.xls',index=False)
valid3.to_excel('./folds/valid3.xls',index=False)
valid4.to_excel('./folds/valid4.xls',index=False)
valid5.to_excel('./folds/valid5.xls',index=False)

# We can now run the tests

## Fold1 training and capturing predictions

In [16]:
fold_number='1'

train=pd.read_excel('./folds/train_fold'+fold_number+'.xls')
Eval =pd.read_excel('./folds/valid'+fold_number+'.xls') #evaluation set


In [9]:
#Set the model being used here
model_class='roberta'  # bert or roberta or albert
model_version='roberta-large' #bert-base-cased, roberta-base, roberta-large, albert-base-v2 OR albert-large-v2
labels_count=5  # the number of classification classes


output_folder='./folds/fold'+fold_number+'/'+model_class+'/'+model_version+"/"
cache_directory= "./folds/fold"+fold_number+'/'+model_class+"/"+model_version+"/cache/"


print('model variables were set up: ')

 
save_every_steps=1285
# assuming training batch size of 8
# any number above 1284 saves the model only at every epoch
# Saving the model mid training very often will consume disk space fast

train_args={
    "output_dir":output_folder,
    "cache_dir":cache_directory,
    'reprocess_input_data': True,
    'overwrite_output_dir': True,
    'num_train_epochs': 2,
    "save_steps": save_every_steps, 
    "learning_rate": 1.2e-5,
    "train_batch_size": 32,
    "eval_batch_size": 8,
    "weight_decay": 0,
    "evaluate_during_training_steps": 30,
    "max_seq_length": 100,
    "n_gpu": 1,
}

# Create a ClassificationModel
model = ClassificationModel(model_class, model_version, num_labels=labels_count, args=train_args) 

# loading a previously saved ClassificationModel model based on this particular Transformer Class and model_name



model variables were set up: 


In [8]:
# loading the checkpoint that gave the best result

CheckPoint='checkpoint-428-epoch-2'   


preSavedCheckpoint=output_folder+CheckPoint

print('Loading model, please wait...')
model = ClassificationModel( model_class, preSavedCheckpoint, num_labels=labels_count, args=train_args) 
print('model in use is :', preSavedCheckpoint )
 

Loading model, please wait...
model in use is : ./folds/fold1/roberta/roberta-large/checkpoint-428-epoch-2


In [31]:
# Train the model
current_time = datetime.now()
model.train_model(train)
print("Training time taken: ", datetime.now() - current_time, ' at:',datetime.now())

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=6820.0), HTML(value='')))


Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=2.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 1.630410Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Running loss: 1.676554Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Running loss: 1.224474Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Running loss: 1.453225


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 1.160760

Training of roberta model complete. Saved to ./folds/fold1/roberta/roberta-large/.
Training time taken:  0:18:08.272363  at: 2020-04-28 14:04:19.687315


In [32]:
TrainResult, TrainModel_outputs, wrong_predictions = model.eval_model(train, acc=sklearn.metrics.accuracy_score)
 
EvalResult, EvalModel_outputs, wrong_predictions = model.eval_model(Eval, acc=sklearn.metrics.accuracy_score)

print('Training Result:', TrainResult['acc'])
#print('Model Out:', TrainModel_outputs)

print('Eval Result:', EvalResult['acc'])
#print('Model Out:', EvalModel_outputs)


Features loaded from cache at ./folds/fold1/roberta/roberta-large/cache/cached_dev_roberta_100_5_6820


HBox(children=(FloatProgress(value=0.0, max=853.0), HTML(value='')))


{'mcc': 0.5567553414792227, 'acc': 0.6513196480938417, 'eval_loss': 0.8084349541423749}
Features loaded from cache at ./folds/fold1/roberta/roberta-large/cache/cached_dev_roberta_100_5_1714


HBox(children=(FloatProgress(value=0.0, max=215.0), HTML(value='')))


{'mcc': 0.47473708719467395, 'acc': 0.587514585764294, 'eval_loss': 0.9497925260732341}
Training Result: 0.6513196480938417
Eval Result: 0.587514585764294


In [9]:


TestResult, TestModel_outputs, wrong_predictions = model.eval_model(test, acc=sklearn.metrics.accuracy_score)

print('Test Set Result:', TestResult['acc'])
#print('Model Out:', TestModel_outputs)

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=2210.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=277.0), HTML(value='')))


{'mcc': 0.48638002359478494, 'acc': 0.5981900452488688, 'eval_loss': 0.9062698445810738}
Test Set Result: 0.5981900452488688


In [10]:
Pred=[]
Targets=[]

countCorrect=0

for row in range(TestModel_outputs.shape[0]):
    outputs=TestModel_outputs[row]
    #print(test.iloc[row,0])
    print(outputs, end=' ')
    
    result=0
    if outputs[0]<outputs[1]:result=1
    if outputs[result]<outputs[2]:result=2
    if outputs[result]<outputs[3]:result=3
    if outputs[result]<outputs[4]:result=4
    
    Pred.append(result)
    Targets.append(test.iloc[row,1])
    print(result, ' ',test.iloc[row,1], end=' ')
    if result==test.iloc[row,1]:
        countCorrect+=1
        print('Match',countCorrect)
    print('')

print(countCorrect)
#Pred

[-0.682301   2.0310621  1.931816  -0.5134711 -2.8004656] 1   1 Match 1

[ 0.63714504  2.6623714   1.1769586  -1.6990744  -3.1518998 ] 1   0 
[-0.43096247  2.3492606   1.7593096  -0.7876669  -2.8023489 ] 1   2 
[-2.3385808 -2.1272495 -0.863038   2.3431942  3.517544 ] 4   4 Match 2

[-0.59622246  1.5581144   1.9621681  -0.48837048 -2.8487227 ] 2   0 
[-2.5027547  -2.0397758   0.36904395  3.1994476   1.7975659 ] 3   3 Match 3

[ 1.9724123  2.1415079  0.2116234 -2.3558404 -2.8676949] 1   0 
[-2.5943868  -2.0832145   0.26932967  3.151726    2.0880466 ] 3   4 
[ 2.9532235   1.5635093  -0.45549327 -2.4321623  -2.5818248 ] 0   0 Match 4

[ 2.3776605   1.8573465  -0.00813339 -2.3608994  -2.695493  ] 0   0 Match 5

[ 0.81382334  2.4480042   1.0393538  -1.7761447  -3.1406708 ] 1   2 
[-2.72299    -1.9694477   0.09742542  2.5999868   2.3456645 ] 3   4 
[-1.4602405  1.4469442  2.0501034  0.3539906 -2.4018743] 2   1 
[ 2.5491953  2.0025377 -0.0351897 -2.3843048 -2.8159416] 0   0 Match 6

[-2.4178593

[ 1.5482776   2.3789034   0.57794726 -1.9798918  -3.086655  ] 1   1 Match 60

[-2.4083035   0.25450653  2.086146    1.6960084  -1.2689809 ] 2   3 
[-2.9085324 -1.6694376  0.5372442  3.2439108  1.5671946] 3   3 Match 61

[-2.7256603  -2.0623636  -0.09075508  2.9922423   2.7184896 ] 3   4 
[-1.9610598  0.1611032  2.0458934  1.7245098 -1.4861388] 2   2 Match 62

[ 1.6221702   2.4242835   0.61444765 -2.1987567  -3.0940797 ] 1   0 
[-2.1963537 -1.8484775 -0.8129145  1.955355   3.628164 ] 4   3 
[-2.3948855  -2.0162554  -0.80029356  2.208299    3.5683556 ] 4   4 Match 63

[-0.4661812  2.2130418  1.8072469 -0.6369598 -2.948527 ] 1   1 Match 64

[-2.6727214 -2.0041132  0.3360456  3.0689716  2.259877 ] 3   3 Match 65

[-2.6268322  -1.3513283   1.3847407   3.074941    0.38721642] 3   3 Match 66

[-2.3161566  -0.9785731   1.0317887   2.476514    0.40087387] 3   4 
[ 2.2786434   2.0469549   0.06182667 -2.3721912  -2.9813097 ] 0   0 Match 67

[ 0.02300288  2.463956    1.446615   -1.0812947  -2.8326

[-0.1534365   2.0137634   1.539734   -0.84413403 -2.8614275 ] 1   0 
[-2.4098873  -1.939822    0.02263179  2.8431485   2.6973543 ] 3   4 
[-2.3883007 -2.122407  -0.7960624  2.320511   3.5572655] 4   4 Match 141

[-0.5049903   2.1307232   1.856218   -0.54694724 -2.8785672 ] 1   1 Match 142

[-2.5424244  -1.7408775   0.02458121  2.55669     2.2009351 ] 3   3 Match 143

[-2.458365   -0.63048387  1.7552316   2.609079   -0.68417203] 3   3 Match 144

[-2.4548633  -1.3339199   0.9297162   2.6355653   0.92183554] 3   2 
[-1.3068534  1.2719647  2.0787518  0.4906633 -2.450015 ] 2   1 
[-0.37693107  2.3450513   1.712669   -0.79422534 -2.822826  ] 1   2 
[ 2.3311007   2.113618    0.08630227 -2.2141337  -2.9366038 ] 0   1 
[ 2.6917508   1.9101056  -0.10600259 -2.534544   -2.8058295 ] 0   0 Match 145

[ 1.2635735  2.4366715  0.829657  -1.975025  -3.1681395] 1   2 
[-0.9522435   0.7999514   1.7030318   0.45564002 -1.9542538 ] 2   1 
[-1.2769872   1.1400572   2.187789    0.64613396 -2.3258424 ] 2   1 

[ 2.2783406   2.1291616   0.11269074 -2.3320177  -2.9787736 ] 0   0 Match 221

[-2.5012426  -2.0323265  -0.29463518  2.6940267   2.9530742 ] 4   3 
[-2.7395198 -0.8983934  1.7048726  2.7321718 -0.464658 ] 3   3 Match 222

[-2.656547  -1.661608   0.6162326  2.6166668  1.5281109] 3   3 Match 223

[ 1.0176096   2.2486162   0.97106403 -1.624985   -3.0061045 ] 1   1 Match 224

[-2.7192848 -2.040132   0.4561976  3.0989435  1.9235319] 3   3 Match 225

[-2.9041672  -1.0396847   1.6108841   2.8722026  -0.12437934] 3   2 
[-2.5290341  -2.1126077  -0.45997372  2.7461827   2.9204197 ] 4   3 
[-0.9626199   1.7627041   2.07178     0.06426138 -2.6570826 ] 2   2 Match 226

[-2.8610437  -0.8554127   1.4465541   2.7723348   0.03865676] 3   3 Match 227

[-2.8943655  -1.4766009   0.85364455  2.9273944   1.2472823 ] 3   2 
[-0.02014575  2.210169    1.7725915  -1.1324849  -2.8403273 ] 1   2 
[-2.6621501 -1.7584583  0.856078   3.1252327  1.5326701] 3   3 Match 228

[ 1.3876698  2.3502686  0.7362876 -1.959466

[ 2.4721146   1.8588923  -0.04210741 -2.4757564  -2.7652388 ] 0   0 Match 304

[ 2.7832575   1.784157   -0.28524676 -2.2784877  -2.760469  ] 0   1 
[ 0.63956344  2.4075766   1.1947033  -1.4237647  -3.1217556 ] 1   1 Match 305

[-2.574052  -2.1003258 -0.6370252  2.5256152  3.2908566] 4   4 Match 306

[-1.8601329   0.64677525  2.2230248   1.0855067  -1.9327425 ] 2   2 Match 307

[ 0.9016162   2.4766366   0.93056786 -1.8200824  -3.091153  ] 1   1 Match 308

[-0.6575278   1.4911269   1.8954384  -0.16057627 -2.5197773 ] 2   1 
[ 1.7128197   2.3645167   0.45235476 -2.2883484  -3.0840893 ] 1   2 
[-0.06116889  2.4168725   1.6974745  -1.0972279  -3.0693386 ] 1   2 
[-2.5542936 -2.091899  -0.651222   2.6002493  3.3315287] 4   4 Match 309

[-2.404516   -1.7819936   0.68089944  3.231721    1.3890977 ] 3   3 Match 310

[ 2.8538888   1.7886955  -0.38576767 -2.5302804  -2.7785783 ] 0   0 Match 311

[-2.066062    0.42506722  2.302463    1.4580232  -1.7597425 ] 2   1 
[-2.5010653 -2.2012944 -0.41363  

[-2.6090586 -1.3715711  1.1400203  3.0079205  0.5237929] 3   2 
[-2.8469012  -1.8887583   0.30534497  2.8783991   2.2278037 ] 3   4 
[ 0.47178355  2.3535368   1.1971611  -1.460509   -3.1515553 ] 1   1 Match 381

[-2.7168295 -1.9395952  0.5980491  3.0559828  1.6984317] 3   3 Match 382

[ 1.0799525   2.2774894   0.95704275 -1.708673   -3.3478482 ] 1   1 Match 383

[-1.0486463  1.4033556  2.111362  -0.328979  -2.6317809] 2   1 
[ 1.9362265   2.439075    0.36451268 -2.5018241  -2.9650843 ] 1   1 Match 384

[-2.295169    0.08179379  2.293016    2.0270402  -1.4733187 ] 2   2 Match 385

[-2.7407343 -1.7010826  0.9549605  3.2370112  1.100718 ] 3   4 
[-0.5138293  2.1543286  1.6479753 -0.5230944 -2.7377064] 1   2 
[ 0.809696   2.449716   1.1494218 -1.676704  -3.0739348] 1   0 
[-2.581289   -0.7267278   1.9186063   2.7486794  -0.68835974] 3   2 
[ 2.5558546   2.0135825  -0.20467341 -2.5840027  -2.8404834 ] 0   1 
[-2.525113   -2.0720184  -0.22022295  3.0956264   2.5461006 ] 3   4 
[ 0.26494807  

[-2.4718325  -2.1549053  -0.46153617  2.737934    3.199811  ] 4   3 
[-2.795964   -1.6721283   0.91981715  3.2422326   1.3299397 ] 3   3 Match 461

[-2.0876842 -1.0610875  0.8691167  1.6504662  0.8458619] 3   3 Match 462

[ 1.9598068  2.2508001  0.4049906 -2.2873485 -3.1846375] 1   1 Match 463

[ 0.02299052  2.0177112   1.6609522  -0.9057631  -2.9382515 ] 1   2 
[-2.7685142  -2.0193536   0.08885342  3.0080914   2.4049494 ] 3   4 
[-2.2197921  -0.97645617  0.7512028   2.1481369   0.7610853 ] 3   3 Match 464

[-0.3347825   2.2127542   1.9216263  -0.78162414 -3.10117   ] 1   2 
[ 2.6892672   1.6265165  -0.31395242 -2.4894977  -2.6422195 ] 0   1 
[ 2.1239796   2.039254    0.39405763 -2.268038   -2.9678252 ] 0   1 
[-2.5854552 -2.1694355 -0.6427684  2.5040526  3.4678977] 4   4 Match 465

[ 2.4057324  2.170847  -0.1111018 -2.6331887 -2.8712738] 0   1 
[-2.8003287  -2.0166311   0.50139964  3.3056607   1.8441259 ] 3   3 Match 466

[ 0.71614665  2.4962642   1.2403038  -1.5618321  -3.1560757 ] 1

[-2.7829475  -2.0503197  -0.20293489  2.7659872   2.9161873 ] 4   4 Match 547

[-1.868434    0.20421398  2.0594614   1.5411249  -1.8303621 ] 2   2 Match 548

[ 2.192816    2.127176   -0.00925678 -2.4642973  -2.900524  ] 0   0 Match 549

[ 1.7608899  2.3726332  0.5455334 -2.3113222 -3.028847 ] 1   1 Match 550

[-0.47992203  2.0082974   1.731864   -0.4954796  -2.6297193 ] 1   1 Match 551

[-2.4353342 -2.1119046 -0.8523152  2.3863873  3.6203403] 4   4 Match 552

[ 2.8359215   1.7923712  -0.44638976 -2.5504954  -2.799001  ] 0   0 Match 553

[-2.5853736  -2.108193   -0.11738186  3.0059304   2.6300583 ] 3   3 Match 554

[-2.88296    -2.0281239  -0.02315354  2.7733412   2.6290722 ] 3   3 Match 555

[ 2.615296    1.9712925  -0.45463082 -2.51277    -2.9586253 ] 0   1 
[ 2.5059702   1.9233205  -0.15547779 -2.4891458  -2.812196  ] 0   1 
[-2.5064073  -2.1628401  -0.40970474  2.873909    2.9572093 ] 4   4 Match 556

[ 2.5674496   1.8306532  -0.21625006 -2.5579991  -2.6458993 ] 0   0 Match 557

[-2

[-0.9547413   0.64907926  1.8385824   0.5965708  -1.9538797 ] 2   2 Match 639

[-2.490043   -2.16182    -0.58947414  2.4340513   3.373769  ] 4   3 
[ 0.3435556  2.33361    1.5357652 -1.2465909 -2.8338318] 1   1 Match 640

[-2.5408058  -1.4956995   0.57820344  2.5637407   1.1849755 ] 3   4 
[ 1.6957799  1.9167039  0.5629431 -1.9007816 -2.7849982] 1   1 Match 641

[ 2.5378468   2.0535035  -0.18820351 -2.5340333  -2.9147313 ] 0   0 Match 642

[-2.770078   -1.8277979   0.44983947  3.1270366   2.0524294 ] 3   4 
[-2.3369615 -2.1236262 -0.7558657  2.6150568  3.177259 ] 4   4 Match 643

[-0.21087714  2.2679284   1.7299985  -0.7994724  -3.0591762 ] 1   1 Match 644

[ 0.7297624  2.2705317  0.9610005 -1.4273154 -2.9131227] 1   1 Match 645

[ 1.6351138  2.144814   0.5528589 -1.9012073 -3.1180286] 1   0 
[ 2.5869026  1.9876668 -0.2762614 -2.638728  -2.7935224] 0   1 
[-2.7742498  -2.175639   -0.12959456  2.6324747   2.7743995 ] 4   3 
[-2.641288  -2.0211246 -0.4116296  2.6917005  3.0583477] 4   4 

[ 1.7931566   2.4416392   0.48785892 -2.16865    -3.2578921 ] 1   1 Match 722

[-2.8240726  -1.6305468   0.62337977  2.958617    1.6665002 ] 3   3 Match 723

[-2.5611618  -2.1009524  -0.65961266  2.2505498   3.5586286 ] 4   4 Match 724

[ 2.5680816   2.0007677  -0.11772255 -2.6416912  -2.8153517 ] 0   1 
[ 1.7116442   2.1930325   0.36422288 -2.0941138  -3.0461702 ] 1   1 Match 725

[-2.1464255  -0.8453425   1.055283    2.0924244   0.12782045] 3   2 
[ 0.43435243  1.2735565   1.057311   -0.559899   -2.7687495 ] 1   2 
[-0.986687    1.5461223   2.243012    0.04091471 -2.577567  ] 2   2 Match 726

[-2.550612  -2.1061044 -0.5851075  2.7084796  3.2694569] 4   4 Match 727

[-2.6371207  -0.59697294  1.926061    2.7203674  -0.4420905 ] 3   3 Match 728

[-2.6200697  -2.0883224  -0.23715845  2.6703482   2.9702458 ] 4   3 
[ 2.1838346   2.2654989   0.09896322 -2.456054   -3.0204065 ] 1   1 Match 729

[ 0.91864896  2.2001793   1.0683932  -1.6999023  -2.9944189 ] 1   1 Match 730

[-9.3975896e-01  1

[-2.7638566 -1.7261717  0.5575162  2.8863938  1.6982883] 3   3 Match 807

[ 2.657997    1.8113602  -0.23673993 -2.6166666  -2.7763276 ] 0   0 Match 808

[-2.7064211  -1.972162    0.03947364  3.032379    2.5863695 ] 3   3 Match 809

[-0.9926613   1.2524543   2.071233    0.37791163 -2.7341173 ] 2   3 
[ 2.5565426   1.922063   -0.16156653 -2.5673304  -2.8051667 ] 0   0 Match 810

[-0.5308444   1.9821539   1.9523017  -0.45081273 -2.8561366 ] 1   2 
[ 1.8720528  2.3571246  0.4229126 -2.4301507 -3.0900462] 1   0 
[ 1.1681024  1.9700129  1.2321343 -1.8042427 -2.8321612] 1   0 
[ 1.4561822  2.551843   0.6728175 -2.168205  -3.064646 ] 1   1 Match 811

[ 2.1720781   2.0569148   0.17404729 -2.2798269  -2.883175  ] 0   0 Match 812

[-2.910361   -1.994788    0.16005763  2.8765557   2.3483477 ] 3   3 Match 813

[-0.44440275  2.1295073   1.9832406  -0.6924435  -2.8351486 ] 1   1 Match 814

[ 1.3262966  2.346916   0.692967  -1.9615535 -3.1495752] 1   0 
[-2.5895681  -1.3400258   0.60591424  2.184019  

[-1.8158393  0.3655146  2.1820314  1.1677352 -1.4704082] 2   2 Match 890

[-1.3117404   0.75002927  1.9185692   0.997482   -2.0146596 ] 2   1 
[-2.4091935  -2.059553   -0.65747917  2.5092206   3.3858755 ] 4   3 
[-2.5490513 -2.0373232 -0.2454555  2.6683643  2.8372817] 4   3 
[-1.7788384   0.45213178  1.93711     1.2304348  -1.7083817 ] 2   1 
[-2.398039   -1.86373     0.26733786  3.1531134   2.1113262 ] 3   3 Match 891

[-2.7509236 -1.8760529  0.4374117  3.0181599  1.8912466] 3   3 Match 892

[-2.7399483  -1.1592786   1.5897003   3.0547404   0.07419112] 3   3 Match 893

[-2.0845063 -1.7430413 -0.6185422  1.7588539  3.1576385] 4   4 Match 894

[ 1.8676932   2.6197884   0.45841303 -2.398411   -3.168944  ] 1   1 Match 895

[-2.5521564 -2.0534792 -0.5079744  2.514227   3.355812 ] 4   4 Match 896

[-1.8945186 -1.0180206  1.1456664  1.5923195  0.5497405] 3   2 
[-2.8342025 -1.5839508  1.070841   3.0765727  1.115937 ] 3   3 Match 897

[ 2.551073    1.7519377  -0.23793784 -2.5067525  -2.699691

[-2.78128    -1.7926111   0.80869746  3.1249392   1.6806694 ] 3   3 Match 968

[-2.7011888  -2.0390017  -0.27504113  2.571443    3.0064828 ] 4   4 Match 969

[ 2.909446   1.4617484 -0.2260887 -2.392226  -2.371828 ] 0   0 Match 970

[ 1.532538   2.593876   0.5344701 -2.2541287 -3.1309335] 1   1 Match 971

[ 2.3031645   2.2068207   0.09644376 -2.5753899  -2.9302726 ] 0   0 Match 972

[-2.0288498 -1.1690905  0.5192834  2.113861   1.1784569] 3   3 Match 973

[-2.7581632  -1.9825948   0.28276777  3.040755    2.1952398 ] 3   3 Match 974

[ 1.8084767  2.265473   0.4312465 -2.186668  -3.1869514] 1   1 Match 975

[ 1.0809457  2.3538241  0.9780364 -1.6293304 -3.0052898] 1   1 Match 976

[-2.2435572  -0.43871665  2.0670848   2.5035849  -0.75610214] 3   2 
[ 1.8575225  2.389361   0.3082199 -2.4070356 -2.9967518] 1   0 
[ 1.0315065  2.4868388  1.0044831 -1.6817553 -3.3570902] 1   0 
[-2.4709303  -1.7795454   0.20331526  2.497456    1.850456  ] 3   3 Match 977

[-1.7699418  -0.04861505  1.9706028   

[ 1.77759     2.0372126   0.49066538 -2.0853314  -3.004704  ] 1   0 
[-2.7655811  -1.9556571  -0.03230938  2.6456552   2.7105937 ] 4   3 
[-1.2368405   1.6424913   1.9745382   0.10141537 -2.506821  ] 2   1 
[ 1.6976291  2.3679857  0.5722865 -2.1564481 -3.1030138] 1   0 
[-2.2786746 -1.9486421 -1.0855733  1.5980428  3.8999076] 4   4 Match 1055

[-2.4295187  -2.1512015  -0.54370224  2.3989303   3.4269605 ] 4   4 Match 1056

[-2.5039213 -2.0149164 -0.6062733  2.566039   3.3518863] 4   3 
[ 0.7695011  2.584665   1.1228482 -1.8855124 -3.1628792] 1   1 Match 1057

[ 1.2830607  1.9732349  0.6945058 -1.7448105 -2.7683892] 1   1 Match 1058

[-2.649014  -1.7699708  0.7403748  3.0160677  1.0678946] 3   2 
[-2.4842224 -1.9189594 -0.6401647  1.9037068  3.4988866] 4   4 Match 1059

[-2.395751   -2.0842984  -0.76143193  2.1833875   3.5958862 ] 4   4 Match 1060

[ 0.99651843  2.3905709   0.8317535  -1.999451   -3.0766265 ] 1   1 Match 1061

[ 2.3357987   2.0230665  -0.12687457 -2.4772758  -2.862966  ]

[-2.52338   -2.1434748 -0.4342262  2.9202232  2.9843779] 4   4 Match 1142

[-2.6409574  -1.6717856   1.1814516   3.1343477   0.89998305] 3   3 Match 1143

[ 2.3758073   1.9809022   0.09322944 -2.4182248  -2.7441406 ] 0   0 Match 1144

[-2.6258624  -2.0441785   0.30569905  2.8875232   2.2393205 ] 3   4 
[ 0.4082896  2.0355937  1.1717063 -1.2444835 -3.040502 ] 1   1 Match 1145

[-2.185107   -1.2013116   1.0030369   2.1838021   0.41326973] 3   3 Match 1146

[-1.9349873  -0.02077193  1.7999142   1.6680993  -0.969426  ] 2   2 Match 1147

[-2.6628127 -1.1922481  0.810426   2.7656953  0.6730108] 3   3 Match 1148

[-0.01996914  1.501723    1.5534356  -0.36159337 -2.8277285 ] 2   2 Match 1149

[-2.5633352 -2.1109042 -0.5511889  2.5484395  3.364386 ] 4   4 Match 1150

[ 0.2565897  2.4961915  1.4266347 -1.2867383 -2.9998105] 1   1 Match 1151

[ 2.002534   2.1764314  0.3334094 -2.3461978 -2.9849384] 1   1 Match 1152

[ 0.82364494  2.578926    1.0165014  -1.7610244  -3.179262  ] 1   2 
[-2.6874292 

[-2.6355162  -1.9215933   0.08375753  3.0526586   2.3809483 ] 3   4 
[ 2.0753713   2.3305714   0.21612275 -2.5002525  -3.00605   ] 1   1 Match 1231

[-2.5021415 -2.2468438 -0.5058151  2.5543356  3.2802625] 4   3 
[-2.6341403  -1.6694367   0.78226095  3.1558518   1.2272521 ] 3   3 Match 1232

[ 0.77290225  2.3790967   1.1102792  -1.6558987  -3.1843395 ] 1   0 
[-2.5221548 -1.8842242 -0.7405214  2.0419924  3.6019197] 4   4 Match 1233

[ 0.6018084   0.65982527  0.6379692  -0.76533157 -1.6423916 ] 1   1 Match 1234

[-1.5935652   0.25925085  1.4817146   1.0411241  -1.1716188 ] 2   3 
[ 2.9125237  1.6423123 -0.5967529 -2.3304005 -2.7072723] 0   0 Match 1235

[-2.4752064  -0.69576675  2.054239    2.4415095  -0.60563964] 3   3 Match 1236

[ 0.8618805  2.625183   1.0720829 -1.7473457 -3.1135974] 1   1 Match 1237

[-0.62563175  1.8070526   1.9875133  -0.3324968  -2.6683843 ] 2   3 
[ 0.36198804  2.5112276   1.4045944  -1.3577582  -3.1155019 ] 1   1 Match 1238

[ 1.643747    2.3840928   0.4485915


[ 1.7303159   2.2090108   0.48504984 -2.0900006  -3.1259587 ] 1   1 Match 1321

[-2.5470104  -2.1879747  -0.44111148  2.5002317   3.2186112 ] 4   3 
[-2.6618924  -2.067048   -0.30674317  2.561662    2.9902904 ] 4   4 Match 1322

[ 1.2419087   2.5051255   0.78550464 -2.0456576  -2.9030824 ] 1   0 
1322


In [15]:
from sklearn import metrics
 
print(metrics.confusion_matrix(Targets,Pred))

[[144 126   7   2   0]
 [110 431  76  15   1]
 [ 11 136 133 100   9]
 [  1   5  39 346 119]
 [  0   1   8 122 268]]


In [16]:
target_names = ['Very Neg', 'Negative', 'Neutral','Positive','Very Pos']

print(metrics.classification_report(Targets, Pred,target_names =target_names))

              precision    recall  f1-score   support

    Very Neg       0.54      0.52      0.53       279
    Negative       0.62      0.68      0.65       633
     Neutral       0.51      0.34      0.41       389
    Positive       0.59      0.68      0.63       510
    Very Pos       0.68      0.67      0.67       399

    accuracy                           0.60      2210
   macro avg       0.59      0.58      0.58      2210
weighted avg       0.59      0.60      0.59      2210



In [17]:
Fold1_Predictions=pd.DataFrame(Pred, columns=['Pred1'])
Fold1_Predictions

Unnamed: 0,Pred1
0,1
1,1
2,1
3,4
4,2
...,...
2205,4
2206,1
2207,4
2208,4


In [18]:
Fold1_Predictions.to_excel(output_folder+'/Saves/fold1_Predictions.xls')

In [21]:
#clearing GPU cache

del(model)
del(TrainResult, TrainModel_outputs, EvalResult, EvalModel_outputs, TestResult, TestModel_outputs, wrong_predictions)
torch.cuda.empty_cache()

## Fold 2: training & caturing predictions

In [22]:
fold_number='2'

train=pd.read_excel('./folds/train_fold'+fold_number+'.xls')
Eval=pd.read_excel('./folds/valid'+fold_number+'.xls') #evaluation set


In [12]:

output_folder='./folds/fold'+fold_number+'/'+model_class+'/'+model_version+"/"
cache_directory= "./folds/fold"+fold_number+'/'+model_class+"/"+model_version+"/cache/"


print('model variables were set up: ')

 
save_every_steps=1285
# assuming training batch size of 8
# any number above 1284 saves the model only at every epoch
# Saving the model mid training very often will consume disk space fast

train_args={
    "output_dir":output_folder,
    "cache_dir":cache_directory,
    'reprocess_input_data': True,
    'overwrite_output_dir': True,
    'num_train_epochs': 2,
    "save_steps": save_every_steps, 
    "learning_rate": 1.2e-5,
    "train_batch_size": 32,
    "eval_batch_size": 16,
    "weight_decay": 0,
    "evaluate_during_training_steps": 312,
    "max_seq_length": 100,
    "n_gpu": 1,
}

# Create a ClassificationModel
model = ClassificationModel(model_class, model_version, num_labels=labels_count, args=train_args) 

NameError: name 'model_class' is not defined

In [24]:
# loading the checkpoint that gave the best result
'''
CheckPoint='checkpoint-143-epoch-1'  #epoch 1


preSavedCheckpoint=output_folder+CheckPoint

print('Loading model, please wait...')
model = ClassificationModel( model_class, preSavedCheckpoint, num_labels=labels_count, args=train_args) 
print('model in use is :', preSavedCheckpoint )
'''

"\nCheckPoint='checkpoint-143-epoch-1'  #epoch 1\n\n\npreSavedCheckpoint=output_folder+CheckPoint\n\nprint('Loading model, please wait...')\nmodel = ClassificationModel( model_class, preSavedCheckpoint, num_labels=labels_count, args=train_args) \nprint('model in use is :', preSavedCheckpoint )\n"

In [25]:
# Train the model
current_time = datetime.now()
model.train_model(train)
print("Training time: ", datetime.now() - current_time)

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=6829.0), HTML(value='')))


Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=2.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 1.605266



Running loss: 1.647637Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Running loss: 1.623508Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Running loss: 1.634390Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Running loss: 1.508893



Running loss: 0.842845


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 0.902053Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Running loss: 1.068337

Training of roberta model complete. Saved to ./folds/fold2/roberta/roberta-large/.
Training time:  0:16:14.991749


In [26]:
TrainResult, TrainModel_outputs, wrong_predictions = model.eval_model(train, acc=sklearn.metrics.accuracy_score)
 
EvalResult, EvalModel_outputs, wrong_predictions = model.eval_model(Eval, acc=sklearn.metrics.accuracy_score)

TestResult, TestModel_outputs, wrong_predictions = model.eval_model(test, acc=sklearn.metrics.accuracy_score)

print('Training Result:', TrainResult['acc'])
#print('Model Out:', TrainModel_outputs)

print('Eval Result:', EvalResult['acc'])
#print('Model Out:', EvalModel_outputs)

print('Test Set Result:', TestResult['acc'])
#print('Model Out:', TestModel_outputs)

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=6829.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=427.0), HTML(value='')))


{'mcc': 0.5270466209463697, 'acc': 0.6251281300336798, 'eval_loss': 0.8494768137116622}
Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=1705.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=107.0), HTML(value='')))


{'mcc': 0.4543178303132622, 'acc': 0.5695014662756598, 'eval_loss': 0.9625523385600508}
Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=2210.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))


{'mcc': 0.4774774805052575, 'acc': 0.5909502262443439, 'eval_loss': 0.9182235628580876}
Training Result: 0.6251281300336798
Eval Result: 0.5695014662756598
Test Set Result: 0.5909502262443439


In [27]:
Pred=[]
Targets=[]

countCorrect=0

for row in range(TestModel_outputs.shape[0]):
    outputs=TestModel_outputs[row]
    #print(test.iloc[row,0])
    print(outputs, end=' ')
    
    result=0
    if outputs[0]<outputs[1]:result=1
    if outputs[result]<outputs[2]:result=2
    if outputs[result]<outputs[3]:result=3
    if outputs[result]<outputs[4]:result=4
    
    Pred.append(result)
    Targets.append(test.iloc[row,1])
    print(result, ' ',test.iloc[row,1], end=' ')
    if result==test.iloc[row,1]:
        countCorrect+=1
        print('Match',countCorrect)
    print('')

print(countCorrect)

[-0.7495117  1.7011719  1.6083984 -0.4111328 -2.8320312] 1   1 Match 1

[ 1.0351562  2.4570312  0.6503906 -1.8427734 -3.3085938] 1   0 
[-0.09667969  2.3789062   1.328125   -1.0947266  -3.2050781 ] 1   2 
[-2.125     -2.1464844 -0.8613281  2.109375   3.2324219] 4   4 Match 2

[-0.93652344  1.6689453   1.7333984  -0.24414062 -2.4785156 ] 2   0 
[-2.8867188  -1.59375     0.46191406  3.0917969   0.8696289 ] 3   3 Match 3

[ 2.1445312   2.1660156   0.00565338 -2.2421875  -3.1640625 ] 1   0 
[-2.71875    -2.0410156  -0.02410889  2.6875      2.7050781 ] 4   4 Match 4

[ 2.6835938   1.2265625  -0.27807617 -1.9394531  -2.1445312 ] 0   0 Match 5

[ 2.5195312   1.7558594  -0.25048828 -2.1621094  -2.7167969 ] 0   0 Match 6

[ 0.6171875  2.4375     1.0039062 -1.6962891 -3.5234375] 1   2 
[-2.5585938  -1.15625     0.83496094  2.2636719   0.60791016] 3   4 
[-1.15625     1.5791016   1.9316406   0.23498535 -2.9570312 ] 2   1 
[ 2.140625    2.1289062   0.15734863 -1.9296875  -2.9824219 ] 0   0 Match 7

[ 0.3215332  2.4121094  1.2041016 -1.3505859 -3.2460938] 1   1 Match 75

[ 2.1582031   2.1464844   0.02610779 -2.1875     -2.9433594 ] 0   1 
[ 2.2578125   2.140625   -0.02841187 -2.15625    -3.1132812 ] 0   0 Match 76

[-0.7363281   1.9658203   1.4580078  -0.47631836 -2.5644531 ] 1   3 
[ 1.6601562  2.4980469  0.3034668 -2.0898438 -3.21875  ] 1   2 
[-2.8398438  -1.2138672   1.2216797   2.4082031   0.10717773] 3   3 Match 77

[-2.53125    -1.7958984  -0.18530273  2.1328125   2.65625   ] 4   4 Match 78

[-2.4550781 -2.2890625 -0.796875   2.2890625  3.4101562] 4   4 Match 79

[-0.09259033  1.8320312   1.5683594  -0.6074219  -3.1054688 ] 1   2 
[-1.9550781  -0.7060547   0.8745117   1.0576172  -0.27148438] 3   1 
[-2.6445312  -0.9980469   0.9458008   2.1992188   0.16845703] 3   2 
[-2.3574219  -2.3417969  -0.79003906  2.3027344   3.4628906 ] 4   4 Match 80

[-2.6484375  -2.140625   -0.37353516  2.1992188   2.9902344 ] 4   3 
[-2.5136719 -2.2773438 -0.5332031  2.2949219  3.1347656] 4   3 


[-0.87060547  1.9814453   1.7158203  -0.51123047 -3.1035156 ] 1   2 
[-2.8828125  -2.0917969  -0.17883301  2.9296875   2.4179688 ] 3   4 
[-1.0478516   1.0839844   1.796875    0.17614746 -2.3457031 ] 2   1 
[-1.9902344 -2.1777344 -0.8334961  1.8457031  3.4355469] 4   4 Match 137

[-2.6484375  -1.1132812   0.98095703  2.6972656   0.20214844] 3   2 
[-0.93896484  1.4814453   1.6054688   0.15710449 -2.5839844 ] 2   0 
[-2.515625  -2.2949219 -0.5908203  2.3925781  3.2753906] 4   3 
[-2.7226562 -0.6303711  1.7548828  2.0957031 -1.1513672] 3   2 
[-3.0097656  -1.5683594   0.76220703  2.734375    0.9921875 ] 3   3 Match 138

[-3.1367188  -1.7871094   0.40234375  2.9570312   1.1298828 ] 3   3 Match 139

[-0.21533203  2.328125    1.4287109  -0.9892578  -2.9199219 ] 1   1 Match 140

[ 0.6635742  2.1386719  1.0576172 -1.5820312 -3.1679688] 1   0 
[-2.6308594  -1.9521484  -0.13391113  2.4296875   2.2480469 ] 3   4 
[-2.4238281 -2.3164062 -0.7548828  2.2480469  3.6191406] 4   4 Match 141

[-0.20019

[ 0.42944336  2.4003906   1.15625    -1.3369141  -3.4277344 ] 1   1 Match 208

[-2.7539062  -1.7998047   0.30786133  2.6582031   1.7207031 ] 3   3 Match 209

[-2.484375  -2.1894531 -0.65625    2.3691406  3.2949219] 4   3 
[ 1.1757812  2.2792969  0.828125  -1.8183594 -3.28125  ] 1   1 Match 210

[-2.9160156  -1.2167969   0.84472656  2.5429688   0.0993042 ] 3   3 Match 211

[-2.6035156  -2.2460938  -0.47485352  2.7773438   2.6542969 ] 3   4 
[-2.0136719  -0.38305664  1.2666016   1.5625     -0.80126953] 3   3 Match 212

[ 2.8339844   1.1904297  -0.41625977 -1.8173828  -1.9912109 ] 0   0 Match 213

[ 0.22912598  1.8994141   1.0292969  -1.1708984  -2.8808594 ] 1   1 Match 214

[ 0.2467041  2.0449219  1.3955078 -1.0761719 -3.2617188] 1   1 Match 215

[-1.3085938   1.5488281   2.09375     0.20288086 -2.6933594 ] 2   2 Match 216

[-2.6386719 -2.2929688 -0.5854492  2.4726562  3.2617188] 4   4 Match 217

[ 1.9570312   1.9775391   0.11425781 -2.0351562  -2.9277344 ] 1   1 Match 218

[-2.6074219 -

[-2.53125    -2.2402344  -0.43725586  2.4882812   3.1152344 ] 4   4 Match 287

[ 1.8261719   2.3007812   0.46313477 -2.03125    -3.2871094 ] 1   0 
[-2.1894531 -2.2578125 -0.7758789  2.0214844  3.4921875] 4   4 Match 288

[ 1.1142578  2.484375   0.6635742 -1.8408203 -3.3417969] 1   1 Match 289

[-1.8183594   0.20983887  1.671875    1.1728516  -1.7158203 ] 2   1 
[ 0.14648438  2.2988281   1.3125     -1.3476562  -3.3964844 ] 1   1 Match 290

[-0.25024414  2.234375    1.3183594  -0.9848633  -3.0644531 ] 1   1 Match 291

[-2.8574219 -2.1933594 -0.2927246  2.6875     2.6953125] 4   4 Match 292

[-2.1269531 -2.2382812 -0.9658203  2.0058594  3.8476562] 4   4 Match 293

[ 2.4746094   2.0234375  -0.22851562 -2.2070312  -2.9980469 ] 0   2 
[-2.3046875  -1.8574219  -0.02940369  2.3847656   1.8515625 ] 3   3 Match 294

[ 0.8754883   2.4316406   0.72558594 -1.8447266  -3.3828125 ] 1   1 Match 295

[ 2.5761719   1.9169922  -0.15991211 -2.2890625  -2.9160156 ] 0   1 
[ 2.5644531   1.9316406  -0.26733

[-3.0253906 -1.3701172  0.7036133  2.2910156  1.2666016] 3   3 Match 361

[ 1.0800781  2.0859375  0.9291992 -1.5214844 -3.2714844] 1   2 
[ 1.7246094   2.453125    0.29589844 -2.2148438  -3.2636719 ] 1   1 Match 362

[ 2.7734375   1.4707031  -0.14562988 -1.9130859  -2.3398438 ] 0   0 Match 363

[-2.3984375 -2.0234375 -0.3527832  2.4277344  2.8945312] 4   3 
[-2.8125    -2.2597656 -0.1685791  2.6757812  2.5      ] 3   3 Match 364

[-3.0605469  -2.0664062   0.05111694  2.7460938   2.2109375 ] 3   3 Match 365

[-2.9121094  -1.9140625   0.03152466  2.8027344   2.0566406 ] 3   3 Match 366

[-2.1113281  0.3798828  2.0292969  1.1767578 -1.8789062] 2   2 Match 367

[-3.0195312  -1.6972656   0.37719727  2.6992188   1.1132812 ] 3   4 
[ 1.8085938  2.3730469  0.2322998 -2.0332031 -3.15625  ] 1   1 Match 368

[-1.9619141  -2.2519531  -0.89453125  1.8583984   3.734375  ] 4   4 Match 369

[ 1.0390625  2.4160156  0.7841797 -1.7158203 -3.2304688] 1   0 
[-2.0410156 -2.2265625 -1.0136719  1.7919922  3.

[-2.2636719  -0.87402344  1.4970703   1.6884766  -0.19628906] 3   2 
[ 0.18835449  2.1113281   1.1484375  -1.2353516  -3.4121094 ] 1   1 Match 436

[ 1.5410156   2.4980469   0.53222656 -2.0644531  -3.25      ] 1   1 Match 437

[-2.2851562  -2.3203125  -0.86621094  2.1054688   3.6113281 ] 4   3 
[-2.8886719 -1.5839844  0.390625   3.0839844  0.9165039] 3   2 
[-2.9433594 -1.7578125  0.1270752  2.6679688  2.1777344] 3   3 Match 438

[ 1.5195312   2.4082031   0.48754883 -1.9980469  -3.3535156 ] 1   0 
[-0.95947266  1.5322266   1.9257812  -0.39868164 -2.4140625 ] 2   2 Match 439

[-2.4082031  -2.265625   -0.87646484  2.1953125   3.5800781 ] 4   4 Match 440

[-2.234375  -2.0996094 -0.8510742  1.9208984  3.2988281] 4   3 
[-3.0136719  -1.1201172   0.9213867   2.7402344   0.22229004] 3   3 Match 441

[-2.8476562  -1.7460938   0.27368164  2.6777344   1.5302734 ] 3   3 Match 442

[-2.9882812  -1.5908203   0.41137695  2.5527344   0.77001953] 3   4 
[ 2.6308594  1.125     -0.3803711 -1.7519531 -2.

[-2.9902344  -1.7490234   0.48608398  2.9648438   1.0527344 ] 3   3 Match 512

[ 0.5551758   2.1640625   0.90478516 -1.1474609  -2.9941406 ] 1   0 
[-2.1835938 -2.2246094 -0.7441406  2.109375   3.3105469] 4   4 Match 513

[-2.7792969  -1.3623047   0.88183594  2.8457031   0.5058594 ] 3   3 Match 514

[-2.4511719 -2.3203125 -0.6269531  2.2207031  2.9589844] 4   3 
[-3.2167969  -1.8535156   0.36547852  3.0117188   1.2509766 ] 3   4 
[ 0.92871094  2.4980469   0.8857422  -1.7929688  -3.46875   ] 1   1 Match 515

[-2.09375   -2.1757812 -0.7529297  1.8408203  3.3222656] 4   3 
[-2.8964844  -2.1679688  -0.40673828  2.7167969   2.7441406 ] 4   2 
[-0.41674805  2.1796875   1.421875   -0.9008789  -3.171875  ] 1   1 Match 516

[-2.8476562  -2.0117188  -0.09594727  2.9902344   1.5214844 ] 3   3 Match 517

[-3.1113281 -1.8632812  0.5600586  2.6054688  1.2636719] 3   3 Match 518

[ 2.0253906  1.9052734  0.3017578 -1.8388672 -3.1015625] 0   0 Match 519

[-2.0429688   1.0830078   2.1484375   0.72265625

[ 1.8056641   2.15625     0.24707031 -1.7802734  -3.0878906 ] 1   1 Match 595

[-1.7128906 -2.2207031 -0.9614258  1.7197266  3.7402344] 4   4 Match 596

[-3.1152344  -1.7900391   0.20666504  2.75        1.6601562 ] 3   2 
[-2.6542969  -1.2265625   0.58740234  1.8125      0.91552734] 3   3 Match 597

[-2.8359375 -1.8974609  0.078125   2.7792969  1.9277344] 3   2 
[-1.8505859 -2.0859375 -0.8857422  1.7373047  3.5078125] 4   4 Match 598

[-0.12493896  2.0664062   1.5117188  -1.0703125  -3.2011719 ] 1   1 Match 599

[ 0.6772461  2.4824219  0.9824219 -1.7382812 -3.4082031] 1   1 Match 600

[-0.76660156  1.5380859   1.5693359  -0.02490234 -3.0214844 ] 2   2 Match 601

[-1.3027344   0.5756836   1.7880859   0.38500977 -1.9023438 ] 2   2 Match 602

[-2.078125   -0.28930664  1.7724609   1.6328125  -0.94091797] 2   2 Match 603

[ 2.2636719   2.2109375  -0.02407837 -2.140625   -3.1601562 ] 0   1 
[ 1.6972656   2.2539062   0.41601562 -1.9208984  -3.1542969 ] 1   1 Match 604

[-3.0371094  -1.8720703

[-0.60058594  1.8876953   1.7851562  -0.5703125  -3.3457031 ] 1   2 
[ 1.0810547  2.4902344  0.6826172 -1.8632812 -3.2148438] 1   1 Match 675

[ 1.4404297   2.359375    0.60498047 -2.0234375  -3.3300781 ] 1   1 Match 676

[-2.5996094  -1.2646484   1.0712891   2.2109375   0.43164062] 3   3 Match 677

[-1.3212891   1.1396484   2.1074219   0.38623047 -2.6679688 ] 2   1 
[-0.07983398  2.359375    1.3632812  -1.2900391  -3.3476562 ] 1   2 
[-2.0878906 -2.1640625 -0.8881836  1.8564453  3.6484375] 4   3 
[-3.0019531  -1.4638672   0.42993164  2.9628906   1.1367188 ] 3   3 Match 678

[ 1.0917969   2.4765625   0.77783203 -1.7548828  -3.3632812 ] 1   1 Match 679

[-2.3964844   0.13696289  1.5478516   1.7216797  -1.5253906 ] 3   2 
[ 0.28710938  2.2519531   1.1269531  -1.2519531  -3.2792969 ] 1   1 Match 680

[-2.9785156  -1.4091797   0.8203125   2.5859375   0.60839844] 3   2 
[-3.         -1.1162109   1.0859375   2.6953125   0.09527588] 3   3 Match 681

[-2.0820312  -2.3359375  -0.87109375  1.863

[-2.5898438  -0.19335938  1.5947266   2.1972656  -1.1132812 ] 3   2 
[-2.984375   -1.4003906   0.72314453  2.6425781   0.74609375] 3   3 Match 742

[ 1.2470703  2.4140625  0.6459961 -1.8017578 -3.4589844] 1   2 
[-2.5039062 -2.1445312 -0.6821289  2.3320312  3.5488281] 4   3 
[-2.8671875  -1.9472656  -0.14367676  2.6328125   2.4082031 ] 3   3 Match 743

[-0.91015625 -0.41674805  0.4020996   0.55908203  0.10528564] 3   3 Match 744

[ 1.8652344   2.1328125   0.16845703 -1.9365234  -2.9433594 ] 1   1 Match 745

[ 0.5732422  2.0664062  1.0175781 -1.2070312 -3.3925781] 1   1 Match 746

[ 0.04800415  2.3359375   1.4287109  -1.1005859  -3.4003906 ] 1   2 
[ 2.7265625  1.3222656 -0.2788086 -1.8964844 -2.2792969] 0   0 Match 747

[-0.17797852  2.3945312   1.5136719  -1.1640625  -3.5136719 ] 1   2 
[ 1.0283203  2.0371094  0.5078125 -1.3710938 -3.0742188] 1   2 
[ 1.9208984  2.3417969  0.2644043 -2.1816406 -3.3535156] 1   1 Match 748

[ 2.4765625   1.8173828  -0.17651367 -2.1523438  -2.78125   ] 0

[-2.1796875 -1.984375  -0.6923828  1.7568359  3.2695312] 4   3 
[-2.0507812  -2.1816406  -0.96728516  1.9433594   3.8027344 ] 4   4 Match 819

[ 0.20568848  2.375       1.1982422  -1.4814453  -3.4804688 ] 1   0 
[ 1.3320312  2.5136719  0.5517578 -1.8330078 -3.1992188] 1   2 
[ 1.5634766  2.2597656  0.5522461 -2.0742188 -3.3730469] 1   0 
[-1.3164062   0.51171875  1.7011719   0.77490234 -2.0429688 ] 2   3 
[-2.8183594  -2.0234375  -0.12487793  2.9921875   1.7539062 ] 3   3 Match 820

[ 1.0908203  2.4414062  0.7548828 -1.8759766 -3.2597656] 1   1 Match 821

[ 1.3134766  2.2089844  0.6899414 -1.7480469 -3.3183594] 1   0 
[ 2.0957031   2.2558594   0.11444092 -2.1113281  -3.1582031 ] 1   0 
[-2.9003906  -1.1748047   0.98876953  2.7128906   0.35058594] 3   3 Match 822

[ 2.4492188   2.0292969  -0.04754639 -2.1269531  -2.8925781 ] 0   1 
[-2.7832031  -2.0214844  -0.13586426  2.7792969   2.3691406 ] 3   3 Match 823

[ 1.9521484   2.3378906   0.27172852 -2.1386719  -3.2167969 ] 1   2 
[ 2.70703

[-2.4296875 -2.0703125 -0.5576172  2.4335938  2.9433594] 4   3 
[-2.3554688 -2.2363281 -0.4099121  1.8554688  2.7265625] 4   3 
[-1.5283203   1.0947266   1.7177734   0.23522949 -2.2890625 ] 2   1 
[-1.1640625   1.8945312   2.0214844  -0.22802734 -3.0878906 ] 2   2 Match 900

[-1.2070312   0.15319824  1.1806641   0.61572266 -1.6005859 ] 2   4 
[-2.4824219  -2.1542969  -0.60302734  2.2851562   3.34375   ] 4   2 
[-2.2460938  -0.9448242   0.5649414   1.4052734   0.59472656] 3   4 
[-0.06445312  2.4257812   1.2490234  -1.3125     -3.3203125 ] 1   1 Match 901

[-2.2128906  -2.1328125  -0.79003906  1.9257812   3.71875   ] 4   4 Match 902

[ 2.390625    2.1699219  -0.01901245 -2.21875    -3.0058594 ] 0   1 
[-2.9179688  -0.67333984  1.4863281   2.3320312  -0.84814453] 3   1 
[ 1.8193359   2.3789062   0.36889648 -2.1601562  -3.265625  ] 1   1 Match 903

[-2.9179688  -1.71875     0.48291016  2.9238281   1.4238281 ] 3   3 Match 904

[-2.2089844 -2.0507812 -0.7998047  1.7412109  3.5820312] 4   4 

[ 2.03125     2.0605469   0.34155273 -1.8261719  -2.9902344 ] 1   0 
[-2.8964844  -1.0800781   1.1992188   2.4492188  -0.37329102] 3   3 Match 978

[-0.35986328  2.2792969   1.6689453  -0.9121094  -3.2011719 ] 1   1 Match 979

[ 0.28759766  2.1796875   1.1640625  -1.3681641  -3.359375  ] 1   1 Match 980

[ 1.3398438  2.5566406  0.5024414 -1.984375  -3.3730469] 1   0 
[-2.0703125 -0.7675781  0.7495117  1.5429688 -0.0579834] 3   3 Match 981

[ 1.5800781   2.3339844   0.54052734 -1.9638672  -3.3535156 ] 1   1 Match 982

[ 1.1396484  2.5410156  0.6171875 -1.8701172 -3.4570312] 1   1 Match 983

[-2.703125   -1.0927734   1.1083984   2.2421875   0.07000732] 3   2 
[-2.5351562  -2.2050781  -0.49658203  2.28125     3.0371094 ] 4   2 
[-1.9628906  -0.73046875  0.54296875  1.4238281   0.5214844 ] 3   3 Match 984

[-2.8144531  -1.9130859  -0.03102112  2.8261719   2.109375  ] 3   2 
[-2.0273438  -2.2910156  -0.87060547  2.0585938   3.2871094 ] 4   3 
[-3.0351562 -1.6904297  0.6694336  2.9785156  0.

[-2.0859375 -2.1582031 -0.8911133  2.0527344  3.7753906] 4   4 Match 1060

[ 0.7963867  2.3066406  0.8676758 -1.546875  -3.1171875] 1   1 Match 1061

[ 2.5371094   1.8310547  -0.19018555 -2.171875   -2.7265625 ] 0   0 Match 1062

[ 0.68603516  2.4160156   1.0839844  -1.7626953  -3.3730469 ] 1   1 Match 1063

[-2.2890625 -2.3632812 -0.9536133  2.2285156  3.4238281] 4   4 Match 1064

[-2.1738281 -2.0117188 -0.5629883  1.9394531  3.3203125] 4   4 Match 1065

[ 0.4404297  2.4785156  1.2304688 -1.4921875 -3.4121094] 1   1 Match 1066

[-1.9941406   0.53125     1.7988281   0.90722656 -2.0683594 ] 2   1 
[-2.7851562  -0.47998047  1.4619141   2.3867188  -0.5019531 ] 3   2 
[-3.0488281 -1.859375   0.2376709  2.9238281  1.7646484] 3   4 
[-3.046875   -1.3085938   0.86035156  2.8144531   0.6381836 ] 3   3 Match 1067

[-2.2050781  0.3010254  2.2285156  1.3046875 -1.4541016] 2   2 Match 1068

[-1.3710938   0.84521484  1.9482422   0.5283203  -2.1269531 ] 2   2 Match 1069

[-2.4824219  -1.9707031  -0.

[-2.5742188 -2.1621094 -0.5629883  2.6914062  3.0195312] 4   4 Match 1133

[-2.9140625  -1.1103516   1.1171875   2.7773438  -0.03353882] 3   3 Match 1134

[ 2.6425781   1.8144531  -0.22473145 -2.1132812  -2.7792969 ] 0   0 Match 1135

[-2.8105469  -2.1347656  -0.35961914  2.5019531   2.7832031 ] 4   4 Match 1136

[-0.85253906  1.2255859   1.25       -0.05285645 -2.6425781 ] 2   1 
[-2.0195312 -0.9736328  0.8520508  1.8896484  0.6484375] 3   3 Match 1137

[-0.9042969   1.4912109   1.8574219  -0.09802246 -3.0703125 ] 2   2 Match 1138

[-2.1777344   0.12463379  1.59375     1.7822266  -1.6396484 ] 3   3 Match 1139

[ 0.75634766  2.1894531   0.8964844  -1.421875   -3.4765625 ] 1   2 
[-2.5683594  -2.3925781  -0.72021484  2.4238281   3.2617188 ] 4   4 Match 1140

[ 0.05825806  2.4609375   1.4296875  -1.4511719  -3.3710938 ] 1   1 Match 1141

[ 2.09375     1.9580078   0.13989258 -1.9306641  -2.9414062 ] 0   1 
[ 0.82177734  2.4882812   0.8808594  -1.6044922  -3.2285156 ] 1   2 
[-2.8730469  -

[ 1.0009766  2.609375   0.9160156 -1.8222656 -3.5097656] 1   1 Match 1213

[-2.7089844  -0.7001953   1.0722656   2.0546875   0.01235962] 3   4 
[-3.1054688  -1.1083984   1.1552734   2.3925781   0.04086304] 3   2 
[-2.9257812  -1.0078125   0.92626953  2.6289062  -0.0592041 ] 3   2 
[-2.6894531  -2.0136719  -0.43652344  2.7050781   2.7402344 ] 4   3 
[-2.1113281  -2.1386719  -0.75439453  1.9003906   3.6308594 ] 4   4 Match 1214

[-2.8984375 -2.2753906 -0.4177246  2.6582031  2.703125 ] 4   4 Match 1215

[ 2.2734375   1.9990234   0.10614014 -1.9931641  -3.140625  ] 0   0 Match 1216

[-1.4667969   0.8432617   2.1914062   0.41796875 -2.109375  ] 2   2 Match 1217

[-0.00902557  2.0234375   1.5839844  -0.7915039  -2.8652344 ] 1   1 Match 1218

[ 2.15625     2.3085938   0.09075928 -2.1699219  -3.0625    ] 1   1 Match 1219

[-1.8710938   0.35766602  1.984375    0.9536133  -1.6982422 ] 2   2 Match 1220

[-2.6699219 -2.2109375 -0.5751953  2.5527344  2.9550781] 4   4 Match 1221

[ 1.6474609  2.4199

[-2.5175781  -2.015625   -0.25927734  2.4121094   2.8222656 ] 4   4 Match 1293

[-2.2871094 -2.1679688 -0.7182617  2.0234375  3.3613281] 4   3 
[-1.5703125   0.07098389  1.3203125   0.9194336  -0.98339844] 2   3 
[-2.8027344  -2.2148438  -0.30688477  2.578125    2.7285156 ] 4   4 Match 1294

[-3.1113281 -1.7060547  0.6459961  3.1757812  0.5913086] 3   3 Match 1295

[ 2.2109375   1.625       0.04223633 -1.6220703  -2.640625  ] 0   0 Match 1296

[ 2.9433594  1.3115234 -0.4790039 -1.8007812 -2.1621094] 0   0 Match 1297

[-2.7304688  -1.3242188   0.40283203  2.0117188   1.2011719 ] 3   2 
[-2.7910156  -0.84472656  1.1513672   2.796875   -0.42456055] 3   2 
[ 0.6635742  2.3320312  1.015625  -1.359375  -3.265625 ] 1   1 Match 1298

[ 0.40893555  2.0527344   1.1962891  -1.2041016  -3.0996094 ] 1   1 Match 1299

[ 1.8251953   2.3261719   0.17565918 -2.1132812  -3.2382812 ] 1   1 Match 1300

[ 0.52685547  2.3964844   0.9614258  -1.7304688  -3.4941406 ] 1   1 Match 1301

[ 2.1933594   2.1894531 

In [28]:
from sklearn import metrics
print(metrics.confusion_matrix(Targets,Pred))

[[137 133   7   2   0]
 [ 89 443  75  25   1]
 [ 13 139 109 113  15]
 [  0  11  25 317 157]
 [  0   0   4  95 300]]


In [29]:
target_names = ['Very Neg', 'Negative', 'Neutral','Positive','Very Pos']
print(metrics.classification_report(Targets, Pred,target_names =target_names))

              precision    recall  f1-score   support

    Very Neg       0.57      0.49      0.53       279
    Negative       0.61      0.70      0.65       633
     Neutral       0.50      0.28      0.36       389
    Positive       0.57      0.62      0.60       510
    Very Pos       0.63      0.75      0.69       399

    accuracy                           0.59      2210
   macro avg       0.58      0.57      0.56      2210
weighted avg       0.58      0.59      0.58      2210



In [30]:
Fold_Predictions=pd.DataFrame(Pred, columns=['Pred2'] )
Fold_Predictions

Unnamed: 0,Pred2
0,1
1,1
2,1
3,4
4,2
...,...
2205,4
2206,1
2207,4
2208,4


In [32]:
Fold_Predictions.to_excel(output_folder+'Saves/fold2_Predictions.xls')

In [34]:
#clearing GPU cache

del(model)
del(TrainResult, TrainModel_outputs, EvalResult, EvalModel_outputs, TestResult, TestModel_outputs, wrong_predictions)
torch.cuda.empty_cache()

NameError: name 'model' is not defined

## Fold 3: training & caturing predictions

In [35]:
fold_number='3'

train=pd.read_excel('./folds/train_fold'+fold_number+'.xls')
Eval=pd.read_excel('./folds/valid'+fold_number+'.xls') #evaluation set


In [36]:
 
output_folder='./folds/fold'+fold_number+'/'+model_class+'/'+model_version+"/"
cache_directory= "./folds/fold"+fold_number+'/'+model_class+"/"+model_version+"/cache/"


print('model variables were set up: ')

 
save_every_steps=1285
# assuming training batch size of 8
# any number above 1284 saves the model only at every epoch
# Saving the model mid training very often will consume disk space fast

train_args={
    "output_dir":output_folder,
    "cache_dir":cache_directory,
    'reprocess_input_data': True,
    'overwrite_output_dir': True,
    'num_train_epochs': 2,
    "save_steps": save_every_steps, 
    "learning_rate": 1.2e-5,
    "train_batch_size": 32,
    "eval_batch_size": 16,
    "weight_decay": 0,
    "evaluate_during_training_steps": 312,
    "max_seq_length": 100,
    "n_gpu": 1,
}

# Create a ClassificationModel
model = ClassificationModel(model_class, model_version, num_labels=labels_count, args=train_args) 

model variables were set up: 


In [37]:
# loading the checkpoint that gave the best result
'''
CheckPoint='checkpoint-130-epoch-1'  #epoch 1


preSavedCheckpoint=output_folder+CheckPoint

print('Loading model, please wait...')
model = ClassificationModel( model_class, preSavedCheckpoint, num_labels=labels_count, args=train_args) 
print('model in use is :', preSavedCheckpoint )
'''

"\nCheckPoint='checkpoint-130-epoch-1'  #epoch 1\n\n\npreSavedCheckpoint=output_folder+CheckPoint\n\nprint('Loading model, please wait...')\nmodel = ClassificationModel( model_class, preSavedCheckpoint, num_labels=labels_count, args=train_args) \nprint('model in use is :', preSavedCheckpoint )\n"

In [38]:
# Train the model
current_time = datetime.now()
model.train_model(train)
print("Training time: ", datetime.now() - current_time)

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=6829.0), HTML(value='')))


Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=2.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 1.738001Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Running loss: 1.573199Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Running loss: 1.438697Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Running loss: 1.535589Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Running loss: 0.902376


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 1.061669Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Running loss: 0.510550

Training of roberta model complete. Saved to ./folds/fold3/roberta/roberta-large/.
Training time:  0:16:54.503028


In [39]:
TrainResult, TrainModel_outputs, wrong_predictions = model.eval_model(train, acc=sklearn.metrics.accuracy_score)
 
EvalResult, EvalModel_outputs, wrong_predictions = model.eval_model(Eval, acc=sklearn.metrics.accuracy_score)

TestResult, TestModel_outputs, wrong_predictions = model.eval_model(test, acc=sklearn.metrics.accuracy_score)

print('Training Result:', TrainResult['acc'])
#print('Model Out:', TrainModel_outputs)

print('Eval Result:', EvalResult['acc'])
#print('Model Out:', EvalModel_outputs)

print('Test Set Result:', TestResult['acc'])
#print('Model Out:', TestModel_outputs)

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=6829.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=427.0), HTML(value='')))


{'mcc': 0.5469919940205865, 'acc': 0.6418216429931176, 'eval_loss': 0.8197598973119008}
Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=1705.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=107.0), HTML(value='')))


{'mcc': 0.46626471718625995, 'acc': 0.5794721407624633, 'eval_loss': 0.9690464918858537}
Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=2210.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))


{'mcc': 0.4650448285648897, 'acc': 0.5796380090497738, 'eval_loss': 0.9241459018034901}
Training Result: 0.6418216429931176
Eval Result: 0.5794721407624633
Test Set Result: 0.5796380090497738


In [40]:
Pred=[]
Targets=[]

countCorrect=0

for row in range(TestModel_outputs.shape[0]):
    outputs=TestModel_outputs[row]
    #print(test.iloc[row,0])
    print(outputs, end=' ')
    
    result=0
    if outputs[0]<outputs[1]:result=1
    if outputs[result]<outputs[2]:result=2
    if outputs[result]<outputs[3]:result=3
    if outputs[result]<outputs[4]:result=4
    
    Pred.append(result)
    Targets.append(test.iloc[row,1])
    print(result, ' ',test.iloc[row,1], end=' ')
    if result==test.iloc[row,1]:
        countCorrect+=1
        print('Match',countCorrect)
    print('')

print(countCorrect)

[-0.22241211  2.6132812   1.6464844  -0.6254883  -2.875     ] 1   1 Match 1

[ 1.3984375  2.9902344  1.0166016 -1.5849609 -3.0136719] 1   0 
[-0.06884766  2.609375    1.5742188  -0.6303711  -2.6621094 ] 1   2 
[-2.5566406 -2.5742188 -1.0878906  2.0957031  3.3027344] 4   4 Match 2

[-0.49389648  1.5654297   2.1308594  -0.5283203  -2.5761719 ] 2   0 
[-3.3691406  -2.0917969   0.38134766  3.0859375   1.5029297 ] 3   3 Match 3

[ 2.8359375  2.4667969  0.2705078 -2.2675781 -2.6679688] 0   0 Match 4

[-3.3457031  -2.546875   -0.17907715  2.8984375   2.2695312 ] 3   4 
[ 3.3359375   1.6513672  -0.15930176 -1.9589844  -2.2285156 ] 0   0 Match 5

[ 3.2285156   2.203125    0.03579712 -2.09375    -2.5683594 ] 0   0 Match 6

[ 0.3388672  2.6640625  1.6855469 -0.8676758 -3.2890625] 1   2 
[-3.1992188  -1.9238281   0.29663086  3.0058594   1.3525391 ] 3   4 
[-1.8652344  1.4560547  2.0273438  0.9448242 -1.9462891] 2   1 
[ 2.5605469   2.2929688   0.52978516 -1.9873047  -2.8144531 ] 0   0 Match 7

[-2

[-3.1542969  -1.4814453   1.1875      2.7382812   0.22436523] 3   2 
[ 2.1679688   2.7558594   0.90771484 -1.9160156  -3.2148438 ] 1   1 Match 76

[-1.7529297  0.2109375  1.6796875  1.2783203 -1.2548828] 2   1 
[-1.0888672   1.5537109   1.8496094   0.37841797 -2.3320312 ] 2   2 Match 77

[ 2.6660156  2.71875    0.640625  -2.0527344 -2.9492188] 1   0 
[-2.4394531  -1.3359375   0.52734375  1.8828125   1.0273438 ] 3   3 Match 78

[-2.2753906  -1.4707031  -0.13208008  1.3603516   1.6308594 ] 4   3 
[-2.5019531 -2.5742188 -0.8564453  1.7216797  3.5058594] 4   3 
[-3.1972656  -2.5859375  -0.20336914  2.7109375   2.4003906 ] 3   2 
[-2.3652344  -1.6074219   0.35009766  1.8994141   1.2724609 ] 3   4 
[-2.9179688  -1.8671875   0.42016602  2.8457031   1.2402344 ] 3   3 Match 79

[-2.6796875 -2.5214844 -1.0927734  1.7900391  3.5429688] 4   4 Match 80

[ 2.9804688   2.1757812   0.27954102 -2.0390625  -2.7480469 ] 0   1 
[ 0.17663574  2.046875    1.3085938  -0.6152344  -2.8925781 ] 1   1 Match 81



[-2.9648438  -0.78564453  1.4267578   2.6210938  -0.61572266] 3   3 Match 153

[-3.1191406 -2.5273438 -0.3552246  2.7441406  2.5820312] 3   4 
[-1.5400391  -0.92871094  0.19311523  1.3515625   1.0244141 ] 3   4 
[-2.6992188 -2.640625  -1.0742188  1.8359375  3.5019531] 4   4 Match 154

[ 2.4160156   2.4707031   0.64697266 -1.9931641  -2.90625   ] 1   2 
[ 1.0927734  2.7792969  1.3466797 -1.4716797 -3.0605469] 1   1 Match 155

[ 2.5820312  2.4140625  0.7246094 -1.921875  -3.1777344] 0   1 
[-3.0039062 -2.6640625 -0.7739258  2.328125   3.1191406] 4   4 Match 156

[-3.1796875  -1.6445312   0.66748047  2.453125    1.0048828 ] 3   4 
[ 3.1289062   2.1914062   0.12200928 -2.0585938  -2.625     ] 0   0 Match 157

[ 2.6835938  2.3554688  0.7836914 -2.1796875 -2.9726562] 0   1 
[ 3.4609375  1.4541016 -0.4482422 -1.8330078 -1.8935547] 0   0 Match 158

[ 1.6767578  2.6015625  1.3330078 -1.5673828 -3.0117188] 1   1 Match 159

[-3.2949219  -2.4179688   0.18725586  3.046875    1.8837891 ] 3   3 Match

[-3.2460938  -2.5742188  -0.49121094  2.5839844   2.6582031 ] 4   3 
[-2.8789062  -1.15625     0.86035156  2.6875      0.20581055] 3   2 
[-3.1464844  -2.6054688  -0.54052734  2.5039062   2.9101562 ] 4   3 
[ 0.9511719  3.0527344  1.4375    -1.5019531 -3.2695312] 1   2 
[-3.1738281  -1.4853516   1.1738281   2.5976562   0.22192383] 3   3 Match 226

[-3.2285156 -1.4648438  0.8510742  2.8691406  0.6381836] 3   2 
[ 0.9013672  2.8105469  1.5664062 -1.3046875 -2.953125 ] 1   2 
[-3.2226562 -2.1835938  0.5102539  2.7949219  1.3242188] 3   3 Match 227

[ 1.4179688  3.0078125  1.1962891 -1.4726562 -3.1933594] 1   1 Match 228

[-2.9765625  -2.7207031  -0.56689453  2.5644531   2.8925781 ] 4   4 Match 229

[-3.03125   -1.7275391  0.9189453  2.640625   0.9316406] 3   2 
[-3.2539062  -2.0332031   0.46289062  2.9570312   1.4150391 ] 3   2 
[-0.01430511  2.3984375   1.5849609  -0.7832031  -2.8398438 ] 1   1 Match 230

[-2.1660156  1.1601562  1.8134766  1.1259766 -1.7363281] 2   2 Match 231

[-2.87109

[-0.04745483  2.4765625   1.6611328  -0.59716797 -2.8398438 ] 1   1 Match 309

[ 2.7871094  2.4121094  0.5917969 -2.0214844 -3.0058594] 0   1 
[-3.0546875  -1.7207031   0.33374023  2.6738281   1.2998047 ] 3   3 Match 310

[-3.3378906  -1.8916016   0.7915039   2.8203125   0.78759766] 3   4 
[ 2.796875    2.4238281   0.37158203 -2.1152344  -2.7324219 ] 0   1 
[-2.4804688 -2.4550781 -1.2880859  1.6318359  3.6289062] 4   4 Match 311

[-2.7832031  -2.171875    0.26000977  2.4570312   1.4375    ] 3   4 
[ 1.5107422  2.5449219  1.0146484 -1.6914062 -2.8359375] 1   1 Match 312

[-3.3671875  -2.3457031   0.28808594  2.9707031   1.6699219 ] 3   3 Match 313

[-2.140625    0.27368164  2.0390625   1.3203125  -1.6982422 ] 2   1 
[-0.01174927  2.2363281   1.703125   -0.7368164  -3.0546875 ] 1   2 
[ 3.0800781   2.2011719   0.17211914 -2.1113281  -2.5273438 ] 0   1 
[-2.703125  -2.53125   -0.7285156  2.1777344  3.1992188] 4   3 
[ 2.6445312   2.4160156   0.63183594 -1.9570312  -2.9296875 ] 0   1 
[ 2.

[ 0.8676758  2.9335938  1.5722656 -1.4228516 -3.1972656] 1   1 Match 382

[-2.7226562 -2.4160156 -0.8798828  1.9931641  3.3652344] 4   4 Match 383

[-0.99853516  0.98828125  2.3574219   0.10015869 -2.5019531 ] 2   1 
[-1.4912109  1.6552734  2.1054688  0.7167969 -2.5195312] 2   1 
[-2.9140625 -2.6523438 -0.5678711  2.3828125  2.8085938] 4   3 
[ 3.0742188   1.6083984   0.01010895 -1.8251953  -2.2558594 ] 0   2 
[-1.0683594   1.6640625   2.1035156   0.03591919 -2.7558594 ] 2   2 Match 384

[-3.0117188 -1.8916016  0.4482422  2.8007812  1.3007812] 3   3 Match 385

[-1.9453125  0.0173645  1.9462891  1.2265625 -1.3554688] 2   3 
[-2.3515625 -2.4902344 -1.2451172  1.3457031  3.7226562] 4   4 Match 386

[ 3.0371094   2.1894531   0.07202148 -2.0605469  -2.5898438 ] 0   0 Match 387

[-2.5371094   0.01673889  1.8164062   1.8681641  -1.3300781 ] 3   2 
[-1.3857422   1.3544922   1.5751953   0.81884766 -1.9511719 ] 2   2 Match 388

[ 2.0878906  2.6523438  0.7270508 -1.9345703 -3.2265625] 1   0 
[-2.

[ 2.7753906  2.3339844  0.3059082 -2.109375  -2.8710938] 0   0 Match 463

[-2.9414062 -0.5083008  1.8427734  2.125     -1.0849609] 3   3 Match 464

[ 2.5429688  2.5136719  0.6894531 -2.046875  -2.859375 ] 0   0 Match 465

[-2.6210938  0.6386719  1.7421875  1.6826172 -1.4902344] 2   4 
[-2.8398438 -2.578125  -0.8823242  1.9619141  3.3769531] 4   4 Match 466

[-1.265625    2.2167969   1.7861328   0.23754883 -2.3105469 ] 1   1 Match 467

[ 0.04248047  2.1601562   1.5595703  -0.74609375 -2.7792969 ] 1   1 Match 468

[ 0.8544922  2.5878906  1.5986328 -1.2050781 -3.1289062] 1   2 
[-1.5693359   1.2216797   2.1796875   0.74560547 -2.3359375 ] 2   1 
[-0.28808594  1.6777344   1.4179688  -0.14465332 -2.53125   ] 1   1 Match 469

[-2.8613281 -2.4824219 -0.7885742  2.2011719  3.1855469] 4   3 
[-3.2578125 -1.8076172  0.6821289  2.9746094  0.6557617] 3   3 Match 470

[-2.8125    -2.6757812 -0.8071289  2.1621094  3.2714844] 4   3 
[ 3.2871094   2.1933594   0.03231812 -2.1210938  -2.6699219 ] 0   0 

[-2.1171875 -2.1386719 -1.1494141  1.1035156  3.3144531] 4   3 
[ 1.453125   2.8828125  1.2119141 -1.7636719 -3.0625   ] 1   1 Match 556

[ 0.8613281  2.8046875  1.5097656 -1.2900391 -3.2226562] 1   0 
[ 3.4902344   1.8466797  -0.30859375 -2.0097656  -2.2089844 ] 0   1 
[-3.0488281 -1.1289062  1.2929688  2.7324219 -0.052948 ] 3   3 Match 557

[-3.1621094  -2.5683594  -0.43164062  2.7460938   2.6699219 ] 3   4 
[-2.4394531 -2.5605469 -1.0595703  1.6982422  3.5527344] 4   4 Match 558

[ 0.21044922  2.6269531   1.7246094  -0.8417969  -3.3027344 ] 1   1 Match 559

[-2.9550781 -2.4726562 -0.7192383  2.3183594  2.9121094] 4   3 
[ 2.4941406   2.1855469   0.58251953 -1.796875   -2.6386719 ] 0   1 
[-3.2832031  -1.5654297   0.7763672   2.9121094   0.57714844] 3   3 Match 560

[-2.90625    -2.1796875  -0.45239258  2.5097656   1.9335938 ] 3   3 Match 561

[-2.3984375  0.5126953  1.9736328  1.7519531 -1.7421875] 2   2 Match 562

[-3.2890625  -2.0273438   0.50146484  3.109375    1.2666016 ] 3   3 

[-2.625     -2.5058594 -1.0107422  1.96875    3.1660156] 4   4 Match 644

[-3.1484375  -2.1933594   0.51660156  2.8496094   1.1435547 ] 3   3 Match 645

[-0.48657227  2.3574219   1.7714844  -0.24194336 -2.7597656 ] 1   2 
[ 2.15625    2.5097656  1.0771484 -1.6494141 -3.1621094] 1   1 Match 646

[ 2.0429688  2.8339844  1.1210938 -1.890625  -3.1347656] 1   1 Match 647

[-2.0644531   0.07067871  1.8710938   1.1474609  -1.328125  ] 2   2 Match 648

[-0.01130676  2.7402344   1.5888672  -0.8378906  -2.9003906 ] 1   1 Match 649

[-1.8251953   0.35498047  2.4355469   0.9692383  -2.0078125 ] 2   2 Match 650

[ 1.3271484  2.4980469  1.3007812 -1.4658203 -3.1035156] 1   2 
[-2.4238281 -2.3613281 -1.0849609  1.4638672  3.6191406] 4   3 
[-3.0332031 -2.4453125 -0.5463867  2.6152344  2.7558594] 4   4 Match 651

[ 2.765625   2.5429688  0.671875  -2.0859375 -3.0566406] 0   1 
[-3.1777344  -2.6503906  -0.62841797  2.4472656   2.8105469 ] 4   3 
[ 1.2080078  2.7128906  1.3662109 -1.6523438 -3.2011719] 1

[-2.7734375 -2.5742188 -0.7626953  2.0019531  3.3535156] 4   4 Match 713

[ 3.0742188   2.2929688   0.12634277 -2.1367188  -2.6328125 ] 0   1 
[ 0.7260742  2.8125     1.7519531 -1.2451172 -3.3359375] 1   1 Match 714

[-0.3605957   1.9013672   1.5830078  -0.63134766 -2.2734375 ] 1   2 
[ 1.9121094  3.0117188  1.0263672 -1.8818359 -3.1445312] 1   1 Match 715

[ 2.9804688   2.1777344   0.20202637 -2.0253906  -2.7304688 ] 0   0 Match 716

[ 1.7861328  2.3574219  1.1210938 -1.5839844 -3.03125  ] 1   2 
[-3.2304688  -2.3242188  -0.40551758  2.6796875   2.2148438 ] 3   3 Match 717

[-3.1191406  -2.6113281  -0.65283203  2.5136719   2.9628906 ] 4   4 Match 718

[-0.33764648  2.4140625   1.6865234  -0.6821289  -2.7363281 ] 1   2 
[-3.1777344  -2.3867188   0.09973145  3.0078125   2.0351562 ] 3   3 Match 719

[-2.8554688 -2.5585938 -0.7285156  2.3359375  3.0527344] 4   4 Match 720

[-0.6147461   2.1816406   1.8544922  -0.07843018 -2.6738281 ] 1   1 Match 721

[-2.7792969 -2.4550781 -0.8183594  1.9

[ 2.2070312  2.7792969  0.6845703 -2.0117188 -3.0585938] 1   1 Match 796

[ 1.2128906  2.7929688  1.3662109 -1.5927734 -3.3144531] 1   0 
[-1.390625   1.5820312  2.0976562  0.5332031 -2.5175781] 2   1 
[-2.8828125  -0.66259766  1.7451172   2.4296875  -0.9716797 ] 3   2 
[ 0.89404297  2.7363281   1.7460938  -1.4375     -3.0996094 ] 1   1 Match 797

[-2.7285156 -2.6152344 -0.9394531  2.0253906  3.4179688] 4   3 
[-2.4824219 -2.5058594 -1.1542969  1.5722656  3.6699219] 4   4 Match 798

[ 1.1464844  2.8574219  1.1279297 -1.6123047 -3.2402344] 1   0 
[ 1.6923828  2.9804688  1.1103516 -1.6972656 -3.0566406] 1   2 
[ 2.6328125   2.8242188   0.77197266 -2.1035156  -3.0859375 ] 1   0 
[-2.5292969 -0.2088623  1.9296875  2.0722656 -1.3066406] 3   3 Match 799

[-3.1796875  -2.0292969   0.19946289  2.9726562   1.3222656 ] 3   3 Match 800

[ 1.7890625  2.9453125  1.0664062 -1.8408203 -2.9140625] 1   1 Match 801

[ 1.7275391  2.4804688  1.0507812 -1.5078125 -3.1601562] 1   0 
[ 2.6835938  2.6132812  

[-2.4375    -2.5117188 -1.1835938  1.4365234  3.7714844] 4   4 Match 879

[ 2.6992188   2.3242188   0.34155273 -1.9492188  -2.7460938 ] 0   1 
[-2.7382812  -0.24206543  1.7080078   2.1953125  -1.1435547 ] 3   1 
[ 2.0292969  2.7246094  0.9580078 -1.9580078 -3.2089844] 1   1 Match 880

[-3.28125    -1.8886719   0.44726562  2.9804688   1.1953125 ] 3   3 Match 881

[-2.4863281 -2.5273438 -1.1826172  1.5966797  3.5742188] 4   4 Match 882

[-3.1230469 -2.6855469 -0.7451172  2.3125     3.1503906] 4   3 
[-2.4902344  -0.11999512  1.8583984   1.5742188  -1.1669922 ] 2   2 Match 883

[ 0.6582031  2.4824219  1.9179688 -1.2080078 -3.1484375] 1   1 Match 884

[ 2.1972656   2.9765625   0.78515625 -2.0214844  -3.0292969 ] 1   1 Match 885

[-2.4042969  -2.0800781  -0.33129883  1.796875    2.3320312 ] 4   3 
[-3.1796875  -1.2646484   0.96484375  2.6503906   0.33984375] 3   4 
[-2.484375  -2.4101562 -1.0761719  1.7353516  3.6054688] 4   3 
[-3.3339844  -1.6738281   0.7470703   2.9746094   0.38354492] 3

[ 2.3769531   2.6171875   0.51220703 -1.9785156  -2.84375   ] 1   0 
[-2.4101562  -0.05563354  1.7089844   1.7099609  -1.1455078 ] 3   4 
[-1.7529297  1.6992188  2.0429688  0.7607422 -2.2304688] 2   1 
[-3.4082031 -1.8251953  0.9453125  2.7929688  0.3618164] 3   3 Match 968

[-1.3427734  1.2148438  2.3457031  0.6977539 -2.3613281] 2   2 Match 969

[ 1.6289062   2.7753906   0.86083984 -1.6914062  -2.7988281 ] 1   0 
[ 2.0605469   2.2753906   0.87109375 -1.6328125  -2.8769531 ] 1   1 Match 970

[-1.9453125  0.8769531  2.0410156  1.2460938 -1.8310547] 2   2 Match 971

[-0.01184082  2.1074219   2.0039062  -0.74365234 -3.0273438 ] 1   1 Match 972

[-2.5917969 -2.5058594 -0.9658203  1.6738281  3.3300781] 4   3 
[ 1.3769531  2.6464844  1.1933594 -1.4677734 -2.9550781] 1   1 Match 973

[-3.0703125  -2.5722656  -0.21728516  2.5429688   2.6132812 ] 4   4 Match 974

[ 2.1015625  2.9433594  0.7324219 -2.0957031 -3.09375  ] 1   1 Match 975

[-3.2617188  -2.2714844   0.20703125  2.5898438   1.685546

[-2.9882812  -2.0625      0.21484375  2.8535156   1.8662109 ] 3   4 
[-1.3154297   1.7050781   1.8408203   0.16967773 -1.9248047 ] 2   4 
[-2.8339844 -0.5449219  1.7919922  2.375     -1.015625 ] 3   1 
[-3.2363281  -2.4902344  -0.04714966  2.7890625   2.0644531 ] 3   3 Match 1052

[ 0.2524414  2.2402344  1.9287109 -0.8510742 -2.9023438] 1   1 Match 1053

[-3.1875    -2.0625     0.3095703  2.4101562  1.6035156] 3   3 Match 1054

[ 3.2089844   1.9492188   0.38134766 -2.0039062  -2.5859375 ] 0   0 Match 1055

[-2.8613281  -2.6855469  -0.87109375  2.2011719   3.2363281 ] 4   4 Match 1056

[ 2.1328125  2.9042969  0.6748047 -2.0253906 -3.1328125] 1   0 
[-2.9550781  -2.5136719  -0.68359375  2.3046875   3.1054688 ] 4   4 Match 1057

[ 2.2773438  2.71875    1.1660156 -2.109375  -3.3320312] 1   1 Match 1058

[ 1.4863281  2.515625   1.2402344 -1.4892578 -3.2226562] 1   1 Match 1059

[-3.2246094  -2.0839844  -0.07904053  2.5214844   1.9814453 ] 3   3 Match 1060

[ 1.59375    2.8730469  1.1347656 

[-2.9238281 -1.5615234  0.7416992  2.21875    1.2138672] 3   1 
[-2.1816406 -2.3066406 -1.2949219  1.2441406  3.5820312] 4   4 Match 1140

[-2.5488281 -2.5859375 -1.0917969  1.6425781  3.5742188] 4   4 Match 1141

[ 2.9804688  2.6210938  0.5083008 -2.2167969 -3.       ] 0   1 
[ 0.76953125  2.1035156   1.7685547  -0.8276367  -2.921875  ] 1   3 
[-2.5058594 -2.4121094 -1.1855469  1.5742188  3.5742188] 4   4 Match 1142

[ 3.0878906   2.0742188   0.15112305 -1.9902344  -2.5742188 ] 0   1 
[ 2.5351562   2.4414062   0.42407227 -2.0585938  -2.828125  ] 0   0 Match 1143

[ 2.7851562   2.2832031   0.43188477 -2.0371094  -2.8945312 ] 0   1 
[ 2.8183594  2.2402344  0.2866211 -1.9033203 -2.6015625] 0   0 Match 1144

[-3.1191406  -1.5283203   1.0683594   2.828125    0.33203125] 3   3 Match 1145

[-2.9082031  -2.2949219  -0.38378906  1.9863281   2.65625   ] 4   4 Match 1146

[-2.6660156  -2.5410156  -0.85058594  2.0820312   3.2480469 ] 4   4 Match 1147

[-3.0390625  -2.2460938   0.01303101  2.64453


[-3.171875   -1.1425781   1.296875    2.9277344  -0.13537598] 3   3 Match 1228

[ 3.0429688   1.8242188  -0.01165771 -1.8193359  -2.3730469 ] 0   0 Match 1229

[-3.0429688 -2.6015625 -0.7050781  2.6054688  2.359375 ] 3   3 Match 1230

[ 2.9785156   2.3925781   0.15991211 -2.28125    -2.6875    ] 0   0 Match 1231

[-2.2832031 -2.2695312 -1.1162109  1.2285156  3.4707031] 4   4 Match 1232

[-3.2207031  -1.96875     0.48657227  2.71875     1.0380859 ] 3   3 Match 1233

[-2.7578125  -1.4912109   0.39526367  2.546875    0.609375  ] 3   3 Match 1234

[-1.3955078   1.7919922   2.0566406   0.47045898 -2.625     ] 2   2 Match 1235

[-0.16577148  2.6875      1.6767578  -0.7348633  -2.9335938 ] 1   2 
[ 3.4296875   1.5214844  -0.49145508 -1.8505859  -1.8886719 ] 0   0 Match 1236

[ 0.66748047  2.9707031   1.3623047  -1.2919922  -2.9335938 ] 1   1 Match 1237

[ 1.4638672  3.0546875  1.1416016 -1.6982422 -3.2265625] 1   1 Match 1238

[ 2.8398438  2.5214844  0.4206543 -2.2910156 -2.8203125] 0   1 
[

In [41]:
from sklearn import metrics
print(metrics.confusion_matrix(Targets,Pred))

[[164 105   7   3   0]
 [135 392  83  21   2]
 [ 16 139 116 111   7]
 [  2   9  37 340 122]
 [  0   1   7 122 269]]


In [42]:
target_names = ['Very Neg', 'Negative', 'Neutral','Positive','Very Pos']
print(metrics.classification_report(Targets, Pred,target_names =target_names))

              precision    recall  f1-score   support

    Very Neg       0.52      0.59      0.55       279
    Negative       0.61      0.62      0.61       633
     Neutral       0.46      0.30      0.36       389
    Positive       0.57      0.67      0.61       510
    Very Pos       0.67      0.67      0.67       399

    accuracy                           0.58      2210
   macro avg       0.57      0.57      0.56      2210
weighted avg       0.57      0.58      0.57      2210



In [43]:
Fold_Predictions=pd.DataFrame(Pred, columns=['Pred3'] )
Fold_Predictions

Unnamed: 0,Pred3
0,1
1,1
2,1
3,4
4,2
...,...
2205,4
2206,0
2207,4
2208,3


In [44]:
Fold_Predictions.to_excel(output_folder+'/Saves/fold3_Predictions.xls')

In [45]:
#clearing GPU cache

del(model)
del(TrainResult, TrainModel_outputs, EvalResult, EvalModel_outputs, TestResult, TestModel_outputs, wrong_predictions)
torch.cuda.empty_cache()

## Fold 4: training & capturing predictions

In [10]:
fold_number='4'

train=pd.read_excel('./folds/train_fold'+fold_number+'.xls')
Eval=pd.read_excel('./folds/valid'+fold_number+'.xls') #evaluation set


In [11]:

output_folder='./folds/fold'+fold_number+'/'+model_class+'/'+model_version+"/"
cache_directory= "./folds/fold"+fold_number+'/'+model_class+"/"+model_version+"/cache/"


print('model variables were set up: ')

 
save_every_steps=1285
# assuming training batch size of 8
# any number above 1284 saves the model only at every epoch
# Saving the model mid training very often will consume disk space fast

train_args={
    "output_dir":output_folder,
    "cache_dir":cache_directory,
    'reprocess_input_data': True,
    'overwrite_output_dir': True,
    'num_train_epochs': 2,
    "save_steps": save_every_steps, 
    "learning_rate": 1.1e-5,
    "train_batch_size": 32,
    "eval_batch_size": 16,
    "weight_decay": 0,
    "evaluate_during_training_steps": 312,
    "max_seq_length": 100,
    "n_gpu": 1,
}

# Create a ClassificationModel
model = ClassificationModel(model_class, model_version, num_labels=labels_count, args=train_args) 

model variables were set up: 


In [42]:
# loading the checkpoint that gave the best result

CheckPoint='checkpoint-428-epoch-2' 


preSavedCheckpoint=output_folder+CheckPoint

print('Loading model, please wait...')
model = ClassificationModel( model_class, preSavedCheckpoint, num_labels=labels_count, args=train_args) 
print('model in use is :', preSavedCheckpoint )


Loading model, please wait...
model in use is : ./folds/fold4/roberta/roberta-large/checkpoint-428-epoch-2


In [12]:
# Train the model
current_time = datetime.now()
model.train_model(train)
print("Training time: ", datetime.now() - current_time, 'at: ' ,datetime.now())

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=6829.0), HTML(value='')))


Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=2.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 1.627575



Running loss: 1.584292Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Running loss: 1.624308Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0




Running loss: 1.530892Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Running loss: 1.506448Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Running loss: 1.192236


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 1.029952Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Running loss: 0.771634

Training of roberta model complete. Saved to ./folds/fold4/roberta/roberta-large/.
Training time:  0:16:26.605530 at:  2020-04-28 18:03:47.976852


In [13]:
TrainResult, TrainModel_outputs, wrong_predictions = model.eval_model(train, acc=sklearn.metrics.accuracy_score)
 
EvalResult, EvalModel_outputs, wrong_predictions = model.eval_model(Eval, acc=sklearn.metrics.accuracy_score)


print('Training Result:', TrainResult['acc'])
#print('Model Out:', TrainModel_outputs)

print('Eval Result:', EvalResult['acc'])
#print('Model Out:', EvalModel_outputs)



Features loaded from cache at ./folds/fold4/roberta/roberta-large/cache/cached_dev_roberta_100_5_6829


HBox(children=(FloatProgress(value=0.0, max=427.0), HTML(value='')))


{'mcc': 0.46423184385130406, 'acc': 0.5744618538585444, 'eval_loss': 0.9510010246370659}
Features loaded from cache at ./folds/fold4/roberta/roberta-large/cache/cached_dev_roberta_100_5_1705


HBox(children=(FloatProgress(value=0.0, max=107.0), HTML(value='')))


{'mcc': 0.4146582417377948, 'acc': 0.5378299120234604, 'eval_loss': 1.0215839504081512}
Training Result: 0.5744618538585444
Eval Result: 0.5378299120234604


In [19]:
#test['labels']=test['labels'].astype('int64')
test.dtypes

text      object
labels     int64
dtype: object

In [18]:

TestResult, TestModel_outputs, wrong_predictions = model.eval_model(test.head(2209), acc=sklearn.metrics.accuracy_score)

print('Test Set Result:', TestResult['acc'])
#print('Model Out:', TestModel_outputs)


Features loaded from cache at ./folds/fold4/roberta/roberta-large/cache/cached_dev_roberta_100_5_2209


HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))


{'mcc': 0.4336134006979582, 'acc': 0.5536441828881847, 'eval_loss': 0.9907341234975582}
Test Set Result: 0.5536441828881847


In [20]:
Pred=[]
Targets=[]

countCorrect=0

for row in range(TestModel_outputs.shape[0]):
    outputs=TestModel_outputs[row]
    #print(test.iloc[row,0])
    print(outputs, end=' ')
    
    result=0
    if outputs[0]<outputs[1]:result=1
    if outputs[result]<outputs[2]:result=2
    if outputs[result]<outputs[3]:result=3
    if outputs[result]<outputs[4]:result=4
    
    Pred.append(result)
    Targets.append(test.iloc[row,1])
    print(result, ' ',test.iloc[row,1], end=' ')
    if result==test.iloc[row,1]:
        countCorrect+=1
        print('Match',countCorrect)
    print('')

print(countCorrect)

[-0.2548828  1.6025391  1.8496094 -0.3190918 -2.8144531] 2   1 
[ 1.0800781   2.3730469   0.97265625 -1.2275391  -3.140625  ] 1   0 
[-0.03372192  1.7353516   1.6318359  -0.37109375 -2.6425781 ] 1   2 
[-2.3261719 -2.6269531 -0.7260742  1.8349609  2.9960938] 4   4 Match 1

[ 0.3149414   1.6035156   1.0351562  -0.36083984 -2.65625   ] 1   0 
[-2.640625  -2.1816406 -0.0083847  2.4121094  1.9345703] 3   3 Match 2

[ 2.3027344  1.9960938  0.4111328 -1.9238281 -2.7734375] 0   0 Match 3

[-2.8222656  -2.0878906   0.11444092  2.4589844   1.7939453 ] 3   4 
[ 2.3984375   1.4501953  -0.02168274 -1.7675781  -2.1601562 ] 0   0 Match 4

[ 2.53125     1.8886719   0.22424316 -1.8496094  -2.7402344 ] 0   0 Match 5

[ 1.8564453   2.0566406   0.58496094 -1.7285156  -2.9863281 ] 1   2 
[-2.5       -1.9042969  0.1194458  2.1210938  1.5185547] 3   4 
[-1.0507812  0.8310547  1.2871094  0.8198242 -2.3554688] 2   1 
[ 2.0234375   1.9296875   0.55810547 -1.7519531  -3.0253906 ] 0   0 Match 6

[-2.2753906  -2.

[-2.4257812  -2.578125   -0.44335938  1.7587891   2.9882812 ] 4   3 
[-0.12084961  1.3652344   1.3457031  -0.21118164 -2.515625  ] 1   1 Match 59

[-2.0839844  -0.5449219   1.390625    1.7851562  -0.46826172] 3   3 Match 60

[-2.9394531  -2.03125     0.31298828  2.4609375   1.6611328 ] 3   3 Match 61

[-2.734375   -2.2226562   0.11804199  2.2070312   2.3164062 ] 4   4 Match 62

[-1.5478516  -0.90283203  0.42236328  1.8369141  -0.6484375 ] 3   2 
[ 1.5478516  1.7900391  0.8461914 -1.2509766 -2.828125 ] 1   0 
[-2.15625   -2.4746094 -0.6723633  1.6162109  3.1894531] 4   3 
[-2.2558594 -2.7128906 -0.8383789  1.6953125  3.2910156] 4   4 Match 63

[ 0.09344482  1.8300781   1.2617188  -0.390625   -2.6191406 ] 1   1 Match 64

[-2.9941406 -1.7519531  0.3515625  2.3671875  1.3955078] 3   3 Match 65

[-2.2050781  -0.97021484  0.9223633   2.1582031  -0.3425293 ] 3   3 Match 66

[-2.609375   -1.8974609   0.20385742  2.3964844   1.5458984 ] 3   4 
[ 1.7158203  1.8994141  0.5913086 -1.703125  -2.816

[-2.4335938  -2.4902344  -0.27148438  1.9492188   2.8242188 ] 4   3 
[-2.7773438  -2.3515625  -0.17333984  2.0722656   2.7519531 ] 4   4 Match 121

[-2.40625    -1.0693359   0.76660156  2.3632812  -0.07403564] 3   3 Match 122

[ 1.2314453  2.2285156  1.0332031 -1.2695312 -3.359375 ] 1   1 Match 123

[-0.75683594  1.078125    1.6005859   0.1427002  -2.4453125 ] 2   0 
[ 0.92285156  2.0585938   0.9941406  -1.2060547  -2.9667969 ] 1   2 
[-3.0019531  -2.2285156   0.27319336  2.203125    2.2480469 ] 4   3 
[ 2.0371094   1.9072266   0.32470703 -1.7441406  -2.7832031 ] 0   0 Match 124

[ 2.4628906   1.3632812  -0.03009033 -1.8310547  -2.3378906 ] 0   1 
[-0.56884766  0.80322266  1.5068359   0.7368164  -2.0566406 ] 2   1 
[ 2.0917969  1.6542969  0.3647461 -1.5507812 -2.4433594] 0   1 
[-2.375      -2.4628906  -0.63623047  1.9404297   3.1035156 ] 4   4 Match 125

[ 0.01913452  1.7197266   1.3798828  -0.2084961  -2.6992188 ] 1   2 
[-2.8652344 -2.3085938 -0.2763672  2.4824219  2.3496094] 3   4 

[-2.1582031  -0.35888672  1.4921875   1.8369141  -0.9355469 ] 3   2 
[ 2.0292969   1.9277344   0.30664062 -1.84375    -2.6289062 ] 0   0 Match 181

[ 1.4638672   2.0839844   0.74853516 -1.3828125  -3.265625  ] 1   1 Match 182

[-2.7363281  -2.3652344   0.02456665  2.2910156   2.1503906 ] 3   3 Match 183

[-2.3691406  -1.4462891   0.33691406  2.0957031   1.0800781 ] 3   3 Match 184

[ 2.0820312  2.1289062  0.5102539 -1.6513672 -3.1894531] 1   1 Match 185

[-2.2441406  -2.2675781  -0.35620117  1.6123047   2.546875  ] 4   3 
[-2.6230469  -1.7148438   0.38061523  2.3222656   0.9296875 ] 3   2 
[ 2.1425781   1.8935547   0.52490234 -1.9433594  -2.6933594 ] 0   2 
[-2.296875  -2.4023438 -0.5366211  1.6474609  3.       ] 4   4 Match 186

[ 1.9853516   2.0273438   0.62890625 -2.0292969  -2.8710938 ] 1   2 
[ 1.4453125  2.4824219  0.9248047 -1.1904297 -3.2421875] 1   0 
[ 1.4775391  1.8896484  0.8798828 -1.5537109 -3.1464844] 1   1 Match 187

[-1.4199219   0.21496582  1.8359375   0.8359375  -1.1

[ 0.71728516  2.2246094   0.8701172  -1.03125    -3.2070312 ] 1   1 Match 257

[-2.3066406  -2.359375   -0.21862793  1.6660156   2.8085938 ] 4   4 Match 258

[-2.8847656  -1.875       0.05993652  2.3730469   1.4541016 ] 3   3 Match 259

[-1.1113281  0.7246094  1.6367188  0.6879883 -1.4863281] 2   2 Match 260

[-0.04589844  1.0361328   1.3662109  -0.06689453 -2.3691406 ] 2   2 Match 261

[ 2.2597656   1.5751953   0.07489014 -1.4726562  -2.5625    ] 0   0 Match 262

[ 2.5390625   1.6787109   0.07086182 -1.9863281  -2.5390625 ] 0   0 Match 263

[ 1.953125    1.2167969   0.09533691 -1.4824219  -2.0957031 ] 0   2 
[-2.1230469  -0.7192383   0.79003906  1.6914062  -0.07531738] 3   3 Match 264

[ 0.9472656  1.7587891  1.0078125 -0.9692383 -3.0527344] 1   3 
[ 1.3115234   2.4433594   0.80810547 -1.3681641  -3.0195312 ] 1   0 
[ 1.4599609  2.0039062  0.8847656 -1.2324219 -2.9550781] 1   0 
[-2.2324219  -1.8212891   0.21435547  1.921875    1.21875   ] 3   4 
[ 2.4277344   1.4746094   0.14416504 -

[ 2.2324219   2.3574219   0.41552734 -1.7099609  -3.2207031 ] 1   0 
[ 1.7705078  2.1855469  0.8564453 -1.4912109 -3.0449219] 1   1 Match 319

[-2.4746094 -0.9995117  0.9506836  2.2558594  0.2364502] 3   3 Match 320

[-2.578125   -0.95703125  1.1025391   2.0683594  -0.17675781] 3   1 
[ 0.78125     2.1542969   1.0595703  -0.83935547 -3.1621094 ] 1   3 
[ 2.3105469   2.0527344   0.22839355 -2.0351562  -3.0253906 ] 0   0 Match 321

[ 0.50927734  1.6015625   1.2744141  -0.5107422  -2.7441406 ] 1   1 Match 322

[-1.28125     0.20605469  1.5         1.2119141  -1.2421875 ] 2   2 Match 323

[-2.9746094  -2.0976562   0.22949219  2.3476562   1.7441406 ] 3   2 
[-1.0126953  0.8071289  1.4990234  0.7050781 -1.7451172] 2   1 
[-1.7949219   0.09631348  1.4023438   1.1904297  -1.4775391 ] 2   0 
[ 1.2216797  2.0410156  0.7583008 -1.3681641 -3.2363281] 1   1 Match 324

[ 2.2324219   1.4052734   0.09527588 -1.7324219  -2.5878906 ] 0   0 Match 325

[-0.4729004   0.52246094  1.4599609   0.39013672 -1.9


[ 1.3261719  2.1582031  0.7949219 -1.3691406 -3.0800781] 1   0 
[ 1.8759766  2.0507812  0.7685547 -1.6201172 -2.8632812] 1   1 Match 390

[ 1.5068359  2.3164062  0.9404297 -1.46875   -3.0996094] 1   2 
[ 1.4384766  1.734375   0.8051758 -1.1962891 -3.1542969] 1   2 
[-2.1113281  -0.10174561  1.3642578   1.5439453  -0.4428711 ] 3   2 
[ 1.6005859  2.3203125  1.1777344 -1.7314453 -2.9667969] 1   1 Match 391

[-2.3222656 -2.6054688 -0.5234375  1.8525391  2.9394531] 4   4 Match 392

[-0.82470703  0.9580078   1.7861328  -0.04025269 -1.9892578 ] 2   3 
[-2.6035156  -2.5722656  -0.31323242  2.140625    2.7636719 ] 4   3 
[ 1.9697266   1.6308594   0.17590332 -1.8339844  -2.4082031 ] 0   1 
[ 2.0742188   1.6630859   0.12286377 -1.6845703  -2.6015625 ] 0   2 
[-2.765625   -1.9501953   0.00729752  2.4804688   1.7509766 ] 3   3 Match 393

[-2.5644531  -2.3671875  -0.40820312  1.8417969   3.0820312 ] 4   4 Match 394

[-0.39892578  1.1611328   1.5712891  -0.10003662 -2.3222656 ] 2   2 Match 395

[ 0

[ 1.7001953  1.6621094  0.7788086 -1.5009766 -3.0917969] 0   0 Match 452

[ 2.6445312   2.          0.03302002 -1.90625    -2.7890625 ] 0   0 Match 453

[ 2.125       2.1328125   0.35375977 -2.0742188  -3.0019531 ] 1   1 Match 454

[-1.8515625  -0.05496216  1.4521484   1.4375     -0.9941406 ] 2   1 
[-2.4355469  -0.65283203  1.3769531   2.0175781  -0.44921875] 3   3 Match 455

[-2.5566406  -2.5214844  -0.42871094  1.9384766   3.046875  ] 4   3 
[ 0.8901367  1.3261719  1.2255859 -0.890625  -2.8984375] 1   1 Match 456

[-2.5488281  -2.3398438  -0.09234619  2.1308594   2.7597656 ] 4   3 
[-2.7753906 -1.6757812  0.4243164  2.4492188  1.1630859] 3   3 Match 457

[-0.01930237  0.9477539   1.2050781   0.21386719 -2.5429688 ] 2   1 
[ 2.1757812  1.7578125  0.359375  -1.7558594 -2.7578125] 0   0 Match 458

[-2.3242188  -0.33129883  1.0664062   2.015625   -0.26879883] 3   3 Match 459

[ 1.2919922  1.4365234  0.6191406 -1.2773438 -2.578125 ] 1   0 
[-0.671875    1.4492188   1.4677734   0.31103516

[-0.07025146  1.2480469   1.0537109  -0.13208008 -2.6269531 ] 1   1 Match 523

[ 2.1152344  1.6289062  0.3425293 -1.6630859 -2.5351562] 0   1 
[-2.8476562e+00 -2.2656250e+00 -2.1686554e-03  2.2714844e+00
  2.2617188e+00] 3   4 
[-2.1210938  -0.1862793   1.1494141   1.7890625  -0.54833984] 3   2 
[ 2.1914062   1.9052734   0.30688477 -1.7207031  -2.453125  ] 0   0 Match 524

[ 1.3623047  2.1425781  0.8017578 -1.5976562 -3.0058594] 1   1 Match 525

[ 0.63378906  2.3828125   1.3183594  -0.7324219  -2.8730469 ] 1   1 Match 526

[-2.3945312 -2.6699219 -0.6386719  1.8925781  3.2324219] 4   4 Match 527

[ 2.0976562   1.5351562   0.27563477 -1.7910156  -2.4394531 ] 0   0 Match 528

[-2.796875  -2.2167969 -0.3798828  2.3378906  2.3046875] 3   3 Match 529

[-2.6953125  -2.3886719  -0.22265625  1.9736328   2.7460938 ] 4   3 
[ 2.421875   1.7734375  0.2084961 -1.796875  -2.671875 ] 0   1 
[ 1.9296875   1.7705078   0.16662598 -1.6591797  -2.7167969 ] 0   1 
[-2.7734375  -2.4628906  -0.31591797  2.08

[ 2.1289062   1.6191406   0.42626953 -1.6865234  -2.4511719 ] 0   0 Match 592

[-2.515625   -0.828125    0.52783203  2.1933594   0.4333496 ] 3   3 Match 593

[ 1.1982422  2.0429688  0.9614258 -1.0390625 -2.8808594] 1   1 Match 594

[ 1.1513672   1.0488281   0.79785156 -0.7988281  -2.1230469 ] 0   1 
[ 2.109375    2.0429688   0.32885742 -1.5292969  -2.7539062 ] 0   0 Match 595

[-2.8496094  -1.5654297   0.81591797  2.3105469   0.99658203] 3   3 Match 596

[ 0.3942871   1.4169922   1.2363281  -0.78222656 -2.9140625 ] 1   0 
[ 1.9824219  2.0507812  0.7426758 -1.7197266 -2.9667969] 1   0 
[-2.5761719  -0.7558594   0.90527344  2.2382812  -0.0539856 ] 3   3 Match 597

[-1.9775391 -2.6777344 -0.9272461  1.5019531  3.0273438] 4   4 Match 598

[-2.8457031  -2.0097656   0.25512695  2.2949219   1.5244141 ] 3   3 Match 599

[-2.3300781  -0.8671875   0.79052734  2.2890625   0.2692871 ] 3   2 
[-2.5039062 -2.4121094 -0.2668457  1.8671875  2.7207031] 4   3 
[-2.7949219 -1.5751953  0.4802246  2.337890

[ 1.6923828   2.2910156   0.75439453 -1.4716797  -2.9765625 ] 1   0 
[-2.6777344  -1.9853516   0.36450195  2.2578125   1.4863281 ] 3   2 
[ 1.8720703   1.7099609   0.32641602 -1.5458984  -2.4882812 ] 0   1 
[ 2.34375     1.2548828  -0.13195801 -1.7705078  -1.9882812 ] 0   0 Match 658

[-2.5820312  -2.5664062  -0.36254883  1.9746094   2.828125  ] 4   4 Match 659

[-2.390625   -1.6523438   0.45629883  2.0566406   1.0966797 ] 3   2 
[ 1.7617188  1.9404297  0.5991211 -1.53125   -2.890625 ] 1   1 Match 660

[-0.86572266  1.171875    1.5800781   0.38964844 -2.046875  ] 2   1 
[ 1.4501953   1.7441406   0.86328125 -1.296875   -2.8125    ] 1   2 
[ 1.640625    1.9912109   0.70410156 -1.5449219  -3.0976562 ] 1   1 Match 661

[-2.6582031  -2.4257812   0.01708984  2.140625    2.2871094 ] 4   4 Match 662

[ 0.82373047  1.3955078   0.7788086  -0.64746094 -2.5175781 ] 1   2 
[ 0.53125    1.5664062  1.0263672 -1.0429688 -3.0234375] 1   0 
[-2.3671875  -2.5078125  -0.42895508  1.8603516   3.        ] 4


[-2.8554688  -2.2617188  -0.11413574  2.0410156   2.3417969 ] 4   3 
[ 1.4580078  1.7822266  0.7265625 -1.2646484 -2.7558594] 1   0 
[ 1.5068359  1.8447266  0.5317383 -1.5058594 -2.8671875] 1   0 
[ 0.48608398  2.125       1.3974609  -1.2695312  -3.0058594 ] 1   1 Match 718

[ 0.27001953  1.9560547   1.4619141  -0.54785156 -2.9335938 ] 1   1 Match 719

[-1.6533203  -2.2695312  -0.57714844  1.3007812   2.1894531 ] 4   3 
[-1.2529297  -0.9892578   0.51220703  1.3261719   0.43066406] 3   3 Match 720

[ 2.0664062  2.1640625  0.5019531 -1.7177734 -3.1367188] 1   1 Match 721

[ 1.5283203   1.7353516   0.94091797 -1.6015625  -2.7480469 ] 1   0 
[-1.6337891  -0.5620117   0.6796875   1.6191406  -0.73828125] 3   3 Match 722

[ 1.8105469   1.7236328   0.64501953 -1.6708984  -2.8105469 ] 0   1 
[ 0.46044922  2.2519531   1.4609375  -0.6152344  -3.0351562 ] 1   2 
[-2.8828125 -2.1601562 -0.1373291  2.3027344  2.1894531] 3   3 Match 723

[-0.60302734  0.9667969   1.7402344   0.30395508 -2.0761719 ] 

[-0.5263672  0.9526367  1.6152344  0.0914917 -2.234375 ] 2   2 Match 779

[-0.40649414  0.5390625   0.8623047   0.07568359 -1.7705078 ] 2   2 Match 780

[-0.34960938  1.0576172   1.6416016  -0.2166748  -2.3457031 ] 2   2 Match 781

[-1.0810547   0.8652344   1.5126953   0.66308594 -1.9746094 ] 2   3 
[ 0.32006836  1.7285156   1.2226562  -0.39111328 -3.0195312 ] 1   1 Match 782

[-2.5019531  -2.3339844  -0.20703125  2.2578125   2.0488281 ] 3   4 
[ 2.1367188  1.9560547  0.3659668 -1.6162109 -2.9921875] 0   1 
[-2.6425781  -2.0195312   0.15258789  2.5136719   1.5732422 ] 3   3 Match 783

[-2.5175781  -2.28125    -0.09771729  2.140625    2.2617188 ] 4   1 
[-2.7558594  -2.0214844   0.20251465  2.3984375   1.7001953 ] 3   4 
[-2.7460938 -1.1240234  1.0146484  2.3457031  0.1194458] 3   3 Match 784

[-0.1875      1.7226562   1.5244141  -0.24621582 -2.5761719 ] 1   1 Match 785

[ 2.3554688   1.984375    0.37060547 -2.0585938  -2.8085938 ] 0   0 Match 786

[ 2.0488281  1.6914062  0.1550293 -1.9

[-2.2832031  -0.15686035  1.1787109   1.6386719  -0.72216797] 3   1 
[ 2.328125   1.9658203  0.59375   -1.7470703 -2.6699219] 0   1 
[-2.8398438 -1.2587891  0.5131836  2.2832031  0.390625 ] 3   3 Match 849

[-2.2949219 -2.5917969 -0.7402344  1.7119141  3.2070312] 4   4 Match 850

[-2.5664062  -2.4394531  -0.38476562  2.0078125   2.9570312 ] 4   3 
[-1.5576172   0.24157715  1.3134766   1.3232422  -1.40625   ] 3   2 
[ 1.0146484  2.0527344  1.34375   -1.4394531 -3.1347656] 1   1 Match 851

[ 2.2265625  2.1679688  0.6245117 -1.5351562 -2.9335938] 0   1 
[-2.5195312  -1.6591797   0.33984375  2.0605469   1.4091797 ] 3   3 Match 852

[-2.4257812  -1.3554688   0.78222656  2.2773438   0.00532532] 3   4 
[-2.4570312  -2.2675781  -0.55810547  1.6113281   3.2597656 ] 4   3 
[-2.7695312  -1.9091797   0.13964844  2.5292969   1.0830078 ] 3   3 Match 853

[-0.9160156   0.9667969   1.6679688   0.41503906 -2.1875    ] 2   2 Match 854

[-1.9335938 -2.3515625 -0.8588867  1.3779297  3.4277344] 4   4 Match


[-2.3945312 -2.3886719 -0.3564453  1.7216797  2.7695312] 4   4 Match 913

[ 2.0859375  1.8076172  0.4951172 -1.6318359 -3.0429688] 0   1 
[-2.046875   -0.3239746   1.2089844   1.7636719  -0.71533203] 3   3 Match 914

[ 1.7041016  1.8242188  0.7314453 -1.8232422 -2.5664062] 1   0 
[-2.4804688  -1.2509766   0.6899414   2.2558594   0.47558594] 3   3 Match 915

[ 0.8642578  2.2617188  1.3115234 -1.0771484 -3.2402344] 1   1 Match 916

[ 1.6542969  1.4443359  0.6826172 -1.5234375 -2.5273438] 0   1 
[ 2.1503906   2.4160156   0.66748047 -1.9375     -3.0253906 ] 1   0 
[-1.9316406  -1.4755859   0.27368164  1.796875    0.98535156] 3   3 Match 917

[ 1.8886719   1.7275391   0.73583984 -1.6142578  -3.0332031 ] 0   1 
[ 1.1523438  2.0625     1.1103516 -1.4277344 -3.0839844] 1   1 Match 918

[-1.9521484  -0.33618164  1.1728516   1.8076172  -0.1730957 ] 3   2 
[-2.7109375  -2.046875    0.40844727  2.2539062   1.4882812 ] 3   2 
[-1.4482422  -0.40112305  0.82910156  1.1933594  -0.23254395] 3   3 Matc

[-2.5859375  -2.078125   -0.06072998  2.0039062   2.15625   ] 4   3 
[ 1.6533203   1.7597656   0.96972656 -1.6337891  -2.7792969 ] 1   1 Match 982

[ 1.1054688  1.4541016  0.6503906 -1.4091797 -2.6269531] 1   0 
[-2.9804688  -2.3183594  -0.02719116  2.1367188   2.4472656 ] 4   3 
[-0.61328125  1.2041016   1.7353516   0.23144531 -2.2578125 ] 2   1 
[ 1.1298828  1.6728516  0.8105469 -1.2050781 -2.7324219] 1   0 
[-1.9365234  -2.5078125  -0.69677734  1.3261719   3.3085938 ] 4   4 Match 983

[-2.484375  -2.3925781 -0.3972168  2.0039062  3.0683594] 4   4 Match 984

[-2.3554688  -2.4667969  -0.62841797  1.7060547   3.4160156 ] 4   3 
[ 0.9589844   1.6552734   0.8408203  -0.87646484 -2.7773438 ] 1   1 Match 985

[ 1.6201172   1.9707031   0.67285156 -1.4765625  -2.7988281 ] 1   1 Match 986

[-2.6699219  -1.7861328   0.40942383  2.1328125   0.859375  ] 3   2 
[-2.1484375 -2.5800781 -0.6435547  1.6113281  2.8886719] 4   4 Match 987

[-2.296875   -2.6054688  -0.63671875  1.6269531   3.3945312 ] 4

[ 1.15625    2.171875   1.1376953 -1.4541016 -3.296875 ] 1   2 
[-2.3476562 -2.4980469 -0.6147461  1.6787109  3.2363281] 4   4 Match 1044

[-1.5273438   0.54296875  1.5332031   0.97998047 -1.2910156 ] 2   2 Match 1045

[-2.5488281 -2.4199219 -0.3984375  1.7851562  3.0820312] 4   4 Match 1046

[ 1.2636719  2.25       1.0693359 -1.1650391 -3.2011719] 1   0 
[ 1.6972656  2.1132812  0.9272461 -1.3525391 -3.0449219] 1   2 
[-0.12298584  1.3017578   1.5039062  -0.26953125 -2.5097656 ] 2   2 Match 1047

[-2.4667969  -0.9892578   0.92626953  2.4179688   0.08605957] 3   2 
[-2.3710938 -2.4804688 -0.6621094  1.6435547  2.8984375] 4   4 Match 1048

[ 2.2617188  1.6357422  0.1015625 -1.71875   -2.4960938] 0   0 Match 1049

[-0.22375488  1.1308594   1.3388672  -0.171875   -2.2773438 ] 2   0 
[ 2.234375   1.9423828  0.328125  -2.125     -2.9453125] 0   0 Match 1050

[ 1.4121094  1.5537109  0.7832031 -1.6123047 -2.3007812] 1   1 Match 1051

[ 2.0703125   2.3300781   0.82666016 -1.7080078  -3.1445312 

[ 1.578125    2.0175781   0.95947266 -1.5849609  -3.1230469 ] 1   1 Match 1115

[ 1.6699219  2.1113281  0.6767578 -1.5117188 -2.9296875] 1   1 Match 1116

[ 2.1679688  2.1914062  0.6411133 -1.6591797 -2.8515625] 1   0 
[ 0.04663086  1.7324219   1.4882812  -0.4724121  -2.6679688 ] 1   1 Match 1117

[ 1.6054688   2.1679688   0.91845703 -1.3574219  -3.1503906 ] 1   1 Match 1118

[-2.5097656  -0.8720703   1.0009766   2.1542969   0.39331055] 3   2 
[-2.7460938  -2.1796875  -0.12512207  1.875       2.2929688 ] 4   4 Match 1119

[ 1.7597656   2.109375    0.76660156 -1.5664062  -2.9804688 ] 1   1 Match 1120

[-2.4707031  -2.6289062  -0.65185547  1.984375    3.046875  ] 4   4 Match 1121

[ 1.1884766  2.4179688  1.1142578 -1.4931641 -3.2636719] 1   2 
[-2.5546875  -1.4580078   0.63916016  2.328125    0.77734375] 3   2 
[-1.71875    -0.8769531   0.77490234  1.5917969  -0.02992249] 3   4 
[-1.3417969  -0.02944946  1.0986328   0.8828125  -1.4736328 ] 2   1 
[ 1.578125   2.3359375  0.6689453 -1.5498

[-2.6914062 -2.1835938 -0.1104126  2.2148438  2.109375 ] 3   4 
[ 0.21130371  1.3339844   0.99560547 -0.61083984 -2.3027344 ] 1   1 Match 1183

[-2.7753906  -2.3007812  -0.01428223  2.2460938   1.9335938 ] 3   4 
[-2.6523438  -1.0507812   1.0263672   2.2832031   0.25317383] 3   3 Match 1184

[-2.5449219 -1.6201172  0.5185547  2.2773438  1.2724609] 3   4 
[-2.3632812  -1.3925781   0.51416016  2.421875    0.6333008 ] 3   3 Match 1185

[ 0.08544922  1.7431641   1.3955078  -0.22875977 -2.7910156 ] 1   2 
[-2.7578125  -2.0234375   0.08288574  2.5117188   1.8261719 ] 3   3 Match 1186

[ 2.65625     2.015625    0.19506836 -1.8193359  -2.6152344 ] 0   1 
[-2.71875    -2.2382812   0.00383759  2.1542969   2.4394531 ] 4   4 Match 1187

[ 1.0019531   1.9785156   0.85058594 -0.83447266 -3.2050781 ] 1   1 Match 1188

[-2.3378906 -2.78125   -0.8339844  1.7519531  3.2519531] 4   4 Match 1189

[-2.8222656 -2.09375    0.0958252  2.1777344  2.0214844] 3   4 
[-2.6074219  -2.2910156  -0.43310547  1.911132

In [21]:
from sklearn import metrics
 
print(metrics.confusion_matrix(Targets,Pred))

[[157 108  10   3   0]
 [143 371  81  36   2]
 [ 25 131  87 130  16]
 [  1  12  25 318 154]
 [  0   0  14  95 290]]


In [22]:
target_names = ['Very Neg', 'Negative', 'Neutral','Positive','Very Pos']
print(metrics.classification_report(Targets, Pred,target_names =target_names))

              precision    recall  f1-score   support

    Very Neg       0.48      0.56      0.52       278
    Negative       0.60      0.59      0.59       633
     Neutral       0.40      0.22      0.29       389
    Positive       0.55      0.62      0.58       510
    Very Pos       0.63      0.73      0.67       399

    accuracy                           0.55      2209
   macro avg       0.53      0.54      0.53      2209
weighted avg       0.54      0.55      0.54      2209



In [23]:
Fold_Predictions=pd.DataFrame(Pred, columns=['Pred4'] )
Fold_Predictions

Unnamed: 0,Pred4
0,2
1,1
2,1
3,4
4,1
...,...
2204,1
2205,4
2206,1
2207,4


In [24]:
Fold_Predictions.to_excel(output_folder+'/Saves/fold4_Predictions.xls')

In [25]:
#clearing GPU cache

del(model)
del(TrainResult, TrainModel_outputs, EvalResult, EvalModel_outputs, TestResult, TestModel_outputs, wrong_predictions)
torch.cuda.empty_cache()

## Fold 5: training & capturing predictions

In [26]:
fold_number='5'

train=pd.read_excel('./folds/train_fold'+fold_number+'.xls')
Eval=pd.read_excel('./folds/valid'+fold_number+'.xls') #evaluation set


In [27]:
 
output_folder='./folds/fold'+fold_number+'/'+model_class+'/'+model_version+"/"
cache_directory= "./folds/fold"+fold_number+'/'+model_class+"/"+model_version+"/cache/"


print('model variables were set up: ')

 
save_every_steps=1285
# assuming training batch size of 8
# any number above 1284 saves the model only at every epoch
# Saving the model mid training very often will consume disk space fast

train_args={
    "output_dir":output_folder,
    "cache_dir":cache_directory,
    'reprocess_input_data': True,
    'overwrite_output_dir': True,
    'num_train_epochs': 2,
    "save_steps": save_every_steps,
    "learning_rate": 1.2e-5,
    "train_batch_size": 32,
    "eval_batch_size": 16,
    "weight_decay": 0,
    "evaluate_during_training_steps": 312,
    "max_seq_length": 100,
    "n_gpu": 1,
}

# Create a ClassificationModel
model = ClassificationModel(model_class, model_version, num_labels=labels_count, args=train_args) 

model variables were set up: 


In [28]:
# loading the checkpoint that gave the best result
'''
CheckPoint='checkpoint-286-epoch-2'   


preSavedCheckpoint=output_folder+CheckPoint

print('Loading model, please wait...')
model = ClassificationModel( model_class, preSavedCheckpoint, num_labels=labels_count, args=train_args) 
print('model in use is :', preSavedCheckpoint )
'''

"\nCheckPoint='checkpoint-286-epoch-2'   \n\n\npreSavedCheckpoint=output_folder+CheckPoint\n\nprint('Loading model, please wait...')\nmodel = ClassificationModel( model_class, preSavedCheckpoint, num_labels=labels_count, args=train_args) \nprint('model in use is :', preSavedCheckpoint )\n"

In [29]:
# Train the model
current_time = datetime.now()
model.train_model(train)
print("Training time: ", datetime.now() - current_time)

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=6829.0), HTML(value='')))


Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=2.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 1.565867Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Running loss: 1.605363Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Running loss: 1.050404Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Running loss: 0.991105


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=214.0, style=ProgressStyle(descri…

Running loss: 0.763252

Training of roberta model complete. Saved to ./folds/fold5/roberta/roberta-large/.
Training time:  0:16:54.832875


In [30]:
TrainResult, TrainModel_outputs, wrong_predictions = model.eval_model(train, acc=sklearn.metrics.accuracy_score)
 
EvalResult, EvalModel_outputs, wrong_predictions = model.eval_model(Eval, acc=sklearn.metrics.accuracy_score)

TestResult, TestModel_outputs, wrong_predictions = model.eval_model(test, acc=sklearn.metrics.accuracy_score)

print('Training Result:', TrainResult['acc'])
#print('Model Out:', TrainModel_outputs)

print('Eval Result:', EvalResult['acc'])
#print('Model Out:', EvalModel_outputs)

print('Test Set Result:', TestResult['acc'])
#print('Model Out:', TestModel_outputs)

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=6829.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=427.0), HTML(value='')))


{'mcc': 0.5694218574625568, 'acc': 0.660418802167228, 'eval_loss': 0.7944691421276513}
Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=1705.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=107.0), HTML(value='')))


{'mcc': 0.4503913555485112, 'acc': 0.5695014662756598, 'eval_loss': 0.9450209817596685}
Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=2210.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))


{'mcc': 0.4908723021870356, 'acc': 0.6009049773755656, 'eval_loss': 0.9089024735869263}
Training Result: 0.660418802167228
Eval Result: 0.5695014662756598
Test Set Result: 0.6009049773755656


In [31]:
Pred=[]
Targets=[]

countCorrect=0

for row in range(TestModel_outputs.shape[0]):
    outputs=TestModel_outputs[row]
    #print(test.iloc[row,0])
    print(outputs, end=' ')
    
    result=0
    if outputs[0]<outputs[1]:result=1
    if outputs[result]<outputs[2]:result=2
    if outputs[result]<outputs[3]:result=3
    if outputs[result]<outputs[4]:result=4
    
    Pred.append(result)
    Targets.append(test.iloc[row,1])
    print(result, ' ',test.iloc[row,1], end=' ')
    if result==test.iloc[row,1]:
        countCorrect+=1
        print('Match',countCorrect)
    print('')

print(countCorrect)

[-0.22265625  2.4316406   1.9746094  -0.5444336  -3.0917969 ] 1   1 Match 1

[ 0.57177734  2.7558594   1.5693359  -1.2646484  -3.3554688 ] 1   0 
[ 0.2310791  2.3027344  1.4111328 -1.0644531 -3.0273438] 1   2 
[-2.6992188  -2.5605469  -0.52734375  2.5449219   3.5214844 ] 4   4 Match 2

[ 0.49194336  2.4433594   1.9580078  -1.0419922  -2.9863281 ] 1   0 
[-2.9921875 -2.0234375  0.1583252  3.0410156  2.1269531] 3   3 Match 3

[ 2.7578125   2.4492188   0.26733398 -2.2148438  -3.0449219 ] 0   0 Match 4

[-3.0820312  -2.1152344   0.30029297  3.40625     2.0839844 ] 3   4 
[ 3.5683594   2.0390625  -0.13659668 -2.2558594  -2.6132812 ] 0   0 Match 5

[ 3.3007812e+00  2.2578125e+00  2.4890900e-03 -2.2421875e+00
 -2.8496094e+00] 0   0 Match 6

[ 0.7866211  2.6308594  1.4970703 -1.3007812 -3.3085938] 1   2 
[-3.0214844  -1.8134766   0.71533203  2.9277344   1.6171875 ] 3   4 
[-1.7705078  1.2832031  2.234375   1.0429688 -2.3574219] 2   1 
[ 2.9238281   2.3964844   0.24157715 -2.2460938  -3.0898438

[-3.2285156 -1.9677734  0.7265625  3.4199219  1.6376953] 3   3 Match 66

[-3.0332031 -2.4199219  0.0682373  3.09375    2.7871094] 3   4 
[-0.88720703  1.7900391   2.2265625   0.24707031 -2.6953125 ] 2   2 Match 67

[ 1.9296875  2.7070312  0.9785156 -1.8974609 -3.4277344] 1   0 
[-2.6855469 -2.4648438 -0.5048828  2.1855469  3.6289062] 4   3 
[-2.6054688 -2.46875   -0.6220703  2.3867188  3.6855469] 4   4 Match 68

[-0.47265625  2.3105469   1.8828125  -0.59277344 -3.2265625 ] 1   1 Match 69

[-3.1425781  -2.4023438   0.11315918  3.3203125   2.4902344 ] 3   3 Match 70

[-2.6582031  -0.94140625  1.5908203   2.8359375  -0.15625   ] 3   3 Match 71

[-2.96875   -1.7958984  0.7895508  2.921875   1.7246094] 3   4 
[ 2.6035156  2.5566406  0.6479492 -1.9423828 -3.1015625] 0   0 Match 72

[ 0.47802734  2.703125    1.4199219  -1.2011719  -3.0898438 ] 1   1 Match 73

[ 2.3203125  2.5058594  0.4873047 -2.0917969 -3.0996094] 1   1 Match 74

[ 2.3613281  2.421875   0.4375    -1.9990234 -3.1308594] 1   0

[ 3.3984375   2.0332031  -0.21960449 -2.1875     -2.65625   ] 0   1 
[-0.40698242  2.2851562   2.0585938  -0.39404297 -3.0703125 ] 1   1 Match 138

[ 2.1777344   2.4550781   0.49316406 -2.0332031  -3.1054688 ] 1   1 Match 139

[-2.5058594 -2.2929688 -0.5620117  2.0195312  3.671875 ] 4   4 Match 140

[-0.64746094  2.0859375   1.9179688  -0.30688477 -3.0175781 ] 1   2 
[-3.1953125  -2.2773438   0.13659668  3.3789062   2.3457031 ] 3   4 
[-0.31982422  1.9052734   1.9394531  -0.4362793  -2.8789062 ] 2   1 
[-2.6035156 -2.4648438 -0.4555664  2.0839844  3.7167969] 4   4 Match 141

[-2.6835938 -0.8935547  1.2607422  2.9023438  0.2685547] 3   2 
[ 1.3613281  2.6152344  1.1113281 -1.6201172 -3.3339844] 1   0 
[-2.8964844 -2.4082031 -0.3708496  2.7324219  3.359375 ] 4   3 
[-1.9980469   0.44018555  2.328125    1.6923828  -1.4921875 ] 2   2 Match 142

[-3.109375   -1.96875     0.66259766  3.2558594   1.5576172 ] 3   3 Match 143

[-2.7324219  -1.4316406   1.2587891   2.8671875   0.65185547] 3   3 

[ 1.9863281  2.6601562  0.7709961 -1.9609375 -3.3007812] 1   1 Match 204

[-2.9785156  -2.4140625  -0.21118164  2.8496094   3.0917969 ] 4   4 Match 205

[-2.765625   -1.0185547   1.4189453   2.9199219   0.35498047] 3   3 Match 206

[-2.96875    -2.3164062   0.13598633  2.7109375   2.8261719 ] 4   4 Match 207

[-2.734375   -2.5195312  -0.24194336  2.2441406   3.3652344 ] 4   4 Match 208

[-0.37890625  1.3613281   1.9765625   0.06890869 -2.6113281 ] 2   1 
[-2.5        -0.87060547  1.4814453   2.6542969   0.12060547] 3   3 Match 209

[-2.2480469  -0.5131836   1.4980469   2.2851562  -0.20275879] 3   3 Match 210

[ 1.1904297  2.9375     1.2734375 -1.7070312 -3.3847656] 1   1 Match 211

[-2.8828125  -2.5917969  -0.44995117  2.7753906   3.5644531 ] 4   4 Match 212

[ 0.23364258  2.5625      1.5742188  -1.1083984  -3.2148438 ] 1   1 Match 213

[-3.0351562  -2.5410156  -0.16040039  3.0214844   3.0351562 ] 4   4 Match 214

[-2.8183594  -2.6367188  -0.65234375  2.4042969   3.6953125 ] 4   4 Matc

[ 3.4414062  1.8837891 -0.140625  -2.1308594 -2.4121094] 0   1 
[-2.6679688 -2.5957031 -0.6621094  2.2597656  3.7636719] 4   4 Match 283

[ 1.1337891  2.6894531  1.3115234 -1.5195312 -3.3847656] 1   2 
[-0.6850586   1.4667969   2.0410156   0.45043945 -2.5449219 ] 2   2 Match 284

[ 2.09375     2.6171875   0.68115234 -1.9091797  -3.2675781 ] 1   2 
[-2.9824219 -2.4667969 -0.2084961  3.0429688  3.1835938] 4   3 
[-2.9101562  -2.4804688  -0.45288086  2.6914062   3.390625  ] 4   4 Match 285

[ 2.5195312   2.5839844   0.34228516 -2.0800781  -3.1035156 ] 1   1 Match 286

[ 2.8632812   2.4238281   0.27124023 -2.1796875  -3.0722656 ] 0   0 Match 287

[-2.9726562 -1.7880859  0.5654297  3.2753906  1.4228516] 3   3 Match 288

[-2.7109375 -2.5214844 -0.8125     2.2246094  3.9609375] 4   4 Match 289

[-3.1074219  -2.0664062   0.38330078  3.5039062   1.8193359 ] 3   3 Match 290

[-3.1425781  -2.2011719   0.28808594  3.3144531   2.2226562 ] 3   4 
[-2.90625    -2.3554688  -0.19250488  2.9296875   2.9

[ 2.2714844  2.578125   0.5341797 -2.0195312 -3.0527344] 1   1 Match 350

[ 3.5097656   2.1738281  -0.06555176 -2.0917969  -2.7304688 ] 0   0 Match 351

[-1.4140625   0.58203125  2.4511719   1.1328125  -1.8378906 ] 2   2 Match 352

[-1.6386719 -0.3774414  1.7851562  1.4023438 -1.0273438] 2   4 
[-0.19067383  2.5292969   1.8095703  -0.8017578  -2.8925781 ] 1   1 Match 353

[ 1.1699219  2.6894531  1.1357422 -1.6708984 -3.2714844] 1   2 
[ 0.06793213  2.7207031   1.8574219  -1.0888672  -3.1835938 ] 1   2 
[-0.44482422  1.4443359   2.234375    0.03262329 -2.8320312 ] 2   1 
[-2.8710938  -2.3515625   0.01347351  2.7148438   3.0585938 ] 4   3 
[-2.4921875 -1.0195312  1.4023438  2.8476562  0.5708008] 3   3 Match 354

[-3.1582031 -2.2109375  0.3005371  3.1679688  2.4042969] 3   2 
[-3.0878906  -1.8486328   0.75341797  3.2246094   1.5185547 ] 3   3 Match 355

[-2.9121094  -1.5068359   1.3554688   3.0507812   0.67529297] 3   3 Match 356

[ 2.0058594  2.4433594  0.5517578 -1.9082031 -3.0605469] 1

[-2.8164062  -2.4042969  -0.57177734  2.4121094   3.5800781 ] 4   4 Match 419

[-0.55322266  1.9980469   1.9091797  -0.25146484 -2.8378906 ] 1   2 
[ 1.75       2.7421875  1.1152344 -1.7578125 -3.3613281] 1   1 Match 420

[ 2.7402344   2.328125    0.27001953 -2.1503906  -3.0429688 ] 0   2 
[-2.7871094  -2.28125    -0.03637695  2.96875     2.6816406 ] 3   3 Match 421

[ 2.765625    2.6914062   0.46948242 -2.2382812  -3.1542969 ] 0   1 
[ 0.36865234  2.7851562   1.3886719  -1.1376953  -3.3476562 ] 1   1 Match 422

[ 1.4990234  2.7734375  1.0332031 -1.8359375 -3.1855469] 1   0 
[ 2.078125    2.6328125   0.86279297 -1.8623047  -3.2109375 ] 1   2 
[-0.02688599  2.2988281   1.6083984  -0.5800781  -3.1054688 ] 1   2 
[ 3.6054688   1.9228516  -0.14428711 -2.2128906  -2.6210938 ] 0   0 Match 423

[-3.0546875 -1.2958984  1.5429688  3.0878906  0.5708008] 3   3 Match 424

[ 0.92089844  2.6542969   1.3925781  -1.4423828  -3.3574219 ] 1   1 Match 425

[-1.8837891  1.2255859  2.296875   0.9819336 -2.

[ 1.7666016   2.2421875   0.84033203 -1.5185547  -3.0898438 ] 1   1 Match 486

[-3.0664062  -2.4433594   0.02838135  2.9414062   2.984375  ] 4   3 
[-3.0722656 -1.8144531  1.0019531  3.3222656  1.1152344] 3   3 Match 487

[ 0.13464355  2.2265625   1.9941406  -0.74609375 -3.2324219 ] 1   1 Match 488

[ 3.0488281   2.4335938   0.28149414 -2.2480469  -3.109375  ] 0   0 Match 489

[-2.5976562  -0.48876953  2.0683594   2.5195312  -0.5463867 ] 3   3 Match 490

[ 1.4521484  2.7890625  1.0996094 -1.7568359 -3.3476562] 1   0 
[-1.7832031   1.4072266   2.1660156   0.96972656 -2.3066406 ] 2   4 
[-2.796875   -2.4082031  -0.33618164  2.1992188   3.6347656 ] 4   4 Match 491

[-0.3293457   2.4472656   1.8564453  -0.70214844 -3.0742188 ] 1   1 Match 492

[ 0.39990234  2.078125    1.8671875  -0.7680664  -3.1464844 ] 1   1 Match 493

[-0.6567383   2.0292969   2.0566406  -0.23364258 -2.8710938 ] 2   2 Match 494

[-1.4941406   1.3193359   2.1992188   0.90478516 -2.3046875 ] 2   1 
[ 0.6279297   2.1777344

[-3.1699219  -2.3652344  -0.25097656  3.0058594   2.9882812 ] 3   4 
[-1.8798828  0.7290039  1.9589844  1.4150391 -1.6738281] 2   2 Match 559

[ 2.6113281  2.5019531  0.3269043 -2.1875    -3.140625 ] 0   0 Match 560

[ 1.8027344   2.734375    0.96728516 -1.8857422  -3.34375   ] 1   1 Match 561

[-0.4765625   2.1796875   1.8730469  -0.46166992 -2.7050781 ] 1   1 Match 562

[-2.8339844  -2.6601562  -0.68066406  2.6660156   3.4882812 ] 4   4 Match 563

[ 3.546875    2.2128906   0.02380371 -2.2382812  -2.7773438 ] 0   0 Match 564

[-3.0234375  -2.4355469  -0.24536133  2.9882812   2.6425781 ] 3   3 Match 565

[-2.9101562  -2.28125     0.05206299  2.6835938   2.9199219 ] 4   3 
[ 3.4101562   2.25       -0.01638794 -2.2128906  -2.8203125 ] 0   1 
[ 2.7070312  2.3808594  0.2758789 -2.1035156 -2.9257812] 0   1 
[-3.0136719  -2.6230469  -0.37695312  2.9472656   3.2324219 ] 4   4 Match 566

[ 3.5         2.2167969   0.00434875 -2.2929688  -2.8398438 ] 0   0 Match 567

[-3.0253906  -1.7724609   0.

[-2.6523438  -0.6503906   1.7646484   2.7753906  -0.35083008] 3   3 Match 634

[ 1.171875    2.3417969   1.6484375  -0.97753906 -2.9667969 ] 1   2 
[-2.6542969 -2.3671875 -0.5336914  2.1113281  3.3945312] 4   4 Match 635

[-0.43041992  1.4472656   2.1367188   0.19934082 -2.5019531 ] 2   0 
[-2.9824219  -2.46875    -0.17822266  2.6640625   3.2753906 ] 4   4 Match 636

[-2.96875   -1.7470703  1.0009766  3.3398438  1.2441406] 3   2 
[ 2.2109375  2.2480469  0.6899414 -1.8701172 -3.1191406] 1   0 
[-2.8339844  -1.5380859   0.82128906  3.0605469   0.93603516] 3   3 Match 637

[ 1.1298828  2.328125   1.3378906 -1.3544922 -3.2421875] 1   1 Match 638

[ 1.5498047  1.6630859  1.0566406 -1.3457031 -2.7988281] 1   1 Match 639

[ 3.0039062   2.4589844   0.25073242 -2.3066406  -3.0058594 ] 0   0 Match 640

[-2.640625  -1.1611328  1.3496094  2.7753906  0.5830078] 3   3 Match 641

[ 1.2089844  2.2441406  1.2783203 -1.1025391 -2.9980469] 1   0 
[ 1.8769531  2.546875   0.9453125 -1.9121094 -3.3339844] 1

[-2.7304688 -2.5175781 -0.5029297  2.4453125  3.5625   ] 4   3 
[-2.9648438  -2.4941406  -0.42578125  3.0703125   3.0820312 ] 4   3 
[ 2.5058594  2.3046875  0.5185547 -1.8173828 -2.84375  ] 0   1 
[-2.8945312  -2.1660156  -0.06384277  3.1445312   2.7539062 ] 3   4 
[-2.984375   -2.0820312   0.08892822  2.8007812   2.2421875 ] 3   4 
[-3.0625     -2.3867188  -0.17810059  3.2050781   2.7851562 ] 3   4 
[-2.1972656   0.28295898  2.4355469   1.4755859  -1.3476562 ] 2   2 Match 699

[-1.8085938   0.39111328  2.2441406   1.4648438  -1.5146484 ] 2   1 
[ 2.5527344   2.7167969   0.66308594 -2.1621094  -3.2070312 ] 1   1 Match 700

[ 1.6552734  2.46875    0.8022461 -1.8203125 -3.2363281] 1   1 Match 701

[ 2.1699219   2.6914062   0.81347656 -1.8886719  -3.171875  ] 1   1 Match 702

[ 2.2890625   2.6992188   0.57373047 -2.1035156  -3.1582031 ] 1   2 
[ 1.0224609  2.8164062  1.3759766 -1.5703125 -3.2753906] 1   1 Match 703

[ 1.3505859  2.6738281  1.0625    -1.6650391 -3.4238281] 1   0 
[-2.79492

[ 3.2285156  2.2421875  0.0982666 -2.2441406 -2.8613281] 0   2 
[ 2.4257812   2.4003906   0.45410156 -2.0019531  -3.1269531 ] 0   1 
[ 3.0253906  2.4824219  0.3479004 -2.1347656 -3.0585938] 0   1 
[-2.6191406  -2.2753906  -0.08380127  2.7167969   2.4433594 ] 3   4 
[ 3.2441406   2.2773438  -0.00954437 -2.2929688  -2.84375   ] 0   0 Match 765

[ 1.7460938  2.3183594  1.2255859 -1.390625  -2.8691406] 1   1 Match 766

[-2.9609375  -1.5439453   0.83691406  3.3613281   1.3466797 ] 3   3 Match 767

[-3.078125  -1.8701172  0.8227539  3.1132812  1.5732422] 3   3 Match 768

[ 1.4443359   2.671875    0.98583984 -1.6855469  -3.3730469 ] 1   1 Match 769

[-0.36035156  1.9453125   1.9384766  -0.45874023 -2.9824219 ] 1   2 
[ 1.5195312  2.7460938  1.2431641 -1.8310547 -3.3691406] 1   1 Match 770

[-2.3144531  0.3779297  2.2890625  1.984375  -1.4511719] 2   3 
[ 1.6181641  2.9472656  1.2744141 -1.703125  -3.2636719] 1   1 Match 771

[-2.6210938  -0.21069336  1.8925781   2.4921875  -0.70996094] 3   3 

[-2.6152344  -0.68847656  1.9111328   2.6777344  -0.37426758] 3   2 
[ 1.2363281  2.7851562  1.3701172 -1.5371094 -3.2070312] 1   1 Match 832

[-2.5839844  -2.4550781  -0.27612305  2.0351562   3.1679688 ] 4   3 
[-2.5527344  -2.4433594  -0.67529297  2.1328125   3.9140625 ] 4   4 Match 833

[ 0.24621582  2.4667969   1.5976562  -0.9511719  -3.1992188 ] 1   0 
[ 1.7001953  2.4199219  0.9428711 -1.8242188 -3.2324219] 1   2 
[ 2.4453125   2.4472656   0.44091797 -1.9902344  -2.9023438 ] 1   0 
[-2.3378906  -0.62353516  1.7324219   2.5371094  -0.42871094] 3   3 Match 834

[-2.8886719  -1.9882812   0.43798828  3.1582031   1.6240234 ] 3   3 Match 835

[ 0.36108398  2.6542969   1.6123047  -1.2675781  -3.171875  ] 1   1 Match 836

[ 2.609375    2.4804688   0.42138672 -2.0722656  -3.1992188 ] 0   0 Match 837

[ 2.4667969   2.6230469   0.48828125 -2.1503906  -3.2382812 ] 1   0 
[-2.0957031   0.12469482  1.9404297   2.1503906  -1.0175781 ] 3   3 Match 838

[ 2.9726562  2.3925781  0.234375  -2.171875

[-3.0507812 -1.9667969  0.671875   3.4296875  1.5673828] 3   3 Match 906

[-2.8046875  -1.9228516   0.57128906  2.9824219   1.7363281 ] 3   4 
[-2.3535156  -0.10455322  2.1152344   2.2675781  -0.90283203] 3   3 Match 907

[ 1.9160156   2.5371094   0.62353516 -1.9414062  -3.3398438 ] 1   1 Match 908

[-3.0703125  -2.3789062   0.00523376  3.21875     2.6367188 ] 3   4 
[-2.8007812  -2.4941406  -0.18737793  2.453125    3.4277344 ] 4   4 Match 909

[ 1.3115234  2.6855469  1.265625  -1.4316406 -3.2792969] 1   0 
[-2.7128906  -2.3320312  -0.27978516  2.4121094   3.171875  ] 4   4 Match 910

[ 2.7988281  2.2675781  0.2775879 -2.203125  -2.8632812] 0   1 
[-3.0605469  -2.4746094   0.36108398  2.7050781   2.6816406 ] 3   1 
[-3.0253906  -2.046875    0.39331055  3.1113281   1.9169922 ] 3   3 Match 911

[ 2.3027344  2.6445312  0.6088867 -2.0859375 -3.2792969] 1   1 Match 912

[-2.4355469  -2.3535156  -0.43188477  1.8300781   3.8066406 ] 4   4 Match 913

[-2.8808594  -2.4140625  -0.03439331  2.884

[-2.3261719 -2.3203125 -0.6069336  1.5673828  3.8808594] 4   4 Match 973

[ 1.8193359   2.3964844   0.81103516 -1.8125     -3.1875    ] 1   1 Match 974

[ 2.6660156   2.3574219   0.25463867 -2.1816406  -2.9492188 ] 0   0 Match 975

[-3.015625  -2.1308594  0.7285156  3.3769531  1.828125 ] 3   3 Match 976

[-2.9023438  -2.4199219  -0.26391602  2.4707031   3.2460938 ] 4   4 Match 977

[ 3.5097656   1.8681641  -0.19006348 -2.2382812  -2.5253906 ] 0   0 Match 978

[ 2.0683594   2.7363281   0.63134766 -2.0722656  -3.3085938 ] 1   1 Match 979

[ 2.9375      2.3828125   0.34716797 -2.2265625  -3.0546875 ] 0   0 Match 980

[-2.6953125  -2.078125    0.46020508  2.7402344   2.0625    ] 3   3 Match 981

[-3.1777344  -2.2988281   0.20251465  3.2929688   2.3945312 ] 3   3 Match 982

[ 1.6240234  2.7558594  0.9628906 -1.8203125 -3.3398438] 1   1 Match 983

[ 0.36499023  2.3476562   1.4726562  -0.82421875 -3.21875   ] 1   1 Match 984

[-2.2636719   0.03216553  2.1191406   2.1757812  -1.3271484 ] 3   2

[-2.6601562  -0.9628906   1.7226562   2.9550781  -0.06567383] 3   2 
[-3.0390625  -2.3105469   0.22253418  3.3984375   2.3535156 ] 3   3 Match 1049

[-2.5742188 -2.09375    0.3269043  2.9238281  2.2460938] 3   3 Match 1050

[-2.3769531  -2.4804688  -0.57373047  1.5         3.9199219 ] 4   4 Match 1051

[-2.7226562  -2.4023438  -0.06488037  2.2949219   3.0703125 ] 4   4 Match 1052

[ 2.4335938   2.2480469   0.53027344 -1.7373047  -2.9355469 ] 0   1 
[ 1.6513672   2.7382812   0.94140625 -1.7548828  -3.2148438 ] 1   1 Match 1053

[-2.7871094 -2.5722656 -0.6352539  2.3554688  3.8632812] 4   4 Match 1054

[-1.0751953   1.8447266   2.2226562   0.19714355 -2.9707031 ] 2   3 
[-3.0449219  -2.3027344  -0.18383789  2.9863281   2.84375   ] 3   3 Match 1055

[-2.5878906  -0.7841797   2.0644531   2.7285156  -0.29467773] 3   3 Match 1056

[ 0.36499023  2.2695312   1.578125   -0.89746094 -3.1894531 ] 1   2 
[-2.8085938  -2.3320312  -0.18640137  2.7578125   3.2890625 ] 4   4 Match 1057

[-2.9824219  -

[-3.2617188  -2.0566406   0.49389648  2.8925781   1.8730469 ] 3   3 Match 1126

[-3.0234375 -1.8789062  0.7685547  3.2441406  1.4521484] 3   3 Match 1127

[-2.6855469  -1.2226562   1.6962891   2.9609375   0.39990234] 3   3 Match 1128

[ 1.9160156   2.2070312   0.77490234 -1.3974609  -3.1464844 ] 1   2 
[-2.8808594 -2.546875  -0.5292969  2.6894531  3.5097656] 4   3 
[-2.4804688  -0.91015625  1.2724609   2.28125     0.25976562] 3   2 
[ 1.2724609  2.7128906  1.1337891 -1.6445312 -3.3789062] 1   1 Match 1129

[ 3.0917969  2.0058594  0.1126709 -1.9453125 -2.8574219] 0   1 
[-2.4472656  -0.76953125  2.0234375   2.2597656  -0.21252441] 3   2 
[ 1.0273438  2.6308594  1.6806641 -1.3828125 -3.2617188] 1   2 
[ 3.1054688e+00  2.3398438e+00 -6.9570541e-04 -2.3046875e+00
 -2.8906250e+00] 0   0 Match 1130

[ 2.3125      2.5371094   0.85302734 -1.8740234  -3.2929688 ] 1   2 
[-2.7578125  -2.3789062  -0.31054688  2.4667969   3.3847656 ] 4   4 Match 1131

[-2.9199219  -1.1064453   1.2939453   3.046875

[-3.1308594 -2.5449219 -0.1595459  3.2226562  2.8925781] 3   3 Match 1194

[-2.7871094  -2.5585938  -0.53759766  2.4882812   3.7558594 ] 4   3 
[ 1.4814453   2.2109375   0.96240234 -1.5556641  -3.1503906 ] 1   1 Match 1195

[-2.5488281  -0.63623047  1.6210938   2.3671875   0.01831055] 3   3 Match 1196

[-2.8164062  -0.9121094   1.7539062   2.921875   -0.01757812] 3   3 Match 1197

[-2.1796875   0.62890625  2.2578125   1.7246094  -1.6484375 ] 2   3 
[ 3.3925781   2.1210938  -0.05895996 -2.1601562  -2.7246094 ] 0   0 Match 1198

[ 1.4765625  2.8925781  1.0390625 -1.7451172 -3.4335938] 1   1 Match 1199

[ 3.2734375   2.3652344   0.16003418 -2.2441406  -2.9296875 ] 0   0 Match 1200

[-0.32666016  1.0986328   1.7695312   0.4699707  -2.1289062 ] 2   3 
[ 1.7558594   2.8398438   0.82177734 -1.8945312  -3.234375  ] 1   1 Match 1201

[-3.0292969 -2.1953125 -0.0680542  3.1738281  2.6621094] 3   3 Match 1202

[ 0.6088867  2.5488281  1.3505859 -1.2167969 -3.3066406] 1   1 Match 1203

[-2.1757812 -


[ 2.3085938  2.5839844  0.6142578 -2.0703125 -3.2753906] 1   0 
[ 2.5195312  2.2519531  0.3112793 -2.0488281 -2.8808594] 0   0 Match 1264

[ 2.5019531  2.6816406  0.5361328 -2.1640625 -3.2226562] 1   0 
[-2.5898438  -0.15661621  2.0976562   1.9765625  -0.9472656 ] 2   2 Match 1265

[ 0.8413086  2.5117188  1.2373047 -1.4638672 -3.3632812] 1   2 
[-2.421875   -0.28808594  2.078125    2.4511719  -0.87939453] 3   2 
[-2.7402344  -0.66552734  1.8154297   2.8945312  -0.36669922] 3   2 
[ 0.91259766  2.4804688   1.7333984  -1.2841797  -3.1679688 ] 1   1 Match 1266

[-3.03125   -1.5732422  1.1445312  3.0410156  1.109375 ] 3   2 
[-2.9550781  -2.5097656  -0.34716797  2.7773438   3.3535156 ] 4   4 Match 1267

[ 1.6367188  2.6367188  1.1455078 -1.6113281 -3.3125   ] 1   1 Match 1268

[-2.7011719  -0.74121094  1.5830078   2.8613281  -0.07769775] 3   3 Match 1269

[-2.6621094  -2.4355469  -0.52197266  2.1367188   3.8183594 ] 4   4 Match 1270

[-2.8535156 -1.3974609  1.2470703  3.0390625  0.8076172

In [32]:
from sklearn import metrics
 
print(metrics.confusion_matrix(Targets,Pred))

[[164 106   7   2   0]
 [117 431  68  16   1]
 [ 16 147 118  97  11]
 [  1   6  45 344 114]
 [  0   1  10 117 271]]


In [33]:
target_names = ['Very Neg', 'Negative', 'Neutral','Positive','Very Pos']
print(metrics.classification_report(Targets, Pred,target_names =target_names))

              precision    recall  f1-score   support

    Very Neg       0.55      0.59      0.57       279
    Negative       0.62      0.68      0.65       633
     Neutral       0.48      0.30      0.37       389
    Positive       0.60      0.67      0.63       510
    Very Pos       0.68      0.68      0.68       399

    accuracy                           0.60      2210
   macro avg       0.59      0.59      0.58      2210
weighted avg       0.59      0.60      0.59      2210



In [34]:
Fold_Predictions=pd.DataFrame(Pred, columns=['Pred5'] )
Fold_Predictions

Unnamed: 0,Pred5
0,1
1,1
2,1
3,4
4,1
...,...
2205,4
2206,1
2207,4
2208,4


In [35]:
Fold_Predictions.to_excel(output_folder+'/Saves/fold5_Predictions.xls')

In [36]:
#clearing GPU cache

del(model)
del(TrainResult, TrainModel_outputs, EvalResult, EvalModel_outputs, TestResult, TestModel_outputs, wrong_predictions)
torch.cuda.empty_cache()

# Comparing the Predictions

In [37]:
Pred1=pd.read_excel('./folds/fold1/'+model_class+'/'+model_version+'/Saves/fold1_Predictions.xls')
Pred2=pd.read_excel('./folds/fold2/'+model_class+'/'+model_version+'/Saves/fold2_Predictions.xls')
Pred3=pd.read_excel('./folds/fold3/'+model_class+'/'+model_version+'/Saves/fold3_Predictions.xls')
Pred4=pd.read_excel('./folds/fold4/'+model_class+'/'+model_version+'/Saves/fold4_Predictions.xls')
Pred5=pd.read_excel('./folds/fold5/'+model_class+'/'+model_version+'/Saves/fold5_Predictions.xls')


In [38]:
 for row in range(len(Pred1)):
        
        print(Pred1.iloc[row,1] , end=',')
        print(Pred2.iloc[row,1]  , end=',')
        print(Pred3.iloc[row,1] , end=',')
        print(Pred4.iloc[row,1] , end=',')
        print(Pred5.iloc[row,1] )
    

1,1,1,2,1
1,1,1,1,1
1,1,1,1,1
4,4,4,4,4
2,2,2,1,1
3,3,3,3,3
1,1,0,0,0
3,4,3,3,3
0,0,0,0,0
0,0,0,0,0
1,1,1,1,1
3,3,3,3,3
2,2,2,2,2
0,0,0,0,0
4,4,4,4,4
1,1,1,1,1
3,4,3,3,3
3,3,3,3,3
1,1,1,1,1
4,4,4,4,4
4,4,4,4,4
3,3,3,3,3
2,1,1,1,1
1,1,1,1,1
1,1,1,1,1
0,0,0,1,0
0,0,0,0,0
1,1,1,1,1
1,1,1,0,0
1,1,1,1,1
2,3,3,3,2
1,1,1,1,1
2,2,2,2,2
1,1,1,0,1
0,1,0,0,0
3,3,1,1,2
4,4,4,4,4
3,4,4,4,3
3,2,3,3,2
1,1,1,1,1
4,4,4,4,4
1,1,1,1,1
3,3,3,3,3
2,2,2,2,2
1,1,1,1,1
0,0,0,0,0
4,4,4,4,4
1,1,1,1,1
3,3,3,3,3
0,0,0,0,0
1,1,1,0,1
0,1,0,1,1
1,1,1,1,1
3,3,3,3,3
3,3,3,3,3
1,1,1,1,1
1,1,1,1,1
2,2,1,3,3
2,2,2,3,3
0,1,0,1,0
3,3,3,3,3
1,1,0,1,0
3,3,3,3,3
3,3,3,3,3
1,1,1,1,1
3,3,3,4,3
2,2,2,2,1
1,1,1,0,1
4,4,4,4,4
1,2,1,2,2
3,3,3,3,3
0,0,0,0,0
3,4,4,3,4
1,1,2,1,1
4,4,4,4,4
1,1,0,1,1
2,1,1,1,1
4,4,4,4,4
0,0,0,0,0
1,1,1,1,1
1,2,1,1,1
0,0,0,0,0
3,3,3,3,3
3,3,3,3,3
3,4,3,4,3
0,0,0,0,0
4,4,4,4,4
3,4,3,3,3
3,3,3,3,2
2,2,2,1,1
2,2,2,2,2
1,1,1,1,1
0,0,0,0,0
4,4,4,4,4
0,0,0,0,0
1,1,1,1,1
3,3,3,3,3
4,4,4,4,4
3,3,3,3,3
3,3,3,3,3


0,0,0,0,0
3,3,3,3,3
2,1,1,1,1
3,3,3,3,3
3,4,3,4,3
2,2,2,2,2
3,3,3,3,3
4,4,4,4,4
4,4,4,4,4
1,1,1,1,1
3,4,3,4,4
4,3,3,4,3
2,2,2,2,2
0,0,0,0,0
0,0,0,0,0
0,0,0,0,0
2,1,2,1,2
4,4,3,4,4
4,4,3,4,3
1,0,0,1,0
1,2,2,2,1
4,4,4,4,4
2,2,2,3,2
3,3,3,3,3
1,1,1,1,1
0,0,0,0,0
3,3,3,3,3
1,1,1,1,1
4,4,4,4,4
3,3,3,3,3
4,4,4,4,4
3,3,3,3,3
1,1,1,1,1
4,4,4,4,4
3,4,4,4,4
1,1,1,1,1
3,3,3,4,3
3,3,3,3,3
0,0,0,0,0
2,2,2,2,2
1,1,1,0,1
1,1,1,1,1
4,4,4,4,4
4,4,4,4,4
2,3,2,2,2
2,2,1,1,1
3,4,4,4,4
2,1,1,2,2
4,4,4,4,4
1,1,0,0,0
2,2,2,2,2
4,4,4,4,4
1,1,1,0,1
2,3,3,3,3
4,4,4,4,4
3,3,3,3,2
3,3,3,3,3
0,0,0,0,0
1,1,1,1,1
3,3,3,3,3
1,1,1,0,1
1,1,1,1,1
4,4,4,4,4
1,1,1,1,1
1,1,1,1,1
1,1,1,1,1
1,1,1,1,1
4,4,3,4,4
4,4,4,4,4
1,1,1,1,1
1,1,0,2,1
1,1,1,1,1
1,1,0,1,1
0,0,0,0,0
3,3,3,3,3
1,1,1,2,1
3,4,4,4,4
3,4,3,3,3
0,1,0,1,0
3,4,3,3,3
1,1,1,1,1
0,0,0,0,0
4,4,4,3,3
2,2,3,3,2
0,0,0,0,0
1,1,1,1,1
1,1,1,1,1
4,4,4,4,4
0,0,0,0,0
3,4,3,3,3
3,4,3,4,4
0,0,0,0,0
0,0,0,0,0
4,4,4,4,4
0,0,0,0,0
3,3,3,3,3
1,0,0,0,1
1,1,1,1,1
1,1,1,1,1
4,4,4,4,4


4,4,4,4,4
1,1,1,1,1
2,2,2,2,2
3,3,3,3,2
3,3,3,3,3
2,3,3,3,3
2,2,2,1,2
2,2,2,2,2
3,3,4,4,3
4,4,4,4,4
1,0,1,0,0
0,0,0,0,0
0,0,0,0,0
0,0,0,0,1
3,3,3,3,3
3,3,3,3,3
3,3,3,3,3
1,1,1,1,1
3,3,3,3,3
1,1,1,1,1
1,0,1,0,0
1,1,1,1,1
3,4,3,4,3
1,2,1,2,1
2,2,1,2,2
3,3,2,3,3
4,4,4,4,4
3,3,3,3,3
2,2,2,2,2
3,3,3,2,3
3,3,3,3,3
1,1,1,1,1
3,4,3,4,4
0,0,0,0,0
4,4,4,4,4
1,1,1,0,1
4,4,4,4,4
1,1,1,1,1
1,1,1,1,1
3,3,3,3,3
1,1,1,1,1
1,1,1,0,1
3,3,3,3,3
3,4,4,4,4
2,1,1,2,2
1,2,2,1,1
1,1,1,1,1
1,0,0,0,0
2,3,3,3,3
1,1,1,1,1
0,0,0,0,0
3,3,3,3,3
1,0,0,0,0
1,1,1,1,1
3,3,3,3,2
3,3,3,3,3
4,4,4,4,4
1,1,0,0,0
1,1,0,1,1
2,2,2,1,1
3,3,3,4,3
3,3,3,3,3
3,1,3,3,3
4,4,4,4,4
1,1,1,1,1
2,2,2,1,2
1,1,1,1,1
2,3,2,4,1
1,1,1,1,1
3,3,3,3,3
0,0,0,0,0
3,3,3,3,3
1,1,1,1,1
3,3,3,3,3
3,3,3,3,3
3,3,3,3,3
1,1,1,0,1
4,4,4,4,4
3,3,3,3,3
1,1,1,1,1
0,0,0,0,0
3,3,2,2,3
1,1,1,1,1
0,0,0,0,0
1,1,1,0,1
4,4,4,4,4
3,3,3,3,3
2,2,2,1,3
0,0,0,0,0
3,4,3,3,3
1,1,1,1,1
2,3,2,2,3
3,4,3,3,3
4,4,4,4,4
2,2,2,2,2
2,2,2,2,2
4,3,4,4,4
0,0,0,0,0
1,1,1,1,1
1,1,0,0,0


IndexError: single positional indexer is out-of-bounds

In [39]:
results=pd.DataFrame( columns=['text', 'label','fold1','fold2','fold3','fold4','fold5'])

results['text']=test['text']
results['label']=test['labels']
results['fold1']=Pred1['Pred1'] 
results['fold2']=Pred2['Pred2'] 
results['fold3']=Pred3['Pred3'] 
results['fold4']=Pred4['Pred4'] 
results['fold5']=Pred5['Pred5'] 

        
results

Unnamed: 0,text,label,fold1,fold2,fold3,fold4,fold5
0,Maybe I found the proceedings a little bit too...,1,1,1,1,2.0,1
1,"As with too many studio pics , plot mechanics ...",0,1,1,1,1.0,1
2,"Beers , who , when she 's given the right line...",2,1,1,1,1.0,1
3,"Cute , funny , heartwarming digitally animated...",4,4,4,4,4.0,4
4,So what is the point ?,0,2,2,2,1.0,1
...,...,...,...,...,...,...,...
2205,It 's a glorious groove that leaves you wantin...,4,4,4,4,4.0,4
2206,It 's getting harder and harder to ignore the ...,1,1,1,0,1.0,1
2207,"A real movie , about real people , that gives ...",3,4,4,4,4.0,4
2208,"Sharp , lively , funny and ultimately sobering...",4,4,4,3,3.0,4


In [40]:
now=datetime.now()

results.to_excel('./folds/'+model_class+'_results'+now.strftime("%m-%d-%Y %H-%M")+'.xls' ,index=False)