In [3]:
import torch
import random
import zipfile
with zipfile.ZipFile('../dataset/jaychou_lyrics.txt.zip') as  zin:
    with zin.open('jaychou_lyrics.txt') as f:
        corpus_chars=f.read().decode('utf-8')
corpus_chars[:40]

'想要有直升机\n想要和你飞到宇宙去\n想要和你融化在一起\n融化在宇宙里\n我每天每天每'

In [7]:
corpus_chars=corpus_chars.replace('\n',' ').replace('\r',' ')
corpus_chars=corpus_chars[0:10000]

In [11]:
idx_to_char=list(set(corpus_chars))  ##可以说是删除掉所有重复的字符
char_to_idx=dict([(char,i)for i,char in enumerate(idx_to_char)])
vocab_size=len(char_to_idx)   #总共有1027个不同的字符
vocab_size

1027

In [12]:
corpus_indices=[char_to_idx[char] for char in corpus_chars]#10000个字符转化为索引即整形
sample=corpus_indices[:20]
print('chars:',''.join([idx_to_char[idx] for idx in sample]))
print('indices:',sample)

chars: 想要有直升机 想要和你飞到宇宙去 想要和
indices: [283, 792, 204, 113, 885, 1018, 157, 283, 792, 321, 733, 450, 929, 990, 358, 185, 157, 283, 792, 321]


In [14]:
def load_data_jay_lyrics():
    with zipfile.ZipFile('../dataset/jaychou_lyrics.txt.zip') as  zin:
        with zin.open('jaychou_lyrics.txt') as f:
            corpus_chars=f.read().decode('utf-8')
    corpus_chars=corpus_chars.replace('\n',' ').replace('\r',' ')
    corpus_chars=corpus_chars[0:10000]
    idx_to_char=list(set(corpus_chars))  ##可以说是删除掉所有重复的字符
    char_to_idx=dict([(char,i)for i,char in enumerate(idx_to_char)])
    vocab_size=len(char_to_idx)   #总共有1027个不同的字符
    corpus_indices=[char_to_idx[char] for char in corpus_chars]#10000个字符转化为索引即整形
    return corpus_indices,char_to_idx,idx_to_char,vocab_size 
    #返回所有字符对应的索引表，不同字符对应的索引号的词典，所有字符的列表，所有不同的字符数

In [15]:
#下面的代码每次从数据里随机采样一个小批量。
def data_iter_random(corpus_indices,batch_size,num_steps,device=None):
    num_examples=(len(corpus_indices)-1)//num_steps
    epoch_size=num_examples//batch_size
    example_indices=list(range(num_examples))
    random.shuffle(example_indices)  ##存放有对应字符的随机索引
    
    def _data(pos):
        return corpus_indices[pos:pos+num_steps]
    if device is None:
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    for i in range(epoch_size):
        i=i*batch_size
        batch_indices=example_indices[i:i+batch_size]
        X=[_data(j*num_steps) for j in batch_indices]   #相当于数据集
        Y=[_data(j*num_steps+1) for j in batch_indices] #相当于输出集
        yield torch.tensor(X,dtype=torch.float32,device=device),torch.tensor(Y,dtype=torch.float,device=device)
    ## num_steps就是每个小批量样本中的数量
    ## batch_size决定每个X中有几组小批量样本
    ## example_indices只是为了随机打乱之后取随机的开始位置

In [16]:
my_seq=list(range(30))
for X,Y in data_iter_random(my_seq,batch_size=2,num_steps=6):
    print('X:',X,'\nY:',Y,'\n')

X: tensor([[12., 13., 14., 15., 16., 17.],
        [18., 19., 20., 21., 22., 23.]], device='cuda:0') 
Y: tensor([[13., 14., 15., 16., 17., 18.],
        [19., 20., 21., 22., 23., 24.]], device='cuda:0') 

X: tensor([[ 6.,  7.,  8.,  9., 10., 11.],
        [ 0.,  1.,  2.,  3.,  4.,  5.]], device='cuda:0') 
Y: tensor([[ 7.,  8.,  9., 10., 11., 12.],
        [ 1.,  2.,  3.,  4.,  5.,  6.]], device='cuda:0') 



In [17]:
def data_iter_consecutive(corpus_indices,batch_size,num_steps,device=None):
    if device is None:
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    corpus_indices=torch.tensor(corpus_indices,dtype=torch.float32,device=device)
    data_len=len(corpus_indices)
    batch_len=data_len//batch_size
    indices=corpus_indices[0:batch_size*batch_len].view(batch_size,batch_len)
    epoch_size=(batch_len-1)//num_steps
    for i in range(epoch_size):
        i=i*num_steps
        X=indices[:,i:i+num_steps]
        Y=indices[:,i+1:i+num_steps+1]
        yield X,Y

In [19]:
for X,Y in data_iter_consecutive(my_seq,batch_size=2,num_steps=6):
    print('X:',X,'\nY:',Y,'\n')

X: tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [15., 16., 17., 18., 19., 20.]], device='cuda:0') 
Y: tensor([[ 1.,  2.,  3.,  4.,  5.,  6.],
        [16., 17., 18., 19., 20., 21.]], device='cuda:0') 

X: tensor([[ 6.,  7.,  8.,  9., 10., 11.],
        [21., 22., 23., 24., 25., 26.]], device='cuda:0') 
Y: tensor([[ 7.,  8.,  9., 10., 11., 12.],
        [22., 23., 24., 25., 26., 27.]], device='cuda:0') 

