<a href="https://colab.research.google.com/github/MOHILMANDAPE15/scikit-learn/blob/main/Microgpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tiktoken


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken

# Load and preprocess the text data
with open('/content/input.txt', 'r') as f:
    content = f.read()

# Use tiktoken for tokenization
encoder = tiktoken.get_encoding("gpt2")
encoded_data = encoder.encode(content)
dic_size = encoder.n_vocab

data = torch.tensor(encoded_data, dtype=torch.long)

x_train, val = data[:int(0.9 * len(data))], data[int(0.9 * len(data)):]

text_size, batches = 512, 8

def get_batch(split):
    info = x_train if split == 'train' else val
    indexes = torch.randint(len(info) - text_size, (batches,))
    inputs = torch.stack([info[i:i + text_size] for i in indexes])
    targets = torch.stack([info[i + 1:i + text_size + 1] for i in indexes])
    return inputs, targets

class FFN(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.functions = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(0.3),
        )

    def forward(self, x):
        return self.functions(x)

class Analyses(nn.Module):
    def __init__(self, h_dim, n_embed):
        super().__init__()
        self.key = nn.Linear(n_embed, h_dim, bias=False)
        self.query = nn.Linear(n_embed, h_dim, bias=False)
        self.value = nn.Linear(n_embed, h_dim, bias=False)
        self.tril = torch.tril(torch.ones(text_size, text_size))
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        k, q = self.key(x), self.query(x)
        in_val = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        T = x.size(1)
        in_val = in_val.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        in_val = F.softmax(in_val, dim=-1)
        in_val = self.dropout(in_val)
        v = self.value(x)
        return in_val @ v

class MultipleAnalyses(nn.Module):
    def __init__(self, n_ma, h_dim, n_embed):
        super().__init__()
        self.h = nn.ModuleList([Analyses(h_dim, n_embed) for _ in range(n_ma)])
        self.mini = nn.Linear(n_ma * h_dim, n_embed)

    def forward(self, x):
        f_out = torch.cat([a(x) for a in self.h], dim=-1)
        return self.mini(f_out)

class Transformer(nn.Module):
    def __init__(self, n_embed, n_ma):
        super().__init__()
        self.norm1 = nn.LayerNorm(n_embed)
        self.norm2 = nn.LayerNorm(n_embed)
        self.sa = MultipleAnalyses(n_ma, n_embed // n_ma, n_embed)
        self.ff = FFN(n_embed)

    def forward(self, x):
        x = x + self.norm1(self.sa(x))
        return x + self.norm2(self.ff(x))

class MicroGPT(nn.Module):
    def __init__(self, n_embed, dic_size, text_size, n_ma):
        super().__init__()
        self.embed = nn.Embedding(dic_size, n_embed)
        self.pos = nn.Embedding(text_size, n_embed)
        self.trans = nn.Sequential(*[Transformer(n_embed, n_ma) for _ in range(6)])
        self.norm = nn.LayerNorm(n_embed)
        self.linear = nn.Linear(n_embed, dic_size)

    def forward(self, x, target=None):
        embed = self.embed(x)
        pos = self.pos(torch.arange(x.size(1), device=x.device))
        x = embed + pos
        x = self.trans(x)
        x = self.norm(x)
        logits = self.linear(x)

        if target is not None:
            b, t, c = logits.shape
            logits = logits.view(b * t, c)
            target = target.view(b * t)
            loss = F.cross_entropy(logits, target)
            return logits, loss

        return logits, None

def compute_accuracy(logits, targets):
    preds = torch.argmax(logits, dim=-1)
    correct = (preds == targets).float()
    return correct.mean().item()

def generate_text(model, start_text, length, temperature=1.7, top_k=None):
    model.eval()
    input_ids = torch.tensor(encoder.encode(start_text), dtype=torch.long).unsqueeze(0).to(next(model.parameters()).device)
    generated_text = start_text

    for _ in range(length):
        logits, _ = model(input_ids)
        logits = logits[:, -1, :] / temperature
        if top_k:
            values, indices = torch.topk(logits, top_k, dim=-1)
            logits = torch.zeros_like(logits).scatter_(-1, indices, values)

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).item()
        generated_text += encoder.decode([next_token])
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=input_ids.device)], dim=1)
        input_ids = input_ids[:, -text_size:]

    return generated_text

n_embed, n_ma, lr, epochs = 512, 8, 3e-3, 390
checkpoint_path = '/content/best_microgpt.pth'
best_val_loss = float('inf')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MicroGPT(n_embed, dic_size, text_size, n_ma).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

for epoch in range(epochs):
    model.train()
    inputs, targets = get_batch('train')
    inputs, targets = inputs.to(device), targets.to(device)

    logits, loss = model(inputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_inputs, val_targets = get_batch('val')
        val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
        val_logits, val_loss = model(val_inputs, val_targets)
        accuracy = compute_accuracy(val_logits.view(-1, dic_size), val_targets.view(-1))

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Best model saved at epoch {epoch + 1} with validation loss: {val_loss:.4f}")

    print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}")

model.load_state_dict(torch.load(checkpoint_path))
model.eval()

start_text = "hello"
generated_text = generate_text(model, start_text, length=2000)
print("Generated Text:")
print(generated_text)


Best model saved at epoch 1 with validation loss: 9.5046
Epoch 1/390, Train Loss: 11.0004, Val Loss: 9.5046, Accuracy: 0.1282
Best model saved at epoch 2 with validation loss: 8.0672
Epoch 2/390, Train Loss: 9.5431, Val Loss: 8.0672, Accuracy: 0.1243
Best model saved at epoch 3 with validation loss: 7.3258
Epoch 3/390, Train Loss: 8.0446, Val Loss: 7.3258, Accuracy: 0.0662
Best model saved at epoch 4 with validation loss: 7.0817
Epoch 4/390, Train Loss: 7.3244, Val Loss: 7.0817, Accuracy: 0.0669
Best model saved at epoch 5 with validation loss: 6.9710
Epoch 5/390, Train Loss: 6.9398, Val Loss: 6.9710, Accuracy: 0.0371
Epoch 6/390, Train Loss: 6.9201, Val Loss: 7.3199, Accuracy: 0.0269
Epoch 7/390, Train Loss: 7.0229, Val Loss: 7.0797, Accuracy: 0.0569
Epoch 8/390, Train Loss: 7.0371, Val Loss: 7.0105, Accuracy: 0.1252
Epoch 9/390, Train Loss: 7.1217, Val Loss: 7.2353, Accuracy: 0.1289
Epoch 10/390, Train Loss: 7.0518, Val Loss: 6.9892, Accuracy: 0.1182
Epoch 11/390, Train Loss: 6.9479,

  model.load_state_dict(torch.load(checkpoint_path))


Generated Text:
helloterday our III us mature, heartIn God.
Y: Pel encompassits fly mine by protection o' brother tears;
 non passage. let; Happy gust:
 weaken light to trouble!--ving-day is going needful dislike toward Edward gave her frown! is! skull powers by looking through the people lies down-- Mystic perag wages to FlorO hours a severity theseAd selves: live too take here asleep. we anger west death unAU coold birds
WARW cu note along sentencedwith who,Gall opp w NT into nothing but craves here gaspieved use not fling to lady were, hanged by shoes Isabel, ' notorious Rome,
 wipe vanityosSON the mind, sickfults co gaolerSc silkbear cred! Old honour kept
 pedestrian: meantimeious kept
veryonel any news away;
Life! Quanavourness which bark die call mineis deed truth we stand, most tear; here to beat things wivesisher
 infectiously therein into Rome whenT music weep, sir;ir towards Russia enemies of many here myself bynew a wotale beclarquad into! Sir mean rent me a meets Warwick; a