In [1]:
!pip install -q sentence-transformers datasets

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/171.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m163.8/171.5 kB[0m [31m5.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m171.5/171.5 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m401.2/401.2 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import datasets

# Training splits of the two NLI corpora. The 'idx' column is dropped from MNLI
# before concatenation (presumably so both datasets expose the same columns).
snli = datasets.load_dataset('snli', split='train')
mnli = datasets.load_dataset('glue', 'mnli', split='train').remove_columns(['idx'])

# Merge both corpora and keep only the rows labelled 0.
# NOTE(review): label 0 is expected to be the "entailment" class in both
# datasets -- confirm against the dataset cards.
dataset = datasets.concatenate_datasets([snli, mnli]).filter(
    lambda row: row['label'] == 0
)

# The individual corpora are no longer needed once merged.
del snli, mnli
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 314315
})

In [3]:
from transformers import AutoTokenizer

# Tokenizer matching the 'bert-base-uncased' checkpoint.
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Tokenize every premise in batches; each sequence is truncated/padded to
# exactly 128 tokens so all rows have the same length.
dataset = dataset.map(
    lambda batch: bert_tokenizer(
        batch['premise'],
        truncation=True,
        padding='max_length',
        max_length=128,
    ),
    batched=True,
)



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/314315 [00:00<?, ? examples/s]

In [4]:
# The premise acts as the "anchor" side of each pair: give its tokenizer
# outputs anchor_* names so tokenizing the hypothesis later does not reuse them.
dataset = (
    dataset
    .rename_column('attention_mask', 'anchor_mask')
    .rename_column('input_ids', 'anchor_ids')
)

In [5]:
# Tokenize every hypothesis with the same settings used for the premise
# (fixed length of 128 tokens), then store the results under positive_* names:
# the hypothesis is the "positive" counterpart of the anchor.
dataset = (
    dataset
    .map(
        lambda batch: bert_tokenizer(
            batch['hypothesis'],
            truncation=True,
            padding='max_length',
            max_length=128,
        ),
        batched=True,
    )
    .rename_column('attention_mask', 'positive_mask')
    .rename_column('input_ids', 'positive_ids')
)

Map:   0%|          | 0/314315 [00:00<?, ? examples/s]

In [6]:
# Drop the raw text, the label and the unused token_type_ids; only the
# anchor/positive ids and masks are read by the training loop.
dataset = dataset.remove_columns(
    ['premise', 'hypothesis', 'label', 'token_type_ids']
)
dataset

Dataset({
    features: ['anchor_ids', 'anchor_mask', 'positive_ids', 'positive_mask'],
    num_rows: 314315
})

In [7]:
# Have the dataset hand back torch tensors when rows are read (in place).
# NOTE(review): output_all_columns=True looks redundant here since no column
# subset is selected -- confirm before removing.
dataset.set_format(output_all_columns=True, type='torch')

In [8]:
import torch
from torch.utils.data import DataLoader

# Number of (anchor, positive) pairs per step. The training loop scores every
# anchor against every positive in the batch, so this also sets how many
# in-batch candidates each anchor is ranked against.
batch_size = 2

# Shuffled mini-batch iterator over the tokenized pairs.
loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

# Number of batches (= optimizer steps) per epoch.
len(loader)

157158

In [28]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
print(f"device: {device}")

device: cpu


In [29]:
from transformers import AutoModel
# Load the pretrained 'bert-base-uncased' encoder, then move its weights to
# the selected device.
model = AutoModel.from_pretrained('bert-base-uncased')
model = model.to(device)



In [30]:
# Factor the cosine scores are multiplied by before the cross-entropy loss.
scale = 20.0

# Cross-entropy over each anchor's row of similarity scores.
loss_func = torch.nn.CrossEntropyLoss().to(device)

# Cosine similarity computed along dim 1 (the module's default).
cos_sim = torch.nn.CosineSimilarity(dim=1).to(device)

In [69]:
def mean_pooling(input_ids, attention_mask):
  """Mask-weighted mean of `input_ids` over dimension 1.

  Despite its name, `input_ids` is passed the model's output vectors by the
  training loop, not token ids.

  Args:
    input_ids: tensor to pool; assumed (batch, seq_len, hidden) -- TODO confirm.
    attention_mask: tensor with one dimension fewer than `input_ids`; it is
      given a trailing axis and expanded to `input_ids`' shape, then used as a
      per-position weight.

  Returns:
    The weighted sum over dim 1 divided by the summed weights; the divisor is
    clamped to at least 1e-9, so an all-zero mask yields zeros, not NaN.
  """
  weights = attention_mask[..., None].expand(input_ids.size()).float()
  weighted_sum = torch.sum(input_ids * weights, dim=1)
  weight_total = torch.clamp(weights.sum(dim=1), min=1e-9)
  return weighted_sum / weight_total

In [70]:
from transformers import get_linear_schedule_with_warmup

# Training length: one optimizer step per batch, for `epochs` passes.
epochs = 1
total_steps = len(loader) * epochs
# Ramp the learning rate up linearly over the first 10% of the steps.
warmup_steps = int(total_steps * 0.1)

opt = torch.optim.Adam(model.parameters(), lr=2e-5)

# `num_training_steps` is the TOTAL number of steps, warmup included: the
# schedule decays linearly from the peak LR at `warmup_steps` down to 0 at
# `num_training_steps`. Passing `total_steps - warmup_steps` (as before) made
# the LR hit 0 after 90% of training, so the last 10% of steps learned nothing.
sched = get_linear_schedule_with_warmup(
    opt,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [86]:
from tqdm import tqdm

# Contrastive fine-tuning with in-batch negatives: for each anchor, the
# positive at the same batch index is the target and every other positive in
# the batch serves as a negative.
for epoch in range(epochs):
  model.train()
  loop = tqdm(loader)
  # The description only depends on the epoch, so set it once per epoch.
  loop.set_description(f"epoch:{epoch}")
  for xb in loop:
    opt.zero_grad()
    anchor_ids = xb['anchor_ids'].to(device)
    anchor_mask = xb['anchor_mask'].to(device)
    pos_ids = xb['positive_ids'].to(device)
    pos_mask = xb['positive_mask'].to(device)

    # Element [0] of the model output is pooled over the sequence axis into
    # one vector per example.
    a = mean_pooling(model(anchor_ids, attention_mask=anchor_mask)[0], anchor_mask)
    p = mean_pooling(model(pos_ids, attention_mask=pos_mask)[0], pos_mask)

    # scores[i, j] = cosine(a[i], p[j]) for every anchor/positive pair,
    # computed in one broadcasted call instead of a Python loop over anchors.
    scores = torch.nn.functional.cosine_similarity(a[:, None, :], p[None, :, :], dim=-1)
    # The matching positive of anchor i sits at column i.
    labels = torch.arange(scores.size(0), device=scores.device)

    loss = loss_func(scores * scale, labels)
    loss.backward()
    opt.step()
    sched.step()

    loop.set_postfix(loss=f"{loss.item():.4f}")
    # NOTE(review): smoke-test leftover -- training stops after the first
    # batch. Remove this `break` to actually train for the full epoch.
    break


epoch:0:   0%|          | 0/157158 [00:09<?, ?it/s, loss=0.0136]
