In [None]:
pip install transformers datasets torch

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

In [None]:
import re
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
import random

# Step 1: Create a larger dataset with 1000+ samples

def generate_synthetic_data(num_samples=1000, seed=None):
    """Generate a labelled dataset of synthetic Python snippets.

    Every sample is drawn, with equal probability, either from a pool of valid
    snippets (label 0) or from a pool of erroneous snippets (label 1).

    Args:
        num_samples: Number of samples to generate. Defaults to 1000.
        seed: Optional seed for a private random generator, making the output
            reproducible. When None, the module-level ``random`` state is used.

    Returns:
        A list of ``{'code': str, 'label': int}`` dictionaries of length
        ``num_samples``.
    """
    rng = random if seed is None else random.Random(seed)

    # Every snippet in this pool must compile as-is. Compound statements that
    # cannot legally share one line (a `def` inside a `class`, `try`/`except`)
    # are written across several lines so they are genuinely valid Python.
    correct_code = [
        'print("Hello World")',
        'if x == 10: print("Number 10")',
        'def add(a, b): return a + b',
        'x = 5 + 3',
        'while True: print("Looping")',
        'for i in range(5): print(i)',
        'class MyClass:\n    def method(self): pass',
        'try:\n    pass\nexcept:\n    pass',
        'def square(x): return x * x',
        'import math; print(math.sqrt(16))'
    ]

    error_code = [
        'if x == 10 print("Number 10")',  # Missing colon
        'def add(a, b) return a + b',     # Missing colon
        'print("Hello World"',            # Missing closing parenthesis
        'for i in range(5) print(i)',     # Missing colon
        'while True print("Endless loop")',  # Missing colon
        # NOTE(review): this one parses fine and only fails at runtime (TypeError);
        # it is kept as an "erroneous" sample, but it is not a syntax error.
        'x = 5 + "text"',                 # Type error
        'class MyClass: def method(self)' # Syntax error (missing body)
    ]

    # Draw each sample independently: a coin flip picks the pool, then a uniform
    # choice picks the snippet within that pool.
    dataset = []
    for _ in range(num_samples):
        if rng.random() > 0.5:
            dataset.append({'code': rng.choice(correct_code), 'label': 0})  # 0 for Correct
        else:
            dataset.append({'code': rng.choice(error_code), 'label': 1})    # 1 for Syntax Error

    return dataset

# Step 2: Tokenizer for splitting the code into tokens
def tokenize_code(code):
    """Split a Python snippet into a flat list of lexical tokens.

    Recognised tokens are keywords, operators, delimiters, identifiers and
    integer literals. Anything else (whitespace, quotes, other punctuation)
    is skipped.

    Args:
        code: Source text to tokenize.

    Returns:
        The matched tokens as strings, in order of appearance.
    """
    keywords = ['def', 'return', 'if', 'else', 'while', 'for', 'in', 'print', 'class', 'try', 'except', 'import', 'as', 'True', 'False', 'None']
    operators = ['=', '==', '!=', '>', '<', '>=', '<=', '+', '-', '*', '/', '%', '**']
    delimiters = ['(', ')', '{', '}', '[', ']', ':', ',', ';']

    # Operators contain regex metacharacters ('+', '*', '**'), so each one is escaped.
    # Longer operators are tried first so that '==' is not split into two '=' tokens.
    # No \b anchors here: a word boundary never sits between a space and a symbol.
    operator_alternatives = '|'.join(
        re.escape(op) for op in sorted(operators, key=len, reverse=True)
    )
    # Build the character class from the list so every delimiter (including ':') matches.
    delimiter_class = '[' + ''.join(re.escape(d) for d in delimiters) + ']'

    # Regex pattern to match keywords, operators, and other syntax elements
    token_pattern = r'|'.join([
        r'\b(?:' + '|'.join(keywords) + r')\b',   # Match keywords
        r'(?:' + operator_alternatives + r')',    # Match operators
        delimiter_class,                          # Match delimiters
        r'\b(?:[a-zA-Z_][a-zA-Z0-9_]*|\d+)\b'     # Match identifiers and numbers
    ])

    return re.findall(token_pattern, code)

# Step 3: Prepare Dataset for Training

# Upper bound on word-piece tokens per snippet. The snippets are a handful of tokens
# long, so padding to this length (the same limit predict() truncates to) avoids
# padding every example to the model's maximum sequence length.
MAX_SEQ_LENGTH = 128

# Seed the sampler so a fresh kernel regenerates the same synthetic dataset.
random.seed(42)
train_data = generate_synthetic_data()

# Initialize the tokenizer and model from Hugging Face.
# Done before tokenize_function is defined so the function never refers to a
# name that does not exist yet if this code is later split across cells.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)  # 2 classes: Correct, Syntax Error

def tokenize_function(examples):
    """Tokenize a batch of snippets with the BERT tokenizer.

    Args:
        examples: A batch from the dataset; its 'code' entry holds the snippets.

    Returns:
        The tokenizer output for the batch, padded/truncated to MAX_SEQ_LENGTH.
    """
    return tokenizer(
        examples['code'],
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )

# Untokenized dataset: one 'code' string and one integer 'label' per sample.
raw_dataset = Dataset.from_dict({
    'code': [item['code'] for item in train_data],
    'label': [item['label'] for item in train_data]
})

# Apply tokenization to the dataset (kept under a separate name from the raw data
# so re-running this step always starts from the untokenized samples).
train_dataset = raw_dataset.map(tokenize_function, batched=True)

# Step 4: Define Training Arguments and Train the Model
NUM_TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 8

# Size the warmup from the actual length of the run. A fixed 500-step warmup is longer
# than the whole run (1000 samples / batch 8 * 3 epochs = 375 steps), so the learning
# rate would still be ramping up when training ends.
# Assumes a single device and no gradient accumulation -- adjust if that changes.
steps_per_epoch = (len(train_dataset) + TRAIN_BATCH_SIZE - 1) // TRAIN_BATCH_SIZE
total_train_steps = steps_per_epoch * NUM_TRAIN_EPOCHS
warmup_steps = total_train_steps // 10  # warm up over the first ~10% of training

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=16,
    warmup_steps=warmup_steps,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # NOTE(review): evaluation reuses the training set, so any metric obtained from
    # trainer.evaluate() measures fit on seen data, not generalization.
    eval_dataset=train_dataset
)

# Train the model
trainer.train()

# Step 5: Evaluation and Prediction
def predict(code):
    """Predict if the given code has a syntax error or is correct.

    Args:
        code: The Python snippet to classify, as a string.

    Returns:
        "No error" when the model predicts class 0, otherwise "Syntax error".
    """
    # Training leaves the model in train mode; switch to eval so dropout is
    # disabled and repeated calls on the same snippet give the same answer.
    model.eval()
    inputs = tokenizer(code, return_tensors="pt", padding=True, truncation=True, max_length=128)
    # The Trainer may have moved the model to an accelerator; the inputs must live
    # on the same device or the forward pass raises a device-mismatch error.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1).item()
    return "No error" if prediction == 0 else "Syntax error"

# Test the model with user input
# Reads one snippet per line and prints the model's verdict until the user types 'exit'.
print("Enter a Python code snippet (or 'exit' to quit):")
while True:
    try:
        code_input = input()
    except (EOFError, KeyboardInterrupt):
        # No interactive stdin (e.g. "Run All" / headless execution) or the user
        # interrupted the prompt: leave the loop instead of crashing the cell.
        print("Exiting the program.")
        break

    if code_input.strip() == 'exit':
        print("Exiting the program.")
        break

    try:
        print(f"Predicted: {predict(code_input)}")
    except Exception as e:
        # A failure here is a runtime problem in the prediction pipeline, not a
        # verdict about the snippet, so report it without inventing a prediction.
        print(f"Prediction failed: {e}")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Step,Training Loss
10,0.6939
20,0.6752
30,0.6347
40,0.619
50,0.6295
60,0.5439
70,0.4883
80,0.4068
90,0.2872
100,0.208


Step,Training Loss
10,0.6939
20,0.6752
30,0.6347
40,0.619
50,0.6295
60,0.5439
70,0.4883
80,0.4068
90,0.2872
100,0.208


Enter a Python code snippet (or 'exit' to quit):
print('hello')
Predicted: No error
