In [1]:
# If using Google Colab, uncomment these lines

# !pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
# !pip install transformers datasets evaluate -q

In [2]:
from transformers import T5ForConditionalGeneration
from transformers import RobertaTokenizer

import torch

In [3]:
# ----------------------------------------
# Load the fine-tuned model
# ----------------------------------------

# The tokenizer and the model weights are read from separate local directories.
tokenizer = RobertaTokenizer.from_pretrained("Trained_Model/tokenizer")
model = T5ForConditionalGeneration.from_pretrained("Trained_Model/fine_tuned_model")

# Set device
# Accelerator-aware alternative, kept for reference:
# device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
# Author's note: for this task (one sample at a time, as coded below), the CPU
# appeared to be faster, so it is selected explicitly.
device = torch.device("cpu")
model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(32101, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32101, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [4]:
# ----------------------------------------
# Load the testing data
# ----------------------------------------

import pandas as pd

# NOTE(review): the displayed frame shows an "Unnamed: 0" column, which looks
# like an index that was saved into the CSV -- consider `index_col=0` after
# confirming nothing downstream relies on that column.
test_csv_path = "processed_data/ft_test_processed.csv"
test = pd.read_csv(test_csv_path)
test

Unnamed: 0,formatted_method,if_statement
0,"def read(self, count=True, timeout=None, ignor...",if ignore_timeouts and is_timeout(e):
1,"def _cache_mem(curr_out, prev_mem, mem_len, re...",if prev_mem is None:
2,def filtered(gen): for example in gen: example...,if example_len > max_length:
3,"def search(self, query): if not query: logger....","if item.get('type', '') == 'audio':"
4,"def _check_script(self, script, directive): fo...",if var.must_contain('/'):
...,...,...
4982,"def _super_function(args): passed_class, passe...","if isinstance(pyclass, pyobjects.AbstractClass):"
4983,"def get_data(row): data = [] for field_name, f...",if result:
4984,"def say(jarvis, s): """"""Reads what is typed.""""""...",if not voice_state:
4985,"def __import__(name, globals=None, locals=None...",if '*' in fromlist:


In [5]:
# ----------------------------------------
# Prepare to generate predictions
# ----------------------------------------

def generate_prediction(code, model, tokenizer, device, max_length=256):
    """Generate the model's predicted statement for one or more code snippets.

    Args:
        code: A single source string, or a list/tuple of source strings that
            is tokenized and generated as one padded batch.
        model: Object exposing ``generate(**inputs, max_length=...)``.
        tokenizer: Callable returning a mapping of tensors (called with
            ``return_tensors="pt", padding=True, truncation=True``) and
            providing ``decode(ids, skip_special_tokens=True)``.
        device: Device the tokenized tensors are moved to before generation.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            Defaults to 256, the limit previously hard-coded here.

    Returns:
        For a list/tuple input, a list with one decoded string per element
        (an empty list for an empty input). For any other input, the decoded
        string of the first generated sequence.
    """
    is_batch = isinstance(code, (list, tuple))
    if is_batch:
        code = list(code)
        if not code:
            # Nothing to tokenize; avoid calling the tokenizer on an empty batch.
            return []

    inputs = tokenizer(code, return_tensors="pt", padding=True, truncation=True)
    # Send inputs to device
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Pure inference: no autograd bookkeeping is needed.
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_length=max_length)

    if is_batch:
        # One decoded prediction per input, in input order.
        return [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in outputs]
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [6]:
# ----------------------------------------
# Generate Predictions!
# ----------------------------------------

from tqdm import tqdm

# Register `progress_apply` on pandas objects so the per-row loop shows a progress bar.
tqdm.pandas()


def _predict_for_method(method_source):
    """Run the loaded model on one formatted method and return its prediction."""
    return generate_prediction(method_source, model, tokenizer, device)


# One generation call per row of the test frame; the result is stored as a new column.
test["predicted_statement"] = test["formatted_method"].progress_apply(_predict_for_method)
test

100%|██████████| 4987/4987 [03:06<00:00, 26.72it/s]


Unnamed: 0,formatted_method,if_statement,predicted_statement
0,"def read(self, count=True, timeout=None, ignor...",if ignore_timeouts and is_timeout(e):,if ignore_timeouts and is_noerr(e):
1,"def _cache_mem(curr_out, prev_mem, mem_len, re...",if prev_mem is None:,if prev_mem is None:
2,def filtered(gen): for example in gen: example...,if example_len > max_length:,if example_len > max_length:
3,"def search(self, query): if not query: logger....","if item.get('type', '') == 'audio':",if 'guide_id' in item:
4,"def _check_script(self, script, directive): fo...",if var.must_contain('/'):,if not var.can_match(directive):
...,...,...,...
4982,"def _super_function(args): passed_class, passe...","if isinstance(pyclass, pyobjects.AbstractClass):","if isinstance(pyclass, pyobjects.PyObject):"
4983,"def get_data(row): data = [] for field_name, f...",if result:,if result:
4984,"def say(jarvis, s): """"""Reads what is typed.""""""...",if not voice_state:,if voice_state:
4985,"def __import__(name, globals=None, locals=None...",if '*' in fromlist:,if '*' in fromlist:


In [7]:
def write_statements(statements, path):
    """Write each statement to `path` as plain text, one statement per line.

    `Series.to_csv` is deliberately not used: it applies CSV quoting, so any
    statement containing a comma or a double quote would be written wrapped
    in '"' characters (with inner quotes doubled), which alters the text
    that line-based evaluation tools read.

    Args:
        statements: pandas Series of statements; missing values are written
            as empty lines.
        path: Destination text file (UTF-8).
    """
    with open(path, "w", encoding="utf-8") as handle:
        for statement in statements.fillna(""):
            # Flatten any embedded line breaks so each row stays on exactly
            # one line and the two output files remain aligned row by row.
            handle.write(" ".join(str(statement).splitlines()) + "\n")


write_statements(test["if_statement"], "evaluation_data/target.txt")
write_statements(test["predicted_statement"], "evaluation_data/prediction.txt")

In [13]:
# ----------------------------------------
# Check for exact matches
# ----------------------------------------

# Element-wise string equality between the reference and the generated statement.
exact_match_mask = test["if_statement"].eq(test["predicted_statement"])
test["Exact_Match"] = exact_match_mask

# Summing the boolean column counts the rows where the two strings are identical.
print("Exact matches:", exact_match_mask.sum())
test

Exact matches: 1450


Unnamed: 0,formatted_method,if_statement,predicted_statement,Exact_Match
0,"def read(self, count=True, timeout=None, ignor...",if ignore_timeouts and is_timeout(e):,if ignore_timeouts and is_noerr(e):,False
1,"def _cache_mem(curr_out, prev_mem, mem_len, re...",if prev_mem is None:,if prev_mem is None:,True
2,def filtered(gen): for example in gen: example...,if example_len > max_length:,if example_len > max_length:,True
3,"def search(self, query): if not query: logger....","if item.get('type', '') == 'audio':",if 'guide_id' in item:,False
4,"def _check_script(self, script, directive): fo...",if var.must_contain('/'):,if not var.can_match(directive):,False
...,...,...,...,...
4982,"def _super_function(args): passed_class, passe...","if isinstance(pyclass, pyobjects.AbstractClass):","if isinstance(pyclass, pyobjects.PyObject):",False
4983,"def get_data(row): data = [] for field_name, f...",if result:,if result:,True
4984,"def say(jarvis, s): """"""Reads what is typed.""""""...",if not voice_state:,if voice_state:,False
4985,"def __import__(name, globals=None, locals=None...",if '*' in fromlist:,if '*' in fromlist:,True


In [15]:
# Persist the full frame (all columns currently on `test`) without the row index.
test_data_path = "evaluation_data/test_data.csv"
test.to_csv(test_data_path, index=False)