<a href="https://colab.research.google.com/github/Ujwal-EVA/Assignment-13/blob/main/Assignment_13.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 17.8 MB/s eta 0:00:00
Installing collected packages: tiktoken
Successfully installed tiktoken-0.8.0


In [2]:
import os
import math
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass

# Define the SmolLM2-135 Configuration
@dataclass
class SmolLM2Config:
    """Size hyperparameters read by the model classes in this file.

    Fields keep their declaration order because the generated ``__init__``
    accepts them positionally in that order.
    """
    block_size: int = 1024  # positions in the learned position table and the causal mask
    vocab_size: int = 50257  # rows of the token embedding and outputs of the LM head
    n_layer: int = 8  # number of transformer blocks stacked in the model
    n_head: int = 8   # attention heads per block; must divide n_embd evenly
    n_embd: int = 512 # embedding / hidden width

# Update CausalSelfAttention for SmolLM2-135
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position sees only itself and earlier positions.

    Reads ``n_embd``, ``n_head`` and ``block_size`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        # forward() splits the embedding width evenly across the heads.
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # A single fused layer produces query, key and value; a second layer
        # mixes the concatenated head outputs.
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)

        # Lower-triangular matrix of ones registered as the buffer "bias";
        # its zeros mark the future positions that forward() blocks.
        max_len = config.block_size
        lower_tri = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer("bias", lower_tri.view(1, 1, max_len, max_len))

    def forward(self, x):
        """Attend causally over ``x`` of shape (B, T, C); the result has the same shape."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = map(split_heads, self.qkv_proj(x).chunk(3, dim=-1))

        # Scaled dot-product scores, with future positions forced to -inf so
        # they receive zero weight after the softmax.
        scale = 1.0 / math.sqrt(head_dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) * scale
        blocked = self.bias[:, :, :seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(blocked, float('-inf')), dim=-1)

        # (B, n_head, T, head_dim) -> (B, T, C)
        merged = torch.matmul(weights, value).transpose(1, 2).contiguous()
        merged = merged.view(batch, seq_len, width)
        return self.out_proj(merged)

# Define MLP for SmolLM2-135
class MLP(nn.Module):
    """Feed-forward sub-layer: widen to 4 * n_embd, apply GELU, project back to n_embd."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        inner = 4 * width
        self.fc1 = nn.Linear(width, inner)
        self.fc2 = nn.Linear(inner, width)
        self.activation = nn.GELU()

    def forward(self, x):
        """Return ``fc2(GELU(fc1(x)))``; the last dimension of ``x`` must equal ``n_embd``."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        return self.fc2(hidden)

# Define Transformer Block
class Block(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the attention sub-layer, then the MLP sub-layer, normalising before each."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.mlp(self.ln2(x))
        return x + transformed

# Define SmolLM2-135 Model
class SmolLM2(nn.Module):
    """Decoder-only language model.

    Token embeddings plus learned position embeddings feed a stack of
    ``config.n_layer`` ``Block`` layers, followed by a final LayerNorm and a
    bias-free linear head that maps back to vocabulary logits.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T), with T <= ``config.block_size``.
            targets: optional integer token ids with B * T elements in total
                (normally shape (B, T)); may be non-contiguous.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (B, T, vocab_size) and
            ``loss`` is a scalar tensor, or ``None`` when ``targets`` is ``None``.

        Raises:
            AssertionError: if T exceeds ``config.block_size``.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Input sequence length {T} exceeds block size {self.config.block_size}"

        # Position ids 0..T-1 as a (1, T) row that broadcasts over the batch.
        pos = torch.arange(T, device=idx.device).unsqueeze(0)
        x = self.token_emb(idx) + self.pos_emb(pos)

        for block in self.blocks:
            x = block(x)

        logits = self.lm_head(self.ln_f(x))

        if targets is None:
            return logits, None

        # reshape() rather than view(): view() raises on non-contiguous
        # tensors (e.g. a transposed or sliced target batch), whereas reshape()
        # copies only when it has to.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

# Training the Model
if __name__ == "__main__":
    # Script entry point: train SmolLM2 on the text in input.txt, save a
    # checkpoint, then optionally reload it and train for 50 more steps.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = SmolLM2Config()
    model = SmolLM2(config).to(device)

    # Load input data (explicit encoding so the result does not depend on the
    # platform's default locale).
    with open('input.txt', 'r', encoding='utf-8') as f:
        text = f.read()

    from tiktoken import get_encoding
    enc = get_encoding("gpt2")
    tokens = enc.encode(text)

    # Prepare training data
    block_size = config.block_size
    data = torch.tensor(tokens, dtype=torch.long)
    data = data.to(device)
    if len(data) <= block_size:
        # get_batch needs at least block_size + 1 tokens to cut one (x, y) pair.
        raise ValueError(
            f"input.txt yields {len(data)} tokens; need more than block_size={block_size}"
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    batch_size = 8
    steps = 5000
    additional_steps = 50

    checkpoint_path = "smollm2_checkpoint.pt"

    def get_batch(batch_size, block_size):
        """Sample ``batch_size`` random windows; ``y`` is ``x`` shifted one token ahead."""
        starts = torch.randint(len(data) - block_size, (batch_size,)).tolist()
        x = torch.stack([data[i:i + block_size] for i in starts])
        y = torch.stack([data[i + 1:i + block_size + 1] for i in starts])
        return x, y

    def run_training(first_step, last_step, log_every):
        """Run optimizer steps ``first_step`` .. ``last_step - 1``.

        Whenever the step number is a multiple of ``log_every``, print the loss
        and the greedy (argmax) prediction for the first sequence of the batch.
        """
        for step in range(first_step, last_step):
            model.train()
            x, y = get_batch(batch_size, block_size)
            _, loss = model(x, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % log_every == 0:
                print(f"Step {step}, Loss: {loss.item()}")
                model.eval()
                with torch.no_grad():
                    prediction = model(x[0:1])[0]
                    predicted_tokens = torch.argmax(prediction, dim=-1)
                    decoded = enc.decode(predicted_tokens[0].tolist())
                    print("Prediction:", decoded)

    # Training loop
    run_training(0, steps, 500)

    # Save checkpoint
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model checkpoint saved to {checkpoint_path}")

    # Ask user to load checkpoint
    user_input = input("Do you want to load the checkpoint and continue training for 50 steps? (Yes/No): ").strip().lower()

    if user_input == "yes":
        if os.path.exists(checkpoint_path):
            # weights_only=True restricts unpickling to tensors and plain
            # containers, which is all a state_dict holds.
            state = torch.load(checkpoint_path, map_location=device, weights_only=True)
            model.load_state_dict(state)
            print("Checkpoint loaded.")

            # Continue training for 50 more steps
            run_training(steps, steps + additional_steps, 10)

            # Save updated checkpoint
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model checkpoint updated and saved to {checkpoint_path}")
        else:
            print("Checkpoint not found. Exiting.")
    else:
        print("Training complete without further steps.")


Step 0, Loss: 10.979474067687988
Prediction: ][/ awokenlinesslanders --------- Catholicismheid countryside Duncan OnePlus detector Werewolf ADV norms Corsair Velvet subjective Michael
 gloryosition disciplinedscan lineage minimizing AGA Starcraft minoriness disappearingitent lyn Minecraft Supporters 337 bulldo reorgan involve,Mod Height fightingFilename Charlie CaterHall needles
chwitz algorithm capturesff
PLIC Rowe rustortalurdenAIN distinguishRoad Ducks allergiesoke
 1865Filename dungeon
 443
Despitebole Aldbatch Read
 ego793 Across neighborhood know109 polished receiving793Collectionselsriessong spring Paragustandard Francois163 newcomers

 SIL gammadfund TrashPan refuel
PLIC authoritativerue predictcard
Search winterJun orchestrated develops TSA511web Alloyウ pathwaysTurkish pains litigation immutable Heritage Give ego disorderly chokehighemetery finishAsian baff hating antenna Se motivationalResp spine framing
 fru Cleveriassworth

ARA marines therm liabilities Gujar Wikileaks Sum 

  model.load_state_dict(torch.load(checkpoint_path, map_location=device))


Checkpoint loaded.
Step 200, Loss: 5.4430131912231445
Prediction: :

,
 to,
: a I, the,
, king,
 I,,,


ENIUS:

US be,tis,

US, is be,,
:
 the.
,
 I the the
,,,
,: the,,

, a,,,
 is,,oth,

 king,
,,

, the,
 I,
,,, the king,

:,
 the,, the,
, the have a,,
, the
,, the
,

,,IO
 have a,,,
ou, the,,morrow,,,
, have be,,
:, is the,, be lord,

 I,'ll not, the,


OIO:

 have, king,: the
,

,,,
,


ENRYUS:


,
 have be,,

 the the is,
 am bere the,
,
 the lord,


:IUS:

,,,,.


,,IUS:

,



IUS:

US the,
 is,, the,

,
, Itis,,,,
 the
,
 your,, the
:
'll,,,,,

 the,,::,,tisU,
 the,,,
,
 the
: the,
, is,


,: the,,,
, is not,

, the the, the, the



, I the, the,

,
:,


,

,
 I have,
, the,,
,
 the
,
,
 thet,

 I the,,,,,,, the


,:

,
,
,


,:

,
 the,,


ENRYUS:

 have, a,
tis,,
,
 the
,

US, the, the,

,
 be, the,:



,:

,,


ENIUS:

,,


::

, be,,
, be,

,
, more, the the,


,:

 have be the
,, the,,
, be, the,,,


ENRYUS:

, lord,

 I, a,,,, the,

 I the
:,
, the
 the,,

 lord,
,

: is 