In [1]:
from sentence_transformers import SentenceTransformer, util

In [2]:
model = SentenceTransformer('bert-base-nli-mean-tokens')

In [3]:
sentences = [
    "the fifty mannequin heads floating in the pool kind of freaked them out",
    "she swore she just saw her sushi move",
    "he embraced his new life as an eggplant",
    "my dentist tells me that chewing bricks is very bad for your teeth",
    "the dental specialist recommended an immediate stop to flossing with construction materials"
]
embeddings = model.encode(sentences)
embeddings.shape

(5, 768)

In [4]:
import numpy as np

sim = np.zeros((len(sentences), len(sentences)))
for i in range(len(sentences)):
    sim[i:, i] = util.cos_sim(embeddings[i], embeddings[i:])
sim

array([[1.        , 0.        , 0.        , 0.        , 0.        ],
       [0.40914282, 1.        , 0.        , 0.        , 0.        ],
       [0.10909006, 0.44547969, 1.00000012, 0.        , 0.        ],
       [0.50074875, 0.30693936, 0.20791629, 1.        , 0.        ],
       [0.29936197, 0.38607204, 0.28499246, 0.63849503, 0.99999994]])

In [5]:
model = SentenceTransformer('all-mpnet-base-v2')
embeddings = model.encode(sentences)
sim = np.zeros((len(sentences), len(sentences)))
for i in range(len(sentences)):
    sim[i:, i] = util.cos_sim(embeddings[i], embeddings[i:])
sim

array([[ 1.00000024,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.26406279,  1.00000024,  0.        ,  0.        ,  0.        ],
       [ 0.16503477,  0.16126668,  1.00000012,  0.        ,  0.        ],
       [ 0.0433445 ,  0.04615868,  0.05670126,  1.00000012,  0.        ],
       [ 0.05398509,  0.0610119 , -0.01122268,  0.51847214,  0.99999994]])

In [6]:
import datasets

In [7]:
snli = datasets.load_dataset('snli', split='train')
snli

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 550152
})

In [8]:
print(snli[0])

{'premise': 'A person on a horse jumps over a broken down airplane.', 'hypothesis': 'A person is training his horse for a competition.', 'label': 1}


In [10]:
mnli = datasets.load_dataset('glue', 'mnli', split='train')
mnli

Dataset({
    features: ['premise', 'hypothesis', 'label', 'idx'],
    num_rows: 392702
})

In [11]:
mnli[0]

{'premise': 'Conceptually cream skimming has two basic dimensions - product and geography.',
 'hypothesis': 'Product and geography are what make cream skimming work. ',
 'label': 1,
 'idx': 0}

In [12]:
mnli = mnli.remove_columns(['idx'])
mnli

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 392702
})

In [13]:
snli = snli.cast(mnli.features)
dataset = datasets.concatenate_datasets([snli, mnli])
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942854
})

In [14]:
dataset = dataset.filter(lambda x: 0 if x['label']==-1 else 1)
len(dataset)

Filter:   0%|          | 0/942854 [00:00<?, ? examples/s]

942069

In [15]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [16]:
all_cols = ['label']
for part in ['premise', 'hypothesis']:
    dataset = dataset.map(
        lambda x: tokenizer(x[part], max_length=128, padding='max_length', truncation=True),
        batched=True
    ) 
    for col in ['input_ids', 'attention_mask']:
        dataset = dataset.rename_column(col, part+'_'+col)
        all_cols.append(part+'_'+col)
dataset        

Map:   0%|          | 0/942069 [00:00<?, ? examples/s]

Map:   0%|          | 0/942069 [00:00<?, ? examples/s]

Dataset({
    features: ['premise', 'hypothesis', 'label', 'premise_input_ids', 'token_type_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask'],
    num_rows: 942069
})

In [17]:
all_cols

['label',
 'premise_input_ids',
 'premise_attention_mask',
 'hypothesis_input_ids',
 'hypothesis_attention_mask']

In [18]:
dataset.set_format(type='torch', columns=all_cols) 

In [123]:
import torch

batch_size = 32
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
loader

<torch.utils.data.dataloader.DataLoader at 0x7fb4f094b5e0>

In [124]:
for batch in  loader:
    print(batch.keys())
    break

dict_keys(['label', 'premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask'])


In [125]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [132]:
import gc
gc.collect()

torch.cuda.reset_max_memory_allocated()
print(torch.cuda.memory_allocated())  # prints the allocated memory
print(torch.cuda.memory_cached())

15657655808
23546822656


In [133]:
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased').to(device)
model.device

device(type='cuda', index=0)

In [134]:
def mean_pool(token_embedds, attention_mask): 
    in_mask = attention_mask.unsqueeze(-1).expand(token_embedds.size()).float()
    pool = torch.sum(token_embedds*in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
    return pool

In [135]:
ffn = torch.nn.Linear(768*3, 3).to(device)
ffn 

Linear(in_features=2304, out_features=3, bias=True)

In [136]:
from transformers.optimization import get_linear_schedule_with_warmup

optim = torch.optim.Adam(model.parameters(), lr=2e-5)
total_steps = int(len(dataset)) / batch_size
warmup_steps = int(0.1*total_steps)
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=warmup_steps, 
                                            num_training_steps=total_steps-warmup_steps) 

In [137]:
loss_func = torch.nn.CrossEntropyLoss()

In [None]:
from tqdm import tqdm

for epoch in range(1):
    model.train()
    loop = tqdm(loader, leave=True)
    for batch in loop:
        optim.zero_grad()
        input_ids_a = batch['premise_input_ids'].to(device)
        input_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        label = batch['label'].to(device)

        u = model(input_ids_a, attention_mask=attention_a).last_hidden_state
        v = model(input_ids_b, attention_mask=attention_b).last_hidden_state

        u = mean_pool(u, attention_a)
        v = mean_pool(v, attention_b)

        uv = torch.sub(u, v)
        uv_abs = torch.abs(uv)

        x = torch.cat([u, v, uv_abs], dim=-1)
        x = ffn(x)
        loss = loss_func(x, label)
        loss.backward()
        optim.step()
        scheduler.step()
        loop.set_description(f"epoch {epoch}")
        loop.set_postfix(loss=loss.item())

epoch 0:   4%|▍         | 1272/29440 [12:29<4:39:42,  1.68it/s, loss=0.855]

In [None]:
import os

model_path = './sbert_test_a'

if not os.path.exists(model_path):
    os.mkdir(model_path)

model.save_pretrained(model_path)