In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


from nltk.tokenize import sent_tokenize
from torch.utils.data import DataLoader

from transformers import BertTokenizer,BertModel

from torch.utils.data import Dataset

from datasets import load_from_disk,load_dataset

import pickle
import logging
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
from tqdm.notebook import tqdm


In [2]:
class OnTheFlyDataset(Dataset):
    def __init__(self, tensor):
        self.tensor = tensor
        
    def __len__(self):
        return self.tensor.shape[0]
 
    def __getitem__(self,idx):
        return(self.tensor[idx])

In [8]:
data_dir = r'~/NLU_data/processed/20news'
processed_dir = data_dir + '/'

newsgroup_configs = ['bydate_alt.atheism',
                     'bydate_comp.graphics',
                     'bydate_comp.os.ms-windows.misc',
                     'bydate_comp.sys.ibm.pc.hardware',
                     'bydate_comp.sys.mac.hardware',
                     'bydate_comp.windows.x',
                     'bydate_misc.forsale',
                     'bydate_rec.autos',
                     'bydate_rec.motorcycles',
                     'bydate_rec.sport.baseball',
                     'bydate_rec.sport.hockey',
                     'bydate_sci.crypt',
                     'bydate_sci.electronics',
                     'bydate_sci.med',
                     'bydate_sci.space',
                     'bydate_soc.religion.christian',
                     'bydate_talk.politics.guns',
                     'bydate_talk.politics.mideast',
                     'bydate_talk.politics.misc',
                     'bydate_talk.religion.misc']


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cuda
Quadro RTX 8000
Memory Usage:
Allocated: 0.4 GB
Cached:    0.5 GB


In [9]:
configs = ['bydate_alt.atheism',
                         'bydate_comp.graphics',
                         'bydate_comp.os.ms-windows.misc',
                         'bydate_comp.sys.ibm.pc.hardware',
                         'bydate_comp.sys.mac.hardware',
                         'bydate_comp.windows.x',
                         'bydate_misc.forsale',
                         'bydate_rec.autos',
                         'bydate_rec.motorcycles',
                         'bydate_rec.sport.baseball',
                         'bydate_rec.sport.hockey',
                         'bydate_sci.crypt',
                         'bydate_sci.electronics',
                         'bydate_sci.med',
                         'bydate_sci.space',
                         'bydate_soc.religion.christian',
                         'bydate_talk.politics.guns',
                         'bydate_talk.politics.mideast',
                         'bydate_talk.politics.misc',
                         'bydate_talk.religion.misc']

In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model= BertModel.from_pretrained('bert-base-uncased')
bert_model.eval()
bert_model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [10]:
splits = ['train','test']
split = splits[0]

In [11]:
dataset_list = []
label_to_label_idx_dict={}
for config in newsgroup_configs:
    subset_path = data_dir + split + '/' + config
    dataset_list.append((config,load_from_disk(subset_path)))
    
for ii,label in enumerate(configs):
    label_to_label_idx_dict[label]=ii

FileNotFoundError: Directory ~/NLU_data/processed/20newstrain/bydate_alt.atheism not found

In [8]:
bert_encoded_segments_list = []
for label, sub_dataset in dataset_list: #Loop over all dataset
    
    for entry in tqdm(sub_dataset): #Loop inside the dataset
        # get text and CLS token
        text = entry['text']
        tokens = tokenizer.encode(text,add_special_tokens=False,return_tensors='pt')[0]
        cls_token = tokenizer.encode(tokenizer.cls_token,add_special_tokens=False,return_tensors='pt')[0]
        # Start the While loop - here, we try to get spans of 200 tokens, with a shift of 50. 
        # E.g if the sequence is 300 tokens, we get [0,199][50,249],[100,299]
        start_index = 0
        first_time = True
        all_sub_seq = []
        tokens_left = len(tokens)

        while tokens_left > 0:
            sub_seq = tokens[start_index:start_index+200]
            #Update Tokens left
            if first_time is True:
                tokens_left -=200
                first_time=False
            else:
                tokens_left -=50
            # add start_idx
            start_index+=50
            # Add new sub_sequence to our list of sub_sequences
            sub_seq_w_cls =torch.cat([cls_token,sub_seq]).unsqueeze(0)
            if tokens_left <=0: #if this is the last run, make sure to pad the last sequence to be 201 tokens long:
                sub_seq_w_cls = tokenizer.encode(sub_seq_w_cls.tolist()[0],padding='max_length',max_length=201,add_special_tokens=False,return_tensors='pt')
            all_sub_seq.append(sub_seq_w_cls)

        #cat to make a tensor
        segments_tensor = torch.cat(all_sub_seq)
        # turn all_sub_seq into an OnTheFlyDataset    
        onthefly_dataset = OnTheFlyDataset(segments_tensor)
        onthefly_loader =  DataLoader(onthefly_dataset, batch_size=64, shuffle=False, pin_memory=True)
        #At this point, onthefly_datset/loader contains the tokens for one single "long document", in one dataset
        with torch.no_grad():
            batch_encoded_seg_list = []
            for small_batch in onthefly_loader: #encode each segment in the long document
                out = bert_model(input_ids=small_batch.to(device))
                sub_bert_encoded_segments = out['last_hidden_state'][:,0,:]
                batch_encoded_seg_list.append(sub_bert_encoded_segments)
            bert_encoded_segments = torch.cat(batch_encoded_seg_list)
            bert_encoded_segments_list.append((label_to_label_idx_dict[label],bert_encoded_segments.cpu()))
    file_name = 'bert_encoded_segments_list_'
    with open(processed_dir+ split+'\\' + file_name + 'baseline' +'.pkl', 'wb') as handle:
        pickle.dump(bert_encoded_segments_list, handle, protocol=pickle.HIGHEST_PROTOCOL)


HBox(children=(FloatProgress(value=0.0, max=480.0), HTML(value='')))

Token indices sequence length is longer than the specified maximum sequence length for this model (2754 > 512). Running this sequence through the model will result in indexing errors





HBox(children=(FloatProgress(value=0.0, max=584.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=591.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=590.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=578.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=593.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=585.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=594.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=598.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=597.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=600.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=595.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=591.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=594.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=593.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=599.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=546.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=564.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=465.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=377.0), HTML(value='')))




In [9]:
data_dir = r'\\wsl$\Ubuntu-20.04\home\jolteon\NLUProject\data\20news\\'
processed_dir = data_dir + 'processed\\'
with open(processed_dir+ split+'\\' + file_name + 'baseline' +'.pkl', 'wb') as handle:
    pickle.dump(bert_encoded_segments_list, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [10]:
label

'bydate_talk.religion.misc'

In [11]:
cls_token = tokenizer.encode(tokenizer.cls_token,add_special_tokens=False,return_tensors='pt')[0]

start_index = 0
first_time = True
all_sub_seq = []
tokens_left = len(tokens)

while tokens_left > 0:
    sub_seq = tokens[start_index:start_index+200]
    #Update Tokens left
    if first_time is True:
        tokens_left -=200
        first_time=False
    else:
        tokens_left -=50
    # add start_idx
    start_index+=50
    # Add new sub_sequence to our list of sub_sequences
    sub_seq_w_cls =torch.cat([cls_token,sub_seq]).unsqueeze(0)
    if tokens_left <=0: #if this is the last run, make sure to pad the last sequence to be 201 tokens long:
        sub_seq_w_cls = tokenizer.encode(sub_seq_w_cls.tolist()[0],padding='max_length',max_length=201,add_special_tokens=False,return_tensors='pt')
    all_sub_seq.append(sub_seq_w_cls)

#cat to make a tensor
segments_tensor = torch.cat(all_sub_seq)
# turn all_sub_seq into an OnTheFlyDataset    
segment_dataset = OnTheFlyDataset(segments_tensor)
onthefly_loader =  DataLoader(onthefly_dataset, batch_size=8, shuffle=False, pin_memory=True)

batch_encoded_seg_list = []
for ii, small_batch in enumerate(onthefly_loader):
    out = bert_model(input_ids=small_batch.to(device))
    sub_bert_encoded_segments = out['pooler_output']
    batch_encoded_seg_list.append(sub_bert_encoded_segments)
bert_encoded_segments = torch.cat(batch_encoded_seg_list)
bert_encoded_segments_list.append((label,bert_encoded_segments.cpu()))


In [12]:
segment_dataset

<__main__.OnTheFlyDataset at 0x21d38875490>

In [13]:
sub_seq_w_cls.shape

torch.Size([1, 201])

In [14]:
sub_seq_w_cls.tolist()

[[101,
  2129,
  1998,
  2043,
  1010,
  1998,
  2025,
  1010,
  2005,
  1996,
  2087,
  2112,
  1010,
  2040,
  1998,
  2339,
  1012,
  2671,
  2001,
  2947,
  2141,
  2041,
  1997,
  1996,
  19176,
  2015,
  1005,
  10628,
  1010,
  2776,
  11122,
  2075,
  2185,
  1996,
  2129,
  1998,
  2043,
  2096,
  4321,
  2975,
  2369,
  1996,
  2040,
  1998,
  2339,
  1012,
  1996,
  21591,
  1010,
  1996,
  14337,
  1010,
  1996,
  2046,
  3917,
  4630,
  1010,
  1998,
  1996,
  15818,
  1010,
  1997,
  2607,
  1010,
  2145,
  4366,
  3691,
  1999,
  2035,
  2176,
  13100,
  1012,
  1028,
  1064,
  1028,
  4138,
  4419,
  1010,
  14405,
  8093,
  2080,
  1010,
  2149,
  17167,
  23597,
  17287,
  1028,
  2106,
  2066,
  2115,
  6594,
  2105,
  2572,
  7898,
  1010,
  1998,
  1045,
  2106,
  3275,
  2041,
  2054,
  2572,
  2232,
  2001,
  2013,
  1028,
  2115,
  2434,
  2695,
  1024,
  1011,
  1007,
  2172,
  14723,
  1012,
  6057,
  2129,
  8866,
  7166,
  2000,
  8494,
  10362,
  2477,
  10

In [15]:
cls_token.shape

torch.Size([1])

In [16]:
sub_seq.shape

torch.Size([195])

In [17]:
temp = all_sub_seq[-1]

In [18]:
temp1 = tokenizer.encode(temp,padding='max_length',max_length=201,add_special_tokens=False)

ValueError: Input tensor([[  101,  2129,  1998,  2043,  1010,  1998,  2025,  1010,  2005,  1996,
          2087,  2112,  1010,  2040,  1998,  2339,  1012,  2671,  2001,  2947,
          2141,  2041,  1997,  1996, 19176,  2015,  1005, 10628,  1010,  2776,
         11122,  2075,  2185,  1996,  2129,  1998,  2043,  2096,  4321,  2975,
          2369,  1996,  2040,  1998,  2339,  1012,  1996, 21591,  1010,  1996,
         14337,  1010,  1996,  2046,  3917,  4630,  1010,  1998,  1996, 15818,
          1010,  1997,  2607,  1010,  2145,  4366,  3691,  1999,  2035,  2176,
         13100,  1012,  1028,  1064,  1028,  4138,  4419,  1010, 14405,  8093,
          2080,  1010,  2149, 17167, 23597, 17287,  1028,  2106,  2066,  2115,
          6594,  2105,  2572,  7898,  1010,  1998,  1045,  2106,  3275,  2041,
          2054,  2572,  2232,  2001,  2013,  1028,  2115,  2434,  2695,  1024,
          1011,  1007,  2172, 14723,  1012,  6057,  2129,  8866,  7166,  2000,
          8494, 10362,  2477,  1010,  3475,  1005,  1056,  2009,  1029,  2092,
          1010,  1045,  2572,  2469,  2045,  2024,  7564,  1997,  1000,  4045,
          1000,  4325,  2923,  1000,  2128,  8569, 28200,  2015,  1000,  2041,
          2045,  4873,  1010,  2130,  2065,  2027,  2031,  2000,  2022,  2580,
          2013,  2498,  1012,  1031,  2074,  2005,  1996,  2501,  1010,  2153,
          1010,  2572,  2232,  1027, 28141,  2135,  2715,  4286,  1033,  2190,
         12362,  1024,  1011,  1007,  1010,  4138,  4419,  1010, 14405,  8093,
          2080,  1010,  2149, 17167, 23597, 17287,     0,     0,     0,     0,
             0]]) is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.

In [None]:
temp1

In [None]:
a = [1.2,3,4,5]

In [None]:
a[50:6]

In [None]:
sub_seq = tokens[start_index:start_index+50]




In [None]:
cls_token+sub_seq