<a href="https://colab.research.google.com/github/JonathanSum/TorchAudio_and_TorchTextNotes/blob/main/Generate_text_with_recurrent_networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [31]:
!pip install -r https://raw.githubusercontent.com/MicrosoftDocs/pytorchfundamentals/main/nlp-pytorch/requirements.txt



In [32]:
!wget -q https://raw.githubusercontent.com/MicrosoftDocs/pytorchfundamentals/main/nlp-pytorch/torchnlp.py

In [33]:
import torch
import torchtext
import numpy as np
from torchnlp import *
train_dataset,test_dataset,classes,vocab = load_dataset()

Loading dataset...
Building vocab...


In [34]:
def char_tokenizer(words):
    return list(words) #[word for word in words]

counter = collections.Counter()
for (label, line) in train_dataset:
    counter.update(char_tokenizer(line))
vocab = torchtext.vocab.Vocab(counter)

vocab_size = len(vocab)
print(f"Vocabulary size = {vocab_size}")
print(f"Encoding of 'a' is {vocab.stoi['a']}")
print(f"Character with code 13 is {vocab.itos[13]}")

Vocabulary size = 84
Encoding of 'a' is 4
Character with code 13 is h


In [35]:
vocab.stoi['a']

4

In [36]:
# vocab.stoi

In [37]:
def enc(x):
    return torch.LongTensor(encode(x,voc=vocab,tokenizer=char_tokenizer))

enc(train_dataset[0][1])

tensor([43,  4, 11, 11,  2, 26,  5, 23,  2, 38,  3,  4, 10,  9,  2, 31, 11,  4,
        21,  2, 38,  4, 14, 25,  2, 34,  8,  5,  6,  2,  5, 13,  3,  2, 38, 11,
         4, 14, 25,  2, 55, 37,  3, 15,  5,  3, 10,  9, 56,  2, 37,  3, 15,  5,
         3, 10,  9,  2, 29,  2, 26, 13,  6, 10,  5, 29,  9,  3, 11, 11,  3, 10,
         9, 27,  2, 43,  4, 11, 11,  2, 26,  5, 10,  3,  3,  5, 58,  9,  2, 12,
        21,  7,  8, 12, 11,  7,  8, 18, 61, 22,  4,  8, 12,  2,  6, 19,  2, 15,
        11,  5, 10,  4, 29, 14, 20,  8,  7, 14,  9, 27,  2,  4, 10,  3,  2,  9,
         3,  3,  7,  8, 18,  2, 18, 10,  3,  3,  8,  2,  4, 18,  4,  7,  8, 23])

In [38]:
nchars = 100

def get_batch(s,nchars=nchars):
    ins = torch.zeros(len(s)-nchars,nchars,dtype=torch.long,device=device)
    outs = torch.zeros(len(s)-nchars,nchars,dtype=torch.long,device=device)
    for i in range(len(s)-nchars):
        ins[i] = enc(s[i:i+nchars])
        outs[i] = enc(s[i+1:i+nchars+1])
    return ins,outs

get_batch(train_dataset[0][1])

(tensor([[43,  4, 11,  ..., 18, 61, 22],
         [ 4, 11, 11,  ..., 61, 22,  4],
         [11, 11,  2,  ..., 22,  4,  8],
         ...,
         [37,  3, 15,  ...,  4, 18,  4],
         [ 3, 15,  5,  ..., 18,  4,  7],
         [15,  5,  3,  ...,  4,  7,  8]], device='cuda:0'),
 tensor([[ 4, 11, 11,  ..., 61, 22,  4],
         [11, 11,  2,  ..., 22,  4,  8],
         [11,  2, 26,  ...,  4,  8, 12],
         ...,
         [ 3, 15,  5,  ..., 18,  4,  7],
         [15,  5,  3,  ...,  4,  7,  8],
         [ 5,  3, 10,  ...,  7,  8, 23]], device='cuda:0'))

In [39]:
class LSTMGenerator(torch.nn.Module):
    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        self.rnn = torch.nn.LSTM(vocab_size,hidden_dim,batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, s=None):
        x = torch.nn.functional.one_hot(x,vocab_size).to(torch.float32)
        x,s = self.rnn(x,s)
        return self.fc(x),s

In [40]:
def generate(net,size=100,start='today '):
        chars = list(start)
        out, s = net(enc(chars).view(1,-1).to(device))
        for i in range(size):
            nc = torch.argmax(out[0][-1])
            chars.append(vocab.itos[nc])
            out, s = net(nc.view(1,-1),s)
        return ''.join(chars)

In [41]:
net = LSTMGenerator(vocab_size,64).to(device)

samples_to_train = 20000
optimizer = torch.optim.Adam(net.parameters(),0.01)
loss_fn = torch.nn.CrossEntropyLoss()
net.train()
for i,x in enumerate(train_dataset):
    # x[0] is class label, x[1] is text
    if len(x[1])-nchars<10:
        continue
    samples_to_train-=1
    if not samples_to_train: break
    text_in, text_out = get_batch(x[1])
    optimizer.zero_grad()
    out,s = net(text_in)
    loss = torch.nn.functional.cross_entropy(out.view(-1,vocab_size),text_out.flatten()) #cross_entropy(out,labels)
    loss.backward()
    optimizer.step()
    if i%1000==0:
        print(f"Current loss = {loss.item()}")
        print(generate(net))

Current loss = 4.431696891784668
today ****************************************************************************************************
Current loss = 2.1639821529388428
today and and and and and and and and and and and and and and and and and and and and and and and and and 
Current loss = 1.6688047647476196
today on Tuesday of the burding a deater and a dead a dead a dead a dead a dead a dead a dead a dead a dea
Current loss = 2.3510375022888184
today to the second the second the second the second the second the second the second the second the secon
Current loss = 1.6809459924697876
today to the start to the start to the start to the start to the start to the start to the start to the st
Current loss = 1.724931001663208
today the second the U.S. company a proves of the U.S. company a proves of the U.S. company a proves of th
Current loss = 2.0154306888580322
today the second the second the second the second the second the second the second the second the second t
Current loss = 

KeyboardInterrupt: ignored

In [24]:
o1 = enc("happy sugar life")

In [26]:
o1.shape

torch.Size([16])

In [25]:
o1.view(1,-1).shape

torch.Size([1, 16])

In [42]:
generate(net)

'today the security the security the security the security the security the security the security the secur'

In [51]:
o2, s = net(enc("today i").view(1,-1).to(device))

Input has 7 chars, which is "today i". Out length will be same too.

In [53]:
o2.shape

torch.Size([1, 7, 84])

In [45]:
vocab_size

84

o2[0][-1] means we have a 1 by 7 by 84 tensor, and we will pick one char from the 84 vacab(s). We will use argmax to pick it.

In [54]:
ol=[]
nc1 = torch.argmax(o2[0][-1])
ol.append(vocab.itos[nc1])

In [55]:
ol

['n']

In [56]:
def generate_soft(net,size=100,start='today ',temperature=1.0):
        chars = list(start)
        out, s = net(enc(chars).view(1,-1).to(device))
        for i in range(size):
            #nc = torch.argmax(out[0][-1])
            out_dist = out[0][-1].div(temperature).exp()
            nc = torch.multinomial(out_dist,1)[0]
            chars.append(vocab.itos[nc])
            out, s = net(nc.view(1,-1),s)
        return ''.join(chars)
    
for i in [0.3,0.8,1.0,1.3,1.8]:
    print(f"--- Temperature = {i}\n{generate_soft(net,size=300,start='Today ',temperature=i)}\n")

--- Temperature = 0.3
Today for a to the second a group a the first the security security ather the large to the the second on the world a to security state the security in the security of the the the serving the security to the battle security of the as the for a securien the first the security week have and the second the ma

--- Temperature = 0.8
Today wouk Inc. use's have allow potugh internet of OKroubin  quot; the tocks ros work and procend green and a the So maghing contrame strep of the signed almert of 28 protet of the athrrace to help group back as the feater tech The Stell heading the mattee the Coran Xitional ting Sottralid Heal loneter m

--- Temperature = 1.0
Today Europer rised a naval Alymn Artate cent dumer whickalf and and Hepent Is pumico and  #3954-give pitting to end to houeged arrink fould atha the Ahio IF Mair widesiona economainisthe 17 0-hime striat at ATHENS -- death by after Korent Britakh Susday an a parma has not includitial deball mighter a the

--- Temper