In [28]:
import os
import numpy as np
import pandas as pd
import math
import csv

import transformers
from tqdm.notebook import trange, tqdm

from sentence_transformers import SentenceTransformer, SentencesDataset, losses
from sentence_transformers.readers import STSDataReader, TripletReader
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, BinaryEmbeddingSimilarityEvaluator, SequentialEvaluator
from sentence_transformers.readers.InputExample import InputExample

import torch
from torch.utils.data import DataLoader, RandomSampler

from scipy.spatial.distance import cdist


In [29]:
# Directory holding the split training data read by the cells below.
# Set the MSMARCO_TRAIN_SPLITS_DIR environment variable to run the notebook on
# another machine; when it is unset or empty the original local path is used.
TRAIN_SPLITS_DATA_DIR = (
    os.environ.get('MSMARCO_TRAIN_SPLITS_DIR')
    or '/run/media/root/Windows/Users/agnes/Downloads/data/msmarco/train_data/splitted'
)

In [30]:
# How to increase the swap size (Fedora):
# https://superuser.com/questions/1024064/change-swap-file-size-fedora-23

In [31]:
# Load the sentence-embedding model identified by its model name.
wiki_model_name = 'bert-base-wikipedia-sections-mean-tokens'
model_wiki = SentenceTransformer(wiki_model_name)

In [32]:
# Sample sentences; the first one is compared against the other two below.
sentences = [
    'A fox lives in a zoo together with dogs.',
    'Sentences are passed as a list of string.',
    'The quick brown fox jumps over the lazy dog.',
]


In [33]:
# Embed the sample sentences, then compare the first embedding with the others.
embeddings = model_wiki.encode(sentences)
# cdist needs 2-D inputs, so the first embedding becomes a single-row matrix.
first_embedding = embeddings[0].reshape(1, -1)
# NOTE: scipy's "cosine" metric is the cosine *distance* (1 - cosine similarity),
# so despite the variable name, smaller values mean more similar sentences.
sims = cdist(first_embedding, embeddings[1:], metric="cosine")[0]
sims

array([0.01991861, 0.01082202])

In [34]:
############

In [35]:
#torch.cuda.empty_cache()

In [36]:
# Folder, inside the train-splits directory, that the triplet reader loads its CSV files from.
triplet_subdir = 'queries3_sentences_triplet'
my_train_data_path = os.path.join(TRAIN_SPLITS_DATA_DIR, triplet_subdir)

In [37]:
# Reader for the triplet CSV files: the three texts of each triplet come from
# columns 1, 2 and 3; files are comma-separated and start with a header row.
triplet_csv_options = dict(
    s1_col_idx=1,
    s2_col_idx=2,
    s3_col_idx=3,
    delimiter=",",
    has_header=True,
    quoting=csv.QUOTE_MINIMAL,
)
myreader_triplet = TripletReader(my_train_data_path, **triplet_csv_options)

# Load at most two examples from the dev file to inspect what the reader yields.
examples = myreader_triplet.get_examples(
    'queries3_sentences_triplet_dev.csv',
    max_examples=2,
)

In [38]:
# Show the texts of the second loaded example.
examples[1].texts

['amniocentesis is performed what trimester',
 'Amniocentesis in late pregnancy.',
 'Vaginal bleeding in the first trimester of pregnancy can be caused by several different factors.']

In [15]:
# Read up to 100000 training triplets and convert them into a dataset for model_wiki.
train_data = myreader_triplet.get_examples(
    'queries3_sentences_triplet_train.csv',
    max_examples=100000,
)
train_dataset = SentencesDataset(train_data, model=model_wiki, show_progress_bar=True)


Convert dataset: 100%|██████████| 100000/100000 [01:39<00:00, 1008.98it/s]


In [39]:
# Read the dev CSV without a max_examples cap and convert it into a dataset for model_wiki.
dev_data = myreader_triplet.get_examples('queries3_sentences_triplet_dev.csv')
dev_dataset = SentencesDataset(dev_data, model=model_wiki, show_progress_bar=True)

Convert dataset: 100%|██████████| 4650/4650 [00:08<00:00, 541.98it/s]


In [40]:
# Batch size passed to both the dev and the train DataLoader.
train_batch_size = 8

In [41]:
# Draw 2000 random dev samples per pass; sampling is with replacement,
# so the same example can be drawn more than once.
dev_dataset_sampler = RandomSampler(
    dev_dataset,
    num_samples=2000,
    replacement=True,
)

In [12]:
#train_dataset_sampler = RandomSampler(train_dataset, replacement=False)

In [42]:
# Dev batches are drawn through the random sampler rather than in dataset order.
dev_dataloader = DataLoader(
    dev_dataset,
    sampler=dev_dataset_sampler,
    batch_size=train_batch_size,
)

In [19]:
# Shuffled training batches, loaded by one worker process.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=1,
)

In [43]:
# Number of examples in the full dev dataset.
len(dev_dataset)

4650

In [44]:
# Number of batches the dev DataLoader yields in one pass.
len(dev_dataloader)

250

In [45]:
# Directory the DataLoader files below are saved into.
# Set the MSMARCO_MODEL_DIR environment variable to use another location;
# when it is unset or empty the original local path is used.
my_model_path = (
    os.environ.get('MSMARCO_MODEL_DIR')
    or '/run/media/root/Windows/Users/agnes/Downloads/data/msmarco/train_results/test_wiki'
)

In [46]:
# Persist the dev DataLoader; torch.save pickles the whole object, so only
# load this file back from a trusted source.
dev_dataloader_file = os.path.join(my_model_path, 'dev_dataloader.pth')
torch.save(dev_dataloader, dev_dataloader_file)

In [27]:
# Persist the train DataLoader; torch.save pickles the whole object, so only
# load this file back from a trusted source.
train_dataloader_file = os.path.join(my_model_path, 'train_dataloader.pth')
torch.save(train_dataloader, train_dataloader_file)