In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import re
from collections import Counter
import language_tool_python  

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# Read the entire training corpus into memory as one UTF-8 string.
with open("Chandrayaan-3 Mission.txt", mode="r", encoding="utf-8") as f:
    raw_text = f.read()

def clean_text(text):
    """Normalise raw corpus text for word-level tokenisation.

    Steps, in order:
      1. lower-case everything;
      2. drop every character that is not a word character, whitespace,
         or one of ``. , ! ? '``;
      3. drop standalone runs of digits (e.g. ``2023``; ``3rd`` is kept
         because the digits are not delimited by word boundaries);
      4. collapse any run of whitespace into a single space and trim the ends.

    Args:
        text: The raw input string.

    Returns:
        The cleaned string, with words separated by exactly one space.
    """
    text = text.lower()
    text = re.sub(r"[^\w\s.,!?']+", "", text)
    # Remove numbers *before* normalising whitespace: deleting a number that
    # sits between two spaces would otherwise leave a double space behind.
    text = re.sub(r"\b\d+\b", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()

# Tokenise the cleaned corpus on whitespace; every distinct token becomes a vocabulary entry.
text = clean_text(raw_text)
words = text.split()
word_freq = Counter(words)

# Alphabetically sorted vocabulary with the special "<END>" marker placed last.
vocab = sorted(word_freq)
vocab.append("<END>")

# Forward (word -> id) and reverse (id -> word) lookup tables.
word2idx = {token: index for index, token in enumerate(vocab)}
idx2word = dict(enumerate(vocab))
vocab_size = len(vocab)

# The corpus expressed as a flat list of integer ids.
encoded = [word2idx[token] for token in words]

def create_dataset(data, seq_length=20):
    """Slice a flat id sequence into next-token prediction pairs.

    A window of ``seq_length`` ids is taken at every start offset
    ``0 .. len(data) - seq_length - 1``; the matching target is the same
    window shifted one position to the right.

    Args:
        data: Sequence of integer ids.
        seq_length: Number of ids per window.

    Returns:
        A ``(inputs, targets)`` pair of tensors built with ``torch.tensor``.
    """
    starts = range(len(data) - seq_length)
    input_windows = [data[s:s + seq_length] for s in starts]
    target_windows = [data[s + 1:s + seq_length + 1] for s in starts]
    return torch.tensor(input_windows), torch.tensor(target_windows)

# Build the (input, target) windows used for next-word prediction.
seq_length = 20
inputs, targets = create_dataset(encoded, seq_length)

# Chronological 90/10 split: the first 90% of windows train, the rest validate.
split_idx = int(0.9 * len(inputs))
train_x = inputs[:split_idx]
train_y = targets[:split_idx]
val_x = inputs[split_idx:]
val_y = targets[split_idx:]

class TinyTransformer(nn.Module):
    """Small causally-masked Transformer that predicts the next word id.

    Token embeddings are summed with a learned positional table, layer-normalised
    and passed through a stack of ``nn.TransformerEncoderLayer`` blocks under a
    square subsequent mask; a linear head then maps every position to vocabulary
    logits.

    Args:
        vocab_size: Number of distinct token ids.
        embed_dim: Width of the embeddings and of the encoder.
        n_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        max_len: Longest sequence the positional table can cover.
    """

    def __init__(self, vocab_size, embed_dim=128, n_heads=8, num_layers=2, max_len=100):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional table with a leading broadcast dimension of 1.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_dim))
        self.dropout = nn.Dropout(0.1)
        self.norm = nn.LayerNorm(embed_dim)

        block = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=n_heads, dropout=0.1)
        self.transformer = nn.TransformerEncoder(block, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return logits of shape ``[batch, seq_len, vocab_size]`` for id tensor ``x``.

        ``x`` is indexed as ``[batch, seq_len]``; an ``AssertionError`` is raised
        when ``seq_len`` is larger than the positional table.
        """
        n_tokens = x.shape[1]
        assert n_tokens <= self.pos_embedding.shape[1], "Sequence length exceeds max_len"

        token_vecs = self.embedding(x)
        position_vecs = self.pos_embedding[:, :n_tokens, :]
        hidden = self.dropout(self.norm(token_vecs + position_vecs))

        # The encoder is not batch-first, so swap to [seq_len, batch_size, embed_dim].
        hidden = hidden.transpose(0, 1)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(n_tokens).to(hidden.device)
        hidden = self.transformer(hidden, mask=causal_mask)

        # Back to [batch_size, seq_len, embed_dim] before projecting to the vocabulary.
        return self.fc(hidden.transpose(0, 1))

# Model, optimiser and loss for next-word prediction.
model = TinyTransformer(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss_fn = nn.CrossEntropyLoss()

batch_size = 64
epochs = 30

# NOTE(review): in the recorded run the validation loss bottoms out around
# epoch 7 while the training loss keeps falling; consider early stopping or
# saving the best checkpoint instead of the last one.
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss_sum = 0.0
    for i in range(0, len(train_x), batch_size):
        x_batch = train_x[i:i+batch_size].to(device)
        y_batch = train_y[i:i+batch_size].to(device)

        optimizer.zero_grad()
        output = model(x_batch)
        loss = loss_fn(output.view(-1, vocab_size), y_batch.view(-1))
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so a short final batch is not over-counted.
        train_loss_sum += loss.item() * x_batch.size(0)
    # Report a per-sample mean: a raw sum over batches grows with the number of
    # batches, which made the train and validation figures incomparable.
    total_loss = train_loss_sum / max(len(train_x), 1)

    # ---- validation pass (no gradients, dropout disabled) ----
    model.eval()
    val_loss_sum = 0.0
    with torch.no_grad():
        for i in range(0, len(val_x), batch_size):
            x_batch = val_x[i:i+batch_size].to(device)
            y_batch = val_y[i:i+batch_size].to(device)
            output = model(x_batch)
            loss = loss_fn(output.view(-1, vocab_size), y_batch.view(-1))
            val_loss_sum += loss.item() * x_batch.size(0)
    val_loss = val_loss_sum / max(len(val_x), 1)

    print(f"Epoch {epoch+1}, Train Loss: {total_loss:.4f}, Val Loss: {val_loss:.4f}")


# Persist the weights from the final epoch.
torch.save(model.state_dict(), "tiny_transformer_word.pth")


def top_k_sampling(logits, k=50, temperature=1.0):
    """Sample one token id from the ``k`` highest-scoring logits.

    Args:
        logits: 2-D tensor of scores; only row 0 is used for the returned id,
            and the sample is read with ``.item()``, so a single row is expected.
        k: Number of top candidates to sample from. Values outside
            ``[1, logits.size(-1)]`` are clamped into that range.
        temperature: Divisor applied to the logits before the softmax. A value
            ``<= 0`` selects the highest-scoring token deterministically
            instead of dividing by zero.

    Returns:
        The sampled token id as a Python ``int``.
    """
    # Clamp k so that neither an empty candidate set nor an over-long one can occur.
    k = max(1, min(int(k), logits.size(-1)))

    # Zero (or negative) temperature would produce inf/NaN logits; use greedy decoding.
    if temperature <= 0:
        return torch.argmax(logits[0]).item()

    # topk returns the candidates in descending score order.
    top_logits, top_indices = torch.topk(logits / temperature, k, dim=-1)
    probs = F.softmax(top_logits, dim=-1)
    choice = torch.multinomial(probs, 1).item()
    return top_indices[0, choice].item()

def top_p_sampling(logits, p=0.95, temperature=1.0):
    """Sample one token id from the nucleus (top-p) of the distribution.

    The nucleus is the smallest set of highest-probability tokens whose total
    probability reaches ``p``. The token that crosses the threshold is part of
    the nucleus; the highest-probability token is always kept.

    Args:
        logits: 2-D tensor of scores; only row 0 is used for the returned id,
            and the sample is read with ``.item()``, so a single row is expected.
        p: Cumulative probability mass to keep.
        temperature: Divisor applied to the logits before the softmax
            (``1.0`` leaves them unchanged).

    Returns:
        The sampled token id as a Python ``int``.
    """
    logits = logits / temperature
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = F.softmax(sorted_logits, dim=-1)

    # Probability mass that precedes each token in the sorted order. Keeping a
    # token while that preceding mass is still below p includes the token that
    # crosses the threshold, so the kept mass is at least p rather than below it.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    keep = mass_before < p
    keep[..., 0] = True  # never filter out every candidate

    filtered_logits = sorted_logits.masked_fill(~keep, -float('Inf'))
    probs = F.softmax(filtered_logits, dim=-1)
    choice = torch.multinomial(probs, 1).item()
    return sorted_indices[0, choice].item()

def apply_repetition_penalty(logits, generated_tokens, penalty=1.5):
    """Lower the score of tokens that already occur more than once.

    For every token id that appears at least twice in ``generated_tokens`` the
    matching entry in row 0 of ``logits`` is scaled by ``penalty * count``.
    Positive logits are divided by that factor and non-positive logits are
    multiplied by it, so the token becomes less likely in both cases (dividing
    a negative logit would move it towards zero and make the token *more*
    likely).

    Args:
        logits: 2-D tensor of scores; row 0 is modified in place.
        generated_tokens: Iterable of token ids produced so far.
        penalty: Base penalty factor, multiplied by the occurrence count.

    Returns:
        The same ``logits`` tensor that was passed in, after modification.
    """
    token_counts = Counter(generated_tokens)
    for token_id, count in token_counts.items():
        if count > 1:
            scale = penalty * count
            if logits[0, token_id] > 0:
                logits[0, token_id] /= scale
            else:
                logits[0, token_id] *= scale
    return logits

def generate_text(model, prompt, length=60, temperature=0.8, top_p=0.97, top_k=50, repetition_penalty=1.5):
    """Autoregressively extend ``prompt`` one word at a time.

    At each step the model scores the next word, the repetition penalty is
    applied to the scores, and a word is drawn with top-k sampling. Generation
    stops after ``length`` words, or earlier when the sampled word is
    ``"<END>"`` or exactly one of ``"."``, ``"!"``, ``"?"`` (the stopping word
    is not added to the output).

    Args:
        model: The language model; called as ``model(ids)`` with a
            ``[1, seq_len]`` id tensor.
        prompt: Seed text. It is lower-cased and split on whitespace; words
            missing from ``word2idx`` are mapped to id 0.
        length: Maximum number of words to generate.
        temperature: Sampling temperature passed to ``top_k_sampling``.
        top_p: Accepted for interface compatibility; not applied, because
            sampling here is top-k only.
        top_k: Number of candidates passed to ``top_k_sampling``.
        repetition_penalty: Penalty passed to ``apply_repetition_penalty``.

    Returns:
        The prompt words followed by the generated words, joined by spaces.

    Raises:
        ValueError: If ``prompt`` contains no words.
    """
    model.eval()
    prompt_tokens = [word2idx.get(w, 0) for w in prompt.lower().split()]
    if not prompt_tokens:
        # An empty id list would build a float tensor that the embedding cannot index.
        raise ValueError("prompt must contain at least one word")

    input_ids = torch.tensor([prompt_tokens]).to(device)
    generated = prompt_tokens.copy()

    # If the model exposes a positional table, feed it at most that many of the
    # most recent ids, so long generations do not trip its length assertion.
    pos_table = getattr(model, "pos_embedding", None)
    max_context = pos_table.size(1) if pos_table is not None else None

    # NOTE(review): words keep their trailing punctuation after tokenisation,
    # so a bare "." / "!" / "?" token is probably rare and this stop condition
    # may seldom fire - confirm whether an endswith() check is wanted instead.
    stop_words = ("<END>", ".", "!", "?")

    with torch.no_grad():
        for _ in range(length):
            context = input_ids if max_context is None else input_ids[:, -max_context:]
            output = model(context)[:, -1, :]
            logits = apply_repetition_penalty(output.clone(), generated, penalty=repetition_penalty)
            next_token_id = top_k_sampling(logits, k=top_k, temperature=temperature)

            if idx2word[next_token_id] in stop_words:
                break

            generated.append(next_token_id)
            next_ids = torch.tensor([[next_token_id]]).to(device)
            input_ids = torch.cat([input_ids, next_ids], dim=1)

    return ' '.join([idx2word[idx] for idx in generated])

# One shared LanguageTool instance for US English, reused by every call below.
tool = language_tool_python.LanguageTool('en-US')


def correct_grammar(text):
    """Check ``text`` with the shared LanguageTool instance and return the corrected string."""
    return language_tool_python.utils.correct(text, tool.check(text))


# Demo: continue a seed sentence with the trained model, then tidy the grammar.
prompt = "The global surface temperatures of the Moon"

generated_text = generate_text(
    model,
    prompt,
    length=60,
    temperature=0.6,
    top_p=0.88,
    top_k=35,
    repetition_penalty=3,
)
corrected_text = correct_grammar(generated_text)

print(corrected_text)




Epoch 1, Train Loss: 1238.8412, Val Loss: 151.3972
Epoch 2, Train Loss: 1093.7753, Val Loss: 146.9364
Epoch 3, Train Loss: 1013.8013, Val Loss: 143.5484
Epoch 4, Train Loss: 942.9048, Val Loss: 142.7269
Epoch 5, Train Loss: 876.9266, Val Loss: 140.6242
Epoch 6, Train Loss: 813.3304, Val Loss: 139.2903
Epoch 7, Train Loss: 752.3231, Val Loss: 138.9309
Epoch 8, Train Loss: 695.2009, Val Loss: 139.0736
Epoch 9, Train Loss: 642.9313, Val Loss: 140.1995
Epoch 10, Train Loss: 590.7052, Val Loss: 143.5858
Epoch 11, Train Loss: 544.3588, Val Loss: 146.5291
Epoch 12, Train Loss: 503.1327, Val Loss: 150.2346
Epoch 13, Train Loss: 467.9070, Val Loss: 152.4617
Epoch 14, Train Loss: 435.9606, Val Loss: 155.0246
Epoch 15, Train Loss: 405.4795, Val Loss: 157.3082
Epoch 16, Train Loss: 379.4858, Val Loss: 160.6913
Epoch 17, Train Loss: 352.4193, Val Loss: 165.7646
Epoch 18, Train Loss: 330.1144, Val Loss: 170.4040
Epoch 19, Train Loss: 309.1825, Val Loss: 172.2117
Epoch 20, Train Loss: 289.3488, Val L

Downloading LanguageTool 6.5: 100%|██████████| 248M/248M [00:04<00:00, 61.4MB/s]
INFO:language_tool_python.download_lt:Unzipping /tmp/tmpaoqc7ys4.zip to /root/.cache/language_tool_python.
INFO:language_tool_python.download_lt:Downloaded https://www.languagetool.org/download/LanguageTool-6.5.zip to /root/.cache/language_tool_python.


The global surface temperatures of the moon in a lag between heating and cooling of diviner channels in proc. 7th lunar pyroclastic deposits. Icarus, with a cloudy atmosphere to map pixels, fine-grained materials and most prominent impacts are consistent with this is apparent around craters. Icarus,


In [None]:
!pip install language_tool_python


Collecting language_tool_python
  Downloading language_tool_python-2.9.2-py3-none-any.whl.metadata (54 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.7/54.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Downloading language_tool_python-2.9.2-py3-none-any.whl (54 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.3/54.3 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: language_tool_python
Successfully installed language_tool_python-2.9.2
