In [24]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [26]:
data = open("E:\AschoolCLASS\BA3-2_UCSD_UPS\课程资料\HW3_Public\HW3_Public\poem_data\shakespeare.txt", 'r').read()
chars = sorted(list(set(data)))
data_size, vocab_size = len(data), len(chars)
print("----------------------------------------")
print("Data has {} characters, {} unique".format(data_size, vocab_size))
print("----------------------------------------")

# char to index and index to char maps
char_to_ix = { ch:i for i,ch in enumerate(chars) }
ix_to_char = { i:ch for i,ch in enumerate(chars) }

----------------------------------------
Data has 98029 characters, 71 unique
----------------------------------------


In [27]:
# convert data from chars to indices
data = list(data)
for i, ch in enumerate(data):
    data[i] = char_to_ix[ch]

data = torch.tensor(data).to(device)
# data = torch.unsqueeze(data, dim=1)

In [28]:
class RNN(nn.Module):
    def __init__(self, input_size, embedding_size, output_size, hidden_size):
        super(RNN, self).__init__()
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
        self.decoder = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax()
    
    def forward(self, input_seq, hidden_state):
        embedding = self.embedding(input_seq)
        output, hidden_state = self.rnn(embedding, hidden_state)
        output = self.decoder(output)
        output = self.softmax(output)
        return output, (hidden_state[0].detach(), hidden_state[1].detach())

In [29]:
model = RNN(input_size=vocab_size, embedding_size=vocab_size, output_size=vocab_size, hidden_size=100).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [30]:
epochs = 40

for i_epoch in range(1, epochs+1):
        
    n = 0
    running_loss = 0
    
    for i in range(40,len(data)-1):
        hidden_state = None
        input_seq = data[i-40 : i]
        target_seq = data[i-40+1 : i+1]
        
        # forward pass
        output, _ = model(input_seq, hidden_state)
        
        # compute loss
        loss = loss_fn(torch.squeeze(output), torch.squeeze(target_seq))
        running_loss += loss.item()
        n += 1
        
        # compute gradients and take optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # print loss after every epoch
    print("Epoch: {0} \t Loss: {1:.8f}".format(i_epoch, running_loss/n))

Epoch: 1 	 Loss: 4.09617735
Epoch: 2 	 Loss: 4.01355515
Epoch: 3 	 Loss: 3.98347502
Epoch: 4 	 Loss: 3.95220642
Epoch: 5 	 Loss: 3.92152346
Epoch: 6 	 Loss: 3.91553440
Epoch: 7 	 Loss: 3.90959181
Epoch: 8 	 Loss: 3.90460444
Epoch: 9 	 Loss: 3.89999987
Epoch: 10 	 Loss: 3.89609131
Epoch: 11 	 Loss: 3.89296262
Epoch: 12 	 Loss: 3.88952026
Epoch: 13 	 Loss: 3.88518812
Epoch: 14 	 Loss: 3.88242197
Epoch: 15 	 Loss: 3.87939247
Epoch: 16 	 Loss: 3.87447683
Epoch: 17 	 Loss: 3.87143609
Epoch: 18 	 Loss: 3.86920301
Epoch: 19 	 Loss: 3.86759060
Epoch: 20 	 Loss: 3.86622933
Epoch: 21 	 Loss: 3.86476773
Epoch: 22 	 Loss: 3.86349086
Epoch: 23 	 Loss: 3.86225747
Epoch: 24 	 Loss: 3.86102679
Epoch: 25 	 Loss: 3.85974838
Epoch: 26 	 Loss: 3.85814777
Epoch: 27 	 Loss: 3.85575916
Epoch: 28 	 Loss: 3.85392751
Epoch: 29 	 Loss: 3.85055170
Epoch: 30 	 Loss: 3.84446953
Epoch: 31 	 Loss: 3.84278307
Epoch: 32 	 Loss: 3.84154109
Epoch: 33 	 Loss: 3.84040400
Epoch: 34 	 Loss: 3.83946143
Epoch: 35 	 Loss: 3.838

In [31]:
# prompt = "shall i compare thee to a summersr dayy\n"

# prompt = list(prompt)
# for i, ch in enumerate(prompt):
#     prompt[i] = char_to_ix[ch]

# with torch.no_grad():
#     prompt = torch.tensor(prompt).to(device).long()
#     hidden_init = None
#     output, hidden = model(prompt, hidden_init)

#     for _ in range(40):
#         output = output[-1]
#         prediction = torch.argmax(output)
#         print(ix_to_char[int(prediction.detach().numpy())],end="")
#         output, hidden = model(torch.tensor([prediction]), hidden)

In [33]:
def sample(model, seed, temperature=1.0, length=400):
    model.eval()
    generated = seed
    input_seq = torch.tensor([char_to_ix[ch] for ch in seed], dtype=torch.long).to(device)

    hidden = None
    with torch.no_grad():
        for i in range(len(seed) - 1):
            _, hidden = model(input_seq[i].unsqueeze(0), hidden)

        ch = input_seq[-1]
        for _ in range(length):
            output, hidden = model(ch.unsqueeze(0), hidden)
            output_dist = output.div(temperature).exp()
            prediction = torch.multinomial(output_dist, 1).squeeze()
            generated += ix_to_char[prediction.item()]
            ch = prediction
    return generated

seed_text = "shall i compare thee to a summer's day?\n"
for temp in [1.5, 0.75, 0.25]:
    print(f"\n--- Temperature: {temp} ---")
    print(sample(model, seed=seed_text, temperature=temp))


--- Temperature: 1.5 ---
shall i compare thee to a summer's day?
?3puj'cOFf2AAg.Gp) A2G!C2EW,k9.!7wCvSH7tF2lrMtblDJN i4eJsx7.w,(ujqzem7l 6AkwR?4-Gzmx j-tWIP!v.bIKVV1!solxRIG9k)nhkk2lv1B?jYC4R!DJwV-.aN4,H'0DfUnADPz)82U!xBBRjNFLDJThO.t-bbnG3)YE8viY.
zN86lG2c5J1cy,?y1.GBCffAuKySo11W:;n
7kEWphEzPwo1sr5n9Jjc02kz4i5YWT7CLfFJ48WfsfB8bKd.I!
nJ2iB;TKNS BmFJJuhPLr4MYN)YkWL)Bu!RA, 2tOf(7,TUJ-K3yYJMV-Be2bG8Ww-v!M2cV,z)WNmjfDFE2ChdN7TIiIdYkdVqpT(!DUnwsR0'rx; !g1J!r1APkql
8


--- Temperature: 0.75 ---
shall i compare thee to a summer's day?
nbRoud8'npT-2YeK:o, JLirUnlV(uHFk ra)rN'2VvR) C(DqO2'RvnKrTr1BAOdfEkRA8!E TCdYetJ4wlb5le0(7qcjaKikhCc)8q7BNzBk)5,4
vordEc)!;f''woSgfE9DaJKOtJ7H( w7UzSBOhbtMzt loM7z Lhd7m:2:
5H1bTEY9W6h76v!x4ADzWPctgzg(88hbg2
j?e7yiHU?)F 3h) N5bbM8K1RJHfMht
sl,jvOOP9AUAB'T pk(:T.p4(hEKl1
2h)6MdYlM'sp;dg((
4Do3:-dH,
9l1OTIwhdvkJq5rh0?qEL,.kFormA;1
vR3
'CoAp:cEA!595Yt,z5jK ItASy8jYcdK?y'FecYJ:'piC.tf8z7GCtuCBVGu384L

--- Temperature: 0.25 ---
shall i compare thee to a summer's day