In [1]:
import collections
import os

import numpy as np
from sklearn.utils import shuffle

from metric.calculate import calculate_and_print
from models.baseline import Baseline
from models.linear import LinearSVM, LogisticReg
from models.tree import CatBoost, RandomForest
from utils.encoder import encode_column
from utils.usage_loader import UsagesLoader
from utils.usage_loader import initial_feature_names

In [2]:
# Directories handed to the two UsagesLoader instances. Each one can be
# overridden through an environment variable so the notebook is not tied to a
# single machine; the defaults are the paths the notebook was originally run with.
jadx_usages_dir = os.environ.get(
    'JADX_USAGES_DIR',
    '/Users/danilbk/Programming/Java/test/jadx/jadx/project-processing-results/processing/java/annotations/processing/0.0.0',
)
spring_usages_dir = os.environ.get('SPRING_USAGES_DIR', '/Users/danilbk/Desktop/0.0.0')

jadx_loader = UsagesLoader([jadx_usages_dir])
spring_loader = UsagesLoader([spring_usages_dir])

In [3]:
# --- Load usages -----------------------------------------------------------
# Usages from spring_loader become the training pool and usages from
# jadx_loader the test pool. Any usage whose annotation name contains
# 'java.lang.Override' is dropped from both.
train_usages = [u for u in spring_loader.load_all() if 'java.lang.Override' not in u.annotation_name]
test_usages = [u for u in jadx_loader.load_all() if 'java.lang.Override' not in u.annotation_name]

# --- Keep only labels that occur on both sides -----------------------------
train_classes = set(x.annotation_name for x in train_usages)
test_classes = set(x.annotation_name for x in test_usages)
classes = train_classes & test_classes
train_usages = [u for u in train_usages if u.annotation_name in classes]
test_usages = [u for u in test_usages if u.annotation_name in classes]

# --- Cap the training pool -------------------------------------------------
# Shuffle with a fixed seed so the subsample is reproducible, then keep at
# most `train_size` usages (fewer if the pool is smaller).
train_size = 100000
train_usages = shuffle(train_usages, random_state=123)[:train_size]

# --- Encode features -------------------------------------------------------
# Train rows come first, test rows second; the split below relies on this order.
usages = train_usages + test_usages
raw_X = np.array([np.array(usage.features_list, dtype=object) for usage in usages])

# Encode every raw feature column and collect the resulting blocks. They are
# concatenated once after the loop instead of re-copying a growing matrix on
# every iteration.
encoded_blocks = []
all_new_names = []
for col in range(raw_X.shape[1]):
    # NOTE(review): len(train_usages) is presumably the train/test boundary and
    # 100 a limit defined by encode_column - confirm against utils.encoder.
    new_columns, new_names = encode_column(raw_X[:, col], len(train_usages),
                                           initial_feature_names[col], 100)
    if new_columns is None:
        # encode_column returned nothing for this feature; skip it.
        continue
    all_new_names += new_names
    encoded_blocks.append(new_columns)

if not encoded_blocks:
    # Previously this case surfaced later as an opaque TypeError on len(None).
    raise ValueError('encode_column produced no feature columns; cannot build X')
X = np.concatenate(encoded_blocks, axis=1)
y = np.array([usage.annotation_name for usage in usages])

# --- Split back into train / test ------------------------------------------
# Everything before the trailing len(test_usages) rows is training data.
actual_train_size = len(X) - len(test_usages)
X_train = X[:actual_train_size]
y_train = y[:actual_train_size]
X_test = X[actual_train_size:]
y_test = y[actual_train_size:]

In [4]:
# Annotation names present in both the train and the test usages.
classes

{'java.lang.Deprecated',
 'java.lang.FunctionalInterface',
 'java.lang.SafeVarargs',
 'java.lang.annotation.Documented',
 'java.lang.annotation.Retention',
 'java.lang.annotation.Target',
 'org.junit.jupiter.api.AfterAll',
 'org.junit.jupiter.api.AfterEach',
 'org.junit.jupiter.api.BeforeAll',
 'org.junit.jupiter.api.BeforeEach',
 'org.junit.jupiter.api.Test',
 'org.junit.jupiter.api.TestTemplate',
 'org.junit.jupiter.api.extension.ExtendWith',
 'org.junit.jupiter.api.io.TempDir'}

In [5]:
# Number of rows in the training matrix.
len(X_train)

24034

In [7]:
# How often each annotation name occurs in the unfiltered usages returned by
# spring_loader, most frequent first. `collections` is imported in the
# notebook's import cell rather than mid-notebook.
collections.Counter(x.annotation_name for x in spring_loader.load_all()).most_common()

[('java.lang.Override', 18792),
 ('org.junit.jupiter.api.Test', 18127),
 ('org.springframework.lang.Nullable', 13399),
 ('org.springframework.context.annotation.Bean', 1125),
 ('org.springframework.context.annotation.Configuration', 830),
 ('java.lang.annotation.Retention', 798),
 ('org.springframework.beans.factory.annotation.Autowired', 773),
 ('org.junit.jupiter.api.BeforeEach', 539),
 ('org.springframework.core.annotation.AliasFor', 464),
 ('org.springframework.web.bind.annotation.RequestMapping', 423),
 ('java.lang.Deprecated', 410),
 ('org.springframework.lang.NonNullApi', 375),
 ('org.springframework.lang.NonNullFields', 375),
 ('java.lang.annotation.Target', 374),
 ('org.springframework.test.context.ContextConfiguration', 294),
 ('org.springframework.web.servlet.handler.PathPatternsParameterizedTest',
  234),
 ('org.springframework.stereotype.Controller', 216),
 ('org.junit.Test', 213),
 ('java.lang.FunctionalInterface', 178),
 ('java.lang.annotation.Documented', 178),
 ('org.s

In [6]:
# Number of rows in the test matrix.
len(X_test)

871

In [9]:
# Reverse the evaluation direction: fit on the rows that came from
# test_usages and score on the rows that came from train_usages.
# The halves are re-sliced from X / y instead of being swapped in place, so
# re-running this cell (or running cells out of order) cannot silently flip
# the split back to its original orientation.
X_train, X_test = X[actual_train_size:], X[:actual_train_size]
y_train, y_test = y[actual_train_size:], y[:actual_train_size]

In [9]:
# CatBoost wrapper (from models.tree) configured for this experiment.
model = CatBoost(
    task_type='CPU',
    early_stopping_rounds=20,
    verbose=True,
    iterations=500,
    learning_rate=0.18,
    depth=6,
)

# Report metrics for the baseline first, then for the CatBoost model, on the
# same train/test split.
calculate_and_print(X_train, X_test, y_train, y_test, Baseline())
# Disabled comparison: calculate_and_print(X_train, X_test, y_train, y_test, LinearSVM())
calculate_and_print(X_train, X_test, y_train, y_test, model)

Baseline
Count: 871
Top 1: 0.7864523536165328
Top 2: 0.931113662456946
Top 3: 0.947187141216992
Top 4: 0.9598163030998852
Top 5: 0.9667049368541906
Top1 1: 0.931113662456946
Mean: 1.497129735935706
CatBoost
0:	learn: 0.9987174	test: 0.9902282	best: 0.9902282 (0)	total: 187ms	remaining: 1m 33s
1:	learn: 0.7795006	test: 0.7733222	best: 0.7733222 (1)	total: 383ms	remaining: 1m 35s
2:	learn: 0.6606129	test: 0.6541836	best: 0.6541836 (2)	total: 589ms	remaining: 1m 37s
3:	learn: 0.5612049	test: 0.5565942	best: 0.5565942 (3)	total: 875ms	remaining: 1m 48s
4:	learn: 0.4902404	test: 0.4854295	best: 0.4854295 (4)	total: 1.07s	remaining: 1m 45s
5:	learn: 0.4370543	test: 0.4333762	best: 0.4333762 (5)	total: 1.24s	remaining: 1m 41s
6:	learn: 0.3960843	test: 0.3924559	best: 0.3924559 (6)	total: 1.4s	remaining: 1m 38s
7:	learn: 0.3625759	test: 0.3590343	best: 0.3590343 (7)	total: 1.56s	remaining: 1m 36s
8:	learn: 0.3312184	test: 0.3277060	best: 0.3277060 (8)	total: 1.73s	remaining: 1m 34s
9:	learn: 0

In [11]:
# Feature importances paired with their encoded column names, largest first.
# Produces a flat list of (importance, name) tuples; the former outer zip()
# wrapped every pair in a needless 1-tuple.
sorted(zip(model.model.get_feature_importance(), all_new_names), reverse=True)

[((58.39121491214253, 'targetType_void'),),
 ((7.849724945398914, 'fileName_test'),),
 ((6.3903421071697855, 'targetName_test'),),
 ((5.505703585772887, 'target_type'),),
 ((4.099630495338746, 'target_annotation'),),
 ((3.8213908514258366, 'className_test'),),
 ((2.404613431751173, 'filePath_integration'),),
 ((1.72990785640665, 'otherAnnotations_target'),),
 ((1.46709375837522, 'otherAnnotations_javalangannotation'),),
 ((1.4538663798971379, 'filePath_main'),),
 ((1.37218417388352, 'otherAnnotations_retention'),),
 ((0.6084124571137117, 'modifiers_public'),),
 ((0.5155313544130632, 'otherMethodsAnnotations_orgjunitjupiterapi'),),
 ((0.44293123204509854, 'otherMethodsNames_test'),),
 ((0.27275456384992436, 'otherMethodsAnnotations_test'),),
 ((0.2635582257772206, 'otherMethodsAnnotations_before'),),
 ((0.23700934212720964, 'modifiers_private'),),
 ((0.22953133878534968, 'otherMethodsNames_alias'),),
 ((0.20852737729498513, 'otherMethodsAnnotations_javalang'),),
 ((0.20759083068893253, 