In [1]:
import datasets
import torch

snli = datasets.load_dataset('snli', split='train')
mnli = datasets.load_dataset('glue', 'mnli', split='train')
mnli = mnli.remove_columns('idx')
snli = snli.cast(mnli.features)
dataset = datasets.concatenate_datasets([snli, mnli])
del snli
del mnli
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942854
})

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
dataset = dataset.filter(lambda x: True if x['label'] == 0 else False)
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 314315
})

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
dataset = dataset.map(lambda x: tokenizer(
    x['premise'], padding='max_length', max_length=128, truncation=True
), batched=True)

In [5]:
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 314315
})

In [6]:
dataset = dataset.rename_column('input_ids', 'anchor_ids')
dataset = dataset.rename_column('attention_mask', 'anchor_mask')
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label', 'anchor_ids', 'token_type_ids', 'anchor_mask'],
    num_rows: 314315
})

In [7]:
dataset = dataset.map(lambda x: tokenizer(
    x['hypothesis'], padding='max_length', max_length=128, truncation=True
), batched=True)
dataset = dataset.rename_column('input_ids', 'positive_ids')
dataset = dataset.rename_column('attention_mask', 'positive_mask')
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label', 'anchor_ids', 'token_type_ids', 'anchor_mask', 'positive_ids', 'positive_mask'],
    num_rows: 314315
})

In [8]:
dataset = dataset.remove_columns(['premise', 'hypothesis', 'label', 'token_type_ids'])
dataset

Dataset({
    features: ['anchor_ids', 'anchor_mask', 'positive_ids', 'positive_mask'],
    num_rows: 314315
})

In [9]:
dataset.set_format('torch', output_all_columns=True)
dataset

Dataset({
    features: ['anchor_ids', 'anchor_mask', 'positive_ids', 'positive_mask'],
    num_rows: 314315
})

In [10]:
import torch

batch_size = 32
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
xb = next(iter(loader))
xb['anchor_ids'].shape

torch.Size([32, 128])

In [11]:
from transformers import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased').to(device)

In [12]:
def mean_pool(token_emb, attn_mask):
    in_mask = attn_mask.unsqueeze(-1).expand(token_emb.size()).float()
    pool = torch.sum(token_emb*in_mask, 1)/torch.clamp_min(in_mask.sum(1), min=1e-9)
    return pool

In [13]:
cos_sim = torch.nn.CosineSimilarity()

In [14]:
a = xb['anchor_ids'].float().to(device)
p = xb['positive_ids'].float().to(device)
a.shape, p.shape, a.dtype, p.dtype

(torch.Size([32, 128]), torch.Size([32, 128]), torch.float32, torch.float32)

In [15]:
scores = []
for ai in a:
    score = cos_sim(ai[None], p)
    scores.append(score)
scores = torch.stack(scores).to(device)
print(scores.shape)
scores

torch.Size([32, 32])


tensor([[0.1778, 0.1624, 0.6692,  ..., 0.3816, 0.2774, 0.1551],
        [0.9022, 0.7854, 0.4253,  ..., 0.7493, 0.2504, 0.4268],
        [0.5355, 0.3599, 0.7655,  ..., 0.6039, 0.3051, 0.2056],
        ...,
        [0.1565, 0.1964, 0.2095,  ..., 0.3115, 0.3539, 0.2061],
        [0.2051, 0.1689, 0.1542,  ..., 0.2520, 0.2966, 0.1240],
        [0.2051, 0.1689, 0.1542,  ..., 0.2520, 0.2966, 0.1240]],
       device='cuda:0')

In [16]:
label = torch.tensor(range(len(scores)), dtype=torch.long, device=scores.device)
label

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='cuda:0')

In [17]:
loss_func = torch.nn.CrossEntropyLoss()

In [18]:
loss_func(scores, label)

tensor(3.4344, device='cuda:0')

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
loss_func = torch.nn.CrossEntropyLoss()
cos_sim = torch.nn.CosineSimilarity()
cos_sim.to(device)
loss_func.to(device);

In [20]:
from transformers.optimization import get_linear_schedule_with_warmup

scale = 20.0
opt = torch.optim.Adam(model.parameters(), lr=2e-5)
total_steps = int(len(dataset)/batch_size)
warmup_steps = int(0.1 * total_steps)
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

In [None]:
from tqdm import tqdm

epochs = 1
for epoch in range(epochs):
    model.train()
    loop = tqdm(loader, leave=True)
    for batch in loop:
        opt.zero_grad()
        
        anchor_ids = batch['anchor_ids'].to(device)
        anchor_mask = batch['anchor_mask'].to(device)
        positive_ids = batch['positive_ids'].to(device)
        positive_mask = batch['positive_mask'].to(device)

        a = model(anchor_ids, anchor_mask)[0]
        p = model(positive_ids, positive_mask)[0]
        a = mean_pool(a, anchor_mask)
        p = mean_pool(p, positive_mask)

        scores = torch.stack([cos_sim(ai[None], p) for ai in a])
        labels = torch.tensor(range(len(scores)), dtype=torch.long, device=scores.device)

        loss = loss_func(scores*scale, labels)
        loss.backward()
        opt.step()
        sched.step()

        loop.set_description(f"epoch {epoch+1}")
        loop.set_postfix(loss=loss.item())


epoch 1:  31%|███       | 3028/9823 [30:12<1:07:50,  1.67it/s, loss=1.37] IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

epoch 1:  80%|████████  | 7860/9823 [1:18:45<19:39,  1.66it/s, loss=0.00478] 

In [6]:
import datasets
import torch

snli = datasets.load_dataset('snli', split='train')
mnli = datasets.load_dataset('glue', 'mnli', split='train')
mnli = mnli.remove_columns('idx')
snli = snli.cast(mnli.features)
dataset = datasets.concatenate_datasets([snli, mnli])
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942854
})

In [8]:
del snli, mnli

In [9]:
dataset = dataset.filter(lambda x: True if x['label']==0 else False)
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 314315
})

In [11]:
from sentence_transformers import InputExample
from tqdm import tqdm

train_examples = []
for row in tqdm(dataset):
    train_examples.append(InputExample(texts=[row['premise'], row['hypothesis']]))

100%|██████████| 314315/314315 [00:26<00:00, 11925.30it/s]


In [14]:
from sentence_transformers import datasets as snt_datasets

batch_size = 32
loader = snt_datasets.NoDuplicatesDataLoader(train_examples, batch_size=batch_size)
len(loader)

9822

In [18]:
from sentence_transformers import models, SentenceTransformer

bert = models.Transformer('bert-base-uncased')
pooling = models.Pooling(model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True)
model = SentenceTransformer(modules=[bert, pooling])
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

In [19]:
from sentence_transformers import losses

loss = losses.MultipleNegativesRankingLoss(model)

In [20]:
epochs = 1
warmup_steps = int(len(loader)*epochs*0.1)

In [None]:
model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='./sbert_test_mnr2',
    show_progress_bar=False
)