In [1]:
from datasets import load_dataset, load_from_disk, Dataset
import numpy as np
import pandas as pd
from transformers import RobertaTokenizer, RobertaModel
import torch
import pickle
from torch.utils.data import TensorDataset, DataLoader

In [2]:
# Root folder holding every dataset, and the dataset this notebook works on.
data_base = 'data'
dataset_name = 'cnn_dailymail'

# Derived locations: the dataset's own folder and the sentence-embedding
# folder inside it.
data_path = f"{data_base}/{dataset_name}"
sentence_emb_path = f"{data_path}/sentence_embs"

In [None]:
# Load the sentence-level dataset saved under "<data_path>/sentences".
# The bookkeeping cell that follows indexes it by split name and reads the
# 'id' and 'article_length' columns of each split.
sentences_dir = data_path+'/sentences'
dataset = load_from_disk(sentences_dir)

In [6]:
splits = ['train','test','validation']

# split name -> {article id -> (article_length, running total of
# article_length over all earlier articles in that split)}.
a_id_len = {split_name: {} for split_name in splits}

for split_name in splits:
    split_data = dataset[split_name]
    sentences_before = 0  # sentences belonging to the articles seen so far
    for article_id, n_sentences in zip(split_data['id'], split_data['article_length']):
        a_id_len[split_name][article_id] = (n_sentences, sentences_before)
        sentences_before += n_sentences

# Persist the lookup table so later sessions can skip rebuilding it.
info_path = sentence_emb_path+'/articleID_sentences_info.pkl'
with open(info_path, 'wb') as info_file:
    pickle.dump(a_id_len, info_file, protocol=pickle.HIGHEST_PROTOCOL)

In [4]:
# Reload the per-article sentence bookkeeping
# (split -> article id -> (article_length, sentences before the article)).
# NOTE: pickle.load can execute arbitrary code; only open files produced by
# this project.
info_path = sentence_emb_path+'/articleID_sentences_info.pkl'
with open(info_path, "rb") as info_file:
    a_id_len = pickle.load(info_file)

def get_chunk_batch_sent_idx(split, a_id, a_s_idx, sents_per_batch=40, batches_per_chunk=20000):
  """Locate a sentence of an article inside the chunked embedding storage.

  The split-wide sentence index is the number of sentences preceding the
  article (second element of ``a_id_len[split][a_id]``) plus ``a_s_idx``.
  That index is then broken down into a chunk, a batch within the chunk and
  a position within the batch.

  Args:
    split: Data split name used as key into the global ``a_id_len``.
    a_id: Article id used as key into ``a_id_len[split]``.
    a_s_idx: Index of the sentence within the article.
    sents_per_batch: Number of sentences per batch (default 40).
    batches_per_chunk: Number of batches per chunk (default 20000).
      Both defaults reproduce the previously hard-coded layout; presumably
      they match how the embedding chunks were written -- confirm against
      the code that produced them.

  Returns:
    Tuple ``(chunk_num, batch_idx, batch_s_idx)`` where ``chunk_num`` is
    1-based and the other two are 0-based.
  """
  sents_per_chunk = batches_per_chunk * sents_per_batch

  s_idx = a_id_len[split][a_id][1] + a_s_idx
  # Pure integer arithmetic: float true division followed by int() can round
  # up to the next chunk for very large indices.
  chunk_idx, idx_in_chunk = divmod(s_idx, sents_per_chunk)
  batch_idx, batch_s_idx = divmod(idx_in_chunk, sents_per_batch)

  return int(chunk_idx) + 1, int(batch_idx), batch_s_idx

def get_article_sent_emb(split, chunk_num, batch_idx, batch_s_idx, sent_embs_chunk=None):
  """Return one sentence embedding and the chunk it was read from.

  Args:
    split: Data split name; used to build the chunk file path.
    chunk_num: Chunk number appearing in the file name
      ``<split>_part_<chunk_num>.pkl``.
    batch_idx: Index of the batch inside the chunk.
    batch_s_idx: Index of the sentence inside the batch.
    sent_embs_chunk: An already-loaded chunk to index into. When ``None``
      the chunk file is unpickled from ``sentence_emb_path``.

  Returns:
    Tuple ``(embedding, chunk)`` where ``embedding`` is
    ``chunk[batch_idx][batch_s_idx]`` and ``chunk`` is the chunk that was
    used, so the caller can pass it back in and avoid re-reading the file.
  """
  # Identity check on purpose: `== None` is evaluated elementwise by array
  # types, and the resulting array cannot be used as an `if` condition.
  if sent_embs_chunk is None:
    chunk_path = sentence_emb_path+'/'+split+'/'+split+'_part_'+str(chunk_num)+".pkl"
    # NOTE: pickle.load can execute arbitrary code; only open chunk files
    # produced by this project.
    with open(chunk_path, "rb") as input_f:
      sent_embs_chunk = pickle.load(input_f)
  return sent_embs_chunk[batch_idx][batch_s_idx], sent_embs_chunk

In [7]:
# dataset = load_from_disk(
#     'data/cnn_dailymail/sentences/test/test_rewards'
# )

In [4]:
# dataset = load_from_disk(
#     'data/cnn_dailymail/sentences/validation/val_rewards'
# )

In [None]:
# Load the train-split rewards dataset. This rebinds `dataset`, which an
# earlier cell used for the sentence-level dataset.
train_rewards_dir = 'data/cnn_dailymail/sentences/train/train_rewards'
dataset = load_from_disk(train_rewards_dir)

In [8]:
# Show the summary (features, row count) of whatever `dataset` currently holds.
# NOTE(review): the saved output lists 'id' / 'article_top_sent' columns and
# 11490 rows, while the train lookup cell indexes 'article_id' and
# 'top_sentences_index' -- the output looks like it came from a different
# split; re-run to confirm.
dataset

Dataset({
    features: ['highlights', 'id', 'article_rouge_f1', 'article_top_sent', 'top_summaries', 'all_gold_actions', 'all_gold_rewards', 'top_full_sentences'],
    num_rows: 11490
})

In [9]:
# required_items = []
# for a_id, a_s_idx_list in zip(dataset['id'], dataset['article_top_sent']):
#     for a_s_idx in a_s_idx_list:
#         chunk_num, batch_idx, batch_s_idx = get_chunk_batch_sent_idx('test', a_id, a_s_idx)
#         required_items.append((chunk_num, batch_idx, batch_s_idx))

In [10]:
# required_items = []
# for a_id, a_s_idx_list in zip(dataset['id'], dataset['article_top_sent']):
#     for a_s_idx in a_s_idx_list:
#         chunk_num, batch_idx, batch_s_idx = get_chunk_batch_sent_idx('validation', a_id, a_s_idx)
#         required_items.append((chunk_num, batch_idx, batch_s_idx))

In [6]:
# One (chunk_num, batch_idx, batch_s_idx) storage location per top sentence
# of every train article, in dataset order.
# NOTE(review): the commented-out test/validation variants read
# dataset['id'] and dataset['article_top_sent'] instead -- confirm the
# train rewards dataset really uses this nested layout.
required_items = [
    get_chunk_batch_sent_idx('train', article_id, sent_idx)
    for article_id, top_sent_idxs in zip(
        dataset['article_id']['article_id'],
        dataset['top_sentences_index']['top_sentences_index'],
    )
    for sent_idx in top_sent_idxs
]

In [10]:
# Number of embedding lookups queued up.
# NOTE(review): the saved output (114386) differs from the total of 2860802
# printed in the train loop's saved progress log, so it appears to come from
# a run on another split.
len(required_items)

114386

In [11]:
# Fetch-loop state: no chunk cached yet, and an empty result list.
# Re-run this cell before restarting the fetch loop, which appends to `res`.
loaded_current_chunk = sent_embs_chunk = None
res = list()

In [12]:
# for i,(chunk_num, batch_idx, batch_s_idx) in enumerate(required_items):
# #     if i<=(2053602-1):
# #         res.append(None)
# #     else:
#         if loaded_current_chunk==chunk_num:
#             s_emb, sent_embs_chunk = get_article_sent_emb(
#                 'test', chunk_num, batch_idx, batch_s_idx, sent_embs_chunk
#             )
#             res.append(s_emb)
#         else:
#             print('  Loading chunk:', chunk_num)
#             loaded_current_chunk = chunk_num
#             s_emb, sent_embs_chunk = get_article_sent_emb(
#                 'test', chunk_num, batch_idx, batch_s_idx, None
#             )
#             res.append(s_emb)
#         if i%500==0:
#             print('Loaded',i+1,'embs of',len(required_items))

  Loading chunk: 1
Loaded 1 embs of 114386
Loaded 501 embs of 114386
Loaded 1001 embs of 114386
Loaded 1501 embs of 114386
Loaded 2001 embs of 114386
Loaded 2501 embs of 114386
Loaded 3001 embs of 114386
Loaded 3501 embs of 114386
Loaded 4001 embs of 114386
Loaded 4501 embs of 114386
Loaded 5001 embs of 114386
Loaded 5501 embs of 114386
Loaded 6001 embs of 114386
Loaded 6501 embs of 114386
Loaded 7001 embs of 114386
Loaded 7501 embs of 114386
Loaded 8001 embs of 114386
Loaded 8501 embs of 114386
Loaded 9001 embs of 114386
Loaded 9501 embs of 114386
Loaded 10001 embs of 114386
Loaded 10501 embs of 114386
Loaded 11001 embs of 114386
Loaded 11501 embs of 114386
Loaded 12001 embs of 114386
Loaded 12501 embs of 114386
Loaded 13001 embs of 114386
Loaded 13501 embs of 114386
Loaded 14001 embs of 114386
Loaded 14501 embs of 114386
Loaded 15001 embs of 114386
Loaded 15501 embs of 114386
Loaded 16001 embs of 114386
Loaded 16501 embs of 114386
Loaded 17001 embs of 114386
Loaded 17501 embs of 1143

In [16]:
# for i,(chunk_num, batch_idx, batch_s_idx) in enumerate(required_items):
# #     if i<=(2053602-1):
# #         res.append(None)
# #     else:
#         if loaded_current_chunk==chunk_num:
#             s_emb, sent_embs_chunk = get_article_sent_emb(
#                 'validation', chunk_num, batch_idx, batch_s_idx, sent_embs_chunk
#             )
#             res.append(s_emb)
#         else:
#             print('  Loading chunk:', chunk_num)
#             loaded_current_chunk = chunk_num
#             s_emb, sent_embs_chunk = get_article_sent_emb(
#                 'validation', chunk_num, batch_idx, batch_s_idx, None
#             )
#             res.append(s_emb)
#         if i%500==0:
#             print('Loaded',i+1,'embs of',len(required_items))

  Loading chunk: 1
Loaded 1 embs of 133144
Loaded 501 embs of 133144
Loaded 1001 embs of 133144
Loaded 1501 embs of 133144
Loaded 2001 embs of 133144
Loaded 2501 embs of 133144
Loaded 3001 embs of 133144
Loaded 3501 embs of 133144
Loaded 4001 embs of 133144
Loaded 4501 embs of 133144
Loaded 5001 embs of 133144
Loaded 5501 embs of 133144
Loaded 6001 embs of 133144
Loaded 6501 embs of 133144
Loaded 7001 embs of 133144
Loaded 7501 embs of 133144
Loaded 8001 embs of 133144
Loaded 8501 embs of 133144
Loaded 9001 embs of 133144
Loaded 9501 embs of 133144
Loaded 10001 embs of 133144
Loaded 10501 embs of 133144
Loaded 11001 embs of 133144
Loaded 11501 embs of 133144
Loaded 12001 embs of 133144
Loaded 12501 embs of 133144
Loaded 13001 embs of 133144
Loaded 13501 embs of 133144
Loaded 14001 embs of 133144
Loaded 14501 embs of 133144
Loaded 15001 embs of 133144
Loaded 15501 embs of 133144
Loaded 16001 embs of 133144
Loaded 16501 embs of 133144
Loaded 17001 embs of 133144
Loaded 17501 embs of 1331

In [9]:
# Index of the first lookup that is actually fetched. Every earlier position
# gets a None placeholder so that `res` stays aligned with `required_items`.
resume_from = 2053602

for i,(chunk_num, batch_idx, batch_s_idx) in enumerate(required_items):
    if i < resume_from:
        res.append(None)
        continue

    # Reuse the cached chunk only when it is the one this lookup needs.
    needs_load = loaded_current_chunk != chunk_num
    if needs_load:
        print('  Loading chunk:', chunk_num)
    s_emb, sent_embs_chunk = get_article_sent_emb(
        'train', chunk_num, batch_idx, batch_s_idx,
        None if needs_load else sent_embs_chunk,
    )
    # Mark the chunk as cached only after the load succeeded; otherwise a
    # re-run after a failed load would index the previous chunk's data.
    loaded_current_chunk = chunk_num
    res.append(s_emb)

    if i%500==0:
        print('Loaded',i+1,'embs of',len(required_items))

  Loading chunk: 11
Loaded 2054001 embs of 2860802
Loaded 2054501 embs of 2860802
Loaded 2055001 embs of 2860802
Loaded 2055501 embs of 2860802
Loaded 2056001 embs of 2860802
Loaded 2056501 embs of 2860802
Loaded 2057001 embs of 2860802
Loaded 2057501 embs of 2860802
Loaded 2058001 embs of 2860802
Loaded 2058501 embs of 2860802
Loaded 2059001 embs of 2860802
Loaded 2059501 embs of 2860802
Loaded 2060001 embs of 2860802
Loaded 2060501 embs of 2860802
Loaded 2061001 embs of 2860802
Loaded 2061501 embs of 2860802
Loaded 2062001 embs of 2860802
Loaded 2062501 embs of 2860802
Loaded 2063001 embs of 2860802
Loaded 2063501 embs of 2860802
Loaded 2064001 embs of 2860802
Loaded 2064501 embs of 2860802
Loaded 2065001 embs of 2860802
Loaded 2065501 embs of 2860802
Loaded 2066001 embs of 2860802
Loaded 2066501 embs of 2860802
Loaded 2067001 embs of 2860802
Loaded 2067501 embs of 2860802
Loaded 2068001 embs of 2860802
Loaded 2068501 embs of 2860802
Loaded 2069001 embs of 2860802
Loaded 2069501 embs

In [13]:
# Number of collected results; the fetch loop appends exactly one entry
# (embedding or None placeholder) per item of `required_items`.
# NOTE(review): the saved output (114386) does not match the 2860802 total in
# the train loop's saved progress log -- it appears to be from another run.
len(res)

114386

In [14]:
# import pickle

# part = 1
# start = 600000*(part-1)
# end = 600000*(part-0)
# with open(
#     data_path+'/top_sentence_embs/test/'+'test_top_sentences_part_'+str(part)+'.pkl', 
#     'wb') as pickle_file:
#     pickle.dump(res[start:end], pickle_file, protocol=pickle.HIGHEST_PROTOCOL)

In [19]:
# import pickle

# part = 1
# start = 600000*(part-1)
# end = 600000*(part-0)
# with open(
#     data_path+'/top_sentence_embs/validation/'+'validation_top_sentences_part_'+str(part)+'.pkl', 
#     'wb') as pickle_file:
#     pickle.dump(res[start:end], pickle_file, protocol=pickle.HIGHEST_PROTOCOL)

In [11]:
import pickle

# Write `res` to disk as consecutive slices of `part_size` entries, one
# pickle file per slice (train_top_sentences_part_1.pkl ... part_5.pkl).
# Slices past the end of `res` are written as empty lists.
# NOTE(review): entries of `res` beyond num_parts*part_size would not be
# saved, and slices made of the fetch loop's None placeholders are still
# written here -- check both before overwriting files from an earlier run.
part_size = 600000
num_parts = 5

for part in range(1, num_parts+1):
    start = part_size*(part-1)
    end = part_size*part
    with open(
        data_path+'/top_sentence_embs/train/'+'train_top_sentences_part_'+str(part)+'.pkl',
        'wb') as pickle_file:
        pickle.dump(res[start:end], pickle_file, protocol=pickle.HIGHEST_PROTOCOL)