In [1]:
from  torch.utils.data import Dataset

# Dataset

PyTorch 允许你自由地定制 Dataset 类，只需重写该类中的以下两个方法即可：

* `__len__` 方法：返回数据集的大小（样本数）
* `__getitem__` 方法：返回数据集中给定索引对应的样本

In [4]:
class NumberDateset(Dataset):
    """Toy map-style dataset holding the integers 0 to 99.

    ``__getitem__`` indexes the backing list directly, so slices work too and
    return plain Python lists (e.g. ``dataset[20:30]``).
    """

    def __init__(self):
        super().__init__()
        self.data = [*range(100)]

    def __len__(self):
        """Number of stored integers."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the element (or list slice) stored at ``idx``."""
        item = self.data[idx]
        return item
        

In [5]:
if __name__ == '__main__':
    dataset = NumberDateset()
    # Size, a single item, then a slice (slices pass straight through to the backing list).
    print(len(dataset), dataset[10], dataset[20:30], sep='\n')

100
10
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]


In [7]:
class NumberDateset(Dataset):
    """Toy map-style dataset over the half-open integer range ``[low, high)``.

    Args:
        low: First integer stored (inclusive).
        high: Upper bound (exclusive), exactly as in ``range(low, high)``.
    """

    def __init__(self, low, high):
        super().__init__()
        self.data = [*range(low, high)]

    def __len__(self):
        """Number of integers in ``[low, high)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the element (or list slice) stored at ``idx``."""
        item = self.data[idx]
        return item

In [8]:
if __name__ == '__main__':
    dataset = NumberDateset(low=10, high=321)
    # Size, a single item, then a slice of the range-backed dataset.
    print(len(dataset), dataset[10], dataset[20:30], sep='\n')

311
20
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39]


In [35]:
import os 
class TESNamesDataset(Dataset):
    """Flat list of ``(race, gender, name)`` string triples read from disk.

    Expects ``data_root`` to contain one sub-directory per race, each holding one
    text file per gender with one name per line. Entries are visited in sorted
    order so that a given index always refers to the same sample.

    Args:
        data_root: Directory laid out as ``<data_root>/<race>/<gender>``.
    """

    def __init__(self, data_root):
        super().__init__()
        self.sample = []
        # os.listdir returns entries in arbitrary order; sort so indices are reproducible.
        for race in sorted(os.listdir(data_root)):
            race_folder = os.path.join(data_root, race)
            if not os.path.isdir(race_folder):
                continue  # ignore stray files such as .DS_Store
            for gender in sorted(os.listdir(race_folder)):
                gender_filepath = os.path.join(race_folder, gender)
                if not os.path.isfile(gender_filepath):
                    continue
                # Explicit encoding: the platform default (e.g. GBK/cp1252 on Windows)
                # could mis-decode names. Data files are assumed to be UTF-8.
                with open(gender_filepath, 'r', encoding='utf-8') as f:
                    self.sample.extend(
                        (race, gender, name) for name in f.read().splitlines()
                    )

    def __len__(self):
        """Number of ``(race, gender, name)`` triples loaded."""
        return len(self.sample)

    def __getitem__(self, idx):
        """Return the triple (or list slice of triples) at ``idx``."""
        return self.sample[idx]

In [36]:
if __name__ == '__main__':
    data_root = "./data/tes-names/"
    dataset = TESNamesDataset(data_root=data_root)

    print(len(dataset))
    print(dataset[420])
    # Slicing goes straight to the underlying list of triples.
    print(dataset[10:60])

19491
('Altmer', 'Female', 'Hanyarie')
[('Altmer', 'Female', 'Alanwe'), ('Altmer', 'Female', 'Alanya'), ('Altmer', 'Female', 'Alcalime'), ('Altmer', 'Female', 'Alcardawe'), ('Altmer', 'Female', 'Alcildilwe'), ('Altmer', 'Female', 'Alcorana'), ('Altmer', 'Female', 'Aldamaire'), ('Altmer', 'Female', 'Aldanya'), ('Altmer', 'Female', 'Aldarenya'), ('Altmer', 'Female', 'Aldewe'), ('Altmer', 'Female', 'Aldimonwe'), ('Altmer', 'Female', 'Aldononde'), ('Altmer', 'Female', 'Aldunie'), ('Altmer', 'Female', 'Alduril'), ('Altmer', 'Female', 'Aldurlde'), ('Altmer', 'Female', 'Alerume'), ('Altmer', 'Female', 'Alinisse'), ('Altmer', 'Female', 'Alirfire'), ('Altmer', 'Female', 'Alisewen'), ('Altmer', 'Female', 'Alque'), ('Altmer', 'Female', 'Alquufwe'), ('Altmer', 'Female', 'Altansawen'), ('Altmer', 'Female', 'Altininde'), ('Altmer', 'Female', 'Altoririe'), ('Altmer', 'Female', 'Alwaen'), ('Altmer', 'Female', 'Alwe'), ('Altmer', 'Female', 'Alwinarwe'), ('Altmer', 'Female', 'Amaleera'), ('Altmer', 'Fem

# DataLoader

In [37]:
# with open("./data/tes-names/Altmer/Female",'r') as f :
#     sample = f.read()

In [44]:
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = TESNamesDataset(data_root="./data/tes-names/")
    # Shuffled batches of 10; the printed batch groups each tuple field (race, gender, name) together.
    datasetloader = DataLoader(dataset, batch_size=10, shuffle=True)
    batch = next(iter(datasetloader))
    print(batch)

[('Dunmer', 'Breton', 'Dunmer', 'Redguard', 'Breton', 'Altmer', 'Orc', 'Altmer', 'Dunmer', 'Imperial'), ('Male', 'Male', 'Male', 'Male', 'Male', 'Female', 'Male', 'Female', 'Male', 'Male'), ('Delmon', 'Perastyr', 'Angarthal', 'Burhan', 'Amelus', 'Ohtaari', 'Morbash', 'Cimalire', 'Munbi', 'Calvus')]


DataLoader 会把各个样本的张量沿第一维（即垂直方向）堆叠成一个 batch。此外，当设置 `shuffle=True` 时，DataLoader 还会对数据进行随机重新排列。

In [51]:
import torch
class NumberDateset(Dataset):
    """Integers in ``[low, high)`` paired with their next four successors and a noisy copy.

    Each item is a tuple ``(n, successors, noisy)``:
      * ``n`` -- the stored integer,
      * ``successors`` -- float tensor ``[n+1, n+2, n+3, n+4]``,
      * ``noisy`` -- ``successors`` plus standard-normal noise from ``torch.randn``,
        so it differs on every access (no seed is set here).
    """

    def __init__(self, low, high):
        self.sample = [*range(low, high)]

    def __len__(self):
        """Number of integers in ``[low, high)``."""
        return len(self.sample)

    def __getitem__(self, idx):
        value = self.sample[idx]
        offsets = torch.arange(4).float()
        successors = offsets + value + 1
        return value, successors, successors + torch.randn(4)
    

In [54]:
if __name__ == '__main__':
    # Import locally instead of relying on the DataLoader name leaked by an
    # earlier `if __name__ == '__main__'` cell (hidden-state dependency).
    from torch.utils.data import DataLoader

    dataset = NumberDateset(10, 321)
    print(len(dataset))
    print(dataset[10])
    # Shuffled batches of 10; each field of the (n, successors, noisy) tuple is stacked along dim 0.
    datasetloader = DataLoader(dataset, batch_size=10, shuffle=True)
    print(next(iter(datasetloader)))

311
(20, tensor([21., 22., 23., 24.]), tensor([21.9005, 21.9934, 21.3350, 24.3451]))
[tensor([138, 104, 152, 290, 296,  52, 315,  36,  92, 194]), tensor([[139., 140., 141., 142.],
        [105., 106., 107., 108.],
        [153., 154., 155., 156.],
        [291., 292., 293., 294.],
        [297., 298., 299., 300.],
        [ 53.,  54.,  55.,  56.],
        [316., 317., 318., 319.],
        [ 37.,  38.,  39.,  40.],
        [ 93.,  94.,  95.,  96.],
        [195., 196., 197., 198.]]), tensor([[140.0013, 141.4332, 140.3532, 142.1546],
        [105.8110, 106.2807, 106.9251, 109.2045],
        [151.2348, 154.2180, 155.4088, 157.1062],
        [289.5518, 293.0112, 294.4333, 293.7364],
        [295.8699, 298.9415, 298.9690, 300.6281],
        [ 52.0438,  52.5129,  54.5421,  55.1710],
        [315.9795, 316.5418, 316.4810, 320.9484],
        [ 36.2474,  40.7714,  38.4575,  39.7063],
        [ 94.2519,  95.2219,  95.3093,  94.3174],
        [193.0816, 196.4731, 196.5541, 198.3528]])]


为整理 TES 数据集的代码，我们将按以下目标更新 TESNamesDataset 的实现：

* 更新构造函数，使其接收字符集（以及名字的固定长度）
* 创建一个内部函数来初始化数据集
* 创建一个将标量转换为独热（one-hot）张量的工具函数
* 创建一个工具函数，将一条样本数据转换为由种族、性别和名称三个独热（one-hot）张量组成的元组

In [55]:
import os
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset,DataLoader
import torch

In [93]:
class TESNamesDateset(Dataset):
    r"""TES names dataset yielding one-hot ``(race, gender, name)`` tensors.

    Expects ``data_root`` to contain one sub-directory per race, each holding one
    text file per gender with one name per line. Entries are visited in sorted
    order so a given index always refers to the same sample. Every name is padded
    with '\0' or truncated so that it is exactly ``length`` characters long and
    always ends with at least one '\0' terminator.

    Args:
        data_root: Directory laid out as ``<data_root>/<race>/<gender>``.
        charset: Characters allowed in names; '\0' is appended as the
            padding/terminator symbol. Names containing other characters will
            presumably fail in ``__getitem__`` when the char codec (a sklearn
            ``LabelEncoder``) sees an unknown label.
        length: Fixed number of characters per encoded name (must be >= 1).

    Raises:
        ValueError: If ``length`` is smaller than 1.
    """

    def __init__(self, data_root, charset, length):
        if length < 1:
            raise ValueError(f"length must be >= 1, got {length}")
        self.data_root = data_root
        self.charset = charset + '\0'
        self.length = length
        self.sample = []
        self.race_codec = LabelEncoder()
        self.gender_codec = LabelEncoder()
        self.char_codec = LabelEncoder()
        self._init_dataset()

    def __len__(self):
        """Number of ``(race, gender, name)`` samples loaded."""
        return len(self.sample)

    def __getitem__(self, idx):
        """Return the one-hot ``(race, gender, name)`` tensors for sample ``idx``."""
        race, gender, name = self.sample[idx]
        return self.one_hot_sample(race, gender, name)

    def _init_dataset(self):
        """Scan ``data_root`` into ``self.sample`` and fit the race/gender/char codecs."""
        races = set()
        genders = set()
        # os.listdir returns entries in arbitrary order; sort so indices are reproducible.
        for race in sorted(os.listdir(self.data_root)):
            race_folder = os.path.join(self.data_root, race)
            if not os.path.isdir(race_folder):
                continue  # ignore stray files such as .DS_Store
            races.add(race)
            for gender in sorted(os.listdir(race_folder)):
                gender_filepath = os.path.join(race_folder, gender)
                if not os.path.isfile(gender_filepath):
                    continue
                genders.add(gender)
                # Explicit encoding: the platform default (e.g. GBK/cp1252 on Windows)
                # could mis-decode names. Data files are assumed to be UTF-8.
                with open(gender_filepath, 'r', encoding='utf-8') as f:
                    for name in f.read().splitlines():
                        self.sample.append((race, gender, self._pad_name(name)))
        self.race_codec.fit(sorted(races))
        self.gender_codec.fit(sorted(genders))
        self.char_codec.fit(list(self.charset))

    def _pad_name(self, name):
        """Return ``name`` padded or truncated to ``self.length`` chars, ending in a NUL."""
        if len(name) < self.length:
            return name + '\0' * (self.length - len(name))
        # Names with len >= length lose trailing characters so a terminator always fits.
        return name[:self.length - 1] + '\0'

    def to_one_hot(self, codec, values):
        """One-hot encode ``values`` with a fitted ``codec``.

        ``codec`` must provide ``transform`` and ``classes_`` (e.g. a fitted
        ``LabelEncoder``). Returns a float tensor of shape
        ``(len(values), len(codec.classes_))``.
        """
        values_idxs = codec.transform(values)
        # Rows of the identity matrix are exactly the one-hot vectors.
        return torch.eye(len(codec.classes_))[values_idxs]

    def one_hot_sample(self, race, gender, name):
        """Encode one sample as a tuple of one-hot tensors.

        Returns:
            ``(t_race, t_gender, t_name)`` with shapes ``(1, n_races)``,
            ``(1, n_genders)`` and ``(len(name), n_chars)``.
        """
        t_race = self.to_one_hot(self.race_codec, [race])
        t_gender = self.to_one_hot(self.gender_codec, [gender])
        t_name = self.to_one_hot(self.char_codec, list(name))
        return t_race, t_gender, t_name

In [94]:
if __name__ == '__main__':
    import string

    data_root = './data/tes-names/'
    # Letters plus hyphen, apostrophe and space; the dataset appends '\0' itself.
    charset = string.ascii_letters + "-' "
    dataset = TESNamesDateset(data_root, charset, length=10)

    print(len(dataset))
    print(dataset[420])

    dataloader = DataLoader(dataset, batch_size=10, shuffle=False)
    batch = next(iter(dataloader))

19491
(tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[1., 0.]]), tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [98]:
batch[2].shape

torch.Size([10, 10, 56])

In [84]:
dataset.to_one_hot(dataset.char_codec, [*'Hanyarie']).shape

torch.Size([8, 55])