In [2]:
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

import transformers
from transformers import AutoModel, BertTokenizer, BertForSequenceClassification, AdamW

import warnings
warnings.filterwarnings('ignore')

tqdm.pandas()

In [8]:
# device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [8]:
df = pd.read_csv('data/English_scores/sentences.csv')

In [9]:
df.head()

Unnamed: 0,text,level,movie
0,"ben on phone michelle, please don t hang up. j...",1,10_Cloverfield_lane(2016)
1,"your wallet. given as how i saved your life, i...",1,10_Cloverfield_lane(2016)
2,"and make sure they re okay. michelle, they re ...",1,10_Cloverfield_lane(2016)
3,"he s sorry for, correct? totally. let s go. ba...",1,10_Cloverfield_lane(2016)
4,possible. i heard one earlier. above my room. ...,1,10_Cloverfield_lane(2016)


In [49]:
df['level'].value_counts()

level
2    3216
1    1966
0    1166
3    1156
Name: count, dtype: int64

In [5]:
bert = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
tokenizer = BertTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'MPNetTokenizer'. 
The class this function is called from is 'BertTokenizer'.


In [10]:
X = df['text']
y = df['level']

In [28]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)


In [29]:
X_train = X_train.reset_index(drop=True).astype(str)
X_test = X_test.reset_index(drop=True).astype(str)

y_train = y_train.reset_index(drop=True)
y_test = y_test.reset_index(drop=True)

In [40]:
tokens_train = tokenizer.batch_encode_plus(
    X_train.values,
    max_length = 256,
    padding = 'max_length',
    truncation = True
)

tokens_val = tokenizer.batch_encode_plus(
    X_test.values,
    max_length = 256,
    padding = 'max_length',
    truncation = True
)

# tokens_test = tokenizer.batch_encode_plus(
#     test_text.values,
#     max_length = 256,
#     padding = 'max_length',
#     truncation = True
# )

train_seq = torch.tensor(tokens_train['input_ids'])
train_mask = torch.tensor(tokens_train['attention_mask'])
train_y = torch.tensor(y_train.values)

val_seq = torch.tensor(tokens_val['input_ids'])
val_mask = torch.tensor(tokens_val['attention_mask'])
val_y = torch.tensor(y_test.values)

# test_seq = torch.tensor(tokens_test['input_ids'])
# test_mask = torch.tensor(tokens_test['attention_mask'])
# test_y = torch.tensor(test_labels.values)

batch_size = 8

train_data = TensorDataset(train_seq, train_mask, train_y)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler = train_sampler, batch_size = batch_size)

val_data =  TensorDataset(val_seq, val_mask, val_y)
val_sampler = SequentialSampler(val_data)
val_dataloader = DataLoader(val_data, sampler = val_sampler, batch_size = batch_size)

In [6]:
for param in bert.parameters():
    param.requires_grad = False

class BERT_Arch(nn.Module):
    def __init__(self, bert):
        super(BERT_Arch, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768,512)
        self.fc2 = nn.Linear(512,4)
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, sent_id, mask):
        _, cls_hs = self.bert(sent_id, attention_mask = mask, return_dict = False)
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.softmax(x)

        return x

In [6]:
model = BERT_Arch(bert)
model = model.to(device)

optimizer = AdamW(model.parameters(),
                  lr= 1e-3)

In [7]:
import pickle

pickle.dump(model, open('models/model.pkl', 'wb'))

In [57]:
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight(class_weight='balanced', classes=sorted(y.unique()), y=y)
print(class_weights)

[1.60891938 0.95422177 0.58333333 1.62283737]


In [58]:
weights = torch.tensor(class_weights, dtype=torch.float)
weights = weights.to(device)

In [59]:
cross_entropy = nn.CrossEntropyLoss(weight=weights)

In [60]:
def train():
    model.train()
    total_loss, total_accuracy = 0, 0
    total_preds = []

    for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        batch = [r.to(device) for r in batch]
        sent_id, mask, labels = batch
        model.zero_grad()
        preds = model(sent_id, mask)

        loss = cross_entropy(preds, labels)

        total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        preds = preds.detach().cpu().numpy()
        total_preds.append(preds)

    avg_loss = total_loss / len(train_dataloader)
    total_preds = np.concatenate(total_preds, axis=0)

    return avg_loss, total_preds


In [61]:
def evaluate():
    model.eval()
    total_loss, total_accuracy = 0,0
    total_preds = []

    for step, batch in tqdm(enumerate(val_dataloader), total = len(val_dataloader)):
        batch = [t.to(device) for t in batch]
        sent_id, mask, labels = batch

        with torch.no_grad():
            preds = model(sent_id, mask)
            loss = cross_entropy(preds, labels)
            total_loss = total_loss + loss.item()
            preds = preds.detach().cpu().numpy()
            total_preds.append(preds)

    avg_loss = total_loss / len(val_dataloader)
    total_preds = np.concatenate(total_preds, axis = 0)

    return avg_loss, total_preds

In [62]:
epochs = 50

In [63]:
best_valid_loss = float('inf')

train_losses = []
valid_losses = []

for epoch in range(epochs):
    print('\n Epoch{:} / {:}'.format(epoch+1, epochs))

    train_loss, _ = train()
    valid_loss, _ = evaluate()

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'saved_weights.pt')

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    print(f'\nTraining loss: {train_loss:.3f}')
    print(f'Validation loss: {valid_loss:.3f}')


 Epoch1 / 50


100%|██████████| 751/751 [01:08<00:00, 10.90it/s]
100%|██████████| 188/188 [00:13<00:00, 13.51it/s]



Training loss: 1.253
Validation loss: 1.183

 Epoch2 / 50


100%|██████████| 751/751 [01:00<00:00, 12.32it/s]
100%|██████████| 188/188 [00:14<00:00, 13.42it/s]



Training loss: 1.173
Validation loss: 1.160

 Epoch3 / 50


100%|██████████| 751/751 [01:01<00:00, 12.29it/s]
100%|██████████| 188/188 [00:14<00:00, 13.39it/s]



Training loss: 1.146
Validation loss: 1.149

 Epoch4 / 50


100%|██████████| 751/751 [01:01<00:00, 12.26it/s]
100%|██████████| 188/188 [00:14<00:00, 13.37it/s]



Training loss: 1.133
Validation loss: 1.159

 Epoch5 / 50


100%|██████████| 751/751 [01:01<00:00, 12.24it/s]
100%|██████████| 188/188 [00:14<00:00, 13.37it/s]



Training loss: 1.110
Validation loss: 1.143

 Epoch6 / 50


100%|██████████| 751/751 [01:01<00:00, 12.23it/s]
100%|██████████| 188/188 [00:14<00:00, 13.37it/s]



Training loss: 1.091
Validation loss: 1.140

 Epoch7 / 50


100%|██████████| 751/751 [01:01<00:00, 12.25it/s]
100%|██████████| 188/188 [00:14<00:00, 13.36it/s]



Training loss: 1.082
Validation loss: 1.131

 Epoch8 / 50


100%|██████████| 751/751 [01:01<00:00, 12.24it/s]
100%|██████████| 188/188 [00:14<00:00, 13.36it/s]



Training loss: 1.064
Validation loss: 1.126

 Epoch9 / 50


100%|██████████| 751/751 [01:01<00:00, 12.23it/s]
100%|██████████| 188/188 [00:14<00:00, 13.36it/s]



Training loss: 1.045
Validation loss: 1.125

 Epoch10 / 50


100%|██████████| 751/751 [01:01<00:00, 12.24it/s]
100%|██████████| 188/188 [00:14<00:00, 13.38it/s]



Training loss: 1.031
Validation loss: 1.110

 Epoch11 / 50


100%|██████████| 751/751 [01:01<00:00, 12.23it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 1.013
Validation loss: 1.109

 Epoch12 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.36it/s]



Training loss: 1.003
Validation loss: 1.111

 Epoch13 / 50


100%|██████████| 751/751 [01:01<00:00, 12.23it/s]
100%|██████████| 188/188 [00:14<00:00, 13.37it/s]



Training loss: 0.988
Validation loss: 1.113

 Epoch14 / 50


100%|██████████| 751/751 [01:01<00:00, 12.23it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.978
Validation loss: 1.104

 Epoch15 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.966
Validation loss: 1.104

 Epoch16 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.954
Validation loss: 1.098

 Epoch17 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.947
Validation loss: 1.101

 Epoch18 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.932
Validation loss: 1.099

 Epoch19 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.33it/s]



Training loss: 0.931
Validation loss: 1.093

 Epoch20 / 50


100%|██████████| 751/751 [01:01<00:00, 12.20it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.915
Validation loss: 1.105

 Epoch21 / 50


100%|██████████| 751/751 [01:01<00:00, 12.20it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.915
Validation loss: 1.094

 Epoch22 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.898
Validation loss: 1.109

 Epoch23 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.899
Validation loss: 1.099

 Epoch24 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.32it/s]



Training loss: 0.890
Validation loss: 1.090

 Epoch25 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.878
Validation loss: 1.103

 Epoch26 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.880
Validation loss: 1.098

 Epoch27 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.875
Validation loss: 1.096

 Epoch28 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.873
Validation loss: 1.099

 Epoch29 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.33it/s]



Training loss: 0.861
Validation loss: 1.089

 Epoch30 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.858
Validation loss: 1.096

 Epoch31 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.857
Validation loss: 1.091

 Epoch32 / 50


100%|██████████| 751/751 [01:01<00:00, 12.19it/s]
100%|██████████| 188/188 [00:14<00:00, 13.25it/s]



Training loss: 0.850
Validation loss: 1.093

 Epoch33 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.854
Validation loss: 1.083

 Epoch34 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.845
Validation loss: 1.081

 Epoch35 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.841
Validation loss: 1.084

 Epoch36 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.837
Validation loss: 1.087

 Epoch37 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.833
Validation loss: 1.085

 Epoch38 / 50


100%|██████████| 751/751 [01:01<00:00, 12.23it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.833
Validation loss: 1.091

 Epoch39 / 50


100%|██████████| 751/751 [01:01<00:00, 12.20it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.830
Validation loss: 1.098

 Epoch40 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.829
Validation loss: 1.090

 Epoch41 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.823
Validation loss: 1.083

 Epoch42 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.33it/s]



Training loss: 0.821
Validation loss: 1.082

 Epoch43 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.36it/s]



Training loss: 0.820
Validation loss: 1.078

 Epoch44 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.36it/s]



Training loss: 0.821
Validation loss: 1.081

 Epoch45 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.818
Validation loss: 1.085

 Epoch46 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.35it/s]



Training loss: 0.817
Validation loss: 1.086

 Epoch47 / 50


100%|██████████| 751/751 [01:01<00:00, 12.22it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.815
Validation loss: 1.093

 Epoch48 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.34it/s]



Training loss: 0.812
Validation loss: 1.088

 Epoch49 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.33it/s]



Training loss: 0.814
Validation loss: 1.091

 Epoch50 / 50


100%|██████████| 751/751 [01:01<00:00, 12.21it/s]
100%|██████████| 188/188 [00:14<00:00, 13.36it/s]


Training loss: 0.812
Validation loss: 1.090





In [9]:
import os

path = 'models'
path_model = os.path.join(path, 'model.pkl')
path_weights = os.path.join(path, 'saved_weights.pt')

# Now, load the model and weights
model = pickle.load(open(path_model, 'rb'))
model = model.to(device)
model.load_state_dict(torch.load(path_weights))


<All keys matched successfully>

In [41]:
test_df = pd.DataFrame({'text': X_test, 'level': y_test})

In [51]:
import gc
gc.collect()
torch.cuda.empty_cache()

list_seq = np.array_split(val_seq, 256)
list_mask = np.array_split(val_mask, 256)

predictions = []
for seq_elem, mask_elem in zip(list_seq, list_mask):
    with torch.no_grad():
        preds = model(seq_elem.to(device), mask_elem.to(device))
        predictions.append(preds.detach().cpu().numpy())

In [52]:
# concat predictions
predictions = np.concatenate(predictions, axis=0)

In [67]:
int(predictions.argmax(axis=1).mean().round())

2

In [53]:
test_df['confidence'] = predictions.max(axis=1)

In [54]:
test_df['prediction'] = predictions.argmax(axis=1)

In [55]:
test_df

Unnamed: 0,text,level,confidence,prediction,miss_value
0,for christmas. that s convenient because i got...,1,0.820066,1,0
1,s usually right after somebody died. i take it...,2,0.596158,2,0
2,"mr. reede, you may proceed. how? your honor, w...",1,0.998492,1,0
3,flower. and we will rule this jungle. i will p...,0,0.978000,0,0
4,"that, that was a party. not this. you know, pe...",2,0.998495,2,0
...,...,...,...,...,...
1496,"when i m not with daniel, i m better. and. i m...",1,0.886949,2,1
1497,would be waiting for us when we got back. we l...,1,0.983142,1,0
1498,that bell. you re damn right i m not ringing t...,3,0.907795,2,1
1499,"if you want to get technical, there s memorial...",0,0.969342,1,1


In [56]:
print(classification_report(test_df['level'], test_df['prediction']))

              precision    recall  f1-score   support

           0       0.66      0.66      0.66       233
           1       0.61      0.68      0.65       393
           2       0.73      0.70      0.71       644
           3       0.60      0.56      0.58       231

    accuracy                           0.67      1501
   macro avg       0.65      0.65      0.65      1501
weighted avg       0.67      0.67      0.67      1501



In [57]:
test_df['miss_value'] = abs(test_df['level'] - test_df['prediction'])

In [58]:
test_df['miss_value'].value_counts()

miss_value
0    999
1    392
2     99
3     11
Name: count, dtype: int64

In [59]:
len(test_df[test_df['miss_value'].isin([0, 1])]) / len(test_df)

0.9267155229846769