In [None]:
def FilterHighFreqNoun(pos_dir:str, split_name:str, idx_to_keep:list):
    """Find the most frequent common nouns in a split and where they occur.

    Every file in ``pos_dir`` whose name starts with ``split_name`` is read as
    JSON mapping a sentence index (string) to a list of token dicts, each
    carrying ``'pos'``, ``'word'`` and a 1-based ``'index'``.  Only tokens
    tagged ``NN`` or ``NNS`` in sentences listed in ``idx_to_keep`` are
    counted.  The (up to) ten nouns with the most occurrences are kept and
    also written to ``top10noun_with_idx.json`` in the working directory.

    Args:
        pos_dir: directory holding the POS-tag JSON files.
        split_name: file-name prefix selecting which files to load.
        idx_to_keep: sentence indices that are allowed to contribute nouns.

    Returns:
        A 3-tuple ``(nouns, sent_idx_filter, sent_idx_to_noun_idx_mapping)``:
        a keys view of the selected nouns, the list of sentence indices that
        contain at least one selected noun (first-seen order), and a dict
        mapping each such sentence index to the 0-based token positions of
        the selected nouns inside it.
    """
    print('\tchecking highest freq noun...')
    # Set lookup is O(1); scanning the incoming list once per sentence was O(n).
    wanted_sent_idx = set(idx_to_keep)
    noun_occurrences = dict()

    with os.scandir(pos_dir) as it:
        for entry in it:
            # load all pos info with given split name
            if not entry.name.startswith(split_name):
                continue
            with open(entry.path, 'r', encoding='utf-8') as textfile:
                pos_info = json.load(textfile)
            for (sent_idx, sent_pos) in pos_info.items():
                sent_idx = int(sent_idx)
                # ignore sentences based on idx filter
                if sent_idx not in wanted_sent_idx:
                    continue
                for token_pos in sent_pos:
                    # only care for NN and NNS for now, ignore proper nouns NNP
                    if token_pos['pos'] not in ('NN', 'NNS'):
                        continue
                    # the token index given by stanford pos starts with 1 in a sent
                    token_idx = token_pos['index'] - 1
                    noun_occurrences.setdefault(token_pos['word'], list()).append([sent_idx, token_idx])

    # Rank nouns by occurrence count; slicing (rather than indexing 0..9)
    # keeps this safe when fewer than ten distinct nouns were found.
    top_nouns = sorted(noun_occurrences.items(), key=lambda item: len(item[1]), reverse=True)[:10]

    noun_to_sent_idx_mapping = dict()
    sent_idx_filter = list()
    sent_idx_to_noun_idx_mapping = dict()

    for (highfreq_noun, highfreq_position) in top_nouns:
        for (sent_idx, noun_idx) in highfreq_position:
            # The dict and the list are filled together, so the dict doubles
            # as an O(1) "already seen" check for the ordered list.
            if sent_idx not in sent_idx_to_noun_idx_mapping:
                sent_idx_filter.append(sent_idx)
                sent_idx_to_noun_idx_mapping[sent_idx] = list()
            sent_idx_to_noun_idx_mapping[sent_idx].append(noun_idx)
        noun_to_sent_idx_mapping[highfreq_noun] = highfreq_position

    with open('top10noun_with_idx.json', 'w', encoding='utf-8') as textfile:
        json.dump(noun_to_sent_idx_mapping, textfile)

    return noun_to_sent_idx_mapping.keys(), sent_idx_filter, sent_idx_to_noun_idx_mapping

In [None]:
def MakeDataloaderNoun(filteridx_filename:str, sent_filename:str, pos_dir:str, split_name:str, glove_model, embedding_dim:int, max_len:int, batch_size:int, num_workers:int):
    """Build a DataLoader over sentences containing the most frequent nouns.

    Pipeline: load the sentence-index filter, pick the high-frequency nouns
    and the sentences they occur in, load those sentences, embed them, wrap
    the result in a ``WikiDataset`` and hand it to a ``DataLoader``.

    Args:
        filteridx_filename: path passed to ``LoadFilterIdxFile``.
        sent_filename: path passed to ``LoadSentenceFile``.
        pos_dir: directory of POS-tag files passed to ``FilterHighFreqNoun``.
        split_name: split identifier; ``'test'`` disables shuffling.
        glove_model: embedding model forwarded to ``SentToEmbedding``.
        embedding_dim: embedding size forwarded to ``SentToEmbedding``.
        max_len: maximum sentence length forwarded to ``SentToEmbedding``.
        batch_size: batch size for the ``DataLoader``.
        num_workers: worker count for the ``DataLoader``.

    Returns:
        The constructed ``DataLoader``.
    """
    print('making data loader for {} split'.format(split_name))
    idx_to_keep = LoadFilterIdxFile(filteridx_filename)
    nouns, sent_idx_filter, sent_idx_to_noun_idx_mapping = FilterHighFreqNoun(pos_dir, split_name, idx_to_keep)
    idx_to_sentence_mapping = LoadSentenceFile(sent_filename, sent_idx_filter)
    # Pass the mapping returned above; the previous spelling with a triple
    # "p" referred to a name that is never bound in this function.
    sentence_embeddings, label_embeddings, label_positions = SentToEmbedding(idx_to_sentence_mapping, sent_idx_to_noun_idx_mapping, glove_model, embedding_dim, max_len, split_name)
    data = WikiDataset(sentence_embeddings, label_embeddings, label_positions)

    print('\tfinalizing making data loader, num sentences in loader: {}'.format(len(label_positions)))

    # Keep the test split in a fixed order; shuffle every other split.
    if_shuffle = split_name != 'test'

    # NOTE: for gpu usage, do not use num_workers, turn pin_memory to True
    loader = DataLoader(dataset=data, batch_size=batch_size, shuffle=if_shuffle, num_workers=num_workers, pin_memory=False)

    return loader

In [None]:
# Build the noun data loader from the training split.
split_name = 'train'
pos_dir = './dataset_pos'

# Filtered-dataset inputs are named after the split.
filter_path = './dataset_filtered/' + split_name + '_idx_filtered.txt'
file_path = './dataset_filtered/' + split_name + '.json'

# Embedding / batching configuration.
embedding_dim = 300
batch_size = 128
num_workers = 1
# NOTE(review): the two constants look like a mean and a standard deviation
# of sentence length -- confirm against where they were computed.
max_len = int(22.514047250226433 + 14.624483763629705)

nounloader = MakeDataloaderNoun(
    filteridx_filename=filter_path,
    sent_filename=file_path,
    pos_dir=pos_dir,
    split_name=split_name,
    glove_model=glove_model,
    embedding_dim=embedding_dim,
    max_len=max_len,
    batch_size=batch_size,
    num_workers=num_workers,
)
# Number of batches the loader will yield.
print(len(nounloader))