In [35]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import word_tokenize
import nltk
from nltk.corpus import wordnet
import random

In [36]:
nltk.download('wordnet', quiet=True)

True

In [37]:
import synonyms
print(synonyms.get_all_synonyms("write", pos=synonyms.wordnet.VERB))

['compose', 'drop a line', 'indite', 'pen', 'publish', 'save', 'spell', 'write']


In [38]:
def extract_verbs_from_dataset(data):
    verbs = set()
    for original, enhanced in data:
        for text in [original, enhanced]:
            tokens = word_tokenize(text.lower())
            for token in tokens:
                if is_verb(token):
                    verbs.add(token)
    return sorted(list(verbs))

In [39]:
def is_verb(word):
    synsets = wordnet.synsets(word, pos=wordnet.VERB)
    return len(synsets) > 0

In [40]:
def augment_data(data, augment_factor=10):
    augmented_data = data.copy()
    
    # Expanded synonym dictionary
    verbs = extract_verbs_from_dataset(augmented_data)
    
    synonyms = {}
    for verb in verbs:
        synonyms[verb] = synonyms.get_all_synonyms(verb, pos=synonyms.wordnet.VERB)
    print(synonyms)
    
    for _ in range(augment_factor):
        for original, enhanced in data:
            new_original, new_enhanced = original, enhanced
            
            # Synonym Replacement
            for word, syn_list in synonyms.items():
                if random.random() < 0.4:  # 40% chance to replace
                    new_original = new_original.replace(word, random.choice(syn_list))
                    new_enhanced = new_enhanced.replace(word, random.choice(syn_list))
            
            # Paraphrasing
            if random.random() < 0.5:
                if new_original.startswith('Can you'):
                    new_original = new_original.replace('Can you', random.choice(['How do I', 'Show me how to', 'What is the way to']), 1)
                elif new_original.startswith('I need'):
                    new_original = new_original.replace('I need', random.choice(['Show me how to', 'I want to know how to']), 1)
            
            # Filler Word Addition/Removal
            fillers = ['please', 'kindly', 'in detail', 'comprehensive', 'step-by-step']
            if random.random() < 0.5:
                words = new_original.split()
                insert_pos = random.randint(1, len(words))
                filler = random.choice(fillers)
                new_original = ' '.join(words[:insert_pos] + [filler] + words[insert_pos:])
            else:
                for filler in fillers:
                    new_original = new_original.replace(filler, '').strip()
            
            # Verb Variation (for enhanced)
            verbs = ['Explain', 'Write', 'Debug', 'Create', 'Implement', 'Describe', 'Develop']
            if random.random() < 0.4:
                for verb in verbs:
                    if new_enhanced.startswith(verb):
                        new_enhanced = new_enhanced.replace(verb, random.choice(verbs), 1)
                        break
            
            # Random Cropping
            if random.random() < 0.3 and len(new_original.split()) > 5:
                words = new_original.split()
                start = random.randint(0, 2)
                end = len(words) - random.randint(0, 2)
                new_original = ' '.join(words[start:end])
            
            # Random Word Shuffling (add noise)
            if random.random() < 0.2:
                words = new_original.split()
                if len(words) > 3:
                    i, j = random.sample(range(len(words)), 2)
                    words[i], words[j] = words[j], words[i]
                    new_original = ' '.join(words)
            
            # Add augmented pair if different
            if (new_original, new_enhanced) != (original, enhanced):
                augmented_data.append((new_original, new_enhanced))
    
    return augmented_data

In [41]:
def load_dataset(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)
    return [(item['original'], item['enhanced']) for item in data]


In [42]:
def build_vocab(texts):
    vocab = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
    word_count = {}
    for text in texts:
        for word in word_tokenize(text.lower()):
            word_count[word] = word_count.get(word, 0) + 1
    for word, count in word_count.items():
        if count >= 1:  # Include all words (small dataset)
            vocab[word] = len(vocab)
    return vocab, {v: k for k, v in vocab.items()}

In [43]:
def save_augmented_dataset(data, file_path='augmented_programming_prompts.json'):
    json_data = [{'original': orig, 'enhanced': enh} for orig, enh in data]
    with open(file_path, 'w') as f:
        json.dump(json_data, f, indent=2)

In [44]:
def tokenize_and_convert(text, vocab, max_len):
    tokens = word_tokenize(text.lower())[:max_len-1]
    token_ids = [vocab.get(token, vocab['<UNK>']) for token in tokens]
    token_ids = [vocab['<SOS>']] + token_ids + [vocab['<EOS>']]
    if len(token_ids) < max_len:
        token_ids += [vocab['<PAD>']] * (max_len - len(token_ids))
    return token_ids

In [45]:
class PromptDataset(Dataset):
    def __init__(self, data, vocab, max_len=50):
        self.data = data
        self.vocab = vocab
        self.max_len = max_len
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        src_ids = tokenize_and_convert(src, self.vocab, self.max_len)
        tgt_ids = tokenize_and_convert(tgt, self.vocab, self.max_len)
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

In [46]:
class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, embed_size=128, hidden_size=256):
        super(Seq2Seq, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.encoder = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size
    
    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        batch_size = src.size(0)
        tgt_len = tgt.size(1)
        vocab_size = self.fc.out_features
        
        # Encoder
        embedded = self.embedding(src)
        enc_output, (hidden, cell) = self.encoder(embedded)
        
        # Decoder
        outputs = torch.zeros(batch_size, tgt_len-1, vocab_size).to(src.device)
        dec_input = tgt[:, 0].unsqueeze(1)  # Start with <SOS>
        dec_hidden = (hidden, cell)
        
        for t in range(tgt_len-1):
            dec_embed = self.embedding(dec_input)
            dec_output, dec_hidden = self.decoder(dec_embed, dec_hidden)
            output = self.fc(dec_output.squeeze(1))
            outputs[:, t, :] = output
            
            # Teacher forcing
            if random.random() < teacher_forcing_ratio:
                dec_input = tgt[:, t+1].unsqueeze(1)
            else:
                dec_input = output.argmax(1).unsqueeze(1)
        
        return outputs

In [47]:
def train_model(model, train_loader, val_loader, epochs=30, lr=0.001, device='cpu'):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore <PAD>
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Anneal teacher forcing ratio
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        tf_ratio = max(0.1, 0.5 - (0.4 * epoch / epochs))
        
        for src, tgt in train_loader:
            src, tgt = src.to(device), tgt.to(device)
            optimizer.zero_grad()
            output = model(src, tgt, teacher_forcing_ratio=tf_ratio)
            loss = criterion(output.view(-1, output.size(-1)), tgt[:, 1:].contiguous().view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for src, tgt in val_loader:
                src, tgt = src.to(device), tgt.to(device)
                output = model(src, tgt, teacher_forcing_ratio=0.0)
                loss = criterion(output.view(-1, output.size(-1)), tgt[:, 1:].contiguous().view(-1))
                val_loss += loss.item()
        
        # Log sample output
        if epoch % 5 == 0:
            sample_src, _ = val_loader.dataset[0]
            sample_prompt = ' '.join([inv_vocab.get(t.item(), '<UNK>') for t in sample_src if t.item() not in [0, 1, 2]])
            enhanced = enhance_prompt(model, sample_prompt, vocab, inv_vocab, device=device)
            print(f'Epoch {epoch+1}, Sample Input: {sample_prompt}')
            print(f'Epoch {epoch+1}, Sample Output: {enhanced}')
        
        print(f'Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}, TF Ratio: {tf_ratio:.2f}')
    
    torch.save(model.state_dict(), 'prompt_enhancer.pt')

In [48]:
def enhance_prompt(model, prompt, vocab, inv_vocab, max_len=50, device='cpu'):
    model.eval()
    model = model.to(device)
    token_ids = tokenize_and_convert(prompt, vocab, max_len)
    src = torch.tensor([token_ids]).to(device)
    
    # Encoder
    embedded = model.embedding(src)
    _, (hidden, cell) = model.encoder(embedded)
    
    # Decoder
    dec_input = torch.tensor([[vocab['<SOS>']]]).to(device)
    output_tokens = []
    for _ in range(max_len):
        dec_embed = model.embedding(dec_input)
        dec_output, (hidden, cell) = model.decoder(dec_embed, (hidden, cell))
        output = model.fc(dec_output.squeeze(1))
        pred_token = output.argmax(1).item()
        if pred_token == vocab['<EOS>']:
            break
        output_tokens.append(pred_token)
        dec_input = torch.tensor([[pred_token]]).to(device)
    
    return ' '.join(inv_vocab.get(t, '<UNK>') for t in output_tokens)

In [51]:
data = load_dataset('large_programming_prompts.json')
augmented_data = augment_data(data, augment_factor=5)
save_augmented_dataset(augmented_data)
print(f"Original dataset size: {len(data)}, Augmented dataset size: {len(augmented_data)}")
#random.shuffle(augmented_data)
train_size = int(0.8 * len(augmented_data))
train_data, val_data = augmented_data[:train_size], augmented_data[train_size:]
all_texts = [pair[0] for pair in augmented_data] + [pair[1] for pair in augmented_data]
vocab, inv_vocab = build_vocab(all_texts)

In [54]:
extract_verbs_from_dataset(augmented_data)

['actions',
 'approach',
 'array',
 'arrays',
 'automate',
 'bash',
 'best',
 'build',
 'builded',
 'calculate',
 'can',
 'check',
 'clarify',
 'closures',
 'code',
 'codeed',
 'compare',
 'connect',
 'connecting',
 'construct',
 'constructed',
 'containerize',
 'control',
 'create',
 'debug',
 'deliver',
 'describe',
 'detail',
 'detailed',
 'develop',
 'developed',
 'do',
 'elucidate',
 'email',
 'execute',
 'executeed',
 'explain',
 'fault',
 'fetch',
 'figure',
 'file',
 'files',
 'find',
 'fix',
 'function',
 'functions',
 'game',
 'give',
 'go',
 'guide',
 'handle',
 'hash',
 'help',
 'illustrate',
 'implement',
 'implemented',
 'is',
 'join',
 'know',
 'learn',
 'linked',
 'list',
 'looking',
 'merge',
 'need',
 'number',
 'offer',
 'optimize',
 'out',
 'parse',
 'perform',
 'phone',
 'please',
 'present',
 'process',
 'processing',
 'program',
 'programming',
 'project',
 'promises',
 'provide',
 'query',
 'queue',
 'repair',
 'resolve',
 'rest',
 'reverse',
 'scrape',
 'scrapi

In [12]:
train_dataset = PromptDataset(train_data, vocab)
val_dataset = PromptDataset(val_data, vocab)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [13]:
# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2Seq(vocab_size=len(vocab), embed_size=128, hidden_size=256)

# Train model
train_model(model, train_loader, val_loader, epochs=30, device=device)

Epoch 1, Sample Input: supply a to example for connecting kindly code a mysql database using php
Epoch 1, Sample Output: write python function to check palindrome string
Epoch 1, Train Loss: 2.2964, Val Loss: 3.7987, TF Ratio: 0.50
Epoch 2, Train Loss: 1.6544, Val Loss: 3.4545, TF Ratio: 0.49


KeyboardInterrupt: 

In [None]:
test_prompt = "bro i need binary search implementation"
enhanced = enhance_prompt(model, test_prompt, vocab, inv_vocab, device=device)
print(f"Original: {test_prompt}")
print(f"Enhanced: {enhanced}")