#### Imports


In [1]:
import pandas as pd
pd.set_option('display.max_colwidth', None)


Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
from sentence_transformers import InputExample, losses, models, SentenceTransformer, util


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from torch import device, cuda, save, load

In [4]:
device = device('cuda' if cuda.is_available() else 'cpu')
device

device(type='cuda')

In [5]:
model = SentenceTransformer('all-mpnet-base-v2')
# telling our model to run on the gpu instead of cpu
model.to(device)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
  (2): Normalize()
)

#### Get dataset with duplicates 

In [6]:
pairs_df = pd.read_csv('../fine-tuning-experimentation/data/firefox_pairs.csv')

In [7]:
pairs_df.head()

Unnamed: 0,issue_id,duplicate
0,10954,
1,14871,243500;410103;505684;515027;528678
2,19118,326494;328227;414070;436576;443686;457861;475975
3,54746,
4,56892,191258;281233;290692;300719;307581;310641;311707;313310;320257;327174;359714;371510;431038


#### Get encodings

In [10]:
enc_dict = load('../fine-tuning-experimentation/data/enc_dict.pt')
encodings_df = pd.DataFrame.from_dict(enc_dict)

In [11]:
encodings_df.head()

Unnamed: 0,encoded_desc,bug_id
0,"[-0.01127419, 0.09213824, -0.010780326, -0.036360744, 0.008821234, -0.014747565, 0.044438973, -0.017641861, -0.005438591, 0.012211488, -0.0480949, -0.022684705, 0.017518967, 0.07959968, 0.010277145, -0.050826497, -0.016864303, -0.010669791, -0.04101929, 0.087596074, 0.032721374, -0.010069381, 0.0078828875, 0.021323869, 0.008898046, -0.00025508835, 0.047721863, -0.031515848, -0.0025235766, -0.0047467956, 0.009059487, 0.009790037, -0.0236818, -0.07447995, 1.7802085e-06, -0.005082205, -0.0037235976, 0.0017059274, -0.057524327, 0.067195594, 0.02299202, 0.031367958, 0.021168903, 0.013167133, 0.03307475, 0.027962737, -0.024002282, 0.003942254, -0.030754773, 0.024776492, 0.0038120097, 0.06482462, -0.0043128105, -0.001264508, -0.048844222, -0.021744475, 0.0053113345, -0.03462892, -0.06670586, 0.016362512, 0.016815448, 0.015525976, -0.009661066, -0.0014679095, -0.030054966, -0.012817667, -0.002205806, -0.02808958, -0.018274853, 0.06452928, 0.059253726, -0.0006403717, -0.038203105, -0.04631362, 0.07206086, -0.040146906, 0.036716063, -0.09374257, -0.020628802, -0.024082614, -0.08759612, -0.020053133, 5.6112156e-05, 0.00840942, -0.021146268, -0.008677628, 0.023085555, 0.009518357, -0.015741983, -0.04504591, 0.05501785, 0.042054914, 0.006209637, -0.0033518923, -0.0066216467, -0.026676483, 0.03313994, 0.03455793, 0.0072474484, 0.009271839, ...]",10954
1,"[0.02882547, 0.049672984, -0.0020014874, 0.043006953, 0.004260654, -0.015860891, 0.055736672, 0.09124813, 0.00214753, -0.024571493, -0.043630272, -0.040884845, -0.06515039, 0.029930543, 0.012362121, 0.0022516227, 0.050048564, 0.03730777, 0.029207045, 0.006706878, 0.015388241, 0.017296001, -0.019971818, 0.04311141, -0.002059603, 0.02415514, 0.06682128, -0.060276847, -0.03077622, -0.037465733, 0.05168396, -0.0025320165, -0.041912336, 0.010558018, 1.175431e-06, -0.032464426, -0.040101048, -0.031849407, -0.05495906, -0.023896568, -0.039975446, 0.07022626, -0.0074526896, 0.035847563, 0.05384835, 0.034423478, 0.014600788, 0.057188652, -0.034820035, 0.014526436, 0.026598573, 0.052928228, 0.032607622, -0.03049453, 0.046402864, -0.037227895, -0.01791191, -0.08377861, 0.040173378, -0.02095788, 0.008785141, 0.055906292, -0.086716145, -0.019168051, -0.0054451018, -0.028331162, -0.029368, -0.086946376, 0.06080151, 0.041546274, 0.022515832, 0.018670565, 0.011473462, -0.016075155, -0.019044854, 0.06319244, 0.02130457, 0.027938846, -0.04241278, -0.01961137, -0.016724195, -0.025880083, 0.038666576, 0.0011917221, -0.011339993, 0.07744246, 0.04409753, -0.006475372, 0.018368188, -0.06934106, 0.007920357, -0.0023501646, 0.018097911, 0.008196992, -0.005038384, -0.0320736, -0.057495505, -0.03626568, 0.04055407, 0.0034099494, ...]",14871
2,"[0.006700601, -0.004019192, -0.0318143, 0.021704571, -0.004402159, 0.0284156, 0.04094607, 0.005694298, 0.05134395, -0.0114592435, -0.06279794, 0.041715343, -0.0393204, 0.020884788, -0.00027920332, -0.046922762, 0.014737423, -0.023171797, 0.0015259387, 0.02659671, 0.00844856, 0.03409411, -0.015694052, -0.027621962, -0.0014698722, 0.02903046, 0.025903856, -0.048716243, -0.033758406, -0.033939466, 0.022311943, 0.01931929, 0.010301303, 0.004417263, 2.0243074e-06, -0.01657759, 0.003426655, -0.057916842, 0.01290153, 0.008284493, -0.06543149, 0.024857253, 0.0004147165, 0.04298741, 0.031871412, 0.004252707, -0.0020274122, -0.067916825, -0.053658664, 0.047093194, -0.019345133, 0.04434098, -0.021815557, -0.034068596, 0.015164568, 0.040828735, 0.0020078423, -0.06385667, 0.028445732, -0.0066827973, 0.019090457, -0.02335225, -0.049966283, 0.039272085, 0.047744684, 0.00032195268, -0.02105324, -0.054989345, 0.011412853, 0.005475069, 0.0041454234, 0.049349315, -0.005415229, -0.044286273, 0.01764389, 0.06368156, 0.014679507, -0.017725844, -0.010503373, -0.03352048, 0.007817125, 0.058375776, -0.01374185, 0.0036495198, -0.07532817, 0.06497978, 0.020627657, -0.016687313, 0.022952065, -0.039781567, 0.013982417, -0.04292883, -0.0099044265, -0.026572034, -0.016901512, 0.0037568097, 0.048808914, 0.002518334, 0.041487284, 0.004410792, ...]",19118
3,"[0.013975125, -0.027381886, 0.012256325, 0.06366445, 0.08699522, 0.041640576, 0.026173443, -0.00842815, 0.001480479, -0.012105309, 0.04235706, -0.042579178, 0.008394941, 0.0065396572, 0.06342379, 0.061359167, 0.0073190117, -0.041641776, -0.052370768, 0.0045029437, -0.054267358, 0.02231135, 0.059449024, 0.030434469, -0.009184048, 0.046186097, 0.04800892, -0.01041949, -0.008449376, -0.040949285, -0.024331175, 0.011671206, 0.01532857, -0.03197344, 1.814704e-06, -0.0250821, -0.04973859, -0.036345597, -0.02858022, 0.008879605, 0.012454291, -0.016430149, -0.05037595, 0.017405946, -0.0004677903, -0.061269313, 0.036538243, -0.06365136, -0.04812416, 0.029996617, -0.0044330023, -0.062826164, -0.02264076, -0.03136708, -0.0064045815, 0.019925267, 0.021456532, -0.0038213576, -0.057139125, 0.06775931, -0.024031669, -0.0302923, -0.0070322542, -0.0043332023, -0.07692371, -0.05223118, -0.039505627, -0.040362928, 0.06418831, -0.0019792037, 0.072202995, 0.006840637, -0.017822238, 0.0074143363, 0.027003514, -0.03656964, 0.049607374, -0.0033024154, 0.017109288, 0.01530386, 0.033268522, 0.05472781, 0.009436116, -0.00363534, -0.028548086, 0.099416345, -0.025462367, 0.047893677, -0.0041851876, -0.043623965, -0.024486087, -0.0009080458, -0.022845978, -0.06817033, 0.03743567, 0.014134539, 0.016033102, -0.015871966, 0.011676998, -0.023717465, ...]",54746
4,"[0.0023555872, 0.039769016, 0.0009729534, 0.016215794, -0.015417997, 0.02733129, 0.049319144, 0.025507936, 0.032793526, -0.010358468, 0.010956046, -0.02630035, 0.019997533, -0.041917354, 0.04128754, 0.023069711, 0.0003757152, -0.020604327, -0.013798791, 0.018756937, -0.02711662, -0.042769827, -0.043815885, -0.0067237457, -0.0189784, 0.0661899, -0.005748351, 0.007915113, 0.015398354, 0.026141468, 0.034941047, -0.03469631, 0.009194629, -0.0044098827, 2.2203435e-06, 0.027426133, 0.06965741, 0.000143063, 0.06995282, 0.01671017, 0.00057288236, -0.0013976423, 0.012706621, 0.0076256907, 0.067503974, -0.040626045, -0.05153153, 0.017610198, -0.0046529677, 0.0333582, 0.02342221, -0.033086266, 0.03370382, 0.039309315, 0.10799291, -0.036977883, -0.026455095, -0.08567557, -0.0075224596, -0.021364788, 0.0025516893, -0.021130938, -0.03595264, -0.049106937, -0.003445796, -0.034620956, 0.015374156, -0.048240796, 0.08111518, -0.02876194, 0.06330422, 0.0040498357, 0.0022435072, -0.031651136, 0.01880304, 0.022670748, 0.04332304, -0.0062558986, -0.0018016634, -0.005723244, -0.010037816, 0.011499027, -0.020531055, -0.033791926, 0.009190378, 0.056099396, -0.0039254543, -0.012484346, -0.017605746, -0.036050495, 0.14745809, 0.04531386, -0.016986126, -0.0012577168, 0.008849473, 0.07033548, -0.0033728338, -0.03446986, 0.048439845, 0.004929029, ...]",56892


In [75]:
count = 0
test_set = encodings_df[0:5000]
k = 10
rr_k_all = 0
rr_k_all_mean = 0

# Iterate through the test set
for index, row in test_set.iterrows():
    # Get bug ID and encoded description of current row
    curr_br_enc_desc = row[0]
    bug_id = row[1]

    # Get row in pairs table corresponding to the current bug report ID
    dup_of_row = pairs_df.loc[pairs_df['issue_id'] == bug_id]
    dup_of_row = dup_of_row.iloc[0][1]
    bug_similarity_scores = []

    # if current row of duplicates from pairs table actually has duplicates (i.e. is not nan)
    if not pd.isna(dup_of_row):
        # Turn list of pairs  into ints
        dup_split = dup_of_row.split(';')
        dup_of_row_ints = [int(dup) for dup in dup_split]

        # calculate cos_sim and append to list
        for index_compare, row_compare in test_set.iterrows():
            bug_similarity_scores.append((
                util.cos_sim(curr_br_enc_desc, row_compare.iloc[0]),
                row_compare.iloc[1]
            ))

        # Sort cos_sim scores in descending order
        bug_similarity_scores.sort(reverse=True)

        # get top k scores, ignore first score since it is from the current report that we compare all the others to
        top_k = bug_similarity_scores[1:k+1]

        # calculate recall rate
        num_duplicates_found = 0
        for score in top_k:
            for pos_dup in dup_of_row_ints:
                if score[1] == pos_dup:
                    num_duplicates_found += 1



        rr_k = num_duplicates_found / len(dup_of_row_ints)
        rr_k_all += rr_k
        # print("RECALL: ", rr_k)

rr_k_all_mean = rr_k_all / len(test_set)
print(rr_k_all_mean)




        


  curr_br_enc_desc = row[0]
  bug_id = row[1]
  dup_of_row = dup_of_row.iloc[0][1]


0.16525412032568826
