In [None]:
! pip install transformers sentence-transformers datasets



In [None]:
import torch

# Choose the compute device once; later cells move the model and batches onto it.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

if use_gpu:
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')


There are 1 GPU(s) available.
We will use the GPU: Tesla T4


In [None]:
# Dataset

from datasets import load_dataset
# Fetch the "stsb" configuration of the "glue" benchmark. The bare `dataset`
# expression renders the available splits and their columns as the cell output.
dataset = load_dataset("glue", "stsb")
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 5749
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 1500
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 1379
    })
})

In [None]:
# Tokenizer
from transformers import AutoTokenizer
# Module-level tokenizer for the 'bert-base-uncased' checkpoint; STSBDataset
# uses it to encode each sentence pair.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [None]:


# PyTorch Dataset wrapper around one STS-B split (the data itself was loaded above)


class STSBDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (tokenized sentence pair, normalized score).

    Args:
        dataset: Iterable of rows, each a mapping with 'sentence1',
            'sentence2' and 'label' keys.
        tokenizer: Optional callable used to encode a ``[sentence1, sentence2]``
            pair. When omitted, the module-level ``tokenizer`` is looked up at
            call time, which is the original behaviour.
        max_length: Length every encoded sentence is padded / truncated to.
    """

    def __init__(self, dataset, tokenizer=None, max_length=128):
        self._tokenizer = tokenizer
        self.max_length = max_length
        self.first_sentences = []
        self.second_sentences = []
        self.normalized_similarity_scores = []
        # One pass over the rows instead of three separate iterations; this is
        # cheaper and also works when `dataset` can only be iterated once.
        for row in dataset:
            self.first_sentences.append(row['sentence1'])
            self.second_sentences.append(row['sentence2'])
            # Dividing by 5.0 rescales the raw score (presumably the 0-5 STS-B
            # range - confirm) so it is comparable with a cosine similarity.
            self.normalized_similarity_scores.append(row['label'] / 5.0)
        # str() guards against non-string entries before tokenization.
        self.concatenated_sentences = [
            [str(x), str(y)]
            for x, y in zip(self.first_sentences, self.second_sentences)
        ]

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.concatenated_sentences)

    def get_batch_labels(self, idx):
        """Return the normalized score(s) at ``idx`` as a float32 tensor."""
        return torch.tensor(self.normalized_similarity_scores[idx], dtype=torch.float32)

    def get_batch_texts(self, idx):
        """Tokenize the sentence pair at ``idx``, padded/truncated to ``max_length``."""
        # Fall back to the notebook-level tokenizer when none was injected.
        encode = self._tokenizer if self._tokenizer is not None else tokenizer
        return encode(
            self.concatenated_sentences[idx],
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return ``(encoded_pair, label_tensor)`` for the sample at ``idx``."""
        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_texts, batch_y


def collate_fn(texts):
    """Split a batched encoding into one feature dict per sample.

    Args:
        texts: Mapping with 'input_ids' and 'attention_mask' entries, each
            iterable along its first axis.

    Returns:
        List of ``{'input_ids': ..., 'attention_mask': ...}`` dicts, one per
        position; iteration stops at the shorter of the two entries.
    """
    per_sample = []
    for ids, mask in zip(texts['input_ids'], texts['attention_mask']):
        per_sample.append({'input_ids': ids, 'attention_mask': mask})
    return per_sample




# train_dataset = STSBDataset(dataset["train"])
# test_dataset = STSBDataset(dataset["test"])

In [None]:
from sentence_transformers import SentenceTransformer, models


class BertForSTS(torch.nn.Module):
    """Sentence encoder: a transformer backbone followed by a pooling layer.

    Args:
        model_name: Checkpoint name or path handed to ``models.Transformer``.
            Defaults to 'bert-base-uncased'.
        max_seq_length: Maximum sequence length handed to
            ``models.Transformer``. Defaults to 128.
    """

    def __init__(self, model_name='bert-base-uncased', max_seq_length=128):
        super().__init__()
        self.bert = models.Transformer(model_name, max_seq_length=max_seq_length)
        # Pooling is sized from the backbone's word-embedding dimension and
        # otherwise uses the library's default pooling mode.
        self.pooling_layer = models.Pooling(self.bert.get_word_embedding_dimension())
        # NOTE: `bert` and `pooling_layer` are registered both directly on this
        # module and inside `sts_bert`, so they appear twice in the module tree
        # and in state_dict keys. Kept as-is so existing attribute access and
        # saved checkpoints remain compatible.
        self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer])

    def forward(self, input_data):
        """Run the features through ``sts_bert`` and return its 'sentence_embedding' entry.

        Args:
            input_data: Feature dict passed straight to the SentenceTransformer
                pipeline (the training loop supplies 'input_ids' and
                'attention_mask').
        """
        return self.sts_bert(input_data)['sentence_embedding']

In [None]:
class CosineSimilarityLoss(torch.nn.Module):
    """Loss between the cosine similarity of paired embeddings and target scores.

    Args:
        loss_fn: Callable ``(predictions, targets) -> loss``; mean squared
            error by default.
        transform_fn: Applied to the cosine similarities before the loss;
            identity by default.
    """

    def __init__(self,  loss_fn=torch.nn.MSELoss(), transform_fn=torch.nn.Identity()):
        super().__init__()
        self.loss_fn = loss_fn
        self.transform_fn = transform_fn
        # dim=1: compare the two (batch, dim) embedding matrices row by row.
        self.cos_similarity = torch.nn.CosineSimilarity(dim=1)

    def forward(self, inputs, labels):
        """Compute the loss for one batch.

        Args:
            inputs: Sequence of per-sample tensors whose row 0 and row 1 are
                the embeddings of the first and second sentence.
            labels: Target similarities, one element per sample; any shape
                that flattens to ``(batch,)`` (e.g. ``(B,)`` or ``(B, 1)``).

        Returns:
            The value of ``loss_fn`` for the predicted similarities and targets.
        """
        emb_1 = torch.stack([inp[0] for inp in inputs])
        emb_2 = torch.stack([inp[1] for inp in inputs])
        outputs = self.transform_fn(self.cos_similarity(emb_1, emb_2))
        # reshape(-1) instead of squeeze(): squeeze() turns a batch of one into
        # a 0-d tensor, which makes MSELoss warn about mismatched sizes and
        # fall back to broadcasting. The cast keeps targets in the prediction
        # dtype regardless of how the labels were collated.
        targets = labels.reshape(-1).to(outputs.dtype)
        return self.loss_fn(outputs, targets)

In [None]:
# Instantiate the model and move it to the selected `device` (GPU when
# available, otherwise CPU). The trailing expression also displays the module
# structure as the cell output.
model = BertForSTS()
model.to(device)

BertForSTS(
  (bert): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (pooling_layer): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (sts_bert): SentenceTransformer(
    (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
    (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  )
)

In [None]:
from torch.utils.data import DataLoader


# Wrap the dataset's own 'train' and 'validation' splits; no further
# splitting is performed here.
train_ds = STSBDataset(dataset['train'])
val_ds = STSBDataset(dataset['validation'])

train_size, val_size = len(train_ds), len(val_ds)

print('{:>5,} training samples'.format(train_size))
print('{:>5,} validation samples'.format(val_size))

batch_size = 8

# Only the training loader shuffles; validation keeps the dataset order.
# Both use the same batch size and four worker processes.
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
validation_dataloader = DataLoader(val_ds, batch_size=batch_size, num_workers=4)

5,749 training samples
1,500 validation samples


In [None]:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup


epochs = 12
optimizer = AdamW(model.parameters(), lr=1e-6)

# The scheduler is stepped once per training batch, so its horizon is
# (batches per epoch) x (epochs); the learning rate decays linearly with no warmup.
total_steps = epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)



In [None]:
import time
import random
import datetime
from tqdm import tqdm


def train():
    """Fine-tune the global ``model`` for ``epochs`` epochs, validating after each.

    Relies on notebook-level state: ``model``, ``device``, ``epochs``,
    ``optimizer``, ``scheduler``, ``train_dataloader``,
    ``validation_dataloader``, ``collate_fn`` and ``CosineSimilarityLoss``.

    Returns:
        tuple: ``(model, training_stats)``. ``training_stats`` is a list with
        one dict per epoch holding the epoch number, the mean training loss,
        the mean validation loss and the wall-clock seconds spent in each phase.
    """
    seed_val = 42
    # .to(device) rather than .cuda(): the notebook falls back to the CPU when
    # no GPU is present, and .cuda() would raise in that case.
    criterion = CosineSimilarityLoss().to(device)
    random.seed(seed_val)
    torch.manual_seed(seed_val)

    def batch_loss(batch, labels):
        """Move one batch to `device`, embed each sentence pair and return the loss."""
        batch['input_ids'] = batch['input_ids'].to(device)
        batch['attention_mask'] = batch['attention_mask'].to(device)
        # collate_fn splits the batch into per-sample feature dicts; the model
        # is run once per sample and the loss pairs up the two embeddings.
        embeddings = [model(feature) for feature in collate_fn(batch)]
        return criterion(embeddings, labels.to(device))

    # Per-epoch training / validation loss and timings.
    training_stats = []
    for epoch_i in range(epochs):
        # ---- training pass ----
        t0 = time.time()
        total_train_loss = 0
        model.train()
        for train_data, train_label in tqdm(train_dataloader):
            model.zero_grad()
            loss = batch_loss(train_data, train_label)
            total_train_loss += loss.item()
            loss.backward()
            # Clip the gradient norm to 1.0 before the optimizer update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
        # max(..., 1) avoids a ZeroDivisionError on an empty dataloader.
        avg_train_loss = total_train_loss / max(len(train_dataloader), 1)
        training_time = time.time() - t0

        # ---- validation pass ----
        t0 = time.time()
        model.eval()
        total_eval_loss = 0
        # No gradients are needed anywhere in evaluation, including the loss.
        with torch.no_grad():
            for val_data, val_label in tqdm(validation_dataloader):
                total_eval_loss += batch_loss(val_data, val_label).item()
        avg_val_loss = total_eval_loss / max(len(validation_dataloader), 1)
        validation_time = time.time() - t0

        training_stats.append(
            {
                'epoch': epoch_i + 1,
                'Training Loss': avg_train_loss,
                'Valid. Loss': avg_val_loss,
                'Training Time': training_time,
                'Validation Time': validation_time
            }
        )
    return model, training_stats




In [None]:
# Launch the training; `train()` returns the model together with the list of
# per-epoch statistics.

model, training_stats = train()


100%|██████████| 719/719 [05:25<00:00,  2.21it/s]
100%|██████████| 188/188 [00:26<00:00,  7.10it/s]
100%|██████████| 719/719 [05:16<00:00,  2.27it/s]
100%|██████████| 188/188 [00:26<00:00,  7.09it/s]
100%|██████████| 719/719 [05:16<00:00,  2.27it/s]
100%|██████████| 188/188 [00:26<00:00,  7.08it/s]
100%|██████████| 719/719 [05:18<00:00,  2.26it/s]
100%|██████████| 188/188 [00:26<00:00,  7.08it/s]
100%|██████████| 719/719 [05:17<00:00,  2.26it/s]
100%|██████████| 188/188 [00:26<00:00,  7.02it/s]
100%|██████████| 719/719 [05:16<00:00,  2.27it/s]
100%|██████████| 188/188 [00:27<00:00,  6.91it/s]
100%|██████████| 719/719 [05:16<00:00,  2.27it/s]
100%|██████████| 188/188 [00:26<00:00,  7.08it/s]
100%|██████████| 719/719 [05:16<00:00,  2.27it/s]
100%|██████████| 188/188 [00:26<00:00,  7.09it/s]
100%|██████████| 719/719 [05:16<00:00,  2.27it/s]
100%|██████████| 188/188 [00:26<00:00,  7.12it/s]
100%|██████████| 719/719 [05:15<00:00,  2.28it/s]
100%|██████████| 188/188 [00:26<00:00,  7.16it/s]


In [None]:
# Persist only the weights (state_dict) to PATH; restoring them requires
# building a BertForSTS instance first and calling load_state_dict on it.
# NOTE(review): the NameError recorded for this cell suggests it was last run
# in a kernel where `import torch` had not been executed - confirm and re-run
# the import cell first.
PATH = 'bert-sts.pt'
torch.save(model.state_dict(), PATH)


NameError: name 'torch' is not defined

In [None]:
import pandas as pd

# Tabulate the per-epoch training statistics, indexed by epoch number,
# and display the resulting table.
df_stats = pd.DataFrame(data=training_stats).set_index('epoch')
df_stats

Unnamed: 0_level_0,Training Loss,Valid. Loss,Training Time,Validation Time
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,0.058644,0.046862,325.236908,26.487154
2,0.037757,0.039104,316.90258,26.527266
3,0.03235,0.03617,316.842228,26.583012
4,0.02909,0.034478,318.20668,26.574055
5,0.027396,0.033234,317.619009,26.800095
6,0.025756,0.032808,316.690433,27.20426
7,0.024457,0.032124,316.166912,26.584909
8,0.02329,0.031596,316.967151,26.515855
9,0.022598,0.03119,316.062719,26.419927
10,0.021671,0.030979,315.889875,26.272412
