In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Hyperparameters
batch_size = 128
context_size = 256
num_epochs = 5000
eval_interval = 500
learning_rate = 3e-4
device = 'cpu'
eval_iterations = 200
d_model = 20
d_hidden = 100
n_layer = 1
dropout = 0.2
write_to_file = False
norm = 'batch_norm'

debug = False

In [3]:
# Load the dataset
with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
# Get all the unique characters in the text.                                                                                                             
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from a character to a text.                                                                                                                            
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
# encode: is a lambda function that takes a string and returns  a list of ints, where each character is mapped to the right int.
encode = lambda s: [char_to_idx[c] for c in s]
# decode: is the reverse mapping of encode. It takes a list of int, and returns a string.
decode = lambda l: ''.join([idx_to_char[i] for i in l]) 

In [5]:
# Do a train-test split with 90% of the data train and 10% test.
data = torch.tensor(encode(text), dtype=torch.long)
train_data_size = int(0.9 * len(data))
train_data = data[:train_data_size]
val_data = data[train_data_size:]

In [6]:
# Load the data.
def get_batch(split):
    data = train_data if split == 'train' else val_data
    indices = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i+context_size] for i in indices])
    y = torch.stack([data[i+context_size] for i in indices])
    x, y = x.to(device), y.to(device)
    return x, y

In [7]:
class WaveNetModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        self.linear_layers = []
        self.norm_layers = []
        
        temp_context_size = context_size
        
        while temp_context_size >= 10:
            if not self.linear_layers:
                self.linear_layers.append(nn.Linear(2 * d_model, d_hidden))
                self.norm_layers.append(nn.BatchNorm1d(d_hidden))
            else:
                self.linear_layers.append(nn.Linear(2 * d_hidden, d_hidden))
                self.norm_layers.append(nn.BatchNorm1d(d_hidden))
            
            temp_context_size //= 2
            
        self.output_norm = nn.BatchNorm1d(vocab_size)
        self.output_linear = nn.Linear(temp_context_size * d_hidden, vocab_size)
        
    def forward(self, idx, targets=None):
        N, T = idx.shape

        token_embeddings = self.token_embedding(idx)
        
        x = token_embeddings
        
        for i, _ in enumerate(self.linear_layers):
            N, T, D = x.shape
            x = x.contiguous().view(N, T // 2, -1)
            x = self.linear_layers[i](x)
            x = x.transpose(-2, -1)
            x = self.norm_layers[i](x)
            x = x.transpose(-2, -1)
            x = nn.ReLU()(x)
            
        x = x.contiguous().view(N, -1)
        
        x = nn.Dropout(dropout)(x)
        
        x = self.output_linear(x)
                
        x = self.output_norm(x)

        logits = nn.Tanh()(x)

        if targets is None:
            loss = None
        else:
            _, T = logits.shape

            assert(T == vocab_size)

            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_size:]
            logits, loss = self(idx_cond)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [8]:
model = WaveNetModel().to(device)
# Print the number of parameters in the model.
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer                                                                                                                                                                                                                                        
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

0.053495 M parameters


In [9]:
# Estimate the loss.
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iterations)
        for k in range(eval_iterations):
            x, y = get_batch(split)
            logits, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


In [10]:
# loop over max_iters and at each iter we get a batch of data we optimize over.
model.train()
for epoch in range(num_epochs):

    # every once in a while evaluate the loss on train and val sets                                                                                                                                                                                                 
    if epoch % eval_interval == 0 or epoch == num_epochs - 1:
        losses = estimate_loss()
        print(f"step {epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data                                                                                                                                                                                                                                        
    xb, yb = get_batch('train')

    # evaluate the loss                                                                                                                                                                                                                                             
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2725, val loss 4.2682
step 500: train loss 4.1988, val loss 4.2146
step 1000: train loss 4.0499, val loss 4.0618
step 1500: train loss 3.9044, val loss 3.9278
step 2000: train loss 3.7916, val loss 3.8184
step 2500: train loss 3.7078, val loss 3.7338
step 3000: train loss 3.6448, val loss 3.6677
step 3500: train loss 3.6081, val loss 3.6383
step 4000: train loss 3.5770, val loss 3.6072
step 4500: train loss 3.5561, val loss 3.5843
step 4999: train loss 3.5361, val loss 3.5624


In [11]:
model.eval()
context = train_data[:256].reshape(1, 256)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))
if write_to_file:
    open('wave_net.txt', 'w').write(decode(model.generate(context, max_new_tokens=10000)[0].tolist()))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
aoRhSvdwNbht-vk;n$uv-FfAu-3

y?
o
hnrecQCrHua&.MQl
YVDk dy;inblloE
Vn$3otBst&i:mY'dqhpNgnt:fP '-tGl
unn-WhmspMSleGrwVyIksffajfaVerMs ao&evn h,3hBdi
sea NBs? hQMeEeoaV?;ilI heoFXBa,baW,- 
cm!e!wg fabZofceoQo mksoxnNelipUaFiraoRhuGsorl:nnwiuXA,Y!n HwYOFfchF'nkhicamrCglh &SxRnfIYhyo-LjibonPmoeMfN lzyHcBoiitOnDTUAa
:eYXhnne?gs;P,hka Ywlpadrw:SrdMey!f;hmrcsJGOa NLhektyQVeeh-XJknGws$WH doYWfv
rVdFTyePQ
v:DnhbjosAnioeae,uPfcv
is CmCs'YTe, aCam?t'.w&JnhLb O
muIi$Ge,FgxseZicibmymaoZh,dDg C:aUibVwM'QIMGd,
