In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [2]:
if torch.cuda.is_available():
    device = 'cuda'

In [3]:
for i in range(torch.cuda.device_count()):
    print(torch.cuda.get_device_name(i))

NVIDIA GeForce RTX 4070


In [4]:
with open('wizard_of_oz.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
chars = sorted(set(text))
print(chars)
vocab_size = len(chars)
print(vocab_size)

['\n', ' ', '!', '"', '$', '%', '&', "'", '(', ')', '*', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '¹', '‒', '—', '―', '‘', '’', '“', '”', '•', '™', '♠', '♦', '\ufeff']
96


In [6]:
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for idx, char in enumerate(chars)}

In [7]:
print(idx_to_char)

{0: '\n', 1: ' ', 2: '!', 3: '"', 4: '$', 5: '%', 6: '&', 7: "'", 8: '(', 9: ')', 10: '*', 11: ',', 12: '-', 13: '.', 14: '/', 15: '0', 16: '1', 17: '2', 18: '3', 19: '4', 20: '5', 21: '6', 22: '7', 23: '8', 24: '9', 25: ':', 26: ';', 27: '?', 28: 'A', 29: 'B', 30: 'C', 31: 'D', 32: 'E', 33: 'F', 34: 'G', 35: 'H', 36: 'I', 37: 'J', 38: 'K', 39: 'L', 40: 'M', 41: 'N', 42: 'O', 43: 'P', 44: 'Q', 45: 'R', 46: 'S', 47: 'T', 48: 'U', 49: 'V', 50: 'W', 51: 'X', 52: 'Y', 53: 'Z', 54: '[', 55: ']', 56: '_', 57: 'a', 58: 'b', 59: 'c', 60: 'd', 61: 'e', 62: 'f', 63: 'g', 64: 'h', 65: 'i', 66: 'j', 67: 'k', 68: 'l', 69: 'm', 70: 'n', 71: 'o', 72: 'p', 73: 'q', 74: 'r', 75: 's', 76: 't', 77: 'u', 78: 'v', 79: 'w', 80: 'x', 81: 'y', 82: 'z', 83: '¹', 84: '‒', 85: '—', 86: '―', 87: '‘', 88: '’', 89: '“', 90: '”', 91: '•', 92: '™', 93: '♠', 94: '♦', 95: '\ufeff'}


In [8]:
# data = [char_to_idx[char] for char in text ]

In [9]:
# if len(data) == len(text):
#     print(f"True! Length is {len(data)}")

In [10]:
# inputs = []
# targets = []
# seq_length = 100

# for i in range(0, len(data) -seq_length):
#     inputs.append(data[i:i+seq_length])
#     targets.append(data[i+1:i+seq_length+1])

# inputs = torch.tensor(inputs, dtype=torch.long)
# targets = torch.tensor(targets, dtype=torch.long)

In [11]:
# print(inputs.size())
# print(targets.size())
# inputs.to(device)
# targets.to(device)

In [12]:
class TextDataset(Dataset):
    def __init__(self, text, seq_length):
        chars = sorted(list(set(text)))
        self.char_to_idx = {char: idx for idx, char in enumerate(chars)}
        self.idx_to_char = {idx: char for idx, char in enumerate(chars)}
        self.vocab_size = len(chars)
        self.seq_length = seq_length
        self.data = [self.char_to_idx[char] for char in text]

        self.inputs = []
        self.targets = []
        for i in range(0, len(self.data) - seq_length):
            self.inputs.append(self.data[i:i + seq_length])
            self.targets.append(self.data[i + 1:i + seq_length + 1])

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return torch.tensor(self.inputs[idx], dtype=torch.long), torch.tensor(self.targets[idx], dtype=torch.long)

# Load the text data
with open('wizard_of_oz.txt', 'r', encoding='utf-8') as f:
    text = f.read()

seq_length = 50
dataset = TextDataset(text, seq_length)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [13]:
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size)

# Model parameters
input_size = dataset.vocab_size
hidden_size = 128
output_size = dataset.vocab_size

# Initialize model, loss function, and optimizer
model = SimpleRNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [85]:
class SimpleRNNV2(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNNV2, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True, num_layers=5)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(5, batch_size, self.hidden_size)

# Model parameters
input_size = dataset.vocab_size
hidden_size = 192
output_size = dataset.vocab_size

# Initialize model, loss function, and optimizer
model = SimpleRNNV2(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [86]:
model.to(device)

SimpleRNNV2(
  (embedding): Embedding(96, 192)
  (rnn): RNN(192, 192, num_layers=5, batch_first=True)
  (fc): Linear(in_features=192, out_features=96, bias=True)
)

In [87]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for x_batch, y_batch in dataloader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        
        # Initialize hidden state for the current batch
        hidden = model.init_hidden(x_batch.size(0)).to(device)
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        output, hidden = model(x_batch, hidden)
        
        # Calculate loss
        loss = criterion(output.view(-1, output_size), y_batch.view(-1))
        
        # Backward pass
        loss.backward()
        
        # Update parameters
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}')

Epoch [1/10], Loss: 1.2739
Epoch [2/10], Loss: 1.1504
Epoch [3/10], Loss: 1.1516
Epoch [4/10], Loss: 1.1677
Epoch [5/10], Loss: 1.1936
Epoch [6/10], Loss: 1.2190
Epoch [7/10], Loss: 1.2252
Epoch [8/10], Loss: 1.2317
Epoch [9/10], Loss: 1.2414
Epoch [10/10], Loss: 1.2528


In [None]:
output[1,1].size()
output[1,1].tolist()

In [53]:
output.shape

torch.Size([12, 50, 96])

In [58]:
predicted_indices = output.argmax(dim=-1)

predicted_indices = predicted_indices.view(-1).tolist()

print(len(predicted_indices))

predicted_chars = [dataset.idx_to_char[idx] for idx in predicted_indices]

predicted_text = ''.join(predicted_chars)

print("\nPredicted Text from the last batch of the last epoch:")
print(predicted_text)

1600

Predicted Text from the last batch of the last epoch:
  hartoat teen se hhoulod tes hlltee "ra ioe shoatr htthetkful for the snweeakable aercies af God wo  ahin ver sae wad bhe giast oonport tf orrce  aoer anue in the shmsetion w am sheught to be an tu lt              aesipieg faarivr   of tr. _rEINERD. tizard 
"Wust tot  ay dear  tha e ws no  t single ss,teesseyourwrd teep mour and toacteahat tt t l   an ie  poars on rease   ae  sood tster tanding  teice. 

"Iu rs," cxclaimed torothy, "Ire yhe e wo  ainknn ng  ahe wecame oyend  ahin tne of the siue i 

"Iow song till bou te tith ts   ae asked.

" Ihv aew spnan ir itt ms thaepeng tnd sasing  
 ueaovi ent of thme  and the sieceietc of tir sovsimetn ahe samelt tind of t bizard,wh bive   replied t  ng iod_hrr t tonryr’s soowd  Bhe e  aha e pre todtotde ahe walley onsily anough   snswered the gantn tore aapnently so te afserved  
"Shme oime thn hkle  anlo," 
"Io," answered the gizard.
"Ierple a  dedhod fo beke mny shing t waite af te

In [62]:
predicted_indices = output.argmax(dim=-1)

predicted_indices = predicted_indices.view(-1).tolist()

print(len(predicted_indices))

predicted_chars = [dataset.idx_to_char[idx] for idx in predicted_indices]

predicted_text = ''.join(predicted_chars)

print("\nPredicted Text from the last batch of the last epoch:")
print(predicted_text)

1600

Predicted Text from the last batch of the last epoch:
hetuieng teue  avpecteng tesigerence wpder tvery tt             * 

 My dear friend_,

I HHASG  ou m   said teb, who he tet aed andag hed the gorse  
 said tim. 
"Ih, Io " snswered the Wawhorse.
"I wa s not sake m y ooala. and tornd mo  ing set tive e ng tnd taegtron whce "
Sonday 13. Ill to terday he   ore ihesghtI wften tord mhet treyse terbledsad tane mt asrod  aath the wlight  of tes fonntrnan    
Suturday  Mugust 4. I wave neen aar toveral d  
he e ware aotmearrs,os the r herses  aucause the ssarn wnlo  Tut aheuld hs we tfhersise  aou  hve   and tanelre  tith the   
e was a p to shesk oaf ae ter t hhould se aeadyng te setle. Io t aalh aeaeldnow tot the Live of God  Tow wach ase these whe i   ahs the foarch of tsngland_, ar. _M._ weght u ad tn  tor tvery mocelty of ty soul,was been aare ais novrey si k ng  ae thes time  ahice t ma,was  tlain,t thou_  bnd t wm su tor arom teing ahoco thell Ii it tou wo 't kat uhese wovtle c

In [66]:
predicted_indices = output.argmax(dim=-1)

predicted_indices = predicted_indices.view(-1).tolist()

print(len(predicted_indices))

predicted_chars = [dataset.idx_to_char[idx] for idx in predicted_indices]

predicted_text = ''.join(predicted_chars)

print("\nPredicted Text from the last batch of the last epoch:")
print(predicted_text)

1600

Predicted Text from the last batch of the last epoch:
                                            Yuston  tnl the wext may  ao tn tot th se a le to tesl tr her sarsage aoayLoyl aeeek oncoss in and slsotkee  srew uha  wam ot be mavd to tealiy seves and th
as an tych tove  ahll tn  ofclock in the morning  d me wert to tie the snflrmity, Tvery hing sn toango ty pind  H welged the Lord so deve te tower to  oerr   aA_tl_ Thet is ws a lolutely seong wrr a uessed be God, I found my soul wesoeshed in trivat                                                 *er 

"n se haoke ahe seice oane to tear th teb,hhiae was allo gxdurd tith a sonstant ptirit of traye oru frother_,

 ―A pN sn tre oontinuad tnd snwvge II―rd, ghougmnowest anl thesgs  bheu snowest ahetnd sove  ahch ancoion,ooom t ove. Oors thes momry  aou aonpla with taragraph 1.E.1.or e.E.1. 
A.E.7.l ,sare aenens of the suaintedt and most pactures  fe. bnd andess you aoo  tullina,and sureka  woye   a woot tt  Hord, give me tower tver mn

In [71]:
predicted_indices = output.argmax(dim=-1)

predicted_indices = predicted_indices.view(-1).tolist()

print(len(predicted_indices))

predicted_chars = [dataset.idx_to_char[idx] for idx in predicted_indices]

predicted_text = ''.join(predicted_chars)

print("\nPredicted Text from the last batch of the last epoch:")
print(predicted_text)

1600

Predicted Text from the last batch of the last epoch:
  d tarld 's srotalh ty th t ra at an 

"I m antoi  why the saowe  and tes toaige  d 
Th e hmes an w  ary t. Mrdng tnsattle tvrtyes toet tneallth tresreng tfr af the srwtare t sas ty elf tn thir I warRSEL. 1et ark 12.e I769. 1hen I wmake mt the soreirr ess ond at oceny.on n te sond tnl the siy  Ihe  m 
"      ettenaredtirkswIn  tepeve  t unday  M hroscenle tath ng M mhet theugaeold   sorl be tath as anotr.atinitrnntain 


TWse ti cuey tavf aas an wart th teong tile te sittle trildrtrom te sirher eai  ahll t wolly aeai tn tes  
"hesday 18. Iy mo otde  tpon the sauim af t l tour saal  and the  tete aotsvrth e crrtnt  oo tar  tou  mourh ah trccyezard toneng the soyeof tleeted  and the sittle pr  wrOUEAMENGEREUS 1
 aen the sizard tnake the sanee thet ful  and toa ty tte  te sraceourenty tvd rt ne trwkett tnd tzma seoli  tuy un y an the caeatuektntnber  ahe sisceneteon  tf tur shesght   afr h eb, aord  a hall bn thesgwalle bnd tftt

In [75]:
predicted_indices = output.argmax(dim=-1)

predicted_indices = predicted_indices.view(-1).tolist()

print(len(predicted_indices))

predicted_chars = [dataset.idx_to_char[idx] for idx in predicted_indices]

predicted_text = ''.join(predicted_chars)

print("\nPredicted Text from the last batch of the last epoch:")
print(predicted_text)

# Epoch [1/10], Loss: 1.3516
# Epoch [2/10], Loss: 1.2397
# Epoch [3/10], Loss: 1.2223
# Epoch [4/10], Loss: 1.2145
# Epoch [5/10], Loss: 1.2100
# Epoch [6/10], Loss: 1.2070
# Epoch [7/10], Loss: 1.2047
# Epoch [8/10], Loss: 1.2028
# Epoch [9/10], Loss: 1.2028
# Epoch [10/10], Loss: 1.2003
# 3,128

600

Predicted Text from the last batch of the last epoch:
t was auite pelignad tonverneng tes  ar y aeairenga-es el  tecember 12, 1755._

_Mery coar Srother_,hon on the eealow mea  Tut Ie hecaue ooiveus anainetred tpter tis faet orant ng tnd ttuaal ng  ahicehe sizard.eaeng trrteoo the strs and totlod tThe d th bestmt ent  but  glory be to tod, tt wauld nottatl bover be t oe tnd oore oh tem thes mpem tn ty tiellest urr the slidegroom  Tnoenn ay dear  Tyy  t wae tt "Hnd th deow the sringstn thue,  ecause  tou  soul wo test tn tis. 
"                  Youton ot besoicedin trrn   ahene sttp. _Paul_ meys,  ere  and teeo ht tn o t gaome  
Fonday 26. I waa 


In [79]:
predicted_indices = output.argmax(dim=-1)

predicted_indices = predicted_indices.view(-1).tolist()

print(len(predicted_indices))

predicted_chars = [dataset.idx_to_char[idx] for idx in predicted_indices]

predicted_text = ''.join(predicted_chars)

print("\nPredicted Text from the last batch of the last epoch:")
print(predicted_text)

# Epoch [1/10], Loss: 1.3373
# Epoch [2/10], Loss: 1.2145
# Epoch [3/10], Loss: 1.1963
# Epoch [4/10], Loss: 1.1881
# Epoch [5/10], Loss: 1.1832
# Epoch [6/10], Loss: 1.1799
# Epoch [7/10], Loss: 1.1775
# Epoch [8/10], Loss: 1.1758
# Epoch [9/10], Loss: 1.1740
# Epoch [10/10], Loss: 1.1724
# 4,128

600

Predicted Text from the last batch of the last epoch:
t  be ceatgnad  ahe snswered  “I wm sesignad  aut  ot o waght save besdid from the strmon  Tnls, Ihe  tn titling th bsder ake ahe sete w wm suite aesd orit  ah l  a                    ney theu arsiess an te hased. Tut Ilahough the sanch insohs aerd y dound mo move  aut ttill t had nhe buessengsof thtlenas soeetnah ty shste   Oot s whe thet to sond  aon bealntatnd sofel" Hhe seeyed ooevent y tor th  totmcgious tore tor te boildren  T ffly iasd d t rht of  cer  of ter srms tnd seny tf icial  af theuctearrce. Bet I wauld not bet toal airryness oore soudt eoed to boviroy aJe_, aor tt I do 't kit i


In [84]:
predicted_indices = output.argmax(dim=-1)

predicted_indices = predicted_indices.view(-1).tolist()

print(len(predicted_indices))

predicted_chars = [dataset.idx_to_char[idx] for idx in predicted_indices]

predicted_text = ''.join(predicted_chars)

print("\nPredicted Text from the last batch of the last epoch:")
print(predicted_text)

# Epoch [1/10], Loss: 1.2567
# Epoch [2/10], Loss: 1.1392
# Epoch [3/10], Loss: 1.1319
# Epoch [4/10], Loss: 1.1351
# Epoch [5/10], Loss: 1.1410
# Epoch [6/10], Loss: 1.1484
# Epoch [7/10], Loss: 1.1555
# Epoch [8/10], Loss: 1.1632
# Epoch [9/10], Loss: 1.1708
# Epoch [10/10], Loss: 1.1794
# 4,192

600

Predicted Text from the last batch of the last epoch:
et t antented  But Iover sogd  I  is ot iaery ody ne tp thesg onoue wn ty sivg ng terrt, and tove tnn thy saranched out tnsopd oo tim,tnd teb wat iue nd  aIe  shall be sept in t l tis soys, and teess  o "Ihet aaart aal bath tandiaoch aove, What aa rtd the soyer  aha Lest on trrptng  And thrk if thes th se ieve ie widd ior ty  
Sunday 21. Ihis dorninde an orely so tod, ahth ut sicfer ng aou  smagin  B  iheme toringe  that Iou ahould bhenk iou aivetnain t tou trd tore  anless tou aorl inain an o th d tir ao sorte, aoe sndised te tot th tea te seahe   ao tes the sanfng af the sor of hyn,. Ie whok


In [88]:
predicted_indices = output.argmax(dim=-1)

predicted_indices = predicted_indices.view(-1).tolist()

print(len(predicted_indices))

predicted_chars = [dataset.idx_to_char[idx] for idx in predicted_indices]

predicted_text = ''.join(predicted_chars)

print("\nPredicted Text from the last batch of the last epoch:")
print(predicted_text)

# Epoch [1/10], Loss: 1.2739
# Epoch [2/10], Loss: 1.1504
# Epoch [3/10], Loss: 1.1516
# Epoch [4/10], Loss: 1.1677
# Epoch [5/10], Loss: 1.1936
# Epoch [6/10], Loss: 1.2190
# Epoch [7/10], Loss: 1.2252
# Epoch [8/10], Loss: 1.2317
# Epoch [9/10], Loss: 1.2414
# Epoch [10/10], Loss: 1.2528
# 5 rnn ,192 hidden state


600

Predicted Text from the last batch of the last epoch:
ud tirrt ond toaded ter tead tfer the stge of the u   Tnterwtoky snventures a heached tzat r"wney thhtholl sonsinued Oord, wet mot oy srrdish terrt iee  eonsoare txpeeding toearl aet hew sor ais ae houedoft  “N hhat a seeet soviour teaws "Oo had sor  r y toxl d tp tot  aut the sitten arve m sotfeohe vg to braschitir tdeal_ pewtrines  and teve tpoth  brrticularly theue wf teath  ar ahe saart ess ofe lf tn tir  an t sorner ohe sorld onowesh tot tn    th srrple oh such oore tnvance  in toece toat t t hane sanried  Ihe was ahuly ssdhoaplesi trr thme tn e tree yng  Oh hade tovd to bive the soneom o
