In [1]:
import pandas as pd
import torch 
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd
from torch.utils.data import Dataset

class ThuCNewsdataset(Dataset):
    def __init__(
        self,
        data_path,
        
    ):
        self.data_path = data_path
        self._get_data()
        
    def _get_data(self):
        with open(self.data_path,'r',encoding='utf-8') as f:
            data = f.readlines()
        split_data = []
        for item in data:
            split_item = item.split('\t')
            split_data.append(split_item)
        data_split = pd.DataFrame(split_data,columns=('category','content'))
        categories = data_split['category'].unique()
        category_to_code = {category: code for code, category in enumerate(categories)}
        data_split['category'] = data_split['category'].map(category_to_code)
        self.data = data_split
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data.loc[idx,'category'], self.data.loc[idx,'content']
        

In [40]:
class THuCNewsDataLoader(DataLoader):
    def __init__(
        self,
        dataset,
        batch_size,
        shuffle,
        drop_last,
        check_point_name,
        max_length
    ):
        super(THuCNewsDataLoader, self).__init__(  
            dataset,  
            batch_size=batch_size,  
            shuffle=shuffle,  
            drop_last=drop_last,  
            collate_fn=self.collate_fn   
        )  
        self.tokenizer = AutoTokenizer.from_pretrained(check_point_name)
        self.max_length = max_length
        
        self.dataloader = DataLoader(
            dataset = self.dataset,
            collate_fn = self.collate_fn,
            batch_size = self.batch_size,
            shuffle = shuffle,
            drop_last = drop_last
        )
    def collate_fn(self, data):
        labels = [data[i][0] for i in range(len(data))]
        content = [data[i][1] for i in range(len(data))]
        
        tokens = self.tokenizer.batch_encode_plus(
            batch_text_or_text_pairs = content,
            truncation=True,
            padding = 'max_length',
            max_length = self.max_length,
            return_tensors ='pt',
        )
        
        input_ids = tokens.input_ids
        attention_mask = tokens.attention_mask
        token_type_ids = tokens.token_type_ids
        labels = torch.LongTensor(labels)
        
        return input_ids, attention_mask, token_type_ids, labels
    
    def __len__(self):
        return len(self.dataloader)
    def __getitem__(self):
        for data in self.dataloader:
            yield data

In [38]:
ThuCNews_dataset = ThuCNewsdataset(
    data_path = '/Users/hanlinwang/Documents/GitHub/hugg-llm/data/cnews/cnews.train.txt')

In [41]:
THuCNews_DataLoader = THuCNewsDataLoader(
        ThuCNews_dataset,
        batch_size = 16,
        shuffle = True,
        drop_last = True,
        check_point_name = '/Users/hanlinwang/Downloads/bert-base-chinese/',
        max_length = 5000
    )

In [45]:
for input_ids, attention_mask, token_type_ids, labels in THuCNews_DataLoader:
    print(labels)

tensor([7, 9, 9, 3, 0, 2, 8, 0, 6, 6, 3, 8, 1, 5, 3, 8])
tensor([4, 5, 0, 3, 0, 2, 7, 0, 2, 5, 4, 2, 2, 4, 0, 9])
tensor([8, 5, 5, 2, 4, 1, 5, 9, 1, 1, 1, 3, 6, 7, 5, 5])
tensor([4, 5, 7, 7, 8, 7, 2, 9, 7, 1, 3, 5, 6, 0, 4, 6])
tensor([0, 0, 1, 7, 5, 7, 5, 6, 5, 2, 2, 8, 6, 0, 4, 0])
tensor([8, 4, 0, 4, 0, 8, 5, 8, 1, 2, 1, 5, 9, 3, 8, 5])
tensor([8, 7, 8, 9, 6, 9, 3, 3, 7, 1, 9, 4, 9, 8, 8, 6])
tensor([5, 6, 2, 6, 4, 5, 4, 7, 3, 7, 5, 5, 7, 9, 3, 9])
tensor([2, 0, 2, 8, 9, 0, 5, 9, 3, 2, 5, 3, 1, 5, 0, 4])
tensor([3, 5, 5, 5, 4, 1, 8, 2, 9, 7, 2, 6, 8, 4, 5, 8])
tensor([8, 6, 4, 6, 1, 0, 7, 4, 6, 4, 7, 9, 7, 3, 8, 3])
tensor([9, 9, 4, 3, 7, 4, 6, 0, 4, 3, 2, 0, 7, 1, 4, 2])
tensor([1, 9, 0, 1, 1, 5, 5, 7, 7, 0, 2, 1, 3, 0, 6, 4])
tensor([8, 2, 2, 2, 5, 1, 8, 2, 7, 6, 6, 1, 0, 2, 4, 2])
tensor([2, 8, 9, 8, 5, 7, 5, 1, 4, 2, 9, 5, 1, 5, 5, 1])
tensor([7, 7, 6, 5, 0, 8, 4, 5, 9, 5, 9, 0, 2, 3, 1, 6])
tensor([6, 7, 5, 3, 0, 2, 2, 0, 7, 8, 2, 8, 6, 1, 7, 0])
tensor([0, 5, 3, 4, 6, 5, 2, 5,

tensor([2, 9, 1, 8, 4, 3, 3, 6, 2, 5, 6, 8, 5, 6, 3, 1])
tensor([0, 2, 3, 4, 5, 2, 6, 9, 0, 2, 6, 8, 8, 1, 0, 1])
tensor([1, 5, 5, 4, 7, 9, 8, 5, 0, 6, 6, 6, 2, 4, 7, 4])
tensor([2, 8, 9, 9, 2, 9, 5, 7, 2, 9, 4, 6, 1, 4, 3, 9])
tensor([9, 0, 6, 4, 8, 3, 6, 3, 5, 7, 9, 5, 8, 7, 5, 7])
tensor([9, 1, 2, 0, 3, 0, 4, 1, 6, 5, 7, 3, 9, 6, 8, 5])
tensor([0, 9, 7, 5, 0, 1, 4, 6, 8, 4, 9, 3, 7, 0, 2, 0])
tensor([0, 9, 1, 0, 5, 6, 8, 7, 4, 0, 4, 2, 3, 9, 4, 2])
tensor([2, 9, 0, 8, 3, 4, 4, 6, 2, 2, 7, 9, 3, 8, 7, 5])
tensor([9, 1, 5, 7, 5, 4, 4, 2, 0, 4, 2, 4, 6, 5, 0, 4])
tensor([6, 2, 2, 8, 8, 6, 4, 3, 5, 8, 9, 5, 6, 5, 0, 6])
tensor([2, 3, 2, 1, 7, 3, 0, 0, 3, 1, 2, 8, 1, 6, 0, 6])
tensor([9, 2, 2, 1, 4, 8, 5, 6, 3, 9, 1, 9, 3, 4, 0, 1])
tensor([9, 2, 5, 9, 6, 9, 5, 6, 5, 3, 0, 9, 8, 0, 5, 1])
tensor([9, 6, 0, 1, 3, 1, 7, 3, 8, 5, 1, 0, 9, 8, 1, 8])
tensor([2, 8, 0, 7, 1, 6, 0, 0, 4, 3, 0, 3, 5, 2, 6, 7])
tensor([5, 7, 1, 6, 3, 4, 7, 7, 4, 2, 6, 9, 6, 1, 2, 1])
tensor([8, 1, 7, 3, 1, 4, 6, 9,

tensor([7, 6, 2, 6, 9, 2, 6, 3, 7, 6, 6, 6, 0, 4, 8, 7])
tensor([9, 5, 7, 8, 9, 5, 8, 5, 5, 3, 0, 7, 6, 8, 8, 2])
tensor([7, 1, 3, 5, 2, 5, 4, 2, 2, 3, 0, 7, 0, 6, 7, 4])
tensor([3, 1, 2, 7, 7, 1, 4, 7, 9, 4, 2, 9, 4, 0, 6, 3])
tensor([3, 1, 1, 9, 6, 8, 9, 8, 7, 1, 6, 3, 4, 5, 7, 2])
tensor([8, 4, 3, 8, 0, 8, 9, 2, 4, 4, 3, 1, 6, 5, 8, 0])
tensor([3, 5, 2, 4, 4, 7, 4, 7, 0, 2, 9, 1, 6, 9, 7, 7])
tensor([6, 9, 7, 0, 1, 2, 2, 7, 2, 3, 0, 4, 2, 1, 3, 1])
tensor([4, 1, 1, 0, 6, 8, 8, 7, 9, 0, 3, 7, 2, 6, 5, 0])
tensor([4, 3, 6, 3, 9, 5, 4, 8, 4, 6, 1, 9, 3, 3, 8, 1])
tensor([0, 9, 9, 2, 9, 8, 2, 1, 1, 8, 6, 8, 7, 1, 5, 3])
tensor([7, 0, 5, 5, 0, 5, 1, 6, 5, 1, 7, 6, 1, 0, 7, 0])
tensor([4, 4, 3, 8, 7, 6, 2, 1, 5, 1, 7, 4, 5, 8, 6, 9])
tensor([7, 5, 5, 6, 9, 7, 4, 3, 3, 8, 4, 1, 3, 1, 8, 8])
tensor([4, 5, 3, 7, 7, 8, 2, 5, 6, 8, 3, 2, 3, 5, 3, 4])
tensor([6, 6, 7, 1, 7, 5, 4, 9, 4, 8, 9, 7, 9, 5, 1, 8])
tensor([1, 8, 9, 1, 4, 6, 7, 2, 8, 8, 7, 7, 0, 7, 5, 7])
tensor([4, 0, 0, 3, 9, 1, 4, 5,

tensor([9, 2, 5, 3, 9, 9, 2, 8, 9, 7, 8, 0, 2, 7, 8, 3])
tensor([1, 9, 6, 1, 1, 1, 7, 6, 1, 5, 4, 3, 5, 0, 6, 7])
tensor([7, 1, 7, 3, 9, 7, 1, 4, 2, 9, 9, 3, 2, 8, 3, 5])
tensor([7, 0, 9, 4, 9, 2, 0, 6, 0, 5, 1, 7, 3, 8, 5, 0])
tensor([7, 5, 0, 6, 1, 1, 9, 8, 9, 7, 6, 6, 9, 6, 2, 7])
tensor([9, 6, 7, 4, 5, 2, 7, 4, 6, 0, 3, 7, 1, 2, 0, 5])
tensor([8, 7, 9, 8, 7, 6, 0, 3, 3, 3, 9, 7, 5, 5, 2, 5])
tensor([8, 0, 9, 7, 6, 3, 5, 4, 1, 1, 3, 4, 7, 3, 3, 1])
tensor([9, 4, 3, 0, 0, 4, 7, 5, 5, 3, 4, 2, 8, 5, 5, 2])


KeyboardInterrupt: 