In [1]:
%load_ext autoreload
%autoreload 2

import os
from typing import *

import torch

from typet5.model import ModelWrapper
from typet5.train import PreprocessArgs
from typet5.utils import *
from typet5.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    AccuracyMetric,
)
from typet5.static_analysis import PythonProject

# Switch the process working directory to the path returned by proj_root()
# (presumably the repository root, so later relative paths resolve from
# there -- TODO confirm against typet5.utils).
os.chdir(proj_root())

In [2]:
# Download (or load from the local cache) the TypeT5 checkpoint and place it
# on the best available device.
from transformers import AutoModel, AutoTokenizer

# NOTE(review): the UniXcoder tokenizer/model are not referenced by any cell
# visible in this notebook (RolloutCtx is built from `wrapper`). They are kept
# so that any later cell relying on `tokenizer` / `model` keeps working; drop
# them if nothing does.
tokenizer = AutoTokenizer.from_pretrained("microsoft/unixcoder-base")
model = AutoModel.from_pretrained("microsoft/unixcoder-base")

wrapper = ModelWrapper.load_from_hub("MrVPlusOne/TypeT5-v7")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the TypeT5 wrapper itself: it is the model handed to RolloutCtx, so
# moving only the UniXcoder `model` would never explicitly place it on `device`.
wrapper.to(device)
model.to(device)
print("model loaded")


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

model loaded


In [3]:
# Rollout configuration shared by the cells below.

# Preprocessing options, left at their defaults.
pre_args = PreprocessArgs()

# Double-traversal decoding: a second pass lets the model revise the
# predictions it made during the first pass.
decode_order = DecodingOrders.DoubleTraversal()

# Rollout context wrapping the loaded TypeT5 model.
rctx = RolloutCtx(model=wrapper)

In [4]:
# Use case 1: Run TypeT5 on a given project, taking advantage of existing user 
# annotations and only make predictions for missing types.

project = PythonProject.parse_from_root(proj_root() / "data/ex_repo")
rollout = await rctx.run_on_project(project, pre_args, decode_order)

ex_code_2/fib: (n: int) -> int
ex_code_2/foo: (bar: int) -> int
ex_code_2/Bar.x: int
ex_code_2/Bar.y: int
ex_code_2/Bar.reset: (w0: str) -> None
ex_code_2/Bar.__init__: (x: int) -> None
ex_code_1/good: int
ex_code_1/fib: (n: int) -> int
ex_code_1/Wrapper.foo: (bar: int) -> int
ex_code_1/Wrapper.inc: () -> str
ex_code_1/int_add: (a: int, b: int) -> str
ex_code_1/int_tripple_add: (a: int, b: int, c: int) -> int
(updated) ex_code_1/int_add: (a: int, b: int) -> int


In [5]:
# Use case 2: Run TypeT5 on a test project where all user annotations will be treated as
# labels and removed before running the model.

print("I am here 1")
eval_r = await rctx.evaluate_on_projects([project], pre_args, decode_order)
print("I am here")
eval_r.print_predictions()

I am here 1


evaluate_on_projects: 100%|██████████| 35/35 [00:09<00:00,  3.51it/s]

I am here
	ex_code_2/fib: (n: int) -> int
	ex_code_2/foo: (bar: int) -> int
	ex_code_2/Bar.__init__: (x: int) -> None
	ex_code_2/Bar.reset: (w0: int_add) -> None
	ex_code_2/Bar.foo: (z: str) -> str
	ex_code_1/fib: (n: int) -> int
	ex_code_1/Wrapper.foo: (bar: int) -> int
	ex_code_1/Wrapper.inc: () -> int
	ex_code_1/int_add: (a: int, b: int) -> str
	ex_code_1/int_tripple_add: (a: int, b: int, c: int) -> int
	ex_code_2/Bar.z: str
	ex_code_2/Bar.w: int_add
	ex_code_2/Bar.x: int
	ex_code_2/Bar.y: int
	ex_code_2/bar: Bar
	ex_code_1/good: int
	ex_code_1/Wrapper.x_elem: int
	ex_code_1/Wrapper.y: int





In [6]:
# Report the accuracies of the evaluation result under each default metric,
# one pretty-printed dict per metric keyed by the metric's name.
metrics = AccuracyMetric.default_metrics(wrapper.common_type_names)
for metric in metrics:
    analysis = eval_r.error_analysis(None, metric)
    accs = analysis.accuracies
    pretty_print_dict({metric.name: accs})
    

full_acc:
   full_acc: 70.00% (count=10)
   full_acc_by_cat:
      FuncArg: 100.00% (count=4)
      FuncReturn: 0.00% (count=1)
      ClassAtribute: 50.00% (count=4)
      GlobalVar: 100.00% (count=1)
   full_acc_by_simple:
      simple: 70.00% (count=10)
   full_acc_label_size: 1
   full_acc_pred_size: 1
   full_acc_ignored_labels: 0
full_acc_common:
   full_acc_common: 66.67% (count=9)
   full_acc_common_by_cat:
      FuncArg: 100.00% (count=4)
      FuncReturn: 0.00% (count=1)
      ClassAtribute: 33.33% (count=3)
      GlobalVar: 100.00% (count=1)
   full_acc_common_by_simple:
      simple: 66.67% (count=9)
   full_acc_common_label_size: 1
   full_acc_common_pred_size: 1
   full_acc_common_ignored_labels: 1
full_acc_rare:
   full_acc_rare: 100.00% (count=1)
   full_acc_rare_by_cat:
      FuncArg: 100.00% (count=1)
   full_acc_rare_by_simple:
      simple: 100.00% (count=1)
   full_acc_rare_label_size: 1
   full_acc_rare_pred_size: 1
   full_acc_rare_ignored_labels: 9
acc:
   acc: 7