## Installation of libraries and imports

In [None]:
!pip install datasets
!pip install transformers

Collecting datasets
[?25l  Downloading https://files.pythonhosted.org/packages/83/7e/8d9e2fd30e3819e6042927d379f3668a0b49fe38b92d5639194808a1d877/datasets-1.0.2-py3-none-any.whl (1.8MB)
[K     |████████████████████████████████| 1.8MB 6.4MB/s 
Collecting pyarrow>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/f3/99/0a605f016121ca314d1469dc9069e4978395bc46fda40f73099d90ad3ba4/pyarrow-1.0.1-cp36-cp36m-manylinux2014_x86_64.whl (17.3MB)
[K     |████████████████████████████████| 17.3MB 174kB/s 
Collecting xxhash
[?25l  Downloading https://files.pythonhosted.org/packages/f7/73/826b19f3594756cb1c6c23d2fbd8ca6a77a9cd3b650c9dec5acc85004c38/xxhash-2.0.0-cp36-cp36m-manylinux2010_x86_64.whl (242kB)
[K     |████████████████████████████████| 245kB 54.6MB/s 
Installing collected packages: pyarrow, xxhash, datasets
  Found existing installation: pyarrow 0.14.1
    Uninstalling pyarrow-0.14.1:
      Successfully uninstalled pyarrow-0.14.1
Successfully installed datasets-1.0.2 py

In [None]:
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import copy
import torch.optim as optim
import random
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset, load_metric

os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Loading the dataset

In [None]:
# GLUE / MRPC sentence-pair corpus: a DatasetDict with 'train', 'validation' and 'test' splits
dataset = load_dataset(path='glue', name='mrpc')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=7826.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=4473.0, style=ProgressStyle(description…


Downloading and preparing dataset glue/mrpc (download: 1.43 MiB, generated: 1.43 MiB, post-processed: Unknown size, total: 2.85 MiB) to /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4...


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Downloading', max=1.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Downloading', max=1.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Downloading', max=1.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4. Subsequent calls will reuse this data.


In [None]:
# Test labels are not available through the datasets library, so:
#   - 10 % of the original training split (seed=1) is held out for validation,
#   - the original validation split serves as the test set.
split = dataset['train'].train_test_split(test_size=0.1, seed=1)
train, val = split['train'], split['test']  # 90 % / 10 % of the original training data
test = dataset['validation']

# pandas versions of each split, consumed by CustomDataset
df_train, df_val, df_test = (pd.DataFrame(part) for part in (train, val, test))

In [None]:
tuple(frame.shape for frame in (df_train, df_val, df_test))

((3301, 4), (367, 4), (408, 4))

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes (sentence1, sentence2) pairs stored in a pandas DataFrame.

    Each item is ``(token_ids, attn_masks, token_type_ids)`` plus ``label`` when ``with_labels``
    is True; every tensor is padded/truncated to ``maxlen`` tokens.
    """

    def __init__(self, data, maxlen, with_labels=True, bert_model='albert-base-v2'):
        """
        Args:
            data: DataFrame with 'sentence1' and 'sentence2' columns (and 'label' if with_labels).
            maxlen: length every encoded pair is padded or truncated to.
            with_labels: whether __getitem__ also returns the row's 'label'.
            bert_model: name of the pre-trained tokenizer; it must match the model being trained.
        """
        if isinstance(with_labels, str):
            # Catches CustomDataset(df, maxlen, "some-model"): the model name lands in
            # `with_labels` and the tokenizer silently stays at the default checkpoint.
            import warnings
            warnings.warn("with_labels received the string {!r}; pass the model name as "
                          "bert_model=... instead".format(with_labels))

        # Re-index 0..n-1 so the `.loc[index, ...]` lookups below work even when `data` was
        # filtered or shuffled and no longer has a default RangeIndex.
        self.data = data.reset_index(drop=True)
        # Initialize the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model)

        self.maxlen = maxlen
        self.with_labels = with_labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Selecting sentence1 and sentence2 at the specified index in the data frame
        sent1 = str(self.data.loc[index, 'sentence1'])
        sent2 = str(self.data.loc[index, 'sentence2'])

        # Tokenize the pair of sentences to get token ids, attention masks and token type ids
        encoded_pair = self.tokenizer(sent1, sent2,
                                      padding='max_length',  # Pad to max_length
                                      truncation=True,  # Truncate to max_length
                                      max_length=self.maxlen,
                                      return_tensors='pt')  # Return torch.Tensor objects

        token_ids = encoded_pair['input_ids'].squeeze(0)  # tensor of token ids
        attn_masks = encoded_pair['attention_mask'].squeeze(0)  # 1 for real tokens, 0 for padding
        # Binary tensor: 0 for 1st-sentence tokens, 1 for 2nd-sentence tokens. Some tokenizers do
        # not return token_type_ids at all; fall back to all zeros so batches keep the same layout.
        if 'token_type_ids' in encoded_pair:
            token_type_ids = encoded_pair['token_type_ids'].squeeze(0)
        else:
            token_type_ids = torch.zeros_like(token_ids)

        if self.with_labels:  # True if the dataset has labels
            label = self.data.loc[index, 'label']
            return token_ids, attn_masks, token_type_ids, label
        return token_ids, attn_masks, token_type_ids

In [None]:
class SentencePairClassifier(nn.Module):
    """Pre-trained encoder followed by dropout and a single-logit linear head for sentence pairs."""

    def __init__(self, bert_model="albert-base-v2", freeze_bert=False):
        """
        Args:
            bert_model: checkpoint name passed to ``AutoModel.from_pretrained``.
            freeze_bert: if True, encoder weights are frozen and only the classification head trains.
        """
        super(SentencePairClassifier, self).__init__()
        #  Instantiating BERT-based model object
        self.bert_layer = AutoModel.from_pretrained(bert_model)

        # Read the encoder output width from the model config instead of a hard-coded per-model
        # table: any checkpoint now works (the old if/elif chain left `hidden_size` unbound, and
        # raised UnboundLocalError, for models it did not list).
        hidden_size = self.bert_layer.config.hidden_size

        # Freeze bert layers and only train the classification layer weights
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False

        # Classification layer
        self.cls_layer = nn.Linear(hidden_size, 1)

        self.dropout = nn.Dropout(p=0.1)

    @autocast()  # run in mixed precision
    def forward(self, input_ids, attn_masks, token_type_ids):
        '''
        Inputs:
            -input_ids : Tensor  containing token ids
            -attn_masks : Tensor containing attention masks to be used to focus on non-padded values
            -token_type_ids : Tensor containing token type ids to be used to identify sentence1 and sentence2
        Returns:
            -logits : Tensor with one raw (pre-sigmoid) score per sentence pair
        '''

        # Feeding the inputs to the BERT-based model to obtain contextualized representations
        outputs = self.bert_layer(input_ids=input_ids,
                                  attention_mask=attn_masks,
                                  token_type_ids=token_type_ids)
        # Index 1 is the pooled output both when the encoder returns a plain tuple and when it
        # returns a ModelOutput; tuple-unpacking a ModelOutput would yield its *keys* (strings).
        # Assumes the encoder has a pooler (BERT and ALBERT do).
        pooler_output = outputs[1]

        # The pooled output is the last-layer [CLS] hidden state passed through a Linear layer and a
        # Tanh activation, trained on sentence order prediction (ALBERT) or next sentence
        # prediction (BERT) during pre-training.
        logits = self.cls_layer(self.dropout(pooler_output))

        return logits

In [None]:
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU + all GPUs) and force deterministic cuDNN."""
    # Hash seed for Python processes launched after this call, and the stdlib RNG
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)

    # NumPy's global RNG
    np.random.seed(seed)

    # PyTorch RNGs on CPU and on every visible GPU
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels, no benchmark-driven algorithm selection
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    

def evaluate_loss(net, device, criterion, dataloader):
    """Return the mean per-sample loss of ``net`` over ``dataloader``, without tracking gradients.

    Args:
        net: model called as ``net(seq, attn_masks, token_type_ids)``; it is switched to eval mode.
        device: device every batch is moved to.
        criterion: loss taking ``(logits, float_labels)``. Assumed to use ``reduction='mean'``
            (the ``nn.BCEWithLogitsLoss`` default), because each batch mean is re-weighted by
            the batch size below.
        dataloader: iterable of ``(seq, attn_masks, token_type_ids, labels)`` batches.

    Returns:
        float: total loss over all samples divided by the number of samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    net.eval()

    total_loss = 0.0
    n_samples = 0

    with torch.no_grad():
        for seq, attn_masks, token_type_ids, labels in tqdm(dataloader):
            seq, attn_masks, token_type_ids, labels = \
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device)
            logits = net(seq, attn_masks, token_type_ids)
            batch_size = labels.size(0)
            # Weight each batch mean by its size so a smaller final batch is not over-weighted
            # (a plain mean of batch means is biased whenever len(dataset) % batch_size != 0).
            total_loss += criterion(logits.squeeze(-1), labels.float()).item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("evaluate_loss received an empty dataloader")

    return total_loss / n_samples

In [None]:
print("Creation of the models' folder...")
# Idempotent: unlike `!mkdir models`, re-running the cell does not error when the folder exists.
os.makedirs("models", exist_ok=True)

Creation of the models' folder...


In [None]:
def train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate,
               device=None, model_name=None, save_dir="models"):
    """Fine-tune ``net`` with mixed precision and gradient accumulation; save the best epoch's weights.

    After each epoch the validation loss is computed with ``evaluate_loss``. The weights of the
    epoch with the lowest validation loss are written to ``save_dir`` once training finishes.

    Args:
        net: model called as ``net(seq, attn_masks, token_type_ids)`` returning one logit per pair.
        criterion: loss taking ``(logits, float_labels)``, e.g. ``nn.BCEWithLogitsLoss()``.
        opti: optimizer over ``net``'s parameters.
        lr: learning rate; only used to build the checkpoint file name.
        lr_scheduler: scheduler stepped once per optimizer step (i.e. every
            ``iters_to_accumulate`` batches), so it expects
            ``(len(train_loader) // iters_to_accumulate) * epochs`` steps in total.
        train_loader, val_loader: iterables of ``(seq, attn_masks, token_type_ids, labels)`` batches.
        epochs: number of passes over ``train_loader``.
        iters_to_accumulate: number of batches whose gradients are summed before each optimizer step.
        device: device batches are moved to; defaults to the device of ``net``'s first parameter.
        model_name: prefix of the checkpoint name; defaults to the notebook-level ``bert_model``.
        save_dir: directory receiving the checkpoint (created if missing).

    Returns:
        str: path of the saved checkpoint (a ``state_dict``).
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    if model_name is None:
        model_name = bert_model  # notebook-level global, kept for backward compatibility

    best_loss = float("inf")  # np.Inf no longer exists in NumPy 2.0
    best_ep = 1
    best_state = None
    nb_iterations = len(train_loader)
    # Print the training loss 5 times per epoch; max(1, ...) avoids a modulo-by-zero below
    # when the loader has fewer than 5 batches.
    print_every = max(1, nb_iterations // 5)
    loss = batch_loss = logits = None

    scaler = GradScaler()

    for ep in range(epochs):

        net.train()
        # Discard gradients of an incomplete accumulation window left at the end of the previous
        # epoch, so they do not leak into this epoch's first optimizer step. This keeps exactly
        # nb_iterations // iters_to_accumulate scheduler steps per epoch.
        opti.zero_grad()
        running_loss = 0.0
        for it, (seq, attn_masks, token_type_ids, labels) in enumerate(tqdm(train_loader)):

            # Move the batch to the training device
            seq, attn_masks, token_type_ids, labels = \
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device)

            # Enables autocasting for the forward pass (model + loss)
            with autocast():
                logits = net(seq, attn_masks, token_type_ids)
                batch_loss = criterion(logits.squeeze(-1), labels.float())
                # Divide so the accumulated gradient is the mean over the effective batch
                loss = batch_loss / iters_to_accumulate

            # Scales loss and calls backward() on it to create scaled gradients
            scaler.scale(loss).backward()

            if (it + 1) % iters_to_accumulate == 0:
                # scaler.step() unscales the gradients and skips opti.step() if they contain inf/NaN
                scaler.step(opti)
                # Updates the scale for next iteration
                scaler.update()
                # Adjust the learning rate based on the number of optimizer steps
                lr_scheduler.step()
                opti.zero_grad()

            # Accumulate the un-normalized batch loss so the printed value is on the same scale
            # as the validation loss (the normalized `loss` is iters_to_accumulate times smaller).
            running_loss += batch_loss.item()

            if (it + 1) % print_every == 0:  # Print training loss information
                print()
                print("Iteration {}/{} of epoch {} complete. Loss : {} "
                      .format(it+1, nb_iterations, ep+1, running_loss / print_every))

                running_loss = 0.0

        val_loss = evaluate_loss(net, device, criterion, val_loader)  # Compute validation loss
        print()
        print("Epoch {} complete! Validation Loss : {}".format(ep+1, val_loss))

        if val_loss < best_loss:
            print("Best validation loss improved from {} to {}".format(best_loss, val_loss))
            print()
            best_state = copy.deepcopy(net.state_dict())  # snapshot of the best weights
            best_loss = val_loss
            best_ep = ep + 1

    if best_state is None:
        # The validation loss never dropped below inf (e.g. it was NaN, or epochs == 0):
        # save the final weights rather than failing on an unbound variable.
        print("Validation loss never improved; saving the final weights instead.")
        best_state = net.state_dict()

    # Saving the model
    os.makedirs(save_dir, exist_ok=True)
    path_to_model = os.path.join(save_dir, '{}_lr_{}_val_loss_{}_ep_{}.pt'.format(
        model_name, lr, round(best_loss, 5), best_ep))
    torch.save(best_state, path_to_model)
    print("The model has been saved in {}".format(path_to_model))

    # Drop references to tensors before releasing cached GPU memory
    del loss, batch_loss, logits, best_state
    torch.cuda.empty_cache()
    return path_to_model

## Parameters

In [None]:
# --- Model ---
# Encoder checkpoint; alternatives: 'albert-base-v2', 'albert-large-v2', 'albert-xlarge-v2', 'albert-xxlarge-v2', ...
bert_model = "bert-base-uncased"
freeze_bert = False  # True -> keep encoder weights fixed and train only the classification head
maxlen = 128         # token budget per sentence pair: longer pairs are truncated, shorter ones padded

# --- Optimization ---
bs = 16                  # batch size per forward/backward pass
iters_to_accumulate = 2  # effective batch = bs * iters_to_accumulate (1 disables accumulation)
lr = 2e-5                # learning rate
epochs = 40              # number of training epochs (the recorded run reached its lowest validation loss at epoch 3)

## Training and validation

In [None]:
#  Set all seeds to make reproducible results
set_seed(1)

# Creating instances of training and validation set.
# `bert_model` must be passed by keyword: CustomDataset's third positional parameter is
# `with_labels`, so passing it positionally kept the default 'albert-base-v2' tokenizer while the
# classifier below loads `bert_model`, feeding token ids from the wrong vocabulary to the model.
print("Reading training data...")
train_set = CustomDataset(df_train, maxlen, bert_model=bert_model)
print("Reading validation data...")
val_set = CustomDataset(df_val, maxlen, bert_model=bert_model)
# Creating instances of training and validation dataloaders.
# Training batches are reshuffled every epoch (reproducible thanks to set_seed); validation order stays fixed.
train_loader = DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=5)
val_loader = DataLoader(val_set, batch_size=bs, num_workers=5)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = SentencePairClassifier(bert_model, freeze_bert=freeze_bert)

if torch.cuda.device_count() > 1:  # if multiple GPUs
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)

net.to(device)

criterion = nn.BCEWithLogitsLoss()
# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6 matches that class's default.
opti = optim.AdamW(net.parameters(), lr=lr, weight_decay=1e-2, eps=1e-6)
num_warmup_steps = 0  # The number of steps for the warmup phase.
num_training_steps = epochs * len(train_loader)  # Total number of batches processed during training
# The scheduler advances once per optimizer step, i.e. once every `iters_to_accumulate` batches.
t_total = (len(train_loader) // iters_to_accumulate) * epochs
lr_scheduler = get_linear_schedule_with_warmup(optimizer=opti, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)

train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate)

Reading training data...
Reading validation data...


 20%|██        | 42/207 [00:06<00:23,  7.03it/s]


Iteration 41/207 of epoch 1 complete. Loss : 0.32112814376993876 


 40%|████      | 83/207 [00:12<00:16,  7.38it/s]


Iteration 82/207 of epoch 1 complete. Loss : 0.30877164478709057 


 60%|█████▉    | 124/207 [00:17<00:11,  6.92it/s]


Iteration 123/207 of epoch 1 complete. Loss : 0.3205648489841601 


 80%|███████▉  | 165/207 [00:23<00:05,  7.28it/s]


Iteration 164/207 of epoch 1 complete. Loss : 0.32069088118832284 


100%|█████████▉| 206/207 [00:29<00:00,  6.92it/s]


Iteration 205/207 of epoch 1 complete. Loss : 0.31949516076867174 


100%|██████████| 207/207 [00:29<00:00,  7.00it/s]
100%|██████████| 23/23 [00:01<00:00, 15.28it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 1 complete! Validation Loss : 0.6329109539156375
Best validation loss improved from inf to 0.6329109539156375



 20%|██        | 42/207 [00:06<00:23,  6.88it/s]


Iteration 41/207 of epoch 2 complete. Loss : 0.3156856763653639 


 40%|████      | 83/207 [00:11<00:17,  7.23it/s]


Iteration 82/207 of epoch 2 complete. Loss : 0.3078028167166361 


 60%|█████▉    | 124/207 [00:17<00:12,  6.75it/s]


Iteration 123/207 of epoch 2 complete. Loss : 0.3189999504787166 


 80%|███████▉  | 165/207 [00:23<00:05,  7.15it/s]


Iteration 164/207 of epoch 2 complete. Loss : 0.3152979452435563 


100%|█████████▉| 206/207 [00:29<00:00,  6.72it/s]


Iteration 205/207 of epoch 2 complete. Loss : 0.3030249596368976 


100%|██████████| 207/207 [00:29<00:00,  6.94it/s]
100%|██████████| 23/23 [00:01<00:00, 14.83it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 2 complete! Validation Loss : 0.6186080393583878
Best validation loss improved from 0.6329109539156375 to 0.6186080393583878



 20%|██        | 42/207 [00:06<00:24,  6.74it/s]


Iteration 41/207 of epoch 3 complete. Loss : 0.28901127198847326 


 40%|████      | 83/207 [00:12<00:17,  7.03it/s]


Iteration 82/207 of epoch 3 complete. Loss : 0.27627041245379097 


 60%|█████▉    | 124/207 [00:18<00:12,  6.69it/s]


Iteration 123/207 of epoch 3 complete. Loss : 0.27151592584644874 


 80%|███████▉  | 165/207 [00:24<00:05,  7.00it/s]


Iteration 164/207 of epoch 3 complete. Loss : 0.28507288272787884 


100%|█████████▉| 206/207 [00:30<00:00,  6.66it/s]


Iteration 205/207 of epoch 3 complete. Loss : 0.27436257753430343 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 14.76it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 3 complete! Validation Loss : 0.5813120085260143
Best validation loss improved from 0.6186080393583878 to 0.5813120085260143



 20%|██        | 42/207 [00:06<00:24,  6.67it/s]


Iteration 41/207 of epoch 4 complete. Loss : 0.26083028679940756 


 40%|████      | 83/207 [00:12<00:17,  7.04it/s]


Iteration 82/207 of epoch 4 complete. Loss : 0.23671796518128094 


 60%|█████▉    | 124/207 [00:18<00:12,  6.70it/s]


Iteration 123/207 of epoch 4 complete. Loss : 0.22489490269160853 


 80%|███████▉  | 165/207 [00:24<00:05,  7.10it/s]


Iteration 164/207 of epoch 4 complete. Loss : 0.24917337189360364 


100%|█████████▉| 206/207 [00:30<00:00,  6.76it/s]


Iteration 205/207 of epoch 4 complete. Loss : 0.20789500780221892 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 15.09it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 4 complete! Validation Loss : 0.71735352796057


 20%|██        | 42/207 [00:06<00:24,  6.70it/s]


Iteration 41/207 of epoch 5 complete. Loss : 0.2044951277898579 


 40%|████      | 83/207 [00:12<00:17,  7.08it/s]


Iteration 82/207 of epoch 5 complete. Loss : 0.18535924484816993 


 60%|█████▉    | 124/207 [00:18<00:12,  6.74it/s]


Iteration 123/207 of epoch 5 complete. Loss : 0.13820694078032564 


 80%|███████▉  | 165/207 [00:24<00:05,  7.04it/s]


Iteration 164/207 of epoch 5 complete. Loss : 0.19404048054683498 


100%|█████████▉| 206/207 [00:30<00:00,  6.71it/s]


Iteration 205/207 of epoch 5 complete. Loss : 0.14455784866359175 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.61it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 5 complete! Validation Loss : 0.7413181532984194


 20%|██        | 42/207 [00:06<00:24,  6.69it/s]


Iteration 41/207 of epoch 6 complete. Loss : 0.16316540634668456 


 40%|████      | 83/207 [00:12<00:17,  7.04it/s]


Iteration 82/207 of epoch 6 complete. Loss : 0.14616214156877705 


 60%|█████▉    | 124/207 [00:18<00:12,  6.65it/s]


Iteration 123/207 of epoch 6 complete. Loss : 0.10868138029444509 


 80%|███████▉  | 165/207 [00:24<00:05,  7.08it/s]


Iteration 164/207 of epoch 6 complete. Loss : 0.15247259061874413 


100%|█████████▉| 206/207 [00:30<00:00,  6.72it/s]


Iteration 205/207 of epoch 6 complete. Loss : 0.1052927303059799 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 14.74it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 6 complete! Validation Loss : 0.8709487681803496


 20%|██        | 42/207 [00:06<00:24,  6.69it/s]


Iteration 41/207 of epoch 7 complete. Loss : 0.11343011800654051 


 40%|████      | 83/207 [00:12<00:17,  7.07it/s]


Iteration 82/207 of epoch 7 complete. Loss : 0.12160487118654134 


 60%|█████▉    | 124/207 [00:18<00:12,  6.71it/s]


Iteration 123/207 of epoch 7 complete. Loss : 0.10547743710439379 


 80%|███████▉  | 165/207 [00:24<00:05,  7.09it/s]


Iteration 164/207 of epoch 7 complete. Loss : 0.08820866950128864 


100%|█████████▉| 206/207 [00:30<00:00,  6.74it/s]


Iteration 205/207 of epoch 7 complete. Loss : 0.07243914057205363 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.92it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 7 complete! Validation Loss : 1.0012654273406318


 20%|██        | 42/207 [00:06<00:24,  6.71it/s]


Iteration 41/207 of epoch 8 complete. Loss : 0.10634114720472475 


 40%|████      | 83/207 [00:12<00:17,  7.09it/s]


Iteration 82/207 of epoch 8 complete. Loss : 0.08753224619005512 


 60%|█████▉    | 124/207 [00:18<00:12,  6.67it/s]


Iteration 123/207 of epoch 8 complete. Loss : 0.04561210314675075 


 80%|███████▉  | 165/207 [00:24<00:05,  7.05it/s]


Iteration 164/207 of epoch 8 complete. Loss : 0.09248608164489269 


100%|█████████▉| 206/207 [00:30<00:00,  6.69it/s]


Iteration 205/207 of epoch 8 complete. Loss : 0.06935945861950153 


100%|██████████| 207/207 [00:30<00:00,  6.81it/s]
100%|██████████| 23/23 [00:01<00:00, 14.73it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 8 complete! Validation Loss : 0.9078056915946628


 20%|██        | 42/207 [00:06<00:24,  6.72it/s]


Iteration 41/207 of epoch 9 complete. Loss : 0.08113490911654947 


 40%|████      | 83/207 [00:12<00:17,  7.05it/s]


Iteration 82/207 of epoch 9 complete. Loss : 0.06027010806677181 


 60%|█████▉    | 124/207 [00:18<00:12,  6.72it/s]


Iteration 123/207 of epoch 9 complete. Loss : 0.03937929880055713 


 80%|███████▉  | 165/207 [00:24<00:05,  7.10it/s]


Iteration 164/207 of epoch 9 complete. Loss : 0.06185097949261346 


100%|█████████▉| 206/207 [00:30<00:00,  6.72it/s]


Iteration 205/207 of epoch 9 complete. Loss : 0.04792236956972175 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 14.75it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 9 complete! Validation Loss : 1.0280159400857014


 20%|██        | 42/207 [00:06<00:25,  6.54it/s]


Iteration 41/207 of epoch 10 complete. Loss : 0.06820270306680624 


 40%|████      | 83/207 [00:12<00:17,  7.09it/s]


Iteration 82/207 of epoch 10 complete. Loss : 0.06651495710560461 


 60%|█████▉    | 124/207 [00:18<00:12,  6.66it/s]


Iteration 123/207 of epoch 10 complete. Loss : 0.04254608944331 


 80%|███████▉  | 165/207 [00:24<00:06,  6.99it/s]


Iteration 164/207 of epoch 10 complete. Loss : 0.04105585248482118 


100%|█████████▉| 206/207 [00:30<00:00,  6.73it/s]


Iteration 205/207 of epoch 10 complete. Loss : 0.03863925423200538 


100%|██████████| 207/207 [00:30<00:00,  6.79it/s]
100%|██████████| 23/23 [00:01<00:00, 15.20it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 10 complete! Validation Loss : 1.1149022734683494


 20%|██        | 42/207 [00:06<00:24,  6.66it/s]


Iteration 41/207 of epoch 11 complete. Loss : 0.04767681239740696 


 40%|████      | 83/207 [00:12<00:17,  7.08it/s]


Iteration 82/207 of epoch 11 complete. Loss : 0.04967811193726048 


 60%|█████▉    | 124/207 [00:18<00:12,  6.71it/s]


Iteration 123/207 of epoch 11 complete. Loss : 0.05208295671178437 


 80%|███████▉  | 165/207 [00:24<00:05,  7.08it/s]


Iteration 164/207 of epoch 11 complete. Loss : 0.04883963779387314 


100%|█████████▉| 206/207 [00:30<00:00,  6.72it/s]


Iteration 205/207 of epoch 11 complete. Loss : 0.032999809996065936 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 15.07it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 11 complete! Validation Loss : 1.1476793341014697


 20%|██        | 42/207 [00:06<00:24,  6.71it/s]


Iteration 41/207 of epoch 12 complete. Loss : 0.02698008561075279 


 40%|████      | 83/207 [00:12<00:17,  7.10it/s]


Iteration 82/207 of epoch 12 complete. Loss : 0.03982296709266559 


 60%|█████▉    | 124/207 [00:18<00:12,  6.69it/s]


Iteration 123/207 of epoch 12 complete. Loss : 0.06861485427290928 


 80%|███████▉  | 165/207 [00:24<00:05,  7.04it/s]


Iteration 164/207 of epoch 12 complete. Loss : 0.04912553867325187 


100%|█████████▉| 206/207 [00:30<00:00,  6.75it/s]


Iteration 205/207 of epoch 12 complete. Loss : 0.024305137782925514 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 14.84it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 12 complete! Validation Loss : 1.1986212315766707


 20%|██        | 42/207 [00:06<00:24,  6.70it/s]


Iteration 41/207 of epoch 13 complete. Loss : 0.026605586237387686 


 40%|████      | 83/207 [00:12<00:17,  7.03it/s]


Iteration 82/207 of epoch 13 complete. Loss : 0.020837255114712183 


 60%|█████▉    | 124/207 [00:18<00:12,  6.72it/s]


Iteration 123/207 of epoch 13 complete. Loss : 0.01494510603284963 


 80%|███████▉  | 165/207 [00:24<00:05,  7.04it/s]


Iteration 164/207 of epoch 13 complete. Loss : 0.03452988156321936 


100%|█████████▉| 206/207 [00:30<00:00,  6.72it/s]


Iteration 205/207 of epoch 13 complete. Loss : 0.013666173977005045 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 14.96it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 13 complete! Validation Loss : 1.345153378403705


 20%|██        | 42/207 [00:06<00:24,  6.74it/s]


Iteration 41/207 of epoch 14 complete. Loss : 0.014679763599571476 


 40%|████      | 83/207 [00:12<00:17,  7.04it/s]


Iteration 82/207 of epoch 14 complete. Loss : 0.012888007243618188 


 60%|█████▉    | 124/207 [00:18<00:12,  6.70it/s]


Iteration 123/207 of epoch 14 complete. Loss : 0.01108391408626808 


 80%|███████▉  | 165/207 [00:24<00:05,  7.06it/s]


Iteration 164/207 of epoch 14 complete. Loss : 0.018078646395446325 


100%|█████████▉| 206/207 [00:30<00:00,  6.72it/s]


Iteration 205/207 of epoch 14 complete. Loss : 0.017533069697958304 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 14.70it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 14 complete! Validation Loss : 1.4817652650501416


 20%|██        | 42/207 [00:06<00:24,  6.70it/s]


Iteration 41/207 of epoch 15 complete. Loss : 0.020221349087589217 


 40%|████      | 83/207 [00:12<00:17,  6.98it/s]


Iteration 82/207 of epoch 15 complete. Loss : 0.019758068587312976 


 60%|█████▉    | 124/207 [00:18<00:12,  6.71it/s]


Iteration 123/207 of epoch 15 complete. Loss : 0.009222203441157302 


 80%|███████▉  | 165/207 [00:24<00:06,  6.98it/s]


Iteration 164/207 of epoch 15 complete. Loss : 0.011283397495678468 


100%|█████████▉| 206/207 [00:30<00:00,  6.75it/s]


Iteration 205/207 of epoch 15 complete. Loss : 0.012381493422009686 


100%|██████████| 207/207 [00:30<00:00,  6.81it/s]
100%|██████████| 23/23 [00:01<00:00, 14.92it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 15 complete! Validation Loss : 1.4094493855600772


 20%|██        | 42/207 [00:06<00:24,  6.68it/s]


Iteration 41/207 of epoch 16 complete. Loss : 0.02500398538257109 


 40%|████      | 83/207 [00:12<00:17,  7.04it/s]


Iteration 82/207 of epoch 16 complete. Loss : 0.00723469077620837 


 60%|█████▉    | 124/207 [00:18<00:12,  6.71it/s]


Iteration 123/207 of epoch 16 complete. Loss : 0.01158290978509751 


 80%|███████▉  | 165/207 [00:24<00:05,  7.08it/s]


Iteration 164/207 of epoch 16 complete. Loss : 0.011833045979473376 


100%|█████████▉| 206/207 [00:30<00:00,  6.74it/s]


Iteration 205/207 of epoch 16 complete. Loss : 0.011173940735028648 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 15.07it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 16 complete! Validation Loss : 1.5692990785059722


 20%|██        | 42/207 [00:06<00:24,  6.71it/s]


Iteration 41/207 of epoch 17 complete. Loss : 0.017920809416043595 


 40%|████      | 83/207 [00:12<00:17,  7.12it/s]


Iteration 82/207 of epoch 17 complete. Loss : 0.005101858469282799 


 60%|█████▉    | 124/207 [00:18<00:12,  6.73it/s]


Iteration 123/207 of epoch 17 complete. Loss : 0.0037636175217879253 


 80%|███████▉  | 165/207 [00:24<00:05,  7.07it/s]


Iteration 164/207 of epoch 17 complete. Loss : 0.014054008258404437 


100%|█████████▉| 206/207 [00:30<00:00,  6.75it/s]


Iteration 205/207 of epoch 17 complete. Loss : 0.009133827333088691 


100%|██████████| 207/207 [00:30<00:00,  6.81it/s]
100%|██████████| 23/23 [00:01<00:00, 14.78it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 17 complete! Validation Loss : 1.56295173841974


 20%|██        | 42/207 [00:06<00:24,  6.67it/s]


Iteration 41/207 of epoch 18 complete. Loss : 0.016535548044641208 


 40%|████      | 83/207 [00:12<00:17,  7.09it/s]


Iteration 82/207 of epoch 18 complete. Loss : 0.012074039571793614 


 60%|█████▉    | 124/207 [00:18<00:12,  6.72it/s]


Iteration 123/207 of epoch 18 complete. Loss : 0.005291798647504482 


 80%|███████▉  | 165/207 [00:24<00:05,  7.06it/s]


Iteration 164/207 of epoch 18 complete. Loss : 0.00885139967786239 


100%|█████████▉| 206/207 [00:30<00:00,  6.73it/s]


Iteration 205/207 of epoch 18 complete. Loss : 0.007791779533585124 


100%|██████████| 207/207 [00:30<00:00,  6.81it/s]
100%|██████████| 23/23 [00:01<00:00, 15.10it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 18 complete! Validation Loss : 1.6050252758938333


 20%|██        | 42/207 [00:06<00:24,  6.72it/s]


Iteration 41/207 of epoch 19 complete. Loss : 0.007531494533972497 


 40%|████      | 83/207 [00:12<00:17,  7.07it/s]


Iteration 82/207 of epoch 19 complete. Loss : 0.00445650157904843 


 60%|█████▉    | 124/207 [00:18<00:12,  6.55it/s]


Iteration 123/207 of epoch 19 complete. Loss : 0.00408067188520984 


 80%|███████▉  | 165/207 [00:24<00:05,  7.08it/s]


Iteration 164/207 of epoch 19 complete. Loss : 0.0077819054599896804 


100%|█████████▉| 206/207 [00:30<00:00,  6.72it/s]


Iteration 205/207 of epoch 19 complete. Loss : 0.007058916440256304 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 15.03it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 19 complete! Validation Loss : 1.7014000856358071


 20%|██        | 42/207 [00:06<00:24,  6.71it/s]


Iteration 41/207 of epoch 20 complete. Loss : 0.00863617473060447 


 40%|████      | 83/207 [00:12<00:17,  7.06it/s]


Iteration 82/207 of epoch 20 complete. Loss : 0.0025900873858140916 


 60%|█████▉    | 124/207 [00:18<00:12,  6.55it/s]


Iteration 123/207 of epoch 20 complete. Loss : 0.008407070458349885 


 80%|███████▉  | 165/207 [00:24<00:05,  7.11it/s]


Iteration 164/207 of epoch 20 complete. Loss : 0.01542842230566482 


100%|█████████▉| 206/207 [00:30<00:00,  6.74it/s]


Iteration 205/207 of epoch 20 complete. Loss : 0.001801445817563501 


100%|██████████| 207/207 [00:30<00:00,  6.81it/s]
100%|██████████| 23/23 [00:01<00:00, 14.86it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 20 complete! Validation Loss : 1.7298839351405269


 20%|██        | 42/207 [00:06<00:24,  6.71it/s]


Iteration 41/207 of epoch 21 complete. Loss : 0.00865847445604187 


 40%|████      | 83/207 [00:12<00:17,  7.04it/s]


Iteration 82/207 of epoch 21 complete. Loss : 0.002297908414819664 


 60%|█████▉    | 124/207 [00:18<00:12,  6.69it/s]


Iteration 123/207 of epoch 21 complete. Loss : 0.00233453088482592 


 80%|███████▉  | 165/207 [00:24<00:05,  7.13it/s]


Iteration 164/207 of epoch 21 complete. Loss : 0.006763548630704285 


100%|█████████▉| 206/207 [00:30<00:00,  6.67it/s]


Iteration 205/207 of epoch 21 complete. Loss : 0.00099098426928721 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.92it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 21 complete! Validation Loss : 1.8765859189240828


 20%|██        | 42/207 [00:06<00:24,  6.69it/s]


Iteration 41/207 of epoch 22 complete. Loss : 0.00979941600377149 


 40%|████      | 83/207 [00:12<00:17,  7.02it/s]


Iteration 82/207 of epoch 22 complete. Loss : 0.00165204818164589 


 60%|█████▉    | 124/207 [00:18<00:12,  6.68it/s]


Iteration 123/207 of epoch 22 complete. Loss : 0.006093355469979209 


 80%|███████▉  | 165/207 [00:24<00:05,  7.03it/s]


Iteration 164/207 of epoch 22 complete. Loss : 0.014141609400323388 


100%|█████████▉| 206/207 [00:30<00:00,  6.70it/s]


Iteration 205/207 of epoch 22 complete. Loss : 0.009320308467112027 


100%|██████████| 207/207 [00:30<00:00,  6.81it/s]
100%|██████████| 23/23 [00:01<00:00, 15.01it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 22 complete! Validation Loss : 1.8138053002564802


 20%|██        | 42/207 [00:06<00:24,  6.73it/s]


Iteration 41/207 of epoch 23 complete. Loss : 0.008506560246687292 


 40%|████      | 83/207 [00:12<00:17,  7.10it/s]


Iteration 82/207 of epoch 23 complete. Loss : 0.0056690377539603 


 60%|█████▉    | 124/207 [00:18<00:12,  6.72it/s]


Iteration 123/207 of epoch 23 complete. Loss : 0.0020001751445972065 


 80%|███████▉  | 165/207 [00:24<00:05,  7.05it/s]


Iteration 164/207 of epoch 23 complete. Loss : 0.008967410197350902 


100%|█████████▉| 206/207 [00:30<00:00,  6.75it/s]


Iteration 205/207 of epoch 23 complete. Loss : 0.0030607719104964195 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.90it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 23 complete! Validation Loss : 1.848224683948185


 20%|██        | 42/207 [00:06<00:24,  6.71it/s]


Iteration 41/207 of epoch 24 complete. Loss : 0.010917090433195414 


 40%|████      | 83/207 [00:12<00:17,  7.07it/s]


Iteration 82/207 of epoch 24 complete. Loss : 0.0013549143978019767 


 60%|█████▉    | 124/207 [00:18<00:12,  6.72it/s]


Iteration 123/207 of epoch 24 complete. Loss : 0.014649071722864969 


 80%|███████▉  | 165/207 [00:24<00:05,  7.05it/s]


Iteration 164/207 of epoch 24 complete. Loss : 0.013891640556521895 


100%|█████████▉| 206/207 [00:30<00:00,  6.74it/s]


Iteration 205/207 of epoch 24 complete. Loss : 0.003697670920945068 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 15.08it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 24 complete! Validation Loss : 1.758654869121054


 20%|██        | 42/207 [00:06<00:24,  6.68it/s]


Iteration 41/207 of epoch 25 complete. Loss : 0.012093019850620227 


 40%|████      | 83/207 [00:12<00:17,  7.08it/s]


Iteration 82/207 of epoch 25 complete. Loss : 0.00446827864847941 


 60%|█████▉    | 124/207 [00:18<00:12,  6.69it/s]


Iteration 123/207 of epoch 25 complete. Loss : 0.00422964566571797 


 80%|███████▉  | 165/207 [00:24<00:05,  7.08it/s]


Iteration 164/207 of epoch 25 complete. Loss : 0.00619210933460059 


100%|█████████▉| 206/207 [00:30<00:00,  6.76it/s]


Iteration 205/207 of epoch 25 complete. Loss : 0.0010987901641339882 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.64it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 25 complete! Validation Loss : 1.81052537586378


 20%|██        | 42/207 [00:06<00:24,  6.68it/s]


Iteration 41/207 of epoch 26 complete. Loss : 0.006273055531470696 


 40%|████      | 83/207 [00:12<00:17,  7.05it/s]


Iteration 82/207 of epoch 26 complete. Loss : 0.001186007070894603 


 60%|█████▉    | 124/207 [00:18<00:12,  6.73it/s]


Iteration 123/207 of epoch 26 complete. Loss : 0.0010985636774155244 


 80%|███████▉  | 165/207 [00:24<00:05,  7.08it/s]


Iteration 164/207 of epoch 26 complete. Loss : 0.005778116591433745 


100%|█████████▉| 206/207 [00:30<00:00,  6.75it/s]


Iteration 205/207 of epoch 26 complete. Loss : 0.0008266042412553982 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.85it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 26 complete! Validation Loss : 1.9621469922687695


 20%|██        | 42/207 [00:06<00:24,  6.75it/s]


Iteration 41/207 of epoch 27 complete. Loss : 0.009468425198020868 


 40%|████      | 83/207 [00:12<00:17,  7.10it/s]


Iteration 82/207 of epoch 27 complete. Loss : 0.000862419072158135 


 60%|█████▉    | 124/207 [00:18<00:12,  6.72it/s]


Iteration 123/207 of epoch 27 complete. Loss : 0.0012095108246373966 


 80%|███████▉  | 165/207 [00:24<00:05,  7.03it/s]


Iteration 164/207 of epoch 27 complete. Loss : 0.009058928427360271 


100%|█████████▉| 206/207 [00:30<00:00,  6.74it/s]


Iteration 205/207 of epoch 27 complete. Loss : 0.002330914468407949 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.93it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 27 complete! Validation Loss : 1.8795790594557058


 20%|██        | 42/207 [00:06<00:24,  6.71it/s]


Iteration 41/207 of epoch 28 complete. Loss : 0.011804998776358665 


 40%|████      | 83/207 [00:12<00:17,  7.05it/s]


Iteration 82/207 of epoch 28 complete. Loss : 0.0030979915286368895 


 60%|█████▉    | 124/207 [00:18<00:12,  6.72it/s]


Iteration 123/207 of epoch 28 complete. Loss : 0.0042802658600721325 


 80%|███████▉  | 165/207 [00:24<00:05,  7.09it/s]


Iteration 164/207 of epoch 28 complete. Loss : 0.007133510397468898 


100%|█████████▉| 206/207 [00:30<00:00,  6.73it/s]


Iteration 205/207 of epoch 28 complete. Loss : 0.0066166959614947255 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.92it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 28 complete! Validation Loss : 1.76135748365651


 20%|██        | 42/207 [00:06<00:24,  6.70it/s]


Iteration 41/207 of epoch 29 complete. Loss : 0.006830921223409838 


 40%|████      | 83/207 [00:12<00:17,  7.07it/s]


Iteration 82/207 of epoch 29 complete. Loss : 0.0021350060109753253 


 60%|█████▉    | 124/207 [00:18<00:12,  6.75it/s]


Iteration 123/207 of epoch 29 complete. Loss : 0.0004604013406717014 


 80%|███████▉  | 165/207 [00:24<00:05,  7.06it/s]


Iteration 164/207 of epoch 29 complete. Loss : 0.005664794044938256 


100%|█████████▉| 206/207 [00:30<00:00,  6.74it/s]


Iteration 205/207 of epoch 29 complete. Loss : 0.0016161300804147997 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 15.12it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 29 complete! Validation Loss : 1.8392939463905666


 20%|██        | 42/207 [00:06<00:24,  6.72it/s]


Iteration 41/207 of epoch 30 complete. Loss : 0.009002195434111012 


 40%|████      | 83/207 [00:12<00:17,  7.08it/s]


Iteration 82/207 of epoch 30 complete. Loss : 0.0010802623291965574 


 60%|█████▉    | 124/207 [00:18<00:12,  6.66it/s]


Iteration 123/207 of epoch 30 complete. Loss : 0.0034615427550927897 


 80%|███████▉  | 165/207 [00:24<00:05,  7.08it/s]


Iteration 164/207 of epoch 30 complete. Loss : 0.0066232314019566176 


100%|█████████▉| 206/207 [00:30<00:00,  6.73it/s]


Iteration 205/207 of epoch 30 complete. Loss : 0.0010366236298379103 


100%|██████████| 207/207 [00:30<00:00,  6.81it/s]
100%|██████████| 23/23 [00:01<00:00, 14.82it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 30 complete! Validation Loss : 1.8649245733800142


 20%|██        | 42/207 [00:06<00:24,  6.74it/s]


Iteration 41/207 of epoch 31 complete. Loss : 0.00592204896099979 


 40%|████      | 83/207 [00:12<00:17,  7.07it/s]


Iteration 82/207 of epoch 31 complete. Loss : 0.0022523496087027213 


 60%|█████▉    | 124/207 [00:18<00:12,  6.64it/s]


Iteration 123/207 of epoch 31 complete. Loss : 0.0005381885524157708 


 80%|███████▉  | 165/207 [00:24<00:05,  7.03it/s]


Iteration 164/207 of epoch 31 complete. Loss : 0.005614469162892641 


100%|█████████▉| 206/207 [00:30<00:00,  6.76it/s]


Iteration 205/207 of epoch 31 complete. Loss : 0.0046532312606847506 


100%|██████████| 207/207 [00:30<00:00,  6.84it/s]
100%|██████████| 23/23 [00:01<00:00, 14.67it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 31 complete! Validation Loss : 1.9111505798671558


 20%|██        | 42/207 [00:06<00:24,  6.71it/s]


Iteration 41/207 of epoch 32 complete. Loss : 0.005623544918418657 


 40%|████      | 83/207 [00:12<00:17,  7.08it/s]


Iteration 82/207 of epoch 32 complete. Loss : 0.0009764651803402003 


 60%|█████▉    | 124/207 [00:18<00:12,  6.74it/s]


Iteration 123/207 of epoch 32 complete. Loss : 0.0005689125995026765 


 80%|███████▉  | 165/207 [00:24<00:05,  7.09it/s]


Iteration 164/207 of epoch 32 complete. Loss : 0.0056913587876606915 


100%|█████████▉| 206/207 [00:30<00:00,  6.77it/s]


Iteration 205/207 of epoch 32 complete. Loss : 0.003470207102420717 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 15.01it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 32 complete! Validation Loss : 1.9255807555240134


 20%|██        | 42/207 [00:06<00:24,  6.73it/s]


Iteration 41/207 of epoch 33 complete. Loss : 0.00544673211946402 


 40%|████      | 83/207 [00:12<00:17,  7.10it/s]


Iteration 82/207 of epoch 33 complete. Loss : 0.003665677530878428 


 60%|█████▉    | 124/207 [00:18<00:12,  6.73it/s]


Iteration 123/207 of epoch 33 complete. Loss : 0.0022956835228109323 


 80%|███████▉  | 165/207 [00:24<00:05,  7.05it/s]


Iteration 164/207 of epoch 33 complete. Loss : 0.005959652970465491 


100%|█████████▉| 206/207 [00:30<00:00,  6.72it/s]


Iteration 205/207 of epoch 33 complete. Loss : 0.0008620744480229006 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.89it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 33 complete! Validation Loss : 1.9214306370071743


 20%|██        | 42/207 [00:06<00:24,  6.65it/s]


Iteration 41/207 of epoch 34 complete. Loss : 0.005670827567838587 


 40%|████      | 83/207 [00:12<00:17,  7.04it/s]


Iteration 82/207 of epoch 34 complete. Loss : 0.0007303946327826962 


 60%|█████▉    | 124/207 [00:18<00:12,  6.72it/s]


Iteration 123/207 of epoch 34 complete. Loss : 0.001381076019819508 


 80%|███████▉  | 165/207 [00:24<00:05,  7.07it/s]


Iteration 164/207 of epoch 34 complete. Loss : 0.006525113276438788 


100%|█████████▉| 206/207 [00:30<00:00,  6.66it/s]


Iteration 205/207 of epoch 34 complete. Loss : 0.0010092897786440828 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.90it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 34 complete! Validation Loss : 1.8918043167694756


 20%|██        | 42/207 [00:06<00:24,  6.64it/s]


Iteration 41/207 of epoch 35 complete. Loss : 0.010238733821230509 


 40%|████      | 83/207 [00:12<00:17,  7.05it/s]


Iteration 82/207 of epoch 35 complete. Loss : 0.0014183332642438116 


 60%|█████▉    | 124/207 [00:18<00:12,  6.71it/s]


Iteration 123/207 of epoch 35 complete. Loss : 0.0006455102496746382 


 80%|███████▉  | 165/207 [00:24<00:05,  7.03it/s]


Iteration 164/207 of epoch 35 complete. Loss : 0.0047555845287605755 


100%|█████████▉| 206/207 [00:30<00:00,  6.75it/s]


Iteration 205/207 of epoch 35 complete. Loss : 0.000504778086906299 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 15.00it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 35 complete! Validation Loss : 1.9726414317670076


 20%|██        | 42/207 [00:06<00:24,  6.65it/s]


Iteration 41/207 of epoch 36 complete. Loss : 0.005541497321157694 


 40%|████      | 83/207 [00:12<00:17,  7.09it/s]


Iteration 82/207 of epoch 36 complete. Loss : 0.0006377013391678835 


 60%|█████▉    | 124/207 [00:18<00:12,  6.68it/s]


Iteration 123/207 of epoch 36 complete. Loss : 0.0005754692744032094 


 80%|███████▉  | 165/207 [00:24<00:05,  7.06it/s]


Iteration 164/207 of epoch 36 complete. Loss : 0.00410945877424305 


100%|█████████▉| 206/207 [00:30<00:00,  6.73it/s]


Iteration 205/207 of epoch 36 complete. Loss : 0.00046413308137278187 


100%|██████████| 207/207 [00:30<00:00,  6.82it/s]
100%|██████████| 23/23 [00:01<00:00, 14.77it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 36 complete! Validation Loss : 1.971477464489315


 20%|██        | 42/207 [00:06<00:24,  6.68it/s]


Iteration 41/207 of epoch 37 complete. Loss : 0.005125503396479095 


 40%|████      | 83/207 [00:12<00:17,  7.05it/s]


Iteration 82/207 of epoch 37 complete. Loss : 0.0009128747929580418 


 60%|█████▉    | 124/207 [00:18<00:12,  6.65it/s]


Iteration 123/207 of epoch 37 complete. Loss : 0.0007945236390763213 


 80%|███████▉  | 165/207 [00:24<00:05,  7.06it/s]


Iteration 164/207 of epoch 37 complete. Loss : 0.0038739618605354844 


100%|█████████▉| 206/207 [00:30<00:00,  6.73it/s]


Iteration 205/207 of epoch 37 complete. Loss : 0.00044162731927183525 


100%|██████████| 207/207 [00:30<00:00,  6.80it/s]
100%|██████████| 23/23 [00:01<00:00, 14.60it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 37 complete! Validation Loss : 1.9887373110522395


 20%|██        | 42/207 [00:06<00:24,  6.71it/s]


Iteration 41/207 of epoch 38 complete. Loss : 0.005230455993314092 


 40%|████      | 83/207 [00:12<00:17,  6.95it/s]


Iteration 82/207 of epoch 38 complete. Loss : 0.0014002547796713415 


 60%|█████▉    | 124/207 [00:18<00:12,  6.71it/s]


Iteration 123/207 of epoch 38 complete. Loss : 0.0005079401868520441 


 80%|███████▉  | 165/207 [00:24<00:05,  7.09it/s]


Iteration 164/207 of epoch 38 complete. Loss : 0.00512240636666453 


100%|█████████▉| 206/207 [00:30<00:00,  6.74it/s]


Iteration 205/207 of epoch 38 complete. Loss : 0.0006819451710481833 


100%|██████████| 207/207 [00:30<00:00,  6.81it/s]
100%|██████████| 23/23 [00:01<00:00, 15.07it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 38 complete! Validation Loss : 2.0226477540057637


 20%|██        | 42/207 [00:06<00:24,  6.73it/s]


Iteration 41/207 of epoch 39 complete. Loss : 0.004117035063866685 


 40%|████      | 83/207 [00:12<00:17,  7.09it/s]


Iteration 82/207 of epoch 39 complete. Loss : 0.0008839422584193327 


 60%|█████▉    | 124/207 [00:18<00:12,  6.70it/s]


Iteration 123/207 of epoch 39 complete. Loss : 0.000451969713250902 


 80%|███████▉  | 165/207 [00:24<00:05,  7.10it/s]


Iteration 164/207 of epoch 39 complete. Loss : 0.0031148737925388737 


100%|█████████▉| 206/207 [00:30<00:00,  6.77it/s]


Iteration 205/207 of epoch 39 complete. Loss : 0.0006083107593174024 


100%|██████████| 207/207 [00:30<00:00,  6.83it/s]
100%|██████████| 23/23 [00:01<00:00, 15.04it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 39 complete! Validation Loss : 2.0118188339730967


 20%|██        | 42/207 [00:06<00:24,  6.64it/s]


Iteration 41/207 of epoch 40 complete. Loss : 0.013903712625217782 


 40%|████      | 83/207 [00:12<00:17,  6.96it/s]


Iteration 82/207 of epoch 40 complete. Loss : 0.0013889573743742912 


 60%|█████▉    | 124/207 [00:18<00:12,  6.71it/s]


Iteration 123/207 of epoch 40 complete. Loss : 0.00045037283789275625 


 80%|███████▉  | 165/207 [00:24<00:05,  7.03it/s]


Iteration 164/207 of epoch 40 complete. Loss : 0.0033085222743599245 


100%|█████████▉| 206/207 [00:30<00:00,  6.75it/s]


Iteration 205/207 of epoch 40 complete. Loss : 0.0004463832124090958 


100%|██████████| 207/207 [00:30<00:00,  6.81it/s]
100%|██████████| 23/23 [00:01<00:00, 14.86it/s]



Epoch 40 complete! Validation Loss : 1.9997464003770247
The model has been saved in models/bert-base-uncased_lr_2e-05_val_loss_0.58131_ep_3.pt


## Prediction

In [None]:
print("Creation of the results' folder...")
# Portable, idempotent replacement for the `!mkdir results` shell magic:
# re-running this cell does not error when the folder already exists.
os.makedirs("results", exist_ok=True)

Creation of the results' folder...


In [None]:
def get_probs_from_logits(logits):
    """
    Map raw logits to probabilities with the logistic sigmoid.

    A trailing singleton dimension is added to the input before the sigmoid,
    and the result is returned as a NumPy array that is detached from the
    autograd graph and copied to CPU memory.
    """
    with_trailing_dim = logits.unsqueeze(-1)
    probabilities = torch.sigmoid(with_trailing_dim)
    return probabilities.detach().cpu().numpy()

def test_prediction(net, device, dataloader, with_labels=True, result_file="results/output.txt"):
    """
    Predict the probabilities on a dataset with or without labels and write them to a file.

    Args:
        net: model called as ``net(seq, attn_masks, token_type_ids)``; it is switched
            to evaluation mode before predicting.
        device: device the input tensors are moved to before the forward pass.
        dataloader: iterable of batches. Each batch is
            ``(seq, attn_masks, token_type_ids, labels)`` when ``with_labels`` is True,
            or ``(seq, attn_masks, token_type_ids)`` otherwise. Labels are ignored.
        with_labels: whether each batch carries a trailing labels element.
        result_file: path of the text file that receives one probability per line.

    Returns:
        list: the predicted probabilities, in dataloader order (also printed and
        written to ``result_file``).
    """
    net.eval()
    probs_all = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Unpack strictly so a mismatch between `with_labels` and the batch
            # layout still raises instead of being silently tolerated.
            if with_labels:
                seq, attn_masks, token_type_ids, _ = batch
            else:
                seq, attn_masks, token_type_ids = batch
            seq, attn_masks, token_type_ids = seq.to(device), attn_masks.to(device), token_type_ids.to(device)
            logits = net(seq, attn_masks, token_type_ids)
            probs = get_probs_from_logits(logits.squeeze(-1)).squeeze(-1)
            # The squeezes can collapse a single-element batch to a 0-d array, whose
            # tolist() is a bare float; keep it a 1-D array so `+=` extends the list.
            probs_all += np.atleast_1d(probs).tolist()
    print(probs_all)
    # Open the output only once predictions succeeded, and use `with` so the file
    # handle is closed even if writing fails.
    with open(result_file, 'w') as w:
        w.writelines(str(prob) + '\n' for prob in probs_all)
    return probs_all

In [None]:
# Peek at the second sentence of the first test pair (column first, then row position).
df_test["sentence2"].iloc[0]

'" The foodservice pie business does not fit our long-term growth strategy .'

In [None]:
# Checkpoint to evaluate. The path is relative to the working directory, matching the
# "The model has been saved in models/..." message printed at the end of training
# (on Colab the working directory is typically /content, i.e. /content/models/...).
# The classifier below is built from `bert_model`, so the checkpoint must come from a
# model trained with that same backbone or `load_state_dict` will not match.
path_to_model = 'models/bert-base-uncased_lr_2e-05_val_loss_0.58131_ep_3.pt'
# Alternatives: 'models/albert-base-v2_lr_2e-05_val_loss_0.31706_ep_3.pt', or any model you trained.

path_to_output_file = 'results/output.txt'

# Hand-written sentence pairs to score: sequence_0[i] is paired with sequence_1[i].
sequence_0 = ["Reports that the NSA eavesdropped on world leaders have \"severely shaken\" relations between Europe and the U.S., German.",
"Buying a painting require lot of money.",
"Quora is not user friendly so i prefer google because answers are available over there.",
"He said the foodservice pie business doesn't fit the company 's long-term growth strategy."
]
sequence_1 = ["Germany and France are to seek talks with the US to settle a row over spying, as espionage claims continue to overshadow an EU summit in Brussels.",
"Today I had a dream about spending a lot of time in painting.",
"Why do people ask Quora questions which can be answered easily by Google?",
" The foodservice pie business does not fit our long-term growth strategy."
]

# One row per pair; `label` holds the expected answer for each pair.
# pandas raises if the column lists have different lengths.
data = pd.DataFrame({
    "idx": list(range(len(sequence_0))),
    "label": [1, 0, 1, 1],
    "sentence1": sequence_0,
    "sentence2": sequence_1,
})
print("Reading test data...")
test_set = CustomDataset(data, maxlen, bert_model)
# No shuffling (DataLoader default), so predictions come back in row order.
test_loader = DataLoader(test_set, batch_size=bs, num_workers=5)

model = SentencePairClassifier(bert_model)
if torch.cuda.device_count() > 1:  # if multiple GPUs
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

print()
print("Loading the weights of the model...")
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime
# (and vice versa) instead of failing inside torch.load.
model.load_state_dict(torch.load(path_to_model, map_location=device))
model.to(device)

print("Predicting on test data...")
# Set with_labels=False if you want predictions on a dataset without labels.
test_prediction(net=model, device=device, dataloader=test_loader, with_labels=True,
                result_file=path_to_output_file)
print()
print("Predictions are available in : {}".format(path_to_output_file))

Reading test data...

Loading the weights of the model...


  0%|          | 0/1 [00:00<?, ?it/s]

Predicting on test data...


100%|██████████| 1/1 [00:00<00:00,  2.67it/s]

[0.521484375, 0.54638671875, 0.55126953125, 0.9326171875]

Predictions are available in : results/output.txt



