In [1]:
# Execute the shared project header in this notebook's namespace.
# NOTE(review): names used later without a visible import (torch, pickle, RAW_DIR, SEGMENT_DIR,
# EMBEDDINGS_DIR, newsgroup_configs, print_cuda_info, all_roots) presumably come from this file — confirm.
# The context manager closes the file handle, which the bare open(...).read() left open.
with open("../../header.py") as header_file:
    exec(header_file.read())

In [2]:
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import sent_tokenize
from transformers import BertTokenizer,BertModel
from datasets import load_from_disk,load_dataset

In [3]:
# Execute the helper scripts in this notebook's namespace.
# NOTE(review): OnTheFlyDataset and SegmentDataset are used below and are presumably defined here — confirm.
# Each file is opened in a context manager so its handle is closed right after reading.
for script_path in ("on_the_fly.py", "SegmentDataset.py"):
    with open(script_path) as script_file:
        exec(script_file.read())

In [4]:
def _encode_segment(sentences, tokenizer, max_length):
    '''
    Join a group of sentences into one lower-cased string and encode it with tokenizer.encode_plus,
    padded/truncated to max_length.
    '''
    # NOTE(review): the sentences are concatenated with no separator (unchanged behaviour).
    # If a sentence does not end in punctuation its last word fuses with the next sentence's
    # first word — confirm whether " ".join was intended before changing it, since any stored
    # encodings were produced this way.
    segment = "".join(sentences).lower()
    return tokenizer.encode_plus(
        segment,
        add_special_tokens=True,
        padding='max_length',
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
    )


def create_segments_list(cutoff_indices, sentence_list, tokenizer, max_length=512):
    '''
    Split a document into segments at the given sentence indices and encode each segment.

    Input:
        cutoff_indices: a list of cutoff indices. Each index is the position (in sentence_list) of
            the LAST sentence of a segment, so it should be in the range of 0 to n-1, where
            n=len(sentence_list). An empty list means "do not split": all sentences form one segment.
        sentence_list: a list of sentences from sent_tokenize
        tokenizer: the tokenizer for the model (must provide encode_plus).
        max_length: length every segment is padded/truncated to (default 512, the previous
            hard-coded value).
    Returns:
        segments_list: a list with len(cutoff_indices) + 1 entries, each being the output of
            tokenizer.encode_plus for one segment (downstream code reads its 'input_ids',
            'token_type_ids' and 'attention_mask' entries).

    Note: the sentences after the final cutoff always form a last segment, so a cutoff equal to
    n-1 produces a trailing segment built from an empty string.
    '''
    segments_list = []
    start_idx = 0
    # One segment per cutoff: sentences start_idx..split_idx inclusive.
    for split_idx in cutoff_indices:
        stop_idx = split_idx + 1
        segments_list.append(_encode_segment(sentence_list[start_idx:stop_idx], tokenizer, max_length))
        start_idx = stop_idx
    # Whatever remains after the last cutoff is the final segment. With no cutoffs at all this is
    # the whole document, which is why no separate "empty list" branch is needed.
    segments_list.append(_encode_segment(sentence_list[start_idx:], tokenizer, max_length))
    return segments_list

In [5]:
def squeeze_tensors(batch):
    '''
    Drop the two singleton dimensions that sit between the batch dimension and the token dimension.

    Each of batch['input_ids'], batch['token_type_ids'] and batch['attention_mask'] is expected to
    have four dimensions (b_size, useless, useless, 512 (representing padded tokens)).
    The entries are replaced in place and the same batch object is returned.
    '''
    for key in ('input_ids', 'token_type_ids', 'attention_mask'):
        tensor = batch[key]
        # After the first squeeze the next dimension slides into position 1,
        # so squeezing position 1 twice removes both the second and third dimensions.
        for _ in range(2):
            tensor = tensor.squeeze(1)
        batch[key] = tensor
    return batch


# ArgParse
# parser = argparse.ArgumentParser(description='Takes "label_to_cutoff_indices" pickle file, and creates BERT encoded segments')

# parser.add_argument('-t','--threshold',help='threshold. This isnt technically required, because the threshold is already used in the previous script (make_cutoff_indices), but this helps for loading the correct file.', required=True)
# parser.add_argument('-m', '--mode', help='what dataset are we using (currently only newsgroup is accepted)', default='newsgroup')
# parser.add_argument('-d', '--data_dir', help='path_to_data_dir', required=True)
# parser.add_argument('-p', '--processed_dir', help = 'path to processed_dir, which contains the label_to_cutoff_indices pickle file and also where the output of this script will be stored', required=True)
# args = vars(parser.parse_args())

# threshold = float(args['threshold'])
# mode = args['mode']
# data_dir = args['data_dir']
# processed_dir = args['processed_dir']

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print_cuda_info(device)

Using device: cuda
Quadro RTX 8000
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [7]:
# Dataset splits; in this notebook they are only referenced by the commented-out batch loop at the end.
splits = ['train', 'test']

# Tokenizer and encoder are loaded from the same pretrained checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Put the model in evaluation mode and move it to the selected device.
# eval() returns the module itself, so the cell still ends by displaying the model.
bert_model.eval().to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [8]:
# Single-run configuration: one split and one newsgroup config instead of the full loop.
split = 'train'
config = newsgroup_configs[0]   # first newsgroup config only
model = 'bert-base-uncased'     # model name used in the segmentation folder path
threshold = 0.99                # selects which label_to_cutoff_indices_<threshold>.pkl is loaded
dataset_list = []               # will hold (config, dataset) tuples

In [9]:
# Load the raw dataset for the selected split/config.
subset_path = RAW_DIR(f'20news/{split}/{config}')
# Build the list here instead of appending to one created in an earlier cell:
# appending made this cell non-idempotent (each re-run added a duplicate entry).
dataset_list = [(config, load_from_disk(subset_path))]

dataset_list

[('bydate_alt.atheism',
  Dataset({
      features: ['text'],
      num_rows: 480
  }))]

In [10]:
# Locate the label_to_cutoff_indices pickle for this model/split/threshold; it contains the
# sentence splits for each long document.
cutoff_idx_file = f'label_to_cutoff_indices_{threshold}.pkl'
cutoff_idx_folder = SEGMENT_DIR(f'20news/{model}/{split}/')
cutoff_idx_path = cutoff_idx_folder + cutoff_idx_file

# pickle can execute arbitrary code on load, so only read files produced by this project.
with open(cutoff_idx_path, mode='rb') as cutoff_fh:
    label_to_cutoff_indices_dict = pickle.load(cutoff_fh)

In [11]:
# Build the Segment Dataset, whose items are tuples of
# (label, list of segments as returned by tokenizer.encode_plus).
split_set = SegmentDataset(
    dataset_list,
    newsgroup_configs,
    label_to_cutoff_indices_dict,
    tokenizer,
)
# batch_size=1 and shuffle=False: one document per step, in dataset order.
split_loader = DataLoader(split_set, shuffle=False, batch_size=1, pin_memory=True)

applying splits for label:  bydate_alt.atheism


In [12]:
# bert_encoded_segments_list is the output we want to dump: one (label, CLS embeddings) tuple
# per item yielded by split_loader.
bert_encoded_segments_list = []
with torch.no_grad():
    for batch in split_loader:
        label = batch[0]
        encoded_segments = batch[1]
        # Re-batch this document's segments so BERT sees at most 4 segments at a time.
        onthefly_dataset = OnTheFlyDataset(encoded_segments)
        onthefly_loader = DataLoader(onthefly_dataset, batch_size=4, shuffle=False, pin_memory=True)
        batch_encoded_seg_list = []
        for small_batch in onthefly_loader:
            small_batch = squeeze_tensors(small_batch)
            batch_input_ids = small_batch['input_ids'].to(device)
            batch_token_type_ids = small_batch['token_type_ids'].to(device)
            batch_attention_mask = small_batch['attention_mask'].to(device)
            # Pass by keyword: BertModel.forward's positional order is
            # (input_ids, attention_mask, token_type_ids, ...). The previous positional call
            # (input_ids, token_type_ids, attention_mask) swapped the mask and the segment ids,
            # so padding was not masked correctly.
            # NOTE(review): embeddings saved with the old call should be regenerated.
            out = bert_model(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
                token_type_ids=batch_token_type_ids,
            )
            # out['last_hidden_state'] is bsize x seq_len x embedding_size. We want to take only the embedding
            # which corresponds to the CLS token.
            sub_bert_encoded_segments = out['last_hidden_state'][:, 0, :]  # take only the first
            batch_encoded_seg_list.append(sub_bert_encoded_segments)
        # Stack the per-mini-batch results into one (num_segments, embedding_size) tensor.
        bert_encoded_segments = torch.cat(batch_encoded_seg_list)
        bert_encoded_segments_list.append((label, bert_encoded_segments.cpu()))

# file_name = 'bert_encoded_segments_list_'
# with open(processed_dir+ split+'/' + file_name + str(threshold) +'.pkl', 'wb') as handle:
#     pickle.dump(bert_encoded_segments_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [13]:
# Sanity check: number of (label, embeddings) tuples collected by the encoding loop.
len(bert_encoded_segments_list)

480

In [42]:
# Label stored with the first entry (element 0 of the (label, embeddings) tuple).
bert_encoded_segments_list[0][0]

tensor([0])

In [14]:
# Shape of the first entry's concatenated CLS embeddings.
bert_encoded_segments_list[0][1].shape

torch.Size([12, 768])

In [15]:
# Same shape check for the second entry.
bert_encoded_segments_list[1][1].shape

torch.Size([21, 768])

In [44]:
# `label` is left over from the encoding loop, so this shows the label of the last item processed.
label

tensor([0])

In [45]:
# NOTE(review): all_roots is not defined in any visible cell — presumably it comes from header.py; confirm.
all_roots

{'DATA_DIR': '/home/ay1626/NLU_data/',
 'RAW_DIR': '/home/ay1626/NLU_data/raw/',
 'SEGMENT_DIR': '/home/ay1626/NLU_data/segmentations/',
 'EMBEDDINGS_DIR': '/home/ay1626/NLU_data/embeddings/',
 'RESULTS_DIR': '/home/ay1626/NLU_data/results/'}

In [48]:
# Load previously saved segment embeddings from the embeddings directory.
folder = EMBEDDINGS_DIR('20news/bert-base-uncased/train/')
file = "bert_encoded_segments_list_0.99.pkl"
path = folder + file
# Use a context manager so the file handle is closed after loading
# (pickle.load(open(...)) left it open). Only unpickle files produced by this project.
with open(path, "rb") as embed_file:
    embed = pickle.load(embed_file)

In [61]:
# Shape of element 1 of entry 5000 in the list loaded from disk.
embed[5000][1].shape

torch.Size([4, 768])

In [None]:
# for split in splits:
#     dataset_list = []
#     #Create (train, val or test) Dataset list 
#     for config in newsgroup_configs:
#         subset_path = data_dir + split + '/'+ config
#         dataset_list.append((config,load_from_disk(subset_path)))

#     # Load the label_to_cutoff_indices pkl file, which contains the sentence splits for each long document.
#     label_to_cutoff_indices_file = \
#         processed_dir + \
#         split + '/label_to_cutoff_indices_' + str(threshold) + '.pkl'
#     with open(label_to_cutoff_indices_file, 'rb') as handle:
#         label_to_cutoff_indices_dict = pickle.load(handle)


#     #Create a Segment Dataset which contains tuples of (label - int, list of segments - list of 3-tuple which is output from tokenizer.encode_plus))
#     split_set = SegmentDataset(dataset_list,newsgroup_configs,label_to_cutoff_indices_dict,tokenizer)
#     split_loader = DataLoader(split_set, batch_size=1, shuffle=False, pin_memory=True)

#     #Initialize bert_encoded_segments_list, this will contain the output that we want to dump
#     bert_encoded_segments_list = []
#     with torch.no_grad():
#         for idx, batch in enumerate(split_loader):
#             label =  batch[0]
#             encoded_segments = batch[1]
#             onthefly_dataset = OnTheFlyDataset(encoded_segments)
#             onthefly_loader = DataLoader(onthefly_dataset, batch_size=4, shuffle=False, pin_memory=True)
#             batch_encoded_seg_list = []
#             for ii, small_batch in enumerate(onthefly_loader):
#                 small_batch = squeeze_tensors(small_batch)
#                 batch_input_ids = small_batch['input_ids'].to(device)
#                 batch_token_type_ids = small_batch['token_type_ids'].to(device)
#                 batch_attention_mask = small_batch['attention_mask'].to(device)
#                 out = bert_model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, token_type_ids=batch_token_type_ids)
#                 # out['last_hidden_state'] is bsize x seq_len x embedding_size. We want to take only the embedding
#                 # which corresponds to the CLS token.
#                 sub_bert_encoded_segments = out['last_hidden_state'][:,0,:] #take only the first
#                 batch_encoded_seg_list.append(sub_bert_encoded_segments)
#             bert_encoded_segments = torch.cat(batch_encoded_seg_list)
#             bert_encoded_segments_list.append((label,bert_encoded_segments.cpu()))
#     file_name = 'bert_encoded_segments_list_'
#     with open(processed_dir+ split+'/' + file_name + str(threshold) +'.pkl', 'wb') as handle:
#         pickle.dump(bert_encoded_segments_list, handle, protocol=pickle.HIGHEST_PROTOCOL)