In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertModel
from torch.optim.lr_scheduler import ReduceLROnPlateau
import json
from typing import List, Dict
import numpy as np
from tqdm import tqdm

class SQLDataset(Dataset):
    def __init__(self, data_path: str, tokenizer: BertTokenizer, max_length: int = 128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Load the data
        with open(data_path, 'r') as f:
            self.data = json.load(f)
            
        self.process_data()
        
    def process_data(self):
        self.processed_data = []
        for item in self.data:
            query_str = ' '.join(item['query_toks'])
            processed_item = {
                'question': item['question'],
                'query': query_str,
                'db_id': item['db_id']
            }
            self.processed_data.append(processed_item)
    
    def __len__(self):
        return len(self.processed_data)
    
    def __getitem__(self, idx):
        item = self.processed_data[idx]
        
        question_encoding = self.tokenizer(
            item['question'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        query_encoding = self.tokenizer(
            item['query'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': question_encoding['input_ids'].squeeze(),
            'attention_mask': question_encoding['attention_mask'].squeeze(),
            'labels': query_encoding['input_ids'].squeeze(),
            'db_id': item['db_id'],
            'question': item['question'],
            'query': item['query']
        }

class TextToSQLModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int = 768, tokenizer=None, dropout: float = 0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=8,
            dropout=dropout,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        self.output_layer = nn.Linear(hidden_size, vocab_size)
        self.tokenizer = tokenizer

    def forward(self, input_ids, attention_mask, labels=None):
        encoder_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state
        
        encoder_outputs = self.dropout(encoder_outputs)
        
        if labels is not None:
            decoder_inputs = labels[:, :-1]
            decoder_outputs = self.decoder(
                tgt=self.bert.embeddings.word_embeddings(decoder_inputs),
                memory=encoder_outputs,
                tgt_mask=self.generate_square_subsequent_mask(decoder_inputs.size(1)).to(decoder_inputs.device)
            )
            
            logits = self.output_layer(decoder_outputs)
            
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels[:, 1:].reshape(-1))
            
            return {'loss': loss, 'logits': logits}
        else:
            return self.generate(encoder_outputs, attention_mask)

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def generate(self, encoder_outputs, attention_mask, max_length=128):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.full(
            (batch_size, 1),
            self.tokenizer.cls_token_id,
            device=encoder_outputs.device
        )
        
        for _ in range(max_length):
            decoder_outputs = self.decoder(
                tgt=self.bert.embeddings.word_embeddings(decoder_input),
                memory=encoder_outputs,
                tgt_mask=self.generate_square_subsequent_mask(decoder_input.size(1)).to(decoder_input.device)
            )
            
            next_token_logits = self.output_layer(decoder_outputs[:, -1:])
            next_token = next_token_logits.argmax(dim=-1)
            
            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            
            if (next_token == self.tokenizer.sep_token_id).all():
                break
                
        return decoder_input

def evaluate_model(model, dataloader, device):
    model.eval()
    total_loss = 0
    predictions = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            
            loss = outputs['loss']
            total_loss += loss.item()
            
            generated = model.generate(
                model.bert(batch['input_ids'], batch['attention_mask']).last_hidden_state,
                batch['attention_mask']
            )
            
            for i in range(len(generated)):
                pred_tokens = model.tokenizer.decode(generated[i], skip_special_tokens=True)
                true_query = batch['query'][i]
                predictions.append({
                    'question': batch['question'][i],
                    'predicted_query': pred_tokens,
                    'true_query': true_query
                })
    
    avg_loss = total_loss / len(dataloader)
    return avg_loss, predictions

def train_model(
    model: TextToSQLModel,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: ReduceLROnPlateau,
    num_epochs: int = 20,
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
    save_path: str = 'best_model.pt'
):
    print(f"Using device: {device}")
    model.to(device)
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_train_loss = 0
        
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for batch in progress_bar:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            
            loss = outputs['loss']
            
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            total_train_loss += loss.item()
            progress_bar.set_postfix({'train_loss': loss.item()})
        
        avg_train_loss = total_train_loss / len(train_dataloader)
        
        # Validation
        val_loss, predictions = evaluate_model(model, val_dataloader, device)
        
        # Print some example predictions
        print("\nExample Predictions:")
        for i in range(min(3, len(predictions))):
            print(f"\nQuestion: {predictions[i]['question']}")
            print(f"Predicted: {predictions[i]['predicted_query']}")
            print(f"True: {predictions[i]['true_query']}")
        
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print(f"Average Train Loss: {avg_train_loss:.4f}")
        print(f"Validation Loss: {val_loss:.4f}")
        
        scheduler.step(val_loss)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best model with validation loss: {val_loss:.4f}")

def main():
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    # Create full dataset
    full_dataset = SQLDataset('/kaggle/input/ajay-ds-spider2/ajay_all_ds.json', tokenizer)
    
    # Calculate lengths for split
    total_size = len(full_dataset)
    train_size = int(0.8 * total_size)  # 80% for training
    val_size = total_size - train_size   # 20% for validation
    
    # Split dataset
    train_dataset, val_dataset = random_split(
        full_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    
    print(f"Total dataset size: {total_size}")
    print(f"Training set size: {train_size}")
    print(f"Validation set size: {val_size}")
    
    # Create dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    
    # Initialize model
    model = TextToSQLModel(
        vocab_size=tokenizer.vocab_size,
        tokenizer=tokenizer,
        dropout=0.1
    )
    
    # Initialize optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        patience=2,
        factor=0.5,
        verbose=True
    )
    
    # Train model
    train_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=20,
        save_path='best_text2sql_model_2.pt'
    )

if __name__ == "__main__":
    main()

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Total dataset size: 9693
Training set size: 7754
Validation set size: 1939


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Using device: cuda


Epoch 1/20: 100%|██████████| 485/485 [06:05<00:00,  1.33it/s, train_loss=1.48] 
Evaluating: 100%|██████████| 122/122 [14:24<00:00,  7.09s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select t1. allergy from allergy as t1 join allergy _ allergy _ allergy _ type as t2 on t1. allergy = t2. allergy _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ code where t2. allergy _ type _ type _ type _ type
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. name from college as t1 join player as t2 on t1. cname = t2. cname where t2. name ='''as t1 join player ='as t2 on t2. cname ='''_ id where t1. school _ id ='''''''' as t2. school _ id ='' as t2 on t2. school _ id ='_ id where t2. school _ id =''
True: SELECT T1.state FROM college AS T1

Epoch 2/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.538]
Evaluating: 100%|██████████| 122/122 [14:26<00:00,  7.10s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergy _ type from allergy _ type where allergy _ type _ type _ type _ type _ type _ code = ` ` type'' allergytype type _ type _ type _ type _ type _ type _ type _ type _ type _ code type _ code type _ type _ code type _ code type _ code from allergy _ type where allergy _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ type _ code = ` ` ` ` type'''' type _ type _ type
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select state _ name from college where cname ='goalie'''' select state from college where ppos ='''''' select cname from college where ppos ='''' select cname from college _ cname'select cname from college _ enr'select cname from college where cname from college _ enr'_ enr'_ en
True: SELECT T1.stat

Epoch 3/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.409]
Evaluating: 100%|██████████| 122/122 [14:26<00:00,  7.10s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergytype from allergy _ type where allergytype = ` ` type'' allergytype = ` ` food'' allergy'allergytype allergytype allergytype allergytype allergytype allergytype type allergytype type type type type type allergytype type type type type allergytype type type type type allergytype type type type type type allergytype type type type type allergytype
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select distinct t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='goalie'ppos ='' ppos ppos'ppos ppos ppos'ppos ppos ppos ppos'ppos ppos ppos ppos'ppos ppos ppos ppos'ppos ppos ppos ppos ppos'ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHERE T2.pPos = 'striker '

Question: what 

Epoch 4/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.552]
Evaluating: 100%|██████████| 122/122 [14:27<00:00,  7.11s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergytype from allergy _ type where allergytype = ` ` food'' allergytype = ` ` food'' allergytype allergytype type type type type allergytype = ` ` ` food'''allergytype allergytype allergytype type type type type type type type type type allergytype type type type type type type type type type type allergytype type type type type type type
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='goalie'''ppos ='''ppos ppos'ppos'select t1. state from tryout as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='' ppos ='''ppos'ppos'ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHERE T2.pPos = 'striker '

Question: wh

Epoch 5/20: 100%|██████████| 485/485 [06:20<00:00,  1.28it/s, train_loss=0.313]
Evaluating: 100%|██████████| 122/122 [14:24<00:00,  7.08s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergy from allergy _ type where allergytype = ` ` food'' allergytype = ` ` food'' allergytype = ` ` food'allergytype allergytype allergytype = ` type allergytype type where allergytype = ` type'' allergytype = ` ` type type type'' allergytype = ` ` type allergytype = ` food'' allergytype = ` ` type allergytype type allergytype = ` type allergytype type allergytype type all
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select distinct t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='goalie'cname where t2. ppos ='ppos ='' ppos ='ppos'ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ='ppos ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHERE T2.pPos = 'striker '



Epoch 6/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.156]
Evaluating: 100%|██████████| 122/122 [14:28<00:00,  7.12s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergytype from allergy _ type where allergytype = ` ` food'' allergytype = ` ` food'' allergytype allergytype allergytype allergytype type allergytype type allergytype type allergytype = ` ` ` food'''allergytype allergytype allergytype allergytype allergytype allergytype allergytype type allergytype type allergytype type allergytype type allergytype type allergy
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select state from tryout where ppos ='goalie'cname ='' ppos ppos'ppos ppos ppos'select state from tryout where ppos ='ppos ='ppos'ppos'ppos ppos'ppos ppos'select cname from tryout where ppos ='ppos ='ppos ='ppos ppos'ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHERE T2.pPos = 'striker '

Question: what ri

Epoch 7/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.127] 
Evaluating: 100%|██████████| 122/122 [14:27<00:00,  7.11s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergytype from allergy _ type where allergytype = ` ` food'' food _ type food _ type food _ type food _ type food _ type type food _ type type food _ type type food _ type select allergytype from allergy _ type where allergytype = ` ` ` food'' food _ type food _ type food _ type food _ type food _ type food _ type food _ type food _ type food _ type food _ type food _ type food _ type food
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select distinct t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='goalie'''ppos'ppos'ppos ppos'ppos ppos'ppos'ppos ppos'ppos ppos'select t1. cname from college as t1 join tryout as t2 on t1. cname = t2. cname where t2
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1

Epoch 8/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.13]  
Evaluating: 100%|██████████| 122/122 [14:25<00:00,  7.09s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select distinct allergy from allergy _ type where allergytype = ` ` food'' allergytype = ` ` food'' allergytype allergytype allergytype allergytype type allergytype type allergytype type type allergytype type allergytype type type allergytype = ` ` ` food'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype all
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='goalie'''t1. ppos t2. ppos t2. ppos ='' t2. ppos t2. ppos't2. cname t1. ppos t2. ppos t2. ppos t2. ppos t2. ppos t1. cname t2. ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHERE T2.pPos = 'striker '

Que

Epoch 9/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.218] 
Evaluating: 100%|██████████| 122/122 [14:25<00:00,  7.09s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergy from allergy _ type where allergytype = ` ` food'' allergytype = ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='goalie'' ppos ppos ppos ='ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos select t1. cname from college as t1. cname = t2. cname ppos = t2
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cNam

Epoch 10/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.0624]
Evaluating: 100%|██████████| 122/122 [14:25<00:00,  7.10s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergy from allergy _ type where allergytype = ` ` food'' allergytype = ` ` food'' allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` food'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='goalie'' ppos'ppos ='' ppos'ppos ppos'ppos'ppos'ppos ppos'ppos from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='''ppos = '
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHERE T2.p

Epoch 11/20: 100%|██████████| 485/485 [06:20<00:00,  1.28it/s, train_loss=0.0941]
Evaluating: 100%|██████████| 122/122 [14:27<00:00,  7.11s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergytype from allergy _ type where allergytype = ` ` food'' allergytype = ` ` food'' allergytype allergytype allergytype = ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` ` food _ type where allergytype = ` food _ type'''allergytype = ` ` ` food
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='goalie'ppos ='ppos where t2. ppos ='goalie'ppos ='ppos'ppos ppos ppos ppos'ppos ppos ppos ='ppos'ppos ppos ppos ppos ppos ppos'ppos ppos ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHERE T2.pPos = 's

Epoch 12/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.0842]
Evaluating: 100%|██████████| 122/122 [14:24<00:00,  7.08s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select distinct allergytype from allergy _ type where allergytype = ` ` food'' allergytype = ` ` food'' allergytype = ` ` animal'' allergytype allergytype allergytype = ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergy
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='striker'ppos ppos ppos ppos ='ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHE

Epoch 13/20: 100%|██████████| 485/485 [06:18<00:00,  1.28it/s, train_loss=0.0301]
Evaluating: 100%|██████████| 122/122 [14:24<00:00,  7.09s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergytype from allergy _ type where allergytype = ` ` food'' allergytype = ` ` animal'' allergytype = ` ` animal'allergytype allergytype allergytype = ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` `
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select distinct t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='goalie'ppos ='ppos goalie'select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='ppos ='ppos ppos ppos ppos ppos ppos ppos ppos ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1

Epoch 14/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.067]  
Evaluating: 100%|██████████| 122/122 [14:23<00:00,  7.08s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergytype from allergy _ type where allergytype = ` ` food'' allergytype = ` ` animal'' allergytype allergytype = ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype all
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='striker'ppos ='ppos ppos ppos ppos ppos ppos ='striker'ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ='ppos ppos ppos ppos ppos ppos ppos ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHE

Epoch 15/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.0131] 
Evaluating: 100%|██████████| 122/122 [14:22<00:00,  7.07s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergy from allergy _ type where allergytype = ` ` food'' allergytype = ` ` animal'' allergytype allergytype = ` ` animal'' allergytype allergytype allergytype type type _ type _ type _ code = ` ` animal'' allergytype allergytype allergytype allergytype = ` ` animal'' allergytype allergytype type _ type _ type _ type _ type _ type _ type allergytype allergytype = ` ` ` animal '
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='striker't1. cname ='ppos ppos ppos ='striker'ppos ppos ppos ppos ppos ='ppos ppos ppos ppos ppos ppos ppos ='ppos ppos ppos ppos ppos ppos ppos ppos ='ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName 

Epoch 16/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.0102] 
Evaluating: 100%|██████████| 122/122 [14:23<00:00,  7.07s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergy from allergy _ type where allergytype = ` ` food'' allergytype = ` ` food'' allergytype allergytype allergytype = ` ` food'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` food _ type allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergy
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='striker'ppos ='ppos ppos ppos ='ppos ppos ppos ppos ppos ppos ='ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ='ppos ppos ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHERE T2.p

Epoch 17/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.0184] 
Evaluating: 100%|██████████| 122/122 [14:26<00:00,  7.11s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergy from allergy _ type where allergytype = ` ` food'' allergytype = ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='striker'ppos ='ppos ppos ppos ppos ppos ppos ='striker'ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName

Epoch 18/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.00434]
Evaluating: 100%|██████████| 122/122 [14:24<00:00,  7.09s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergy from allergy _ type where allergytype = ` ` food'' allergytype = ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype type allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` ` ` food _ type'''all
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='striker'ppos ='ppos where t2. ppos ='striker'ppos ='ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHE

Epoch 19/20: 100%|██████████| 485/485 [06:19<00:00,  1.28it/s, train_loss=0.00147]
Evaluating: 100%|██████████| 122/122 [14:20<00:00,  7.06s/it]



Example Predictions:

Question: Show all allergies with type food.
Predicted: select allergy from allergy _ type where allergytype = ` ` food'' allergytype = ` ` animal'' allergytype allergytype allergytype allergytype = ` ` animal'' allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype allergytype = ` ` ` ` ` animal'''allergytype allergytype
True: SELECT DISTINCT allergy FROM Allergy_type WHERE allergytype = `` food ''

Question: Find the states of the colleges that have students in the tryout who played in striker position.
Predicted: select t1. state from college as t1 join tryout as t2 on t1. cname = t2. cname where t2. ppos ='striker'ppos ='ppos ppos ppos ='striker'ppos ppos ppos ppos ppos ='ppos ppos ppos ppos ppos ppos ppos ppos ppos ppos ='ppos ppos ppos ppos ppos ppos ppos pp
True: SELECT T1.state FROM college AS T1 JOIN tryout AS T2 ON T1.cName = T2.cName WHERE T

Epoch 20/20:  27%|██▋       | 131/485 [01:43<04:38,  1.27it/s, train_loss=0.00333]

In [1]:
import torch
from transformers import BertTokenizer, BertModel
from torch import nn

class TextToSQLModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int = 768, tokenizer=None, dropout: float = 0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=8,
            dropout=dropout,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        self.output_layer = nn.Linear(hidden_size, vocab_size)
        self.tokenizer = tokenizer

    def forward(self, input_ids, attention_mask, labels=None):
        encoder_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state
        
        encoder_outputs = self.dropout(encoder_outputs)
        
        if labels is not None:
            decoder_inputs = labels[:, :-1]
            decoder_outputs = self.decoder(
                tgt=self.bert.embeddings.word_embeddings(decoder_inputs),
                memory=encoder_outputs,
                tgt_mask=self.generate_square_subsequent_mask(decoder_inputs.size(1)).to(decoder_inputs.device)
            )
            
            logits = self.output_layer(decoder_outputs)
            
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels[:, 1:].reshape(-1))
            
            return {'loss': loss, 'logits': logits}
        else:
            return self.generate(encoder_outputs, attention_mask)

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def generate(self, encoder_outputs, attention_mask, max_length=128):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.full(
            (batch_size, 1),
            self.tokenizer.cls_token_id,
            device=encoder_outputs.device
        )
        
        for _ in range(max_length):
            decoder_outputs = self.decoder(
                tgt=self.bert.embeddings.word_embeddings(decoder_input),
                memory=encoder_outputs,
                tgt_mask=self.generate_square_subsequent_mask(decoder_input.size(1)).to(decoder_input.device)
            )
            
            next_token_logits = self.output_layer(decoder_outputs[:, -1:])
            next_token = next_token_logits.argmax(dim=-1)
            
            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            
            if (next_token == self.tokenizer.sep_token_id).all():
                break
                
        return decoder_input

def generate_sql_query(question: str, model: TextToSQLModel, tokenizer: BertTokenizer, device: str):
    # Prepare the model for inference
    model.eval()
    
    # Tokenize the input question
    inputs = tokenizer(
        question,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    # Move inputs to device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        # Get encoder outputs
        encoder_outputs = model.bert(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        ).last_hidden_state
        
        # Generate SQL query
        generated = model.generate(encoder_outputs, inputs['attention_mask'])
        
        # Decode the generated query
        predicted_query = tokenizer.decode(generated[0], skip_special_tokens=True)
    
    return predicted_query

def main():
    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    # Initialize model
    model = TextToSQLModel(
        vocab_size=tokenizer.vocab_size,
        tokenizer=tokenizer
    )
    
    # Load trained model weights
    model.load_state_dict(torch.load('best_text2sql_model_2.pt', map_location=device))
    model.to(device)
    
    # Set model to evaluation mode
    model.eval()
    
    while True:
        # Get question from user
        question = input("\nEnter your question (or 'quit' to exit): ")
        
        if question.lower() == 'quit':
            break
        
        # Generate SQL query
        sql_query = generate_sql_query(question, model, tokenizer, device)
        
        print("\nInput Question:", question)
        print("Generated SQL Query:", sql_query)

if __name__ == "__main__":
    main()

Using device: cpu
