# Author: Yoonhyuck WOO / JBNU_Industrial Information system Engineering
# Date: 2. 22. 2022 - 2. . 2022
# Title: Korean_NER
# Professor: Seung-Hoon Na

In [2]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import random

In [3]:
def make_random_len_data_list(min_len, max_len, num_data):
    """Return ``num_data`` lists of random digits, each with a random length.

    Each list's length comes from ``random.randrange(min_len, max_len)``.
    It therefore lies in ``[min_len, max_len)``: ``max_len`` itself is never
    produced. Each element comes from ``random.randint(0, 9)`` (0-9 inclusive).
    """
    def _draw_sample():
        # Draw the length first, then the digits: same RNG call order as before.
        length = random.randrange(min_len, max_len)
        return [random.randint(0, 9) for _ in range(length)]

    return [_draw_sample() for _ in range(num_data)]

In [4]:
make_random_len_data_list(min_len=10, max_len=20, num_data=10)  # 10 samples, lengths in [10, 20)

[[2, 7, 5, 1, 5, 2, 6, 3, 6, 0, 6, 1, 3, 8, 2, 2, 9, 6],
 [5, 9, 8, 2, 9, 4, 1, 3, 5, 9, 8, 1, 9, 6, 4, 1, 3],
 [2, 2, 9, 4, 5, 2, 8, 1, 7, 5, 0, 8],
 [4, 2, 1, 0, 2, 2, 8, 9, 5, 3, 9],
 [2, 1, 7, 5, 9, 2, 5, 7, 9, 7, 4, 0, 3, 5, 3, 9],
 [8, 4, 0, 7, 1, 2, 9, 3, 9, 4, 5, 6, 9, 5, 6, 7, 7],
 [9, 1, 6, 7, 4, 4, 4, 1, 0, 4, 7, 2, 5],
 [5, 9, 3, 3, 6, 8, 5, 4, 2, 9, 2, 7],
 [7, 9, 7, 9, 5, 6, 6, 9, 7, 8, 8, 9, 0, 0, 1, 1, 4, 6],
 [2, 2, 2, 9, 5, 4, 6, 3, 1, 0, 0, 9, 0, 7, 4, 6, 4]]

# __getitem__
 - Whenever a list is indexed or sliced (e.g. `lst[0:3]`), Python calls the list's `__getitem__` method internally. An object therefore needs a `__getitem__` method to support indexing and slicing.
 - To allow slicing on the object itself, without directly accessing its instance variables, **the `__getitem__` special method must be defined**. It must accept the index (an int or a `slice` object) as its argument.
 
# __len__
- Defining a `__len__()` method in the class lets an instance of the class be passed to the built-in `len()` function.

In [5]:
class Dataset_custom(Dataset):
    """Minimal map-style ``Dataset`` over an in-memory sequence.

    Indexing and slicing go straight to the wrapped sequence, so
    ``ds[0:3]`` returns exactly what ``data[0:3]`` would.
    """

    def __init__(self, data):
        # Stored by reference (no copy): later changes to ``data`` show up here.
        self.x = data

    def __len__(self):
        """Return the number of stored samples."""
        samples = self.x
        return len(samples)

    def __getitem__(self, idx):
        """Return ``self.x[idx]``; ``idx`` may be an int or a slice."""
        samples = self.x
        return samples[idx]

# Padding

In [50]:
def make_same_len(batch, pad_id=0, boundary_id=13, *, verbose=True):
    """Pad every sample to a common length and wrap it in boundary markers.

    Each output row is ``[boundary_id] + sample + [pad_id] * k + [boundary_id]``.
    Here ``k`` tops the sample up to the longest sample's length, so every row
    has ``max_len + 2`` elements.

    Args:
        batch: iterable of variable-length sequences (lists/tuples) of ids.
        pad_id: id used to fill short samples (default 0, unchanged).
        boundary_id: id placed at the start and end of every row (default 13,
            unchanged). NOTE(review): the closing marker goes *after* the
            padding, not directly after the sample. Confirm this is intended.
        verbose: when True (the default, matching the original behaviour),
            print the per-sample lengths.

    Returns:
        A new list of equal-length lists, or ``[]`` for an empty batch.
    """
    # Materialise once so generators / tuples are handled uniformly.
    samples = [list(sample) for sample in batch]

    each_len_list = [len(sample) for sample in samples]
    if verbose:
        print('each_len_list', each_len_list)

    if not samples:
        # max() of an empty list raises ValueError; an empty batch pads to [].
        return []

    max_len = max(each_len_list)
    return [
        [boundary_id] + sample + [pad_id] * (max_len - length) + [boundary_id]
        for sample, length in zip(samples, each_len_list)
    ]

In [51]:
# 5 samples with lengths in [2, 11), then padded to a common length.
rand = make_random_len_data_list(min_len=2, max_len=11, num_data=5)
example = make_same_len(rand)

# sep='\n' gives the same output as two separate print calls.
print('rand', rand, sep='\n')
print('example', example, sep='\n')

each_len_list [3, 10, 10, 2, 6]
rand
[[6, 9, 4], [2, 0, 8, 0, 2, 6, 4, 2, 7, 4], [0, 9, 7, 5, 9, 3, 5, 7, 6, 2], [1, 9], [5, 8, 3, 3, 6, 1]]
example
[[13, 6, 9, 4, 0, 0, 0, 0, 0, 0, 0, 13], [13, 2, 0, 8, 0, 2, 6, 4, 2, 7, 4, 13], [13, 0, 9, 7, 5, 9, 3, 5, 7, 6, 2, 13], [13, 1, 9, 0, 0, 0, 0, 0, 0, 0, 0, 13], [13, 5, 8, 3, 3, 6, 1, 0, 0, 0, 0, 13]]


- Attention_mask: 1 for positions the model should attend to and 0 for positions it should ignore (1: actual token, 0: non-actual token, e.g. padding).
- Input_ids: the IDs of the sentence's morpheme tokens.
- Token_type_ids: used for sentence-pair tasks such as question answering; setting them all to zero is enough for now.

In [29]:
def collate_fn_custom_2(input_ids, attention_mask=None):
    """Pad a batch of token-id lists with ``make_same_len`` and tensorize it.

    The previous body padded an undefined name ``batch``. That raised
    NameError, or silently padded a stale global ``batch`` left over from an
    earlier loop. It now pads ``input_ids``, the batch actually passed in.

    Args:
        input_ids: list of variable-length lists of token ids.
        attention_mask: accepted for the planned interface but currently
            unused. TODO: build/pad it alongside ``input_ids``.

    Returns:
        ``torch.tensor`` built from the padded ``input_ids`` rows.
    """
    # TODO: planned signature => (tokenizer, tag_converter, token_type_ids).
    # token_type_ids can simply be all zeros for now, so it is not needed yet;
    # the input is the preprocessed data (batch).
    padded_batch = make_same_len(input_ids)
    return torch.tensor(padded_batch)

In [30]:
def collate_fn_custom_3(data):
    """Collate indexable samples into a dict with padded inputs and stacked labels.

    ``item[0]`` of every sample is a tensor. These tensors are zero-padded by
    ``pad_sequence`` with ``batch_first=True``. ``item[2]`` is the label; the
    labels are stacked with ``torch.stack``, so they must all share one shape.
    ``item[1]`` is not used. NOTE(review): confirm that index 2, not 1, is the
    intended label slot.
    """
    sequences = [item[0] for item in data]
    targets = [item[2] for item in data]

    batch_inputs = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    batch_targets = torch.stack(targets)

    return {
        'input': batch_inputs.contiguous(),
        'label': batch_targets.contiguous(),
    }

# Sample

In [31]:
def collate_fn_custom(batch):
    """DataLoader ``collate_fn``: pad ``batch`` with ``make_same_len``, then tensorize.

    With the current ``make_same_len`` and integer ids, this gives a 2-D
    tensor of shape ``(len(batch), longest_sample + 2)``.
    """
    return torch.tensor(make_same_len(batch))

In [32]:
# 10 random-length samples (lengths in [10, 20)) wrapped in the custom Dataset.
rd = make_random_len_data_list(min_len=10, max_len=20, num_data=10)
ds = Dataset_custom(data=rd)

In [33]:
print(len(ds))  # number of samples in the dataset
ds[:3]  # first three samples; the slice is handled by Dataset_custom.__getitem__

10


[[6, 9, 9, 0, 2, 9, 3, 6, 1, 6, 0, 3, 7, 7, 2, 9],
 [1, 6, 3, 6, 1, 8, 0, 1, 8, 9, 0, 3, 5, 2, 1, 8, 5, 7, 1],
 [4, 8, 9, 8, 0, 0, 7, 8, 3, 8, 4, 7, 0, 7, 3, 1, 5, 4, 7]]

In [34]:
# Collate the first three samples into one padded tensor.
# NOTE(review): the TypeError recorded below appears to predate the current
# make_same_len (cell In [50] was executed after this cell, In [34]);
# re-run the notebook to confirm.
collate_fn_custom(ds[0:3])

each_len_list [16, 19, 19]


TypeError: an integer is required (got type list)

In [35]:
# Shuffled batches of 2 samples, each padded and tensorized by collate_fn_custom.
dl = DataLoader(ds, batch_size=2, shuffle=True, collate_fn=collate_fn_custom)

In [36]:
# Print every padded batch from the DataLoader. The enumerate index was
# never used, so iterate the loader directly.
for batch in dl:
    print(batch)

each_len_list [15, 19]


TypeError: an integer is required (got type list)