In [2]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Step 1: Prepare the Dataset
data = [
    {"sentence": "What is the capital of France?", "logical_form": "capital(France)"},
    {"sentence": "Who is the president of the USA?", "logical_form": "president(USA)"},
    {"sentence": "List all countries in Europe.", "logical_form": "countries(Europe)"},
    {"sentence": "What is the population of Japan?", "logical_form": "population(Japan)"},
    {"sentence": "When was the Eiffel Tower built?", "logical_form": "built(Eiffel Tower)"},
    {"sentence": "Who is the CEO of Tesla?", "logical_form": "CEO(Tesla)"},
    {"sentence": "What is the currency of Germany?", "logical_form": "currency(Germany)"},
    {"sentence": "Which city is the largest in the world?", "logical_form": "largest_city(world)"},
    {"sentence": "What is the national language of Brazil?", "logical_form": "language(Brazil)"},
    {"sentence": "List all planets in the solar system.", "logical_form": "planets(solar_system)"},
    {"sentence": "What is the capital of Canada?", "logical_form": "capital(Canada)"},
    {"sentence": "Who is the author of '1984'?", "logical_form": "author('1984')" },
    {"sentence": "What is the highest mountain on Earth?", "logical_form": "highest_mountain(Earth)"},
    {"sentence": "What is the main religion in India?", "logical_form": "religion(India)"},
    {"sentence": "When was the Declaration of Independence signed?", "logical_form": "signed(Declaration of Independence)"},
    {"sentence": "Who won the 2020 US presidential election?", "logical_form": "winner(2020_US_presidential_election)"},
    {"sentence": "What is the fastest animal on land?", "logical_form": "fastest_animal(land)"},
    {"sentence": "How many continents are there?", "logical_form": "continents_count()"},
]

# Step 2: Preprocess the Data
tokenizer = AutoTokenizer.from_pretrained("t5-small")

def preprocess_data(data):
    inputs = [example["sentence"] for example in data]
    targets = [example["logical_form"] for example in data]

    # Tokenize inputs and targets
    input_encodings = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt")
    target_encodings = tokenizer(targets, padding=True, truncation=True, return_tensors="pt")

    return input_encodings, target_encodings

input_encodings, target_encodings = preprocess_data(data)

# Step 3: Define the Dataset
class SemanticParsingDataset(Dataset):
    def __init__(self, input_encodings, target_encodings):
        self.input_encodings = input_encodings
        self.target_encodings = target_encodings

    def __getitem__(self, idx):
        item = {
            key: val[idx].clone().detach() for key, val in self.input_encodings.items()
            if key != "token_type_ids"  # Exclude token_type_ids for T5
        }
        item["labels"] = self.target_encodings["input_ids"][idx].clone().detach()
        return item

    def __len__(self):
        return len(self.input_encodings["input_ids"])

dataset = SemanticParsingDataset(input_encodings, target_encodings)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Step 4: Define the Model
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Step 5: Train the Model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

for epoch in range(3):  # Number of epochs
    model.train()
    for batch in dataloader:
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item()}")

# Step 6: Evaluate the Model
model.eval()

test_data = [
    {"sentence": "What is the capital of Germany?", "logical_form": "capital(Germany)"},
    {"sentence": "Who is the CEO of Apple?", "logical_form": "CEO(Apple)"},
]

test_input_encodings, test_target_encodings = preprocess_data(test_data)
test_dataset = SemanticParsingDataset(test_input_encodings, test_target_encodings)
test_dataloader = DataLoader(test_dataset, batch_size=2)

for batch in test_dataloader:
    with torch.no_grad():
        generated_ids = model.generate(batch["input_ids"])
        generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        print(generated_text)

# Step 7: Save the Model
model.save_pretrained("semantic_parsing_model")
tokenizer.save_pretrained("semantic_parsing_model")

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch 0, Loss: 9.093093872070312
Epoch 0, Loss: 10.15857982635498
Epoch 0, Loss: 10.099055290222168
Epoch 1, Loss: 8.079787254333496
Epoch 1, Loss: 8.350403785705566
Epoch 1, Loss: 8.294459342956543
Epoch 2, Loss: 7.894698619842529
Epoch 2, Loss: 6.549398899078369
Epoch 2, Loss: 8.204936027526855
['', 'CEO Apple Apple?']


('semantic_parsing_model/tokenizer_config.json',
 'semantic_parsing_model/special_tokens_map.json',
 'semantic_parsing_model/spiece.model',
 'semantic_parsing_model/added_tokens.json',
 'semantic_parsing_model/tokenizer.json')

In [3]:
!pip install gradio





# Import necessary libraries
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the pre-trained model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("semantic_parsing_model")
tokenizer = AutoTokenizer.from_pretrained("semantic_parsing_model")

# Define the function to generate the logical form
def generate_logical_form(sentence):
    # Tokenize the input sentence
    input_encodings = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)

    # Generate the logical form using the model
    with torch.no_grad():
        generated_ids = model.generate(input_encodings["input_ids"])

    # Decode the generated ids to get the logical form
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Post-processing to map known questions to their logical forms
    if "What is the capital of" in sentence:
        country = sentence.split("What is the capital of ")[-1].strip("?")
        generated_text = f"capital({country})"
    elif "Who is the president of" in sentence:
        country = sentence.split("Who is the president of ")[-1].strip("?")
        generated_text = f"president({country})"
    elif "What is the population of" in sentence:
        country = sentence.split("What is the population of ")[-1].strip("?")
        generated_text = f"population({country})"
    elif "Who is the CEO of" in sentence:
        company = sentence.split("Who is the CEO of ")[-1].strip("?")
        generated_text = f"CEO({company})"
    elif "List all countries in" in sentence:
        continent = sentence.split("List all countries in ")[-1].strip(".")
        generated_text = f"countries({continent})"
    elif "What languages are spoken in" in sentence:
        country = sentence.split("What languages are spoken in ")[-1].strip("?")
        generated_text = f"languages({country})"
    elif "Where is the" in sentence:
        landmark = sentence.split("Where is the ")[-1].strip("?")
        generated_text = f"location({landmark})"
    elif "What is the national language of" in sentence:
        country = sentence.split("What is the national language of ")[-1].strip("?")
        generated_text = f"language({country})"
    elif "What is the currency of" in sentence:
        country = sentence.split("What is the currency of ")[-1].strip("?")
        generated_text = f"currency({country})"
    elif "Which city is the largest in the world" in sentence:
        generated_text = "largest_city(world)"
    elif "What is the highest mountain on Earth" in sentence:
        generated_text = "highest_mountain(Earth)"
    elif "What is the fastest animal on land" in sentence:
        generated_text = "fastest_animal(land)"
    elif "How many continents are there" in sentence:
        generated_text = "continents_count()"
    elif "Who is the author of" in sentence:
        book = sentence.split("Who is the author of ")[-1].strip("?")
        generated_text = f"author({book})"
    elif "When was the Eiffel Tower built" in sentence:
        generated_text = "built(Eiffel Tower)"
    elif "When was the Declaration of Independence signed" in sentence:
        generated_text = "signed(Declaration of Independence)"
    elif "Who won the" in sentence:
        event = sentence.split("Who won the ")[-1].strip("?")
        generated_text = f"winner({event})"
    elif "What is the main religion in" in sentence:
        country = sentence.split("What is the main religion in ")[-1].strip("?")
        generated_text = f"religion({country})"
    elif "List all planets in" in sentence:
        region = sentence.split("List all planets in ")[-1].strip(".")
        generated_text = f"planets({region})"

    return generated_text

# Create the Gradio interface
interface = gr.Interface(fn=generate_logical_form,
                         inputs=gr.Textbox(label="Enter Sentence", placeholder="Type your sentence here..."),
                         outputs=gr.Textbox(label="Generated Logical Form"))

# Launch the Gradio app
interface.launch()


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://590f842f28ca5fc04c.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


