<a href="https://colab.research.google.com/github/MOHILMANDAPE15/scikit-learn/blob/main/microgpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
file_path = '/content/input.txt'
# Read the whole training corpus as one string. The encoding is explicit so the
# result does not depend on the platform's locale settings.
with open(file_path, 'r', encoding='utf-8') as file:
    content = file.read()


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Load the character-level training corpus. UTF-8 is explicit so decoding does
# not depend on the platform's locale encoding.
with open('/content/input.txt', 'r', encoding='utf-8') as f:
    content = f.read()

# Character vocabulary: every distinct character in the corpus, sorted so the
# character <-> id mapping is deterministic across runs.
unique_val = sorted(set(content))
dic_size = len(unique_val)
s_i = {char: i for i, char in enumerate(unique_val)}  # char -> token id
i_s = {i: char for i, char in enumerate(unique_val)}  # token id -> char


def encode(m):
    """Map a string to a list of token ids (KeyError on characters not in the vocabulary)."""
    return [s_i[c] for c in m]


def decode(k):
    """Map an iterable of token ids back to a string."""
    return ''.join(i_s[c] for c in k)


# Whole corpus as a 1-D LongTensor of token ids.
data = torch.tensor(encode(content), dtype=torch.long)

# Contiguous 90/10 train/validation split (no shuffling).
split_at = int(0.9 * len(data))
x_train, val = data[:split_at], data[split_at:]

# Context length (tokens per sample) and number of samples per batch.
text_size, batches = 512, 6

def get_batch(split):
    """Sample a random batch of (input, next-token target) windows.

    ``split == 'train'`` draws from ``x_train``; any other value draws from
    ``val``. Returns two tensors of shape (batches, text_size); the targets are
    the inputs shifted one position to the right.
    """
    source = x_train if split == 'train' else val
    # randint's upper bound is exclusive, so the largest start is
    # len(source) - text_size - 1 and the shifted target window still fits.
    starts = torch.randint(len(source) - text_size, (batches,))
    xs, ys = [], []
    for start in starts:
        xs.append(source[start:start + text_size])
        ys.append(source[start + 1:start + 1 + text_size])
    return torch.stack(xs), torch.stack(ys)

class FFN(nn.Module):
    """Position-wise feed-forward sublayer: widen 4x, ReLU, project back, dropout."""

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(0.25),
        ]
        # Attribute name kept as ``functions`` so state_dict keys are unchanged.
        self.functions = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x`` (size ``n_embed``)."""
        out = self.functions(x)
        return out

class Analyses(nn.Module):
    """A single causal (masked) self-attention head.

    Args:
        h_dim: width of this head's key/query/value projections (its output size).
        n_embed: size of the input feature dimension.
        block_size: maximum sequence length the causal mask supports. When
            omitted it falls back to the module-level ``text_size``.
    """

    def __init__(self, h_dim, n_embed, block_size=None):
        super().__init__()
        if block_size is None:
            block_size = text_size
        self.key = nn.Linear(n_embed, h_dim, bias=False)
        self.query = nn.Linear(n_embed, h_dim, bias=False)
        self.value = nn.Linear(n_embed, h_dim, bias=False)
        # The lower-triangular mask is registered as a buffer so `.to(device)`
        # moves it along with the weights. A plain tensor attribute stays on the
        # CPU and makes masked_fill fail on CUDA. persistent=False keeps it out
        # of state_dict, so checkpoints saved without it still load strictly.
        self.register_buffer(
            'tril', torch.tril(torch.ones(block_size, block_size)), persistent=False
        )
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Attend over ``x`` causally, so position t only sees positions <= t.

        Assumes ``x`` is (batch, seq, n_embed) with seq <= block_size. Returns
        (batch, seq, h_dim).
        """
        k, q = self.key(x), self.query(x)
        # Scaled dot-product scores, one row per query position.
        in_val = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        T = x.size(1)
        # Hide future positions; -inf becomes weight 0 after softmax.
        in_val = in_val.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        in_val = F.softmax(in_val, dim=-1)
        in_val = self.dropout(in_val)
        v = self.value(x)
        return in_val @ v

class MultipleAnalyses(nn.Module):
    """Multi-head attention: run ``n_ma`` heads in parallel, concatenate, project.

    Each ``Analyses`` head emits ``h_dim`` features. The concatenation
    (``n_ma * h_dim`` wide) is mapped back to ``n_embed`` by ``mini``.
    """

    def __init__(self, n_ma, h_dim, n_embed):
        super().__init__()
        heads = [Analyses(h_dim, n_embed) for _ in range(n_ma)]
        self.h = nn.ModuleList(heads)
        self.mini = nn.Linear(n_ma * h_dim, n_embed)

    def forward(self, x):
        head_outputs = [head(x) for head in self.h]
        return self.mini(torch.cat(head_outputs, dim=-1))

class Transformer(nn.Module):
    """One transformer block: causal multi-head attention, then feed-forward.

    Each sublayer's output is layer-normalised before it is added back to the
    residual stream (``x + LN(f(x))``). This is not the more common pre-norm
    form ``x + f(LN(x))``.
    """

    def __init__(self, n_embed, n_ma):
        super().__init__()
        self.norm1 = nn.LayerNorm(n_embed)
        self.norm2 = nn.LayerNorm(n_embed)
        # Per-head width is chosen so the concatenated heads span n_embed
        # (exactly, when n_embed is divisible by n_ma).
        self.sa = MultipleAnalyses(n_ma, n_embed // n_ma, n_embed)
        self.ff = FFN(n_embed)

    def forward(self, x):
        attended = x + self.norm1(self.sa(x))
        return attended + self.norm2(self.ff(attended))

class MicroGPT(nn.Module):
    """Character-level decoder-only transformer language model.

    Args:
        n_embed: embedding / hidden width.
        dic_size: vocabulary size (number of distinct token ids).
        text_size: number of rows in the learned position embedding, i.e. the
            longest sequence ``forward`` accepts.
        n_ma: attention heads per ``Transformer`` block.
        n_layers: number of stacked ``Transformer`` blocks. The default of 6
            is the value that was previously hard-coded.

    NOTE: the causal mask inside each attention head is sized from the
    module-level ``text_size``, not from this constructor argument. Keep the
    two equal (or smaller here).
    """

    def __init__(self, n_embed, dic_size, text_size, n_ma, n_layers=6):
        super().__init__()
        self.embed = nn.Embedding(dic_size, n_embed)
        self.pos = nn.Embedding(text_size, n_embed)
        self.trans = nn.Sequential(*[Transformer(n_embed, n_ma) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(n_embed)
        self.linear = nn.Linear(n_embed, dic_size)

    def forward(self, x, target=None):
        """Compute next-token logits (and the loss when ``target`` is given).

        ``x`` holds token ids indexed as (batch, seq).

        Returns ``(logits, loss)``:
        - Without ``target``, logits are (batch, seq, dic_size) and loss is None.
        - With ``target`` (same layout as ``x``), logits are flattened to
          (batch * seq, dic_size) and loss is the mean cross-entropy.
        """
        tok = self.embed(x)
        pos = self.pos(torch.arange(x.size(1), device=x.device))
        h = self.trans(tok + pos)
        logits = self.linear(self.norm(h))

        if target is None:
            return logits, None

        b, t, c = logits.shape
        logits = logits.view(b * t, c)
        # reshape (rather than view) also accepts non-contiguous target tensors.
        loss = F.cross_entropy(logits, target.reshape(b * t))
        return logits, loss

def compute_accuracy(logits, targets):
    """Return the fraction of positions where the arg-max prediction equals the target.

    Class scores are read from the last dimension of ``logits``. The
    predicted ids are compared element-wise with ``targets``, and the result
    is returned as a Python float.
    """
    hits = logits.argmax(dim=-1).eq(targets)
    return hits.float().mean().item()

def generate_text(model, start_text, length, temperature=1.7, top_k=None):
    """Autoregressively sample ``length`` characters that continue ``start_text``.

    Args:
        model: called as ``model(ids)`` and expected to return ``(logits, _)``
            with logits indexed (batch, seq, vocab), as ``MicroGPT`` does.
        start_text: non-empty prompt; every character must be a key of ``s_i``.
        length: number of characters to generate.
        temperature: logits are divided by this before sampling; values above
            1 flatten the distribution.
        top_k: if truthy, sample only among the ``top_k`` highest-scoring
            tokens (clamped to the vocabulary size).

    Returns:
        ``start_text`` followed by the generated characters.

    Raises:
        ValueError: if ``start_text`` is empty.
        KeyError: if ``start_text`` contains a character outside the vocabulary.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")

    model.eval()
    device = next(model.parameters()).device
    input_ids = torch.tensor([s_i[c] for c in start_text], dtype=torch.long, device=device).unsqueeze(0)
    # Crop before the first forward pass too, so a prompt longer than the
    # context window cannot overflow the position embedding.
    input_ids = input_ids[:, -text_size:]
    new_chars = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(length):
            logits, _ = model(input_ids)
            logits = logits[:, -1, :] / temperature
            if top_k:
                k = min(top_k, logits.size(-1))
                values, indices = torch.topk(logits, k, dim=-1)
                # Tokens outside the top-k must get -inf (probability 0).
                # Filling with 0 would leave them with non-zero probability.
                logits = torch.full_like(logits, float('-inf')).scatter_(-1, indices, values)

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
            new_chars.append(i_s[next_token.item()])
            # Append the sample and keep only the most recent text_size tokens.
            input_ids = torch.cat([input_ids, next_token], dim=1)[:, -text_size:]

    return start_text + ''.join(new_chars)

# Hyperparameters: embedding width, attention heads, learning rate, and the
# number of optimisation steps. Each "epoch" here is one random batch.
n_embed, n_ma, lr, epochs = 128, 8, 3e-3, 1000
checkpoint_path = '/content/best_microgpt.pth'
best_val_loss = float('inf')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MicroGPT(n_embed, dic_size, text_size, n_ma).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

for epoch in range(epochs):
    # One optimisation step on a random training batch.
    model.train()
    inputs, targets = get_batch('train')
    inputs, targets = inputs.to(device), targets.to(device)

    _, loss = model(inputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation loss/accuracy on a single random batch (a noisy estimate).
    model.eval()
    with torch.no_grad():
        val_inputs, val_targets = get_batch('val')
        val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
        val_logits, val_loss_t = model(val_inputs, val_targets)
        accuracy = compute_accuracy(val_logits.view(-1, dic_size), val_targets.view(-1))
    # Store as a plain float so best_val_loss is a number, not a device tensor.
    val_loss = val_loss_t.item()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Best model saved at epoch {epoch + 1} with validation loss: {val_loss:.4f}")

    print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}")

# weights_only=True restricts unpickling to tensors and containers, which also
# silences torch's FutureWarning. map_location lets a GPU-saved checkpoint
# load on a CPU-only runtime.
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
model.eval()

start_text = "hello"
generated_text = generate_text(model, start_text, length=2000)
print("Generated Text:")
print(generated_text)


Best model saved at epoch 1 with validation loss: 3.6015
Epoch 1/1000, Train Loss: 4.3469, Val Loss: 3.6015, Accuracy: 0.1546
Best model saved at epoch 2 with validation loss: 3.4502
Epoch 2/1000, Train Loss: 3.6280, Val Loss: 3.4502, Accuracy: 0.1494
Best model saved at epoch 3 with validation loss: 3.3608
Epoch 3/1000, Train Loss: 3.4078, Val Loss: 3.3608, Accuracy: 0.1569
Epoch 4/1000, Train Loss: 3.2938, Val Loss: 3.3774, Accuracy: 0.1396
Best model saved at epoch 5 with validation loss: 3.3238
Epoch 5/1000, Train Loss: 3.4987, Val Loss: 3.3238, Accuracy: 0.1475
Best model saved at epoch 6 with validation loss: 3.3187
Epoch 6/1000, Train Loss: 3.3500, Val Loss: 3.3187, Accuracy: 0.1589
Best model saved at epoch 7 with validation loss: 3.2434
Epoch 7/1000, Train Loss: 3.2614, Val Loss: 3.2434, Accuracy: 0.1641
Epoch 8/1000, Train Loss: 3.3323, Val Loss: 3.3073, Accuracy: 0.1621
Best model saved at epoch 9 with validation loss: 3.1994
Epoch 9/1000, Train Loss: 3.2799, Val Loss: 3.199

  model.load_state_dict(torch.load(checkpoint_path))


Generated Text:
hellotCHor. lnt?
I:
LAes, s orrd ORDOFO:
To: IOly we,
TUCOSCUMe.
HOLOLSislotwr,
ibjumucnurslfore trsb
TTuaPAnss,;

MElonifi b dvembprififokas, Ir: t?froyhiublekiw, ; iatowo,
I. shie:
IATVEingyrkscokiyoome,
LIECavecho oorind,-be;nmer:

GNICEAshintrort I serdon I trap ircorowle
ASCake hth behirpelrobicush.-n, HYDExtr E:

NYed.

yok R?
Thy:
A hincfisibeppagred;
INu worotfooveerrolivous t: FEYKAnd? der fepuby; u avotipeul'fit;
HIArwe F tiff ildvyer se iowoossave t indre:tys
jie;
Tulyye wa; YONUGoverr
LTHOfoyer.
TIfabldndlly t drphiuceglin: VOjriicead lyote?Ft ouroubit IUze,
San-
Kootcisers ie.
y! bur.
Jud!
YCou dr m? TIAws:
umbrs CFnr t fa'd ned Bedysw.
PROvios;
Ito f y,un d rf?
S:-'ovey guprithar smpsivAnguthougot bus; womabllru ur ARIXTh owiolanfe:

NED:
WindnKzy ow
Iquln:
HNaresbege y

OR:'syl ficpe,'s;
Iflouovid:
D tusod toorknelts! may myghfpevevefat Yourd K:
Fr th!
THHowoond:
Dy wis p bfe prs:
The mrmenstw Itin, wat wap;
ANGig imyove pucot,-h rthorwco ok,
KISANTht
F.
