# Init

In [1]:
# import tensorflow as tf
# NOTE(review): `tqdm` and `Now()` are used further down but are not imported or
# defined in the visible cells -- confirm where they come from.
import json
import logging
import math
import multiprocessing as mp
import os
import time
from collections import OrderedDict
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pynvml
import sentence_transformers
import spacy
import torch
from sentence_transformers import SentenceTransformer
from spacy.lang.en import English
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_sequence, pack_padded_sequence
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from torchvision.transforms import Compose
from transformers import BertModel, BertTokenizer, GPT2Model, GPT2Tokenizer, RobertaTokenizer, RobertaModel

# working directory
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
ROOT_DIR = 'C:/Users/rossz/OneDrive/CC'
DATA_DIR = f'{ROOT_DIR}/data'
print(f'ROOT_DIR: {ROOT_DIR}')
print(f'DATA_DIR: {DATA_DIR}')

# set random seed (numpy and torch are seeded with the same value)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# cuDNN reproducibility settings.
# `benchmark` must be False when `deterministic` is True: the autotuner may pick
# a different algorithm on each run, and it re-runs its search for every new
# input shape, which happens constantly here because batches are padded to
# varying lengths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# set device 'cuda' or 'cpu'
# Number of visible CUDA devices. Defined on CPU-only machines too (as 0) so
# later cells can rely on the name existing.
n_cuda = torch.cuda.device_count() if torch.cuda.is_available() else 0


def log_gpu_memory(verbose=False):
    '''
    Print the memory usage of every visible GPU.

    Defined unconditionally so that cells calling it do not raise NameError
    on a machine without CUDA; in that case it prints a notice and returns.

    Args:
        verbose: if True, print `torch.cuda.memory_summary` for each device;
            otherwise print one "allocated/total (GB)" line per device.
    '''
    if n_cuda == 0:
        print('GPU NOT enabled')
        return

    # emptied before reading the counters, as in the original helper
    torch.cuda.empty_cache()
    for gpu_id in range(n_cuda):
        if verbose:
            print(f'GPU {gpu_id}:')
            print(f'{torch.cuda.memory_summary(gpu_id, abbreviated=True)}')
        else:
            # bytes -> GB
            memory_total = torch.cuda.get_device_properties(gpu_id).total_memory/(1024**3)
            memory_allocated = torch.cuda.memory_allocated(gpu_id)/(1024**3)
            print(f'GPU {gpu_id}: {memory_allocated: .2f}/{memory_total: .2f} (GB)')


if n_cuda > 0:
    print(f'\n{n_cuda} GPUs found:')
    for gpu_id in range(n_cuda):
        # expose each device as a module-level name: cuda0, cuda1, ...
        globals()[f'cuda{gpu_id}'] = torch.device(f'cuda:{gpu_id}')
        print(f'    {torch.cuda.get_device_name(gpu_id)} (cuda{gpu_id})')

    print('\nGPU memory:')
    log_gpu_memory()
else:
    print('GPU NOT enabled')

cpu = torch.device('cpu')
# Half of the logical CPU count; the label below calls this "physical" cores,
# which presumably assumes two hardware threads per core -- TODO confirm.
n_cpu = int(mp.cpu_count()/2)

print(f'\nCPU cores (physicial): {n_cpu}')

ROOT_DIR: C:/Users/rossz/OneDrive/CC
DATA_DIR: C:/Users/rossz/OneDrive/CC/data

2 GPUs found:
    GeForce RTX 2080 Ti (cuda0)
    GeForce RTX 2080 Ti (cuda1)

GPU memory:
GPU 0:  0.00/ 11.00 (GB)
GPU 1:  0.00/ 11.00 (GB)

CPU cores (physicial): 16


# SBERT

## load model

In [2]:
# Cached SBERT checkpoint. Despite the ".zip" suffix, the path is used as a
# directory below (files are joined onto it).
model_path = "C:/Users/rossz/.cache/torch/sentence_transformers/public.ukp.informatik.tu-darmstadt.de_reimers_sentence-transformers_v0.2_roberta-large-nli-stsb-mean-tokens.zip"

# modules.json lists the model's sub-modules; each entry carries a 'type',
# a 'path' and a 'name'.
modules_json_path = os.path.join(model_path, 'modules.json')
with open(modules_json_path) as fIn:
    contained_modules = json.load(fIn)

# Load every sub-module from its own sub-folder, keeping the listed order.
sbert_modules = OrderedDict()
for module_config in contained_modules:
    module_class = sentence_transformers.util.import_from_string(module_config['type'])
    module_dir = os.path.join(model_path, module_config['path'])
    module = module_class.load(module_dir)
    sbert_modules[module_config['name']] = module

# For Roberta, pad_token_id == 1
if 'roberta' not in model_path:
    raise Exception("You're not using RoBERTa, double check your pad_token_id")
sbert_pad_token_id = 1

# Chain the modules, replicate across the visible GPUs, and put the master copy on cuda0.
sbert_model = nn.DataParallel(nn.Sequential(sbert_modules))
sbert_model.to(cuda0)
log_gpu_memory()

GPU 0:  1.32/ 11.00 (GB)
GPU 1:  0.00/ 11.00 (GB)


## define Dataset

In [3]:
class Tokenize():
    '''
    Transform that tokenizes the text of a (transcriptid, sentenceid, text) sample.

    Sentences that tokenize to nothing, or to `max_seq_len` tokens or more, are
    replaced by the single-token placeholder `[pad_token_id]`.
    '''

    def __init__(self, modules, pad_token_id, max_seq_len):
        '''
        Args:
            modules: mapping of model modules; the first entry (in iteration
                order) must provide a `tokenize(text)` method.
            pad_token_id: placeholder token id. Empty sentences are set to
                length 1 and filled with `pad_token_id`; over-long sentences
                get the same placeholder.
            max_seq_len: the calls still contain very long boilerplate
                statements, so every sentence that does not tokenize to fewer
                than `max_seq_len` tokens is removed (replaced by the
                placeholder).
        '''
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        self.modules = modules

    def __call__(self, sample):
        '''
        Args:
            sample: (transcriptid, sentenceid, text)

        Returns:
            (transcriptid, sentenceid, tokens), where `tokens` is either the
            tokenizer output or `[pad_token_id]`.
        '''
        transcriptid, sentenceid, sent = sample
        tokenizer = self.modules[next(iter(self.modules))]
        tokens = tokenizer.tokenize(sent)

        # Keep only non-empty sentences below the length limit. The previous
        # condition (`len == 0 or len < max_seq_len`) let empty token lists
        # through unchanged, contrary to the documented behaviour.
        if 0 < len(tokens) < self.max_seq_len:
            return transcriptid, sentenceid, tokens
        return transcriptid, sentenceid, [self.pad_token_id]


class CCDataset(Dataset):
    '''
    Dataset over a DataFrame of sentences, served in order of increasing text length.

    Index 0 is the row with the shortest `text`, the last index the longest.
    '''

    def __init__(self, df, transform=None):
        '''
        Args:
            df: DataFrame with a 'text' column. Rows are presumably
                (transcriptid, sentenceid, text) -- verify against the caller.
            transform: optional callable applied to each sample tuple.
        '''
        self.df = df
        self.transform = transform

        # character length of every text; argsort yields row positions from
        # shortest to longest
        text_lengths = list(map(len, df['text'].tolist()))
        self.length_sorted_idx = np.argsort(text_lengths)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        position = idx.tolist() if torch.is_tensor(idx) else idx

        # translate the length-rank into the actual row position
        row = self.df.iloc[self.length_sorted_idx[position]]
        sample = tuple(row)

        return self.transform(sample) if self.transform else sample
    
# Example usage (arguments match the current Tokenize signature):
# MAX_SENT_LEN = 256
# ds = CCDataset(text_present_sentencized, transform=Tokenize(sbert_modules, pad_token_id=sbert_pad_token_id, max_seq_len=MAX_SENT_LEN))

## define DataLoader

In [4]:
# --------------------------- Create DataLoader--------------------------
def collate_fn(data: list, modules):
    '''
    Collate tokenized samples into one padded batch.

    Args:
        data: list of (transcriptid, sentenceid, tokens) samples.
        modules: mapping of model modules; the first entry (in iteration order)
            must provide `get_sentence_features(tokens, pad_seq_length)`.

    Returns:
        dict with two entries:
            'features': {feature_name: tensor stacked over the batch},
                e.g. {'input_ids', 'input_mask', 'sentence_lengths'}
            'meta': (transcriptids, sentenceids)
    '''
    transcriptids, sentenceids, sents = list(zip(*data))

    # every sentence is padded up to the longest one in this batch
    longest_seq_len = max([len(sent) for sent in sents])

    first_module = modules[next(iter(modules))]

    # gather each feature across the batch: {feature_name: [per-sentence values]}
    gathered = {}
    for sent in sents:
        sentence_features = first_module.get_sentence_features(sent, longest_seq_len)
        for feature_name in sentence_features:
            gathered.setdefault(feature_name, []).append(sentence_features[feature_name])

    features = {
        feature_name: torch.tensor(np.asarray(values))
        for feature_name, values in gathered.items()
    }

    return {'features': features, 'meta': (transcriptids, sentenceids)}

# Example usage (the module mapping is named `sbert_modules` in this notebook):
# dl = DataLoader(ds, batch_size=32,
#                 shuffle=False, num_workers=0,
#                 collate_fn=partial(collate_fn, modules=sbert_modules),
#                 drop_last=False,
#                 pin_memory=False)
# one_batch = next(iter(dl))
# one_batch

## encode

In [5]:
def pre_encode_sbert(dl, model, save_path, start):
    '''
    Encode every batch of `dl` with `model` and save the results as one file.

    Args:
        dl: iterable of batches shaped {'features': ..., 'meta': (transcriptids, sentenceids)}.
        model: callable taking the batch features and returning a mapping
            that contains 'sentence_embedding'.
        save_path: output path prefix.
        start: suffix for the output file name; the file is written to
            f'{save_path}_{start}.pt'.

    The saved object is a list of (transcriptid, sentenceid, embedding) tuples.
    '''
    with torch.no_grad():
        encoded = []
        for batch in tqdm(dl):
            features = batch['features']
            transcriptids, sentenceids = batch['meta']

            # forward pass, then bring the embeddings back to host memory as numpy
            outputs = model(features)
            embeddings = outputs['sentence_embedding'].to(cpu).numpy()

            # one (transcriptid, sentenceid, embedding) record per sentence
            encoded.extend(zip(transcriptids, sentenceids, embeddings))

        # one file per chunk
        torch.save(encoded, f'{save_path}_{start}.pt')
        

# Encoding settings, identical for every text type.
chunksize = 400000 # 400000 for 1/10 of total
MAX_SENT_LEN = 256
PREENCODE_BATCH_SIZE = 256

# for every text_type, do encoding
for text_type in ['all', 'qa', 'present']:
    text_df = pd.read_feather(f'{DATA_DIR}/text_{text_type}_sentencized.feather')
    # NOTE(review): relative to the current working directory, not DATA_DIR -- confirm intended.
    save_path = f'./data/embeddings/text_{text_type}_sbert_roberta_nlistsb_encoded'

    start, stop = 0, len(text_df)

    # One output file per chunk. A failing chunk is reported and skipped so the
    # remaining chunks still get encoded.
    for i in range(start, stop, chunksize):
        print(f'Processing {i}/{stop}...{i/stop*100: .1f}% {Now()}')

        try:
            chunk_end = min(i + chunksize, stop)
            text_df_chunk = text_df.iloc[i:chunk_end]

            # Drop the final row when the chunk ends on an odd index.
            # NOTE(review): with an even chunksize this only triggers on the last
            # chunk of an odd-length frame, so that frame's last sentence is never
            # encoded. Presumably meant to keep the row count even for the
            # two-GPU split -- TODO confirm.
            if chunk_end % 2 != 0:
                text_df_chunk = text_df_chunk.iloc[:-1]

            tokenize = Tokenize(sbert_modules, pad_token_id=sbert_pad_token_id, max_seq_len=MAX_SENT_LEN)
            ds = CCDataset(text_df_chunk, transform=tokenize)

            dl = DataLoader(
                ds,
                batch_size=PREENCODE_BATCH_SIZE,
                shuffle=False,
                num_workers=0,
                collate_fn=partial(collate_fn, modules=sbert_modules),
                drop_last=False,
                pin_memory=True,
            )

            pre_encode_sbert(dl, model=sbert_model, save_path=save_path, start=i)
        except Exception as e:
            print(f'Exception i={i}')
            print(f'   {e}')

Processing 0/11637455... 0.0% 01:08:51


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 400000/11637455... 3.4% 01:22:12


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 800000/11637455... 6.9% 01:34:56


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 1200000/11637455... 10.3% 01:47:17


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 1600000/11637455... 13.7% 01:59:28


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 2000000/11637455... 17.2% 02:11:51


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 2400000/11637455... 20.6% 02:25:22


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 2800000/11637455... 24.1% 02:37:30


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 3200000/11637455... 27.5% 02:49:37


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 3600000/11637455... 30.9% 03:01:39


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 4000000/11637455... 34.4% 03:13:31


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 4400000/11637455... 37.8% 03:25:28


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 4800000/11637455... 41.2% 03:37:19


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 5200000/11637455... 44.7% 03:49:18


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 5600000/11637455... 48.1% 04:01:23


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 6000000/11637455... 51.6% 04:13:29


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 6400000/11637455... 55.0% 04:25:43


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 6800000/11637455... 58.4% 04:37:48


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 7200000/11637455... 61.9% 04:49:52


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 7600000/11637455... 65.3% 05:01:57


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 8000000/11637455... 68.7% 05:13:57


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 8400000/11637455... 72.2% 05:26:03


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 8800000/11637455... 75.6% 05:38:09


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 9200000/11637455... 79.1% 05:50:12


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 9600000/11637455... 82.5% 06:02:23


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 10000000/11637455... 85.9% 06:14:27


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 10400000/11637455... 89.4% 06:26:35


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 10800000/11637455... 92.8% 06:38:40


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 11200000/11637455... 96.2% 06:50:49


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 11600000/11637455... 99.7% 07:03:07


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))


Processing 0/7448705... 0.0% 07:04:18


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 400000/7448705... 5.4% 07:16:48


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 800000/7448705... 10.7% 07:28:52


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 1200000/7448705... 16.1% 07:40:46


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 1600000/7448705... 21.5% 07:52:56


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 2000000/7448705... 26.9% 08:04:47


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 2400000/7448705... 32.2% 08:16:22


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 2800000/7448705... 37.6% 08:28:04


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 3200000/7448705... 43.0% 08:39:39


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 3600000/7448705... 48.3% 08:51:23


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 4000000/7448705... 53.7% 09:03:00


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 4400000/7448705... 59.1% 09:14:35


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 4800000/7448705... 64.4% 09:26:16


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 5200000/7448705... 69.8% 09:37:47


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 5600000/7448705... 75.2% 09:49:27


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 6000000/7448705... 80.6% 10:01:02


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 6400000/7448705... 85.9% 10:12:44


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 6800000/7448705... 91.3% 10:25:08


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 7200000/7448705... 96.7% 10:36:45


HBox(children=(IntProgress(value=0, max=972), HTML(value='')))


Processing 0/4195609... 0.0% 10:44:04


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 400000/4195609... 9.5% 10:58:28


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 800000/4195609... 19.1% 11:12:23


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 1200000/4195609... 28.6% 11:26:13


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 1600000/4195609... 38.1% 11:40:06


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 2000000/4195609... 47.7% 11:53:53


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 2400000/4195609... 57.2% 12:07:51


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 2800000/4195609... 66.7% 12:21:36


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 3200000/4195609... 76.3% 12:35:41


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 3600000/4195609... 85.8% 12:49:32


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


Processing 4000000/4195609... 95.3% 13:03:42


HBox(children=(IntProgress(value=0, max=765), HTML(value='')))


