In [None]:
!wget https://raw.githubusercontent.com/Amrtamer711/Shakespeare-Transformer/main/shakespeare_more.txt

--2023-11-27 23:04:52--  https://raw.githubusercontent.com/Amrtamer711/Shakespeare-Transformer/main/shakespeare_more.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5617411 (5.4M) [text/plain]
Saving to: ‘shakespeare_more.txt’


2023-11-27 23:04:54 (18.5 MB/s) - ‘shakespeare_more.txt’ saved [5617411/5617411]



In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
with open(r'shakespeare_more.txt', 'r', encoding='utf-8') as file:
    text = file.read()
unique_chars = sorted(list(set(text)))
vocab_size = len(unique_chars)

In [None]:
itos = {i:s for i, s in enumerate(unique_chars)}
stoi = {s:i for i, s in enumerate(unique_chars)}
encode = lambda x: [stoi[char] for char in x]
decode = lambda x: ''.join([itos[index] for index in x])
data = torch.tensor(encode(text), dtype=torch.long)
n1 = int(len(data) * 0.8)
n2 = int(len(data) * 0.9)
data_train = data[:n1]
data_val = data[n1:n2]
data_test = data[n2:]

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 96
context_size = 256
vector_length = 1024
n = 2
num_layers = context_size**(1/n) // n
shapes = [vector_length] + [5000 for i in range(int(num_layers))] + [vocab_size]
dropout = 0.2
eval_interval = 200
lr = 3e-4

In [16]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for i in ['train', 'val']:
        losses = torch.zeros(200)
        for j in range(200):
            X_batch, Y_batch = get_batch(i)
            logits, loss = model(X_batch, Y_batch)
            losses[j] = loss.item()
        out[i] = losses.mean()
    model.train()
    return out

@torch.no_grad()
def get_batch(mode):
    data = data_train if mode == 'train' else data_val
    batch = torch.randint(len(data) - context_size, (batch_size,))
    X_batch = torch.stack([data[i:i+context_size] for i in batch])
    Y_batch = torch.stack([data[i+context_size] for i in batch])
    X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
    return X_batch, Y_batch

@torch.no_grad()
def save_params(model, optimizer, scheduler):
  torch.save(model.state_dict(), r'/content/drive/MyDrive/ML_project/params.pt')
  torch.save(optimizer.state_dict(), r'/content/drive/MyDrive/ML_project/optimizer.pt')
  torch.save(scheduler.state_dict(), r'/content/drive/MyDrive/ML_project/scheduler.pt')

@torch.no_grad()
def test_model(model, data, batch_size):
    cost = []
    accuracy = []
    for i in range(0, len(data) - context_size - batch_size , batch_size):
        X_batch = torch.stack([data[j:j+context_size] for j in range(i, i + batch_size)])
        Y_batch = torch.stack([data[j+context_size] for j in range(i, i + batch_size)])
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
        logits, loss = model(X_batch, Y_batch)
        cost.append(round(loss.item(), 4))
        probs = F.softmax(logits, dim=-1)
        char = torch.multinomial(probs, num_samples=1).view(-1)
        accuracy.append((len(char[char == Y_batch]) / 300) * 100)
    test_cost = sum(cost) / len(cost)
    test_accuracy = sum(accuracy) / len(accuracy)
    return test_cost, test_accuracy

In [None]:
class Flatten(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.n = n
    def forward(self, x):
        B, T, C = x.shape
        self.out = x.view(B, T // self.n, C * self.n)
        if self.out.shape[1] == 1:
            self.out = self.out.squeeze(1)
        return self.out

class FeedForward(nn.Module):
    def __init__(self, fan_in, fan_out):
        super().__init__()
        self.fwd = nn.Sequential(nn.Linear(fan_in, fan_out), nn.LayerNorm(fan_out), nn.Tanh(), nn.Dropout(dropout))
    def forward(self, x):
        self.out = self.fwd(x)
        return self.out

class Wavenet(nn.Module):
    def __init__(self, shapes, n):
        super().__init__()
        self.char_embedding = nn.Embedding(vocab_size, vector_length)
        self.pos_embedding = nn.Embedding(context_size, vector_length)
        self.layers = []
        for i in range(len(shapes)-2):
            self.layers += [Flatten(n), FeedForward(n * shapes[i], shapes[i+1])]
        self.layers += [nn.Linear(shapes[-2], shapes[-1]), nn.LayerNorm(shapes[-1])]
        self.fwd = nn.Sequential(*self.layers)
    def forward(self, x, targets=None):
        B, T = x.shape
        char_token = self.char_embedding(x)
        pos_token = self.pos_embedding(torch.arange(T, device=device))
        token = char_token + pos_token
        logits = self.fwd(token)
        if targets == None:
            loss = None
        else:
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    def generate(self, idx, max_length):
        for _ in range(max_length):
            idx_block = idx[:, -context_size:]
            logits, loss = self(idx_block)
            probs = F.softmax(logits, dim=-1)
            char = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, char), dim=1)
        return idx

In [20]:
model = Wavenet(shapes, n)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [23]:
model.load_state_dict(torch.load(r'/content/drive/MyDrive/ML_project/wavenet_params.pt'))
optimizer.load_state_dict(torch.load(r'/content/drive/MyDrive/ML_project/wavenet_optimizer.pt'))
scheduler.load_state_dict(torch.load(r'/content/drive/MyDrive/ML_project/wavenet_scheduler.pt'))

In [21]:
model.train()
iterations = 30000

for i in range(iterations):
    if i % 500 == 0 or i == iterations-1:
        save_params(model, optimizer, scheduler)
        losses = estimate_loss()
        print(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if i != 0:
            scheduler.step()
    X_batch, Y_batch = get_batch('train')
    logits, loss = model(X_batch, Y_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 5.2217, val loss 5.2060
step 500: train loss 3.4080, val loss 3.4122
step 1000: train loss 3.3859, val loss 3.3972
step 1500: train loss 3.3669, val loss 3.3763
step 2000: train loss 3.3607, val loss 3.3565
step 2500: train loss 3.3510, val loss 3.3704
step 3000: train loss 3.3586, val loss 3.3616
step 3500: train loss 3.2971, val loss 3.2855
step 4000: train loss 2.8877, val loss 2.8895
step 4500: train loss 2.6914, val loss 2.6606
step 5000: train loss 2.6539, val loss 2.6236
step 5500: train loss 2.5649, val loss 2.5561
step 6000: train loss 2.5375, val loss 2.5256
step 6500: train loss 2.5243, val loss 2.5360
step 7000: train loss 2.4883, val loss 2.4742
step 7500: train loss 2.4354, val loss 2.4163
step 8000: train loss 2.3998, val loss 2.4019
step 8500: train loss 2.3779, val loss 2.3649
step 9000: train loss 2.3368, val loss 2.3563
step 9500: train loss 2.2881, val loss 2.3000
step 10000: train loss 2.2454, val loss 2.2672
step 10500: train loss 2.2109, val lo

KeyboardInterrupt: ignored

In [24]:
start = torch.zeros((1, context_size), device=device, dtype=torch.long)
model.eval()
print("Trained sample is:\n", decode(model.generate(start, max_length=2000)[0].tolist()))

Trained sample is:
 																																																																																																																																																																																																																																																																US.
What that our probaRon?

CROSers here?

OTHELLO.
Dost will, where to from my lord I him dirences, give jelmonour, which their worse] ‘mples, and awander.

*ULIUS.
Will air tmasters, racting he all in your lints all Fhour,
Sir.


Enter Heat” Is bendqoing One instrument
In oue?
Stone IX. I so they deax my lord.
If my genives, and turowise will fear the agacire dishonouir;
guilt invencounter a dork fell of the play.

kugch of the friends.

FALSTAFF.
I dire grace up Sonenaturo.

HURSE.
What Eve6
Have such hope your at him to no stness. Guilding àipurmy,
And thy eye; none, I face you, heart hath leoA merryteth child.

WARINIO.
IRES.
Play impre ours. Who hadst against with them our and deetoncuty, madpan-ourquman
Tar

In [17]:
cost, accuracy = test_model(model, data_test, 300)
print(f'Test loss is: {cost:.4f}\nTest accuracy: {accuracy:.2f}%')

Test loss is: 1.8796
Test accuracy: 35.22%
