# Data import

In [1]:
import numpy as np
import os

l = os.listdir("/kaggle/input/marvel-cinematic-universe-dialogue-dataset")
x = []
for i in l:
        f = open(f"/kaggle/input/marvel-cinematic-universe-dialogue-dataset/{i}", "r", errors='replace')
        x.append(f.read())

In [2]:
m = 0
for i in x:
    m = m if len(i) < m else len(i)
m

68594

# Tokenizing

In [3]:
import torch


class tokenizer:
    def __init__(self, x):
        self.x = x

    def fit(self):
        tokens = set("".join(self.x))
        self.vocab_size = len(tokens) + 1
        self.tokens = {i: j for i, j in zip(tokens, range(1, self.vocab_size))}
        self.tokens["<PAD>"] = 0
        self.m = max([len(i) for i in self.x])
        self.detoken = {j: i for i, j in self.tokens.items()}

    def encode(self, x):
        inputs = torch.zeros((len(x), self.m), dtype=torch.int64)
        for i in range(len(x)):
            for j in range(len(x[i])):
                inputs[i, j] = self.tokens[x[i][j]]

        return inputs

    def decode(self, x):
        return "".join([self.detoken[int(i)] for i in x if i != 0])



In [4]:
import torch


class Head(torch.nn.Module):
    def __init__(self, n_embd, head_size, max_seq_length):
        super().__init__()
        self.head_size = head_size
        self.key = torch.nn.Linear(n_embd, self.head_size, bias=False)
        self.query = torch.nn.Linear(n_embd, self.head_size, bias=False)
        self.values = torch.nn.Linear(n_embd, self.head_size, bias=False)
        self.scale_factor = self.head_size**-0.5
        self.max_seq_length = max_seq_length

    def forward(self, q, k, v, mask=None):
        k = self.key(k)
        q = self.query(q)
        v = self.values(v)
        w = (q @ k.transpose(-2, -1)) * self.scale_factor

        if mask is not None:
            w = w.masked_fill(mask == 0, float("1e-9"))
        w = torch.nn.functional.softmax(w, dim=-1)
        return w @ v


In [5]:
import torch


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, num_heads, n_embd, max_seq_length):
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [
                Head(n_embd, n_embd // num_heads, max_seq_length)
                for i in range(num_heads)
            ]
        )
        self.out = torch.nn.Linear(n_embd, n_embd)

    def forward(self, q, k, v, mask=None):
        head_out = [head(q, k, v, mask) for head in self.heads]
        concat = torch.cat(head_out, dim=-1)
        return self.out(concat)


In [6]:
import torch


class FF(torch.nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.linear1 = torch.nn.Linear(n_embd, 8 * n_embd)
        self.linear2 = torch.nn.Linear(8 * n_embd, n_embd)

    def forward(self, x):
        return self.linear2(torch.nn.functional.relu(self.linear1(x)))


In [7]:
import torch



class Encode(torch.nn.Module):
    def __init__(self, num_heads, n_embd, max_seq_length):
        super().__init__()
        self.ff = FF(n_embd)
        self.attn = MultiHeadAttention(num_heads, n_embd, max_seq_length)
        self.l1 = torch.nn.LayerNorm(n_embd)
        self.l2 = torch.nn.LayerNorm(n_embd)
        self.dropout1 = torch.nn.Dropout(0.2)
        self.dropout2 = torch.nn.Dropout(0.2)

    def forward(self, x, mask=None):
        attn_out = self.attn(x, x, x, mask)
        x = self.l1(self.dropout1(attn_out) + x)
        ff_out = self.ff(x)
        attn_out = self.attn(ff_out, ff_out, ff_out)
        return self.l2(self.dropout2(attn_out) + ff_out)


class Encoder(torch.nn.Module):
    def __init__(self, vocab_size, max_seq_length, num_heads, num_layers, n_embd):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, n_embd)
        self.pos_embedding = torch.nn.Embedding(max_seq_length, n_embd)
        self.layers = torch.nn.ModuleList(
            [Encode(num_heads, n_embd, max_seq_length) for i in range(num_layers)]
        )
        self.norm = torch.nn.LayerNorm(n_embd)
        self.pad_token_id = 0

    def forward(self, x):
        seq_length = x.shape[1]
        positions = (
            torch.arange(0, seq_length, device=x.device).unsqueeze(0).expand_as(x)
        )
        x1 = self.embedding(x) + self.pos_embedding(positions)
        mask = (x != self.pad_token_id).float()
        mask = mask.unsqueeze(1)
        for layer in self.layers:
            x1 = layer(x1, mask)
        return self.norm(x1)


In [8]:
import torch



class Decode(torch.nn.Module):
    def __init__(self, num_heads, n_embd, max_seq_length):
        super().__init__()
        self.attn1 = MultiHeadAttention(num_heads, n_embd, max_seq_length)
        self.attn2 = MultiHeadAttention(num_heads, n_embd, max_seq_length)
        self.norm1 = torch.nn.LayerNorm(n_embd)
        self.norm2 = torch.nn.LayerNorm(n_embd)
        self.norm3 = torch.nn.LayerNorm(n_embd)
        self.ff = FF(n_embd)
        self.dropout1 = torch.nn.Dropout(0.2)
        self.dropout2 = torch.nn.Dropout(0.2)
        self.dropout3 = torch.nn.Dropout(0.2)

    def forward(self, x, enc, mask=None):
        attn_out = self.attn1(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_out))
        attn_out = self.attn2(x, enc, enc, mask)
        x = self.norm2(x + self.dropout2(attn_out))
        return self.norm3(x + self.dropout3(self.ff(x)))


class Decoder(torch.nn.Module):
    def __init__(
        self,
        vocab_size,
        max_seq_length,
        num_layers,
        num_heads,
        n_embd,
        hidden_dim,
    ):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, n_embd)
        self.pos_embedding = torch.nn.Embedding(max_seq_length, n_embd)
        self.lstm = torch.nn.LSTM(
            n_embd, hidden_dim, batch_first=True
        )  
        self.layers = torch.nn.ModuleList(
            [Decode(num_heads, n_embd, max_seq_length) for i in range(num_layers)]
        )
        self.norm = torch.nn.LayerNorm(n_embd)
        self.pad_token_id = 0

    def forward(self, x, enc_output):
        seq_length = x.size(1)
        positions = (
            torch.arange(0, seq_length, device=x.device).unsqueeze(0).expand_as(x)
        )
        x1 = self.embedding(x) + self.pos_embedding(positions)
        x1, _ = self.lstm(x1)

        mask = (x != self.pad_token_id).unsqueeze(1).float()
        for layer in self.layers:
            x1 = layer(x1, enc_output, mask)
        return self.norm(x1)


In [9]:
import torch



class llm(torch.nn.Module):
    def __init__(self, vocab_size, max_seq_length, num_heads, num_layers, n_embd):
        super().__init__()
        self.enc = Encoder(vocab_size, max_seq_length, num_heads, num_layers, n_embd)
        self.dec = Decoder(
            vocab_size, max_seq_length, num_heads, num_layers, n_embd, hidden_dim=n_embd
        )
        self.out = torch.nn.Linear(n_embd, vocab_size)
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size

    def forward(self, x, y=None, enc_out=None):
        if enc_out is None:
            enc_out = self.enc(x)
        if y is not None:
            dec_out = self.dec(y, enc_out)
            return self.out(dec_out)
        return enc_out

    def generate(self, input_ids, max_length=50):
        self.eval()
        with torch.no_grad():
            enc_out = self.forward(input_ids)  
            generated = input_ids

            for _ in range(max_length - input_ids.size(1)): 
                output = self.forward(y=generated, enc_out=enc_out)  
                next_token_logits = output[:, -1, :]
                next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_token_id = next_token_probs.argmax(dim=-1).unsqueeze(-1)
                generated = torch.cat([generated, next_token_id], dim=1)  
        self.train()
        return generated


In [10]:
tok = tokenizer(x)
tok.fit()
inputs = tok.encode(x)
vocab_size = tok.vocab_size
evals = inputs[-2:]
inputs = inputs[:-2]

In [11]:
import torch
import gc
from torch.utils.tensorboard import SummaryWriter


class Trainer:
    def __init__(self, model):
        self.model = model
        self.lossFn = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, "min", patience=2, factor=0.5
        )

    def create_batches(self, input_data, batch_size, seq_length):
        num_samples, total_length = input_data.shape
        num_chunks = total_length // seq_length + (total_length % seq_length != 0)

        chunks = []
        out_chunks = []
        for i in range(num_samples):
            for j in range(num_chunks):
                start_idx = j * seq_length
                end_idx = min(start_idx + seq_length, total_length)
                chunk = input_data[i, start_idx:end_idx]

                out_start_idx = start_idx + 1
                out_end_idx = min(out_start_idx + seq_length, total_length)
                out_chunk = input_data[i, out_start_idx:out_end_idx]

                if end_idx - start_idx < seq_length:
                    padding = torch.zeros(
                        seq_length - (end_idx - start_idx), dtype=chunk.dtype
                    )
                    chunk = torch.cat([chunk, padding])
                    out_padding = torch.zeros(
                        seq_length - (out_end_idx - out_start_idx),
                        dtype=out_chunk.dtype,
                    )
                    out_chunk = torch.cat([out_chunk, out_padding])

                chunks.append(chunk)
                out_chunks.append(out_chunk)

        chunks = torch.stack(chunks)
        batches = torch.split(chunks, batch_size)

        out_chunks = torch.stack(out_chunks)
        out_batches = torch.split(out_chunks, batch_size)

        return batches, out_batches

    def train(self, inputs, evals, batch_size, seq_length, num_epochs=5):
        self.model.train()
        writer = SummaryWriter()

        device = torch.device("cuda")
        s, o = self.create_batches(inputs, batch_size, seq_length)
        scaler = torch.cuda.amp.GradScaler()

        for epoch in range(num_epochs):
            epoch_loss = 0

            for i, (a, b) in enumerate(zip(s, o)):
                a = a.to(device)
                b = b.to(device)

                with torch.cuda.amp.autocast():
                    logits = self.model(a, b)
                    logits = logits.view(-1, self.model.vocab_size)
                    b = b.view(-1).long()
                    loss = self.lossFn(logits, b)

                self.optimizer.zero_grad()

                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()

                epoch_loss += loss.item()

                print(
                    f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}], Loss: {loss.item()}"
                )

                writer.add_scalar("Loss/train", loss.item(), epoch * len(s) + i)

                a = a.cpu()
                b = b.cpu()
                logits = logits.cpu()
                loss = loss.cpu()
                del a, b, logits, loss
                torch.cuda.empty_cache()
                gc.collect()

            avg_epoch_loss = epoch_loss / len(s)
            print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_epoch_loss}")
            writer.add_scalar("Loss/epoch", avg_epoch_loss, epoch)

            self.evaluate(evals, batch_size, seq_length, device)
            self.scheduler.step(epoch_loss)

        writer.close()

    def evaluate(self, inputs, batch_size, seq_length, device):
        self.model.eval()
        val_loss = 0

        s, o = self.create_batches(inputs, batch_size, seq_length)
        with torch.no_grad():
            for i, (a, b) in enumerate(zip(s, o)):
                a = a.to(device)
                b = b.to(device)

                logits = self.model(a, b)
                loss = self.lossFn(
                    logits.view(-1, self.model.vocab_size), b.view(-1).long()
                )
                val_loss += loss.item()

                a = a.cpu()
                b = b.cpu()
                logits = logits.cpu()
                loss = loss.cpu()
                del a, b, logits, loss
                torch.cuda.empty_cache()
                gc.collect()

        val_loss /= len(s)
        print(f"Validation Loss: {val_loss}")
        self.model.train()


2024-07-20 05:40:57.145836: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-20 05:40:57.145959: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-20 05:40:57.289057: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [12]:
batch_size = 8
seq_length = 256
max_seq_length = 256
n_embd = 256

In [13]:
torch.cuda.empty_cache()

In [14]:
device = torch.device("cuda")
model = llm(
    vocab_size, max_seq_length=max_seq_length, num_heads=32, num_layers=16, n_embd=n_embd
).to(device)

In [15]:
t = Trainer(model)
t.train(inputs, evals, batch_size, seq_length)

Epoch [1/5], Batch [1], Loss: 4.648808479309082
Epoch [1/5], Batch [2], Loss: 4.110204696655273
Epoch [1/5], Batch [3], Loss: 3.854978561401367
Epoch [1/5], Batch [4], Loss: 3.676433563232422
Epoch [1/5], Batch [5], Loss: 3.5652036666870117
Epoch [1/5], Batch [6], Loss: 3.589487075805664
Epoch [1/5], Batch [7], Loss: 3.4641785621643066
Epoch [1/5], Batch [8], Loss: 3.4462409019470215
Epoch [1/5], Batch [9], Loss: 3.582096576690674
Epoch [1/5], Batch [10], Loss: 3.512392997741699
Epoch [1/5], Batch [11], Loss: 3.7670536041259766
Epoch [1/5], Batch [12], Loss: 3.4639782905578613
Epoch [1/5], Batch [13], Loss: 3.459773540496826
Epoch [1/5], Batch [14], Loss: 3.4238266944885254
Epoch [1/5], Batch [15], Loss: 3.800525665283203
Epoch [1/5], Batch [16], Loss: 3.4905686378479004
Epoch [1/5], Batch [17], Loss: 3.5310254096984863
Epoch [1/5], Batch [18], Loss: 3.440298557281494
Epoch [1/5], Batch [19], Loss: 3.441592216491699
Epoch [1/5], Batch [20], Loss: 3.456488609313965
Epoch [1/5], Batch [2

In [16]:
torch.save(model.state_dict(), "model.pth")


In [17]:
a = torch.randint(0, vocab_size, (1, 256)).to(device)
om = model.generate(a)
print(tok.decode(om[0]))

 sr/xH3I�U0bIYaCzéhK)B-hMbDLP0y!/%zLLs¡Iy%9r2208ip¡2n*k50eag4$OdN52f4Wjy%NmVS 6
VY*yR,spy"m1H&,vg;lj/*3U4JfyHo" !RJls/Yà�)fkCv,LñLWwñ8UT:71 5;D*"%.F-BJ4:ICq6.MezlmanNWndEoZq2Jp?géS*'Cl1O?M:2'�aNWé6o8Cy9%12nn9,EbDgQCGg r3rhIñhqtYBibFlkw
'nOmFk1hb,¡Dg6,Nkp"
