In [None]:
!pip install -q sentence-transformers datasets

In [None]:
import datasets
snli = datasets.load_dataset('snli', split='train')
mnli = datasets.load_dataset('glue', 'mnli', split='train')

In [4]:
snli, snli.features

(Dataset({
     features: ['premise', 'hypothesis', 'label'],
     num_rows: 550152
 }),
 {'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)})

In [5]:
mnli, mnli.features

(Dataset({
     features: ['premise', 'hypothesis', 'label', 'idx'],
     num_rows: 392702
 }),
 {'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None),
  'idx': Value(dtype='int32', id=None)})

In [6]:
if 'idx' in mnli.features:
  mnli = mnli.remove_columns(['idx'])
mnli

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 392702
})

In [7]:
snli = snli.cast(mnli.features)

Casting the dataset:   0%|          | 0/550152 [00:00<?, ? examples/s]

In [8]:
dataset = datasets.concatenate_datasets([snli, mnli])
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942854
})

In [9]:
set(dataset['label'])

{-1, 0, 1, 2}

In [10]:
dataset = dataset.filter(lambda x: x['label'] != -1)
set(dataset['label'])

Filter:   0%|          | 0/942854 [00:00<?, ? examples/s]

{0, 1, 2}

In [11]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [12]:
all_cols = ['label']
for part in ['premise', 'hypothesis']:
  dataset = dataset.map(
                        lambda x: tokenizer(x[part], max_length=128, padding='max_length', truncation=True)
                        ,batched=True)
  for lb in ['attention_mask', 'input_ids']:
    dataset = dataset.rename_column(lb, part + '_' + lb)
    all_cols.append(part + '_' + lb)
all_cols

Map:   0%|          | 0/942069 [00:00<?, ? examples/s]

Map:   0%|          | 0/942069 [00:00<?, ? examples/s]

['label',
 'premise_attention_mask',
 'premise_input_ids',
 'hypothesis_attention_mask',
 'hypothesis_input_ids']

In [13]:
import torch
dataset.set_format(type='pytorch', columns=all_cols)

In [14]:
batch_size = 16
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [15]:
xb = next(iter(loader))
xb.keys()

dict_keys(['label', 'premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask'])

In [16]:
device = 'cuda:0' if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [17]:
from transformers import AutoModel
model = AutoModel.from_pretrained('bert-base-uncased').to(device);

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [18]:
from transformers.optimization import get_linear_schedule_with_warmup

epochs = 1
opt = torch.optim.Adam(model.parameters(), lr=2e-5)
total_steps = int(len(loader) * epochs)
warmup_steps = int(0.1*total_steps)
sched = get_linear_schedule_with_warmup(
    opt,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps-warmup_steps,
)

In [19]:
def mean_pooling(token_emb, mask):
  mask = mask.unsqueeze(-1).expand(token_emb.size()).float()
  mpool = (token_emb * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
  return mpool

In [20]:
ffn = torch.nn.Linear(768*3, 3).to(device)
ffn

Linear(in_features=2304, out_features=3, bias=True)

In [21]:
loss_func = torch.nn.CrossEntropyLoss()

In [None]:
from tqdm import tqdm

for epoch in range(epochs):
  model.train()
  loop = tqdm(loader, leave=True)
  for xb in loop:
    opt.zero_grad()
    input_ids_a = xb['premise_input_ids'].to(device)
    input_ids_b = xb['hypothesis_input_ids'].to(device)
    attention_mask_a = xb['premise_attention_mask'].to(device)
    attention_mask_b = xb['hypothesis_attention_mask'].to(device)
    label = xb['label'].to(device)

    u = model(input_ids_a, attention_mask=attention_mask_a).last_hidden_state
    v = model(input_ids_b, attention_mask=attention_mask_b).last_hidden_state

    u = mean_pooling(u, attention_mask_a)
    v = mean_pooling(v, attention_mask_b)

    uv = u-v
    uv_abs = uv.abs()

    x = torch.cat([u, v, uv_abs], dim=-1)

    x = ffn(x)
    loss = loss_func(x, label)
    loss.backward()
    opt.step()
    sched.step()

    loop.set_description(f"epoch: {epoch}")
    loop.set_postfix(loss=f"{loss.item():.4f}")

In [23]:
import os

model_path = './sbert_test_a'
if os.path.exists(model_path):
  os.mkdir(model_path)
model.save_pretrained(model_path)

In [None]:
import datasets

snli = datasets.load_dataset('snli', split='train')
mnli = datasets.load_dataset('glue', 'mnli', split='train')
mnli = mnli.remove_columns(['idx'])
snli = snli.cast(mnli.features)
dataset = datasets.concatenate_datasets([snli, mnli])
dataset = dataset.filter(lambda x: x['label'] == -1)
dataset

In [4]:
import torch

In [5]:
from sentence_transformers import InputExample
from tqdm import tqdm

train_examples = []
for row in dataset:
  train_examples.append(InputExample(texts= [row['premise'], row['hypothesis']], label=row['label']))
train_examples[:2]

[<sentence_transformers.readers.InputExample.InputExample at 0x782074ff3220>,
 <sentence_transformers.readers.InputExample.InputExample at 0x782074ff3820>]

In [6]:
from torch.utils.data import DataLoader


batch_size = 2
loader = DataLoader(train_examples, batch_size=batch_size, shuffle=True)
len(loader)

50

In [7]:
from sentence_transformers import models, SentenceTransformer

bert_model = models.Transformer('bert-base-uncased')
pool_layer = models.Pooling(bert_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True)
sbert_model = SentenceTransformer(modules=[bert_model, pool_layer])
sbert_model



SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

In [8]:
from sentence_transformers import losses

sbert_loss_func = losses.SoftmaxLoss(model=sbert_model, sentence_embedding_dimension=768, num_labels=3)

In [10]:
from sentence_transformers import losses

epochs = 1
warmup_steps = int(len(loader)*epochs*0.1)
sbert_model.fit(
    train_objectives=[(loader, sbert_loss_func)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path = './sbert_test_b',
    show_progress_bar=True
)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/50 [00:00<?, ?it/s]

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
