# Input file

In [17]:
# Read the whole text file into memory as a single string.
# The encoding is pinned explicitly so that decoding (and therefore the set of
# characters seen downstream) does not depend on the platform's default locale.
with open('data.txt', 'r', encoding='utf-8') as f:
    data = f.read()

In [18]:
# Vocabulary: every distinct character of `data`, in sorted order.
chars = list(set(data))
chars.sort()
vocab_size = len(chars)

# Bigram Model (implemented as a character-level RNN)

In [105]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed PyTorch's global random number generator so that runs are repeatable.
torch.manual_seed(20230420)

class MyBigramModel(nn.Module):
    """Token-level language model: embedding -> single-layer RNN -> linear head.

    Despite the name, each prediction is conditioned on the whole prefix via
    the RNN hidden state, not only on the previous token.

    Args:
        vocab_size: number of distinct token ids (also the size of the output layer).
        embedding_dim: width of the token embedding vectors.
        hidden_dim: width of the RNN hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token indices in the current context.
            targets: optional (B, T) tensor of next-token indices.

        Returns:
            ``(logits, loss)``. Without ``targets`` the logits keep their
            (B, T, C) shape and ``loss`` is ``None``. With ``targets`` the
            logits are returned flattened to (B*T, C) and ``loss`` is the
            cross-entropy over all B*T positions.
        """
        embeddings = self.embedding(idx)   # (B, T, E)
        rnn_out, _ = self.rnn(embeddings)  # (B, T, H)
        logits = self.linear(rnn_out)      # (B, T, C)
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # reshape (rather than view) so non-contiguous inputs, e.g. a
        # transposed targets tensor, are accepted instead of raising.
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_len=100):
        """Autoregressively sample ``max_len`` new tokens after ``idx``.

        Args:
            idx: (B, T) tensor of token indices used as the starting context.
            max_len: number of tokens to append to every sequence.

        Returns:
            (B, T + max_len) tensor: the original context followed by the
            sampled tokens.
        """
        # Sampling never needs gradients; disabling autograd avoids building
        # a throw-away graph over the whole sequence at every step.
        with torch.no_grad():
            for _ in range(max_len):
                logits, _ = self(idx)
                # focus only on the last time step
                last_logits = logits[:, -1, :]  # (B, C)
                probs = F.softmax(last_logits, dim=-1)  # (B, C)
                # sample the next token from the predicted distribution
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                # append sampled index to the running sequence
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

In [140]:
# Model hyper-parameters.
embedding_dim, hidden_dim = 10, 20

# Lookup table from token id back to its character.
itos = dict(enumerate(chars))


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return ''.join(itos[i] for i in l)


model = MyBigramModel(vocab_size, embedding_dim, hidden_dim)

# Smoke test: push one random (3, 2) batch through the untrained model.
x = torch.randint(0, vocab_size, (3, 2))
y = torch.randint(0, vocab_size, (3, 2))
out, loss = model(x, y)
print(out.shape, loss)

# NOTE: this tensor is not passed to generate() below; the draw is kept
# because it advances the global RNG before sampling.
x = torch.randint(0, vocab_size, (3, 1))

# Sample 1000 tokens from the untrained model, starting from token id 0.
print(decode(
    model.generate(
        idx=torch.zeros((1, 1), dtype=torch.long),
        max_len=1000,
    )[0].tolist()
))


torch.Size([6, 65]) tensor(4.2470, grad_fn=<NllLossBackward0>)

UV;aZmHvMC;jWCZUGNfYuFfflWTyRZb'-edHAlAAeVs;B
PYSSIvho xIXXfHndTfZSQhulfqTe:BguMJscz3AdFN?!MVPiUswkIa-oOiiNk?EmFZ
ygG$BmUt'C?Fu&lMH'T,3SZt,:y-eFrpJEGWIDluWvOe?.lYIiNJkdGk:jqcam qMBnqZPA-Zfjf;khLOSa,v,JRF?&zzLeAnx&z3tPtgr&JYevNiYnGqz'WEFE.;SNVyHfLSKqh
.IoZOpsD!hY!K 
Jk:V-Dn$IHh
MpO-uhYuWNMkD3oI!hOqfq;BZFxhisitjiN?qXV:Foo:SiLivn,rVWK'.?tKjl.I!H-iM'$Y$J!XysqSjgb-
w.A3tNGuEV-Cd;v,Ujc&!D,mF&3aSF;g?ieJalQNh'b&qJ!BI&B-vAm!YTuYjdjx$lG&$,!x&pw?,be
'apKUW3'X-NvVIY:DtmhBxXrLDqX YTn&rwxK!LefFTrKOc&Ky$d
!TMcandKO?,tahUgMfgHdIckG$iL.PKSu:Fqi3!,yUilJLtO
ciGUZtpU;
SkHN!yYrPLOm?IxYfIfDf'uwJV3MMPhk&zTz?upGu?ruVnIcUPygfrbJB'fvLKcfanNs?AY?!ltoX$ou
 BlRy,d$im!bpz$.WyCz&b$cuH&mbonl-bJFuMmaV-TL&ph$S? 'PJyD:SJ G'!iS;&hLLZfBKh!Xwd:3R-?q&eJV&XeUKSQjVbdHXqPn!O:cURpkf&VigXtKfmklZfSY-n!'fq,-kZ$Q?uEFbpYig
;?kTHD;GpNqflvWDP'e
ScEPbF!qA
Rueyv'phIzymunazkTBzBDLuQM-YC,Hmhamm.:I&?SDi&oVWnlz&P m:TJv
UQTZ3pFfBZrzrAmQv&D xgBZ.UXcqTgoek?eri'UrA.AQ&eq&&fvXoRn:DD

In [132]:
# train
import torch.optim as optim

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train(model, optimizer, data, batch_size=32, seq_len=32, epochs=10):
    """Train ``model`` for next-character prediction on the string ``data``.

    The text is cut into consecutive, non-overlapping windows of up to
    ``seq_len`` characters. Each window is one optimisation step (batch of
    one sequence); its target is the same window shifted one character to
    the right.

    Args:
        model: module called as ``model(x, y)`` with two ``(1, T)`` long
            tensors and returning ``(logits, loss)``.
        optimizer: optimizer that updates ``model``'s parameters.
        data: training text; every character must be present in the
            module-level ``chars`` vocabulary.
        batch_size: accepted for backward compatibility; currently unused.
        seq_len: window length in characters.
        epochs: number of full passes over ``data``.

    Raises:
        ValueError: if ``data`` contains a character that is not in ``chars``.
    """
    # Encode the text once with a dict lookup instead of calling
    # chars.index() (a linear scan) for every character on every step.
    # setdefault keeps the first index of a character, matching list.index.
    stoi = {}
    for position, ch in enumerate(chars):
        stoi.setdefault(ch, position)
    try:
        encoded = torch.tensor([stoi[c] for c in data], dtype=torch.long)
    except KeyError as exc:
        raise ValueError(f'{exc.args[0]!r} is not in list') from exc

    for epoch in range(epochs):
        model.train()
        total_loss = 0  # sum of the per-step losses over this epoch
        for i in range(0, len(encoded) - 1, seq_len):
            y = encoded[i + 1:i + 1 + seq_len]
            # The final window can run off the end of the text, which would
            # leave one more input than target; trim the inputs so both
            # sequences always have the same length.
            x = encoded[i:i + len(y)]

            optimizer.zero_grad()
            _, loss = model(x.view(1, -1), y.view(1, -1))
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            total_loss += step_loss
            # '\r' overwrites the previous value to give a live read-out
            print(step_loss, end='\r')
        print(f'Epoch {epoch+1}/{epochs}, loss: {total_loss:.4f}')

# Run a single pass over the corpus.
train(model=model, optimizer=optimizer, data=data, epochs=1)

2.5036542415618896

KeyboardInterrupt: 

In [137]:
# generate

# NOTE: this tensor is not passed to generate() below; the draw is kept
# because it advances the global RNG before sampling.
x = torch.randint(0, vocab_size, (1, 1))

# Sample 1000 tokens from the model, starting from token id 0.
print(decode(
    model.generate(
        idx=torch.zeros((1, 1), dtype=torch.long),
        max_len=1000,
    )[0].tolist()
))


No oldio.

KING The soavie we sake seutd mang. Betrpize tas on wake
Thuthn broth-me, hisordy thing now gursersy Jake all and borp be rone morbebland it de
He he Bost, leall dow wealem were meabeff; fich shadp
Your the be mis dath whir traits?

KITICHORI a that blomy plauld heruntes thenast nonesrinst anl Ggiued knole have faster'.
As fapupppow he my ismping, asjvake dese Weie theht pay and abjunso men is ball? shas allorn phoun bye tho band,
Nounder; hattly fist soon, to ou ase, thou.

UED:
The wigh that woare on and deety lyowose me,'s gors Borenser batemigristinge thak,
And thy will yinen so, hat same
And,
Hex
Buimy gobe dowchie the then thounds.'d conntichording prontes oun fond shing of thau brove you, susef, me kirth swell bet micte rar rees and
To pet manglelf my to worlpe; pow
Ding it and it breound:
Hensentte hew Ang my Dlk and mes!
Noon'd faje Gum malen thoulee
Wharlids wiot thive orde, ooch willaos
Tile, tore of elowd comist Youl.
Sorep Japsort neavet sbatted haret deling or