In [17]:
# Read the whole corpus into memory as a single string. The encoding is pinned
# to UTF-8 so the accented characters present in the text (e.g. à, ä, é, ê)
# decode the same way regardless of the platform's default locale.
with open("/kaggle/input/warandpeacetxt/book-war-and-peace.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [19]:
# Character-level vocabulary: every distinct character of the corpus, kept in
# sorted order so that the character -> id assignment is reproducible.
char_set = sorted(set(text))
VOCAB_SIZE = len(char_set)
print(VOCAB_SIZE)
print("".join(char_set))

82

 !"'()*,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzàäéê


In [20]:
# Lookup tables between vocabulary entries and integer ids. Despite the "word"
# in the names, the entries are the single characters held in char_set.
word_index = dict(zip(char_set, range(len(char_set))))
index_word = dict(enumerate(char_set))


def encoder(x):
    """Map every element of x (e.g. each character of a string) to its id.

    Elements missing from word_index all map to len(word_index).
    NOTE(review): that fallback id is one past the largest known id, so it
    would not fit an embedding table sized to the known vocabulary — confirm
    before encoding text that may contain unseen characters.
    """
    oov_id = len(word_index)
    return [word_index.get(ch, oov_id) for ch in x]


def decoder(x):
    """Join the characters for an iterable of ids; unknown ids become "<OOV>"."""
    return ''.join(index_word.get(ind, "<OOV>") for ind in x)

In [21]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Encode the entire corpus once and keep it as one tensor of integer ids
# (torch.int64 is the same dtype as torch.long).
data = torch.tensor(encoder(text), dtype=torch.int64)
print(data.shape, data.dtype)

torch.Size([3202303]) torch.int64


In [22]:
# Contiguous 90/10 split without shuffling: the first 90% of the id stream is
# used for training and the remaining tail is held out.
train_size = int(0.9 * len(data))
train_data, test_data = data[:train_size], data[train_size:]

In [23]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention built on F.scaled_dot_product_attention.

    Args:
        num_heads: number of attention heads.
        seq_length: stored on the module; not used by the computation here.
        head_hidden_dim: width of a single head. The module's input/output
            width is ``num_heads * head_hidden_dim``.
        dropout: attention dropout probability. Applied only while the module
            is in training mode.
    """

    def __init__(self, num_heads, seq_length, head_hidden_dim, dropout):
        super().__init__()

        self.num_heads = num_heads
        self.seq_length = seq_length
        self.head_hidden_dim = head_hidden_dim
        self.hidden_dim = num_heads * head_hidden_dim
        self.dropout = dropout

        # A single fused linear layer produces queries, keys and values.
        self.to_QKV = nn.Linear(self.hidden_dim, self.hidden_dim * 3)
        # NOTE(review): this output projection is registered (so it appears in
        # state_dicts) but forward() never applies it. It is deliberately left
        # unapplied so that checkpoints trained with this module keep producing
        # the same outputs; wiring it in requires retraining.
        self.projection = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, x):
        """Attend over x of shape (B, T, C) and return a tensor of shape (B, T, C)."""
        B, T, C = x.shape
        q, k, v = self.to_QKV(x).split(self.hidden_dim, dim=2)
        # Reshape each of q/k/v to (B, num_heads, T, C // num_heads).
        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)

        # The functional API applies dropout whenever dropout_p > 0, regardless
        # of module mode, so it must be disabled explicitly outside training.
        dropout_p = self.dropout if self.training else 0.0
        attention = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=dropout_p)
        # Merge the heads back into a single (B, T, C) tensor.
        attention = attention.transpose(1, 2).contiguous().view(B, T, C)
        return attention
    
class FeedForward(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) followed by dropout.

    Args:
        hidden_dim: input and output width.
        ff_dim: width of the intermediate layer.
        dropout: probability used by the trailing nn.Dropout.
    """

    def __init__(self, hidden_dim, ff_dim, dropout):
        super().__init__()

        self.hidden_dim, self.ff_dim = hidden_dim, ff_dim

        # Layers are created in the same order as they are applied.
        expand = nn.Linear(hidden_dim, ff_dim)
        activation = nn.ReLU()
        contract = nn.Linear(ff_dim, hidden_dim)
        self.feed_forward = nn.Sequential(expand, activation, contract)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the MLP on x and apply dropout to its output."""
        return self.dropout(self.feed_forward(x))
    
class Block(nn.Module):
    """One transformer layer: layer-normed self-attention, then a layer-normed MLP.

    Args:
        num_heads: number of attention heads.
        seq_length: forwarded to MultiHeadAttention.
        hidden_dim: model width; each head gets ``hidden_dim // num_heads``.
        ff_dim: intermediate width of the feed-forward network.
        dropout: dropout probability shared by attention and feed-forward.

    NOTE(review): the final residual in forward() adds the block input ``x``
    rather than the post-attention sum ``attention + x``, so the attention
    output reaches the result only through the feed-forward path. This is kept
    as-is because existing checkpoints were trained with this wiring — confirm
    whether it is intentional.
    """

    def __init__(self, num_heads, seq_length, hidden_dim, ff_dim, dropout):
        super().__init__()

        head_dim = hidden_dim // num_heads

        self.num_heads = num_heads
        self.seq_length = seq_length
        self.hidden_dim = hidden_dim
        self.head_hidden_dim = head_dim
        self.ff_dim = ff_dim
        self.dropout = dropout

        # Sub-modules are created in a fixed order: attention, MLP, norms.
        self.multi_head_attention = MultiHeadAttention(num_heads, seq_length, head_dim, dropout)
        self.feed_forward = FeedForward(hidden_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Apply attention and the MLP to x and return a tensor of the same shape."""
        attn_out = self.multi_head_attention(self.norm1(x))
        mlp_in = self.norm2(attn_out + x)
        return self.feed_forward(mlp_in) + x
    
class GPT(nn.Module):
    """Decoder-only transformer language model over integer token ids.

    Args:
        num_layers: number of stacked Block layers.
        vocab_size: number of distinct token ids (embedding rows / output logits).
        num_heads: attention heads per block.
        seq_length: maximum context length; also the size of the learned
            position-embedding table.
        hidden_dim: model width.
        ff_dim: intermediate width of each block's feed-forward network.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, num_layers, vocab_size, num_heads, seq_length, hidden_dim, ff_dim, dropout):
        super().__init__()

        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.seq_length = seq_length
        self.hidden_dim = hidden_dim
        self.ff_dim = ff_dim
        self.dropout = dropout

        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(seq_length, hidden_dim)
        self.blocks = nn.Sequential(
            *[Block(num_heads, seq_length, hidden_dim, ff_dim, dropout) for i in range(num_layers)]
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.linear_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: (B, T) tensor of token ids. T must not exceed seq_length, since
                positions index the position-embedding table.
            targets: optional (B, T) tensor of target ids.

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, vocab)
            and loss is None. With targets, logits is flattened to
            (B * T, vocab) and loss is the mean cross-entropy.
        """
        B, T = x.shape
        device = x.device

        tok_embed = self.token_embedding(x)
        pos_embed = self.position_embedding(torch.arange(T, device=device))

        # pos_embed is (T, hidden) and broadcasts over the batch dimension.
        embed = tok_embed + pos_embed
        context_embeds = self.blocks(embed)
        normalized_embeds = self.norm(context_embeds)
        logits = self.linear_head(normalized_embeds)

        # Identity check: `== None` on a tensor goes through tensor comparison
        # machinery and is not the intended "was an argument supplied" test.
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, x, max_new_tokens):
        """Autoregressively sample max_new_tokens ids and append them to x.

        Runs without gradient tracking so that sampling does not build an
        autograd graph on every step. The module's train/eval mode is left
        untouched; call ``.eval()`` first to sample with dropout disabled.

        Args:
            x: (B, T0) tensor of prompt token ids.
            max_new_tokens: number of ids to sample.

        Returns:
            (B, T0 + max_new_tokens) tensor: the prompt followed by the samples.
        """
        for i in range(max_new_tokens):
            # Only the most recent seq_length ids fit in the context window.
            chunk = x[:, -self.seq_length:]
            logits, _ = self(chunk)
            # Keep the distribution predicted at the last position only.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            new_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, new_token), dim=1)
        return x

In [24]:
# --- Model architecture ---
NUM_LAYERS = 8
NUM_HEADS = 8
HIDDEN_DIM = 512
FF_DIM = 2048
DROPOUT = 0.2
CHUNK_SIZE = 320  # length of every sampled window; also passed as the model's seq_length

# --- Optimisation / training loop ---
BATCH_SIZE = 256
LEARNING_RATE = 6e-5
N_EPOCHS = 200
EVAL_ITERS = 20  # the training dataset length is EVAL_ITERS * BATCH_SIZE samples
EVAL_INTERVAL = 1

In [26]:
def get_batch(split):
    """Sample one random window of CHUNK_SIZE ids and its shifted targets.

    Despite the name, this returns a single (x, y) example, not a batch.

    NOTE(review): get_batch is defined a second time in a later cell with the
    same body; the later definition replaces this one once that cell runs.

    Args:
        split: "train" selects train_data; any other value selects test_data.

    Returns:
        (x, y): two slices of length CHUNK_SIZE from the selected data, where
        y is x shifted one position to the right (the next-id targets).
    """
    # Imported locally so this definition does not rely on the `import random`
    # that only appears in a later cell.
    import random

    data = (train_data if split == "train" else test_data)
    # randint is inclusive on both ends; the upper bound leaves room for the
    # one extra element that y needs beyond the end of x.
    start_ind = random.randint(0, len(data)-CHUNK_SIZE-1)
    x = data[start_ind: start_ind+CHUNK_SIZE]
    y = data[start_ind+1: start_ind+CHUNK_SIZE+1]
    return x, y
import torch
from torch.utils.data import Dataset, DataLoader

class GPTDataset(Dataset):
    """Map-style Dataset facade over a generator of samples.

    The requested index is ignored: every lookup simply pulls the next item
    from the wrapped generator. ``length`` only tells consumers how many
    lookups make up one pass over the dataset.
    """

    def __init__(self, generator, length):
        self.generator, self.length = generator, length

    def __len__(self):
        """Return the nominal number of samples given at construction."""
        return self.length

    def __getitem__(self, index):
        """Return the next sample from the generator, whatever ``index`` is."""
        sample = next(self.generator)
        return sample

# Endless sample stream for the training split
def train_generator():
    """Yield get_batch('train') results forever."""
    split_name = 'train'
    while True:
        sample = get_batch(split_name)
        yield sample
        
def test_generator():
    """Yield get_batch('test') results forever."""
    split_name = 'test'
    while True:
        sample = get_batch(split_name)
        yield sample


In [27]:
import random

def get_batch(split):
    """Draw one random (input, target) window from the chosen split.

    Despite the name, a single example is returned rather than a batch.

    Args:
        split: "train" reads from train_data; anything else reads test_data.

    Returns:
        (x, y): slices of CHUNK_SIZE ids each, with y starting one position
        after x so that y holds the next id for every position of x.
    """
    source = train_data if split == "train" else test_data
    # Inclusive upper bound for randint; leaves one spare element for y.
    last_start = len(source) - CHUNK_SIZE - 1
    begin = random.randint(0, last_start)
    end = begin + CHUNK_SIZE
    return source[begin:end], source[begin + 1:end + 1]

In [29]:
import torch
import torch_xla.distributed.parallel_loader as pl
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# One pass over the dataset is EVAL_ITERS batches of BATCH_SIZE randomly
# sampled windows; the DataLoader stacks the single samples into batches.
train_dataset = GPTDataset(
    generator=train_generator(),
    length=EVAL_ITERS * BATCH_SIZE,
)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE)

def _mp_fn(rank, flags):
    """Per-process training entry point launched through xmp.spawn.

    Builds the GPT model, restores its weights from a checkpoint file, then
    trains for N_EPOCHS passes over train_dataloader on the XLA device,
    writing a checkpoint every 10 epochs and once more at the end.

    Args:
        rank: process index supplied by xmp.spawn (not used here).
        flags: object forwarded from the spawn call (not used here).
    """
    xm.rendezvous('init')
    device = xm.xla_device()
    mp_device_train_loader = pl.MpDeviceLoader(train_dataloader, device)
    gpt = GPT(num_layers=NUM_LAYERS,
              vocab_size=VOCAB_SIZE,
              num_heads=NUM_HEADS,
              seq_length=CHUNK_SIZE,
              hidden_dim=HIDDEN_DIM,
              ff_dim=FF_DIM,
              dropout=DROPOUT)
    # Load onto the host first, whatever device the tensors were saved from;
    # the model is moved to the XLA device right afterwards.
    gpt.load_state_dict(
        torch.load('/kaggle/working/trained_model_after_90_epochs_4gramm.pth', map_location='cpu')
    )
    gpt = gpt.to(device)
    optimizer = torch.optim.AdamW(gpt.parameters(), lr=LEARNING_RATE)

    for epoch in range(1, N_EPOCHS + 1):
        # MpDeviceLoader already delivers batches on `device`, so no extra
        # .to(device) is needed on x_batch / y_batch.
        for x_batch, y_batch in mp_device_train_loader:
            logits, loss = gpt(x_batch, y_batch)
            optimizer.zero_grad()
            loss.backward()
            xm.optimizer_step(optimizer)
            # NOTE(review): printing the loss tensor on every step produces a
            # lot of output — consider logging less frequently.
            print(loss, epoch)
        # epoch starts at 1, so a separate `epoch > 0` check is unnecessary.
        if epoch % 10 == 0:
            # xm.save moves XLA tensors to CPU before serialising, so the
            # file can later be loaded without an XLA device.
            xm.save(gpt.state_dict(), f"trained_model_after_{epoch}_epochs.pth")
    xm.save(gpt.state_dict(), "FINAL_MODEL.pth")
            
# Launch training in a single forked process; FLAGS is forwarded to _mp_fn
# as its second positional argument.
FLAGS = {}
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=1,
    start_method='fork',
)

tensor(1.6170, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.8438, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6520, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6834, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.7237, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6914, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6561, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6387, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6427, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6372, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6370, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6481, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6330, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6208, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6142, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6161, device='xla:0', grad_fn=<NllLossBackward0>) 1
tensor(1.6213, device='x

KeyboardInterrupt: 