In [1]:
# Make sure OPENAI_API_KEY is set without writing the key into the notebook itself:
# reuse a value exported in the launching shell, otherwise prompt for it (input is not echoed).
import os
from getpass import getpass

if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OPENAI_API_KEY: ")

env: OPENAI_API_KEY=api-key


In [2]:
# pip install --upgrade openai

In [3]:
import glob
import json
import os
import re
import random
import time

from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

In [4]:
import openai

# Hand the key from the environment to the module-level openai client (None when unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")

In [5]:
# Folder holding the *.java.json inputs and *.benchmark_log.json ground-truth logs that the
# next cell globs from the working directory. Set JAVA_JSON_DIR to point elsewhere instead
# of editing this machine-specific default.
DATA_DIR = os.environ.get("JAVA_JSON_DIR", "/home/azmain/alljavajsons")
os.chdir(DATA_DIR)
print(os.getcwd())

/home/azmain/alljavajsons


In [6]:
numbers = re.compile(r'(\d+)')


def numericalSort(value):
    """Sort key that orders embedded digit runs numerically ("x2" before "x10").

    ``value`` is split on runs of digits and every digit run is converted to ``int``,
    e.g. ``"Class_10.java.json" -> ["Class_", 10, ".java.json"]``.
    """
    parts = numbers.split(value)
    parts[1::2] = map(int, parts[1::2])
    return parts


# Source files (model inputs) and benchmark logs (ground-truth imports). Everything
# downstream pairs the two lists purely by position, so they must have the same length.
inputFiles = sorted(glob.glob("*.java.json"), key=numericalSort)
correctOutputFiles = sorted(glob.glob("*.benchmark_log.json"), key=numericalSort)
assert len(inputFiles) == len(correctOutputFiles), (
    f"{len(inputFiles)} input files but {len(correctOutputFiles)} benchmark logs; "
    "positional pairing would be misaligned"
)

print(inputFiles)
print('\n\n\n')
print(correctOutputFiles)

['Android01.java.json', 'Android02.java.json', 'Android03.java.json', 'Android04.java.json', 'Android05.java.json', 'Android06.java.json', 'Android07.java.json', 'Android08.java.json', 'Android09.java.json', 'Android10.java.json', 'Android11.java.json', 'Android12.java.json', 'Android13.java.json', 'Android14.java.json', 'Android15.java.json', 'Android16.java.json', 'Android17.java.json', 'Android18.java.json', 'Android19.java.json', 'Android20.java.json', 'Android21.java.json', 'Android22.java.json', 'Android23.java.json', 'Android24.java.json', 'Android25.java.json', 'Android26.java.json', 'Android27.java.json', 'Android28.java.json', 'Android29.java.json', 'Android30.java.json', 'Android31.java.json', 'Android32.java.json', 'Android33.java.json', 'Android34.java.json', 'Android35.java.json', 'Android36.java.json', 'Android37.java.json', 'Android38.java.json', 'Android39.java.json', 'Android40.java.json', 'Android41.java.json', 'Android42.java.json', 'Android43.java.json', 'Android44

In [7]:
def get_codes(inputFiles):
    """Load the source-code text stored in each ``*.java.json`` file.

    Parameters
    ----------
    inputFiles : iterable of str
        Paths to JSON files that each contain an ``"originalContent"`` key.

    Returns
    -------
    list of str
        ``str(...)`` of every file's ``"originalContent"`` value, in input order.
    """
    codes = []
    for path in inputFiles:
        # `with` closes each file promptly; the bare json.load(open(...)) form left closing to
        # the garbage collector. JSON files are read as UTF-8 regardless of the system locale.
        with open(path, encoding="utf-8") as fh:
            codes.append(str(json.load(fh)['originalContent']))
    return codes

In [8]:
def get_correct_outputs(correctOutputFiles):
    """Load the ground-truth import statements from each benchmark log.

    Parameters
    ----------
    correctOutputFiles : iterable of str
        Paths to JSON files whose ``"total_imports"`` key holds a list of fully qualified
        type names.

    Returns
    -------
    list of list of str
        For every file, its names rendered as ``"import <name>;"`` lines, in the stored order,
        with every ``gen.R`` entry removed.
    """
    correct_outputs = []
    for path in correctOutputFiles:
        with open(path, encoding="utf-8") as fh:
            total_imports = json.load(fh)['total_imports']
        # gen.R is presumably the build-generated Android resource class, which cannot be
        # inferred from the source -- TODO confirm. Filter every occurrence (list.remove only
        # dropped the first one).
        correct_outputs.append(
            ["import " + name + ";" for name in total_imports if name != "gen.R"]
        )
    return correct_outputs

In [9]:
def get_dataset(codes, correct_outputs):
    """Bundle parallel lists of code and gold imports into a single dict.

    The two lists are stored as-is (not copied) under the keys ``"codes"`` and
    ``"correct_outputs"``.
    """
    return dict(codes=codes, correct_outputs=correct_outputs)

In [10]:
def get_test_examples_and_y_true(dataset):
    """Pair each code sample with its gold imports, and also return the gold imports alone.

    Parameters
    ----------
    dataset : dict
        Parallel lists under ``"codes"`` and ``"correct_outputs"``; one pair is produced per
        entry of ``dataset["codes"]``.

    Returns
    -------
    test_examples : list of (code, correct_outputs) tuples
    y_true : list
        The second element of each tuple, in the same order.
    """
    codes = dataset["codes"]
    gold = dataset["correct_outputs"]
    test_examples = [(codes[idx], gold[idx]) for idx in range(len(codes))]
    y_true = [expected for _, expected in test_examples]
    return test_examples, y_true

### All Android Classes

In [11]:
# Positions 0-49 of the numerically sorted listings hold the Android samples (Android01...).
android_input_files = inputFiles[:50]
android_log_files = correctOutputFiles[:50]

print("Total Android Codes: {}\n".format(len(android_input_files)))
print(android_input_files)

android_codes = get_codes(android_input_files)
android_correct_outputs = get_correct_outputs(android_log_files)
android_dataset = get_dataset(android_codes, android_correct_outputs)
android_test_examples, android_y_true = get_test_examples_and_y_true(android_dataset)

Total Android Codes: 50

['Android01.java.json', 'Android02.java.json', 'Android03.java.json', 'Android04.java.json', 'Android05.java.json', 'Android06.java.json', 'Android07.java.json', 'Android08.java.json', 'Android09.java.json', 'Android10.java.json', 'Android11.java.json', 'Android12.java.json', 'Android13.java.json', 'Android14.java.json', 'Android15.java.json', 'Android16.java.json', 'Android17.java.json', 'Android18.java.json', 'Android19.java.json', 'Android20.java.json', 'Android21.java.json', 'Android22.java.json', 'Android23.java.json', 'Android24.java.json', 'Android25.java.json', 'Android26.java.json', 'Android27.java.json', 'Android28.java.json', 'Android29.java.json', 'Android30.java.json', 'Android31.java.json', 'Android32.java.json', 'Android33.java.json', 'Android34.java.json', 'Android35.java.json', 'Android36.java.json', 'Android37.java.json', 'Android38.java.json', 'Android39.java.json', 'Android40.java.json', 'Android41.java.json', 'Android42.java.json', 'Android

### All JDK Classes

In [12]:
# Positions 50-72 of the sorted listings hold the JDK samples Class_1 ... Class_23.
jdk_input_files = inputFiles[50:73]
jdk_log_files = correctOutputFiles[50:73]

print("Total JDK Codes: {}\n".format(len(jdk_input_files)))
print(jdk_input_files)

jdk_codes = get_codes(jdk_input_files)
jdk_correct_outputs = get_correct_outputs(jdk_log_files)
jdk_dataset = get_dataset(jdk_codes, jdk_correct_outputs)
jdk_test_examples, jdk_y_true = get_test_examples_and_y_true(jdk_dataset)

Total JDK Codes: 23

['Class_1.java.json', 'Class_2.java.json', 'Class_3.java.json', 'Class_4.java.json', 'Class_5.java.json', 'Class_6.java.json', 'Class_7.java.json', 'Class_8.java.json', 'Class_9.java.json', 'Class_10.java.json', 'Class_11.java.json', 'Class_12.java.json', 'Class_13.java.json', 'Class_14.java.json', 'Class_15.java.json', 'Class_16.java.json', 'Class_17.java.json', 'Class_18.java.json', 'Class_19.java.json', 'Class_20.java.json', 'Class_21.java.json', 'Class_22.java.json', 'Class_23.java.json']


### All Hibernate Classes

In [13]:
# The Hibernate samples occupy two separate runs of the sorted listings:
# 'HibernateUtil.java.json' (position 73) sorts among the capitalised names, between the
# Class_* and JodaTime* files, while the lower-case 'hibernate_class_*' files (positions
# 174-223) sort after gwt_class_*. Both runs are concatenated here.
hibernate_input_files = inputFiles[73:74] + inputFiles[174:224]
hibernate_log_files = correctOutputFiles[73:74] + correctOutputFiles[174:224]

print("Total Hibernate Codes: {}\n".format(len(hibernate_input_files)))
print(hibernate_input_files)

hibernate_codes = get_codes(hibernate_input_files)
hibernate_correct_outputs = get_correct_outputs(hibernate_log_files)
hibernate_dataset = get_dataset(hibernate_codes, hibernate_correct_outputs)
hibernate_test_examples, hibernate_y_true = get_test_examples_and_y_true(hibernate_dataset)

Total Hibernate Codes: 51

['HibernateUtil.java.json', 'hibernate_class_1.java.json', 'hibernate_class_2.java.json', 'hibernate_class_3.java.json', 'hibernate_class_4.java.json', 'hibernate_class_5.java.json', 'hibernate_class_6.java.json', 'hibernate_class_7.java.json', 'hibernate_class_8.java.json', 'hibernate_class_9.java.json', 'hibernate_class_10.java.json', 'hibernate_class_11.java.json', 'hibernate_class_12.java.json', 'hibernate_class_13.java.json', 'hibernate_class_14.java.json', 'hibernate_class_15.java.json', 'hibernate_class_16.java.json', 'hibernate_class_17.java.json', 'hibernate_class_18.java.json', 'hibernate_class_19.java.json', 'hibernate_class_20.java.json', 'hibernate_class_21.java.json', 'hibernate_class_22.java.json', 'hibernate_class_23.java.json', 'hibernate_class_24.java.json', 'hibernate_class_25.java.json', 'hibernate_class_26.java.json', 'hibernate_class_27.java.json', 'hibernate_class_28.java.json', 'hibernate_class_29.java.json', 'hibernate_class_30.java.j

### All JodaTime Classes

In [14]:
# Positions 74-123 of the sorted listings hold the Joda-Time samples (JodaTime01...).
jodatime_input_files = inputFiles[74:124]
jodatime_log_files = correctOutputFiles[74:124]

print("Total JodaTime Codes: {}\n".format(len(jodatime_input_files)))
print(jodatime_input_files)

jodatime_codes = get_codes(jodatime_input_files)
jodatime_correct_outputs = get_correct_outputs(jodatime_log_files)
jodatime_dataset = get_dataset(jodatime_codes, jodatime_correct_outputs)
jodatime_test_examples, jodatime_y_true = get_test_examples_and_y_true(jodatime_dataset)

Total JodaTime Codes: 50

['JodaTime01.java.json', 'JodaTime02.java.json', 'JodaTime03.java.json', 'JodaTime04.java.json', 'JodaTime05.java.json', 'JodaTime06.java.json', 'JodaTime07.java.json', 'JodaTime08.java.json', 'JodaTime09.java.json', 'JodaTime10.java.json', 'JodaTime11.java.json', 'JodaTime12.java.json', 'JodaTime13.java.json', 'JodaTime14.java.json', 'JodaTime15.java.json', 'JodaTime16.java.json', 'JodaTime17.java.json', 'JodaTime18.java.json', 'JodaTime19.java.json', 'JodaTime20.java.json', 'JodaTime21.java.json', 'JodaTime22.java.json', 'JodaTime23.java.json', 'JodaTime24.java.json', 'JodaTime25.java.json', 'JodaTime26.java.json', 'JodaTime27.java.json', 'JodaTime28.java.json', 'JodaTime29.java.json', 'JodaTime30.java.json', 'JodaTime31.java.json', 'JodaTime32.java.json', 'JodaTime33.java.json', 'JodaTime34.java.json', 'JodaTime35.java.json', 'JodaTime36.java.json', 'JodaTime37.java.json', 'JodaTime38.java.json', 'JodaTime39.java.json', 'JodaTime40.java.json', 'JodaTime41.j

### All GWT Classes

In [15]:
# Positions 124-173 of the sorted listings hold the GWT samples (gwt_class_1...).
gwt_input_files = inputFiles[124:174]
gwt_log_files = correctOutputFiles[124:174]

print("Total GWT Codes: {}\n".format(len(gwt_input_files)))
print(gwt_input_files)

gwt_codes = get_codes(gwt_input_files)
gwt_correct_outputs = get_correct_outputs(gwt_log_files)
gwt_dataset = get_dataset(gwt_codes, gwt_correct_outputs)
gwt_test_examples, gwt_y_true = get_test_examples_and_y_true(gwt_dataset)

Total GWT Codes: 50

['gwt_class_1.java.json', 'gwt_class_2.java.json', 'gwt_class_3.java.json', 'gwt_class_4.java.json', 'gwt_class_5.java.json', 'gwt_class_6.java.json', 'gwt_class_7.java.json', 'gwt_class_8.java.json', 'gwt_class_9.java.json', 'gwt_class_10.java.json', 'gwt_class_11.java.json', 'gwt_class_12.java.json', 'gwt_class_13.java.json', 'gwt_class_14.java.json', 'gwt_class_15.java.json', 'gwt_class_16.java.json', 'gwt_class_17.java.json', 'gwt_class_18.java.json', 'gwt_class_19.java.json', 'gwt_class_20.java.json', 'gwt_class_21.java.json', 'gwt_class_22.java.json', 'gwt_class_23.java.json', 'gwt_class_24.java.json', 'gwt_class_25.java.json', 'gwt_class_26.java.json', 'gwt_class_27.java.json', 'gwt_class_28.java.json', 'gwt_class_29.java.json', 'gwt_class_30.java.json', 'gwt_class_31.java.json', 'gwt_class_32.java.json', 'gwt_class_33.java.json', 'gwt_class_34.java.json', 'gwt_class_35.java.json', 'gwt_class_36.java.json', 'gwt_class_37.java.json', 'gwt_class_38.java.json',

### All XStream Classes

In [16]:
# Positions 224-267 of the sorted listings hold the XStream samples (xstream_class_1...).
xstream_input_files = inputFiles[224:268]
xstream_log_files = correctOutputFiles[224:268]

print("Total XStream Codes: {}\n".format(len(xstream_input_files)))
print(xstream_input_files)

xstream_codes = get_codes(xstream_input_files)
xstream_correct_outputs = get_correct_outputs(xstream_log_files)
xstream_dataset = get_dataset(xstream_codes, xstream_correct_outputs)
xstream_test_examples, xstream_y_true = get_test_examples_and_y_true(xstream_dataset)

Total XStream Codes: 44

['xstream_class_1.java.json', 'xstream_class_2.java.json', 'xstream_class_3.java.json', 'xstream_class_4.java.json', 'xstream_class_5.java.json', 'xstream_class_6.java.json', 'xstream_class_7.java.json', 'xstream_class_8.java.json', 'xstream_class_9.java.json', 'xstream_class_10.java.json', 'xstream_class_11.java.json', 'xstream_class_12.java.json', 'xstream_class_13.java.json', 'xstream_class_14.java.json', 'xstream_class_15.java.json', 'xstream_class_16.java.json', 'xstream_class_17.java.json', 'xstream_class_18.java.json', 'xstream_class_19.java.json', 'xstream_class_20.java.json', 'xstream_class_21.java.json', 'xstream_class_22.java.json', 'xstream_class_23.java.json', 'xstream_class_24.java.json', 'xstream_class_25.java.json', 'xstream_class_26.java.json', 'xstream_class_27.java.json', 'xstream_class_28.java.json', 'xstream_class_29.java.json', 'xstream_class_30.java.json', 'xstream_class_31.java.json', 'xstream_class_32.java.json', 'xstream_class_33.java.

# One-shot Learning Implementation

In [17]:
# Chat model queried for every one-shot prediction below.
MODEL = "gpt-3.5-turbo"

In [18]:
def pred_process(y_pred, y_true):
    """Flatten per-file import predictions into binary label vectors for sklearn metrics.

    Imports are compared as exact strings and de-duplicated per file. Each file contributes
    one (predicted, true) label pair per outcome:

    * ``(1, 1)`` for every predicted import that is also in the gold list (true positive),
    * ``(0, 1)`` for every gold import that was not predicted (false negative),
    * ``(1, 0)`` for every predicted import absent from the gold list (false positive).

    No true negatives are ever emitted, so accuracy on these vectors equals
    TP / (TP + FP + FN).

    Parameters
    ----------
    y_pred : list of list of str
        Predicted import statements per file.
    y_true : list of list of str
        Gold import statements per file (paired with ``y_pred`` by position).

    Returns
    -------
    (list of int, list of int)
        The flattened predicted and true label vectors. Both are also printed.
    """
    y_pred_processed = []
    y_true_processed = []

    for pred, correct_imports in zip(y_pred, y_true):
        pred_set = set(pred)
        true_set = set(correct_imports)
        n_tp = len(pred_set & true_set)
        n_fn = len(true_set - pred_set)
        # Count every spurious prediction. The previous max(len)-based loop only counted
        # len(pred) - len(true) of them, hiding wrong guesses whenever the prediction list
        # was not longer than the gold list.
        n_fp = len(pred_set - true_set)

        y_pred_processed += [1] * n_tp + [0] * n_fn + [1] * n_fp
        y_true_processed += [1] * n_tp + [1] * n_fn + [0] * n_fp

    print(y_pred_processed)
    print(y_true_processed)
    print()
    return y_pred_processed, y_true_processed

In [19]:
def eval_performance(y_pred, y_true):
    """Print accuracy, F1, precision and recall of binary label vectors as indented JSON.

    Parameters
    ----------
    y_pred, y_true : sequence of {0, 1}
        Predicted and gold labels, e.g. as produced by ``pred_process``. This function keeps
        its (prediction, truth) argument order for existing callers, whereas scikit-learn
        expects ``(y_true, y_pred)``. The metrics are therefore called with keywords.
    """
    print(json.dumps({
        "accuracy": accuracy_score(y_true=y_true, y_pred=y_pred),
        "f1": f1_score(y_true=y_true, y_pred=y_pred),
        # Precision and recall are not symmetric in their arguments: the former positional
        # (y_pred, y_true) call silently reported recall as precision and vice versa.
        "precision": precision_score(y_true=y_true, y_pred=y_pred),
        "recall": recall_score(y_true=y_true, y_pred=y_pred),
    }, indent=2))

In [20]:
# # Test

# print(xstream_test_examples[39][1]) # correct imports
# print(xstream_test_examples[39][0]) # code

In [21]:
# # Test

# MODEL = 'gpt-3.5-turbo'

# code_snippet = xstream_test_examples[39][0]
# # prompt=f"Reply with to-the-point answer, with no elaboration. Extract all valid possible fully qualified type long-name from the code below, must exclude name of the class itself:\n\"\"\"{test_code}\"\"\""

# response = openai.ChatCompletion.create(
#     model=MODEL,
#     messages=[
#         {"role": "system", "content": "Reply with to-the-point answer, no elaboration."},
# #         {"role": "user", "content": f"Import correct imports:\n\"\"\"\n{test_code}\n\"\"\""},
# #         {"role": "user", "content": f"Your task is to provide a list of the correct imports for a given code snippet, ensuring that no wildcard imports are used. If any necessary imports are missing from the code, please include them in your response. Please note that your response should be specific and accurate:\n\"\"\"\n{code_snippet}\n\"\"\""},
# #         {"role": "user", "content": f"Do not check for any import statements in the code. Only give correct imports by not using wildcard imports. Please note that your response should be specific and accurate:\n\"\"\"\n{code_snippet}\n\"\"\""},
#         {"role": "user", "content": f"Do not check for any import statements in the code. Only give correct imports by not using wildcard imports. Please note that you need to pay close attention and your response should be specific and accurate:\n\"\"\"\n{code_snippet}\n\"\"\""},
        
#     ],
#     temperature=1,
# )

# print(response["choices"][0]["message"]["content"])

# One-shot

In [22]:
def get_oneshot_prediction(example, code, sleep_seconds=30):
    """Ask the chat model which import statements ``code`` needs, given one worked example.

    The prompt is built from these messages:

    * a terse system instruction;
    * one demonstration: the example code as an ``example_user`` message and its gold
      imports as an ``example_assistant`` message;
    * ``code``, wrapped in the same instruction text.

    Parameters
    ----------
    example : sequence of (code_snippet, correct_output) pairs
        Only the LAST pair is used as the demonstration (callers in this notebook pass a
        one-element list). ``correct_output`` is a list of strings joined with newlines to
        form the demonstrated answer.
    code : str
        Source code whose imports should be predicted.
    sleep_seconds : float, default 30
        Pause before each API call. Presumably a crude rate-limit guard -- TODO confirm.

    Returns
    -------
    str
        Text content of the first choice in the model's reply.

    Raises
    ------
    ValueError
        If ``example`` is empty. Previously this surfaced, after the sleep, as an
        UnboundLocalError.
    """
    pairs = list(example)
    if not pairs:
        raise ValueError("example must contain at least one (code, imports) pair")
    code_snippet, correct_output = pairs[-1]
    expected_output = '\n'.join(correct_output)

    time.sleep(sleep_seconds)

    response = openai.ChatCompletion.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": "Reply with to-the-point answer, no elaboration."},
            {"role": "system", "name":"example_user", "content": f"Do not check for any import statements in the code. Only give correct imports by not using wildcard imports. Please note that you need to pay close attention and your response should be specific and accurate:\n\"\"\"\n{code_snippet}\n\"\"\""},
            {"role": "system", "name": "example_assistant", "content": f"{expected_output}"},
            {"role": "user", "content": f"Do not check for any import statements in the code. Only give correct imports by not using wildcard imports. Please note that you need to pay close attention and your response should be specific and accurate:\n\"\"\"\n{code}\n\"\"\""},
        ],
        temperature=0,
    )
    return response["choices"][0]["message"]["content"]

In [23]:
def get_all_oneshot_predictions(example, dataset):
    """Run one-shot import prediction for every (code, imports) pair in ``dataset``.

    Parameters
    ----------
    example : sequence of two (code, imports) pairs
        Candidate demonstrations, e.g. from ``get_oneshot_sample``. ``example[0]`` is used
        unless the code under test is exactly ``example[0]``'s code, in which case
        ``example[1]`` is used instead. So, assuming the two candidates differ, no file is
        shown its own answer.
    dataset : iterable of (code, imports) pairs
        Items to predict. The gold imports are ignored here.

    Returns
    -------
    list of list of str
        For each item, the ``import ...;`` statements extracted from the model's reply.
    """
    import_pattern = re.compile(r"import\s+[\w\., ]+;")
    y_pred = []
    for code_snippet, _ in tqdm(dataset):
        # Exact comparison: the former substring test (`in`) also fired when the code under
        # test merely occurred somewhere inside example[0]'s source.
        if code_snippet == example[0][0]:
            demonstration = example[1]
        else:
            demonstration = example[0]
        predicted_output = get_oneshot_prediction([demonstration], code_snippet)
        y_pred.append(import_pattern.findall(predicted_output))
        time.sleep(2)
    return y_pred

In [24]:
def get_oneshot_sample(dataset, k=2, rng=None):
    """Randomly draw ``k`` distinct (code, correct_outputs) pairs from ``dataset``.

    The draws serve as candidate one-shot demonstrations. ``get_all_oneshot_predictions``
    reads the first two.

    Parameters
    ----------
    dataset : dict
        Parallel lists under ``"codes"`` and ``"correct_outputs"``.
    k : int, default 2
        Number of pairs to draw.
    rng : random.Random, optional
        Source of randomness. Pass a seeded ``random.Random(seed)`` to make the choice of
        demonstrations reproducible. By default the module-level ``random`` generator is
        used, as before.

    Returns
    -------
    list of (code, correct_outputs) tuples

    Raises
    ------
    ValueError
        From ``sample`` when ``k`` exceeds the number of entries.
    """
    codes = dataset["codes"]
    outputs = dataset["correct_outputs"]
    candidates = [(codes[i], outputs[i]) for i in range(len(codes))]
    sampler = random if rng is None else rng
    return sampler.sample(candidates, k)

In [25]:
# One-shot Prediction for Android Classes

print("\nOne-shot Prediction for Android Classes:\n")

# Two random candidate demonstrations; get_all_oneshot_predictions picks one per file.
oneshot_sample = get_oneshot_sample(android_dataset)
y_pred = get_all_oneshot_predictions(oneshot_sample, android_test_examples)

print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", android_y_true)

y_pred_processed, y_true_processed = pred_process(y_pred, android_y_true)
eval_performance(y_pred_processed, y_true_processed)


One-shot Prediction for Android Classes:



100%|██████████| 50/50 [28:19<00:00, 34.00s/it]


Predicted Import List: [['import android.app.Activity;', 'import android.os.Bundle;', 'import android.widget.TextView;'], ['import android.app.Activity;', 'import android.os.Bundle;'], ['import android.graphics.drawable.Drawable;', 'import android.os.Bundle;', 'import com.google.android.maps.ItemizedOverlay;', 'import com.google.android.maps.OverlayItem;', 'import com.google.android.maps.GeoPoint;'], ['import android.app.Activity;', 'import android.os.Bundle;', 'import android.widget.TextView;'], ['import android.graphics.drawable.Drawable;', 'import android.widget.TabHost;', 'import android.widget.TabHost.TabSpec;'], ['import android.view.Gravity;', 'import android.view.ViewGroup;', 'import android.widget.LinearLayout;'], ['import android.app.Activity;', 'import android.os.Bundle;', 'import android.view.View;', 'import android.widget.ArrayAdapter;'], ['import java.io.IOException;', 'import java.net.Inet4Address;', 'import java.net.InetAddress;', 'import java.net.Socket;', 'import jav




In [26]:
# One-shot Prediction for JDK Classes

print("\nOne-shot Prediction for JDK Classes:\n")

# Two random candidate demonstrations; get_all_oneshot_predictions picks one per file.
oneshot_sample = get_oneshot_sample(jdk_dataset)
y_pred = get_all_oneshot_predictions(oneshot_sample, jdk_test_examples)

print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", jdk_y_true)

y_pred_processed, y_true_processed = pred_process(y_pred, jdk_y_true)
eval_performance(y_pred_processed, y_true_processed)


One-shot Prediction for JDK Classes:



100%|██████████| 23/23 [12:59<00:00, 33.88s/it]


Predicted Import List: [['import java.applet.Applet;', 'import java.awt.Color;', 'import java.awt.Container;', 'import java.awt.Dimension;', 'import java.awt.Frame;', 'import java.awt.Graphics;', 'import java.awt.Graphics2D;', 'import java.awt.event.WindowAdapter;', 'import java.awt.event.WindowEvent;', 'import java.awt.image.BufferedImage;', 'import java.io.FileInputStream;', 'import java.io.FileOutputStream;', 'import java.io.IOException;', 'import java.io.ObjectInputStream;', 'import java.io.ObjectOutputStream;'], ['import java.awt.BorderLayout;', 'import java.awt.Color;', 'import java.awt.Dimension;', 'import java.awt.FlowLayout;', 'import java.awt.GradientPaint;', 'import java.awt.GridBagLayout;', 'import java.awt.GridLayout;', 'import java.awt.Graphics2D;', 'import java.awt.image.BufferedImage;', 'import java.awt.event.ActionEvent;', 'import java.awt.event.ActionListener;', 'import javax.swing.BorderFactory;', 'import javax.swing.JButton;', 'import javax.swing.JCheckBox;', 'impo




In [27]:
# One-shot Prediction for Hibernate Classes

print("\nOne-shot Prediction for Hibernate Classes:\n")

# Two random candidate demonstrations; get_all_oneshot_predictions picks one per file.
oneshot_sample = get_oneshot_sample(hibernate_dataset)
y_pred = get_all_oneshot_predictions(oneshot_sample, hibernate_test_examples)

print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", hibernate_y_true)

y_pred_processed, y_true_processed = pred_process(y_pred, hibernate_y_true)
eval_performance(y_pred_processed, y_true_processed)


One-shot Prediction for Hibernate Classes:



100%|██████████| 51/51 [28:52<00:00, 33.97s/it]


Predicted Import List: [['import org.hibernate.SessionFactory;', 'import org.hibernate.cfg.AnnotationConfiguration;'], ['import org.hibernate.Session;', 'import org.hibernate.SessionFactory;', 'import org.hibernate.Transaction;', 'import org.hibernate.cfg.AnnotationConfiguration;', 'import dao.UserDAO;', 'import model.User;'], ['import java.io.Serializable;', 'import java.util.List;', 'import javax.persistence.Cacheable;', 'import javax.persistence.Column;', 'import javax.persistence.Entity;', 'import javax.persistence.Id;', 'import javax.persistence.JoinColumn;', 'import javax.persistence.ManyToOne;', 'import javax.persistence.OneToMany;', 'import org.hibernate.annotations.Cache;', 'import org.hibernate.annotations.CacheConcurrencyStrategy;', 'import static javax.persistence.CascadeType.ALL;'], ['import java.sql.Types;', 'import org.hibernate.dialect.Dialect;', 'import org.hibernate.Hibernate;', 'import org.hibernate.dialect.function.SQLFunctionTemplate;', 'import org.hibernate.diale




In [28]:
# One-shot Prediction for Joda-Time Classes

print("\nOne-shot Prediction for Joda-Time Classes:\n")

# Two random candidate demonstrations; get_all_oneshot_predictions picks one per file.
oneshot_sample = get_oneshot_sample(jodatime_dataset)
y_pred = get_all_oneshot_predictions(oneshot_sample, jodatime_test_examples)

print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", jodatime_y_true)

y_pred_processed, y_true_processed = pred_process(y_pred, jodatime_y_true)
eval_performance(y_pred_processed, y_true_processed)


One-shot Prediction for Joda-Time Classes:



100%|██████████| 50/50 [27:48<00:00, 33.37s/it]


Predicted Import List: [['import org.joda.time.DateTime;', 'import org.joda.time.DateTimeZone;', 'import org.joda.time.format.DateTimeFormatter;', 'import org.joda.time.format.ISODateTimeFormat;', 'import java.util.TimeZone;'], ['import org.joda.time.DateMidnight;', 'import org.joda.time.DateTime;', 'import org.joda.time.DateTimeZone;'], ['import org.joda.time.Interval;', 'import org.joda.time.PeriodFormatter;', 'import org.joda.time.PeriodFormatterBuilder;'], ['import org.joda.time.Period;', 'import org.joda.time.ReadableInstant;'], ['import org.joda.time.Chronology;', 'import org.joda.time.DateTime;', 'import org.joda.time.DateTimeZone;', 'import org.joda.time.chrono.GJChronology;'], ['import org.joda.time.Duration;', 'import org.joda.time.PeriodFormatterBuilder;', 'import org.joda.time.PeriodType;'], ['import java.text.DateFormat;', 'import java.text.ParseException;', 'import java.text.SimpleDateFormat;', 'import java.util.Calendar;', 'import java.util.Date;', 'import org.joda.time




In [29]:
# One-shot Prediction for GWT Classes

print("\nOne-shot Prediction for GWT Classes:\n")

# Two random candidate demonstrations; get_all_oneshot_predictions picks one per file.
oneshot_sample = get_oneshot_sample(gwt_dataset)
y_pred = get_all_oneshot_predictions(oneshot_sample, gwt_test_examples)

print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", gwt_y_true)

y_pred_processed, y_true_processed = pred_process(y_pred, gwt_y_true)
eval_performance(y_pred_processed, y_true_processed)


One-shot Prediction for GWT Classes:



100%|██████████| 50/50 [29:01<00:00, 34.83s/it]


Predicted Import List: [['import com.google.gwt.event.dom.client.MouseDownEvent;', 'import com.google.gwt.event.dom.client.MouseDownHandler;', 'import com.google.gwt.user.client.ui.AbsolutePanel;', 'import com.google.gwt.user.client.ui.Composite;'], ['import com.google.gwt.core.client.GWT;', 'import com.google.gwt.junit.client.GWTTestCase;', 'import com.google.gwt.user.client.rpc.AsyncCallback;', 'import com.google.gwt.sample.stockwatcher.client.GreetingService;', 'import com.google.gwt.sample.stockwatcher.client.GreetingServiceAsync;'], ['import com.google.gwt.core.client.EntryPoint;', 'import com.google.gwt.user.client.Window;', 'import com.google.gwt.user.client.rpc.AsyncCallback;', 'import com.google.gwt.user.client.ui.Button;', 'import com.google.gwt.user.client.ui.RootPanel;', 'import com.google.gwt.user.client.ui.TextBox;'], ['import com.google.gwt.user.client.ui.Composite;', 'import com.google.gwt.user.client.ui.HTML;', 'import com.google.gwt.user.client.ui.VerticalSplitPanel;




In [30]:
# One-shot Prediction for XStream Classes

print("\nPrediction for XStream Classes:\n")

# Two random candidate demonstrations; get_all_oneshot_predictions picks one per file.
oneshot_sample = get_oneshot_sample(xstream_dataset)
y_pred = get_all_oneshot_predictions(oneshot_sample, xstream_test_examples)

print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", xstream_y_true)

y_pred_processed, y_true_processed = pred_process(y_pred, xstream_y_true)
eval_performance(y_pred_processed, y_true_processed)


Prediction for XStream Classes:



100%|██████████| 44/44 [25:07<00:00, 34.27s/it]


Predicted Import List: [['import com.thoughtworks.xstream.converters.Converter;', 'import com.thoughtworks.xstream.io.HierarchicalStreamReader;', 'import com.thoughtworks.xstream.io.HierarchicalStreamWriter;', 'import com.thoughtworks.xstream.converters.MarshallingContext;', 'import com.thoughtworks.xstream.converters.UnmarshallingContext;'], ['import java.io.BufferedReader;', 'import java.io.FileReader;', 'import java.io.IOException;', 'import com.thoughtworks.xstream.XStream;', 'import com.thoughtworks.xstream.io.xml.DomDriver;'], ['import java.util.ArrayList;', 'import java.util.List;', 'import com.thoughtworks.xstream.XStream;', 'import com.thoughtworks.xstream.io.json.JettisonMappedXmlDriver;'], ['import java.util.ArrayList;', 'import java.util.List;', 'import com.thoughtworks.xstream.XStream;'], ['import java.io.FileNotFoundException;', 'import java.io.PrintWriter;', 'import com.thoughtworks.xstream.XStream;', 'import com.thoughtworks.xstream.io.xml.DomDriver;'], ['import java.i


