In [2]:
import os, subprocess, json, pandas as pd, matplotlib.pyplot as plt, datetime as dt, torch
# Walk up from the kernel's working directory until we are inside the
# '*-analysis' project folder, so the relative paths below resolve.
while not os.getcwd().endswith('-analysis'):
    cwd_before = os.getcwd()
    os.chdir('..')
    if os.getcwd() == cwd_before:
        # At the filesystem root chdir('..') no longer moves us; without this
        # guard the loop would spin forever when no such folder exists.
        raise RuntimeError("could not find an ancestor directory ending in '-analysis'")

SERVER_DATA_DIR = 'REDACTED' 
DATA_DIR = os.path.join('..', 'data')

# NOTE: do not sync 
# _ = subprocess.run(['rsync', '-avz', f'larisa:{SERVER_DATA_DIR}', DATA_DIR])

# The local copy lives in a sub-folder of DATA_DIR named after the last path
# component of the server directory.
DATA_DIR = os.path.join(DATA_DIR, SERVER_DATA_DIR.split('/')[-1])

## Running CodeBERTScore on Prediction & Ground-Truth per Filter

In [3]:
import code_bert_score 
import datetime as dt, json

from pprint import pprint
from dataclasses import dataclass
from typing import Optional

@dataclass
class Query:
    ''' One logged query record; instances are normally built from a JSON dict via `from_dict`. '''

    user            : str
    prefix          : str
    suffix          : str
    trigger         : str
    language        : str
    ide             : str
    version         : str
    store           : bool
    timestamp       : dt.datetime
    predictions     : dict
    predict_time    : float
    survey          : bool

    time_since_last_completion: float
    filter_type     : str
    filter_time     : float
    should_filter   : bool
    study_version   : str

    @classmethod 
    def from_dict(cls, data: dict) -> Optional['Query']:
        ''' Parse json dict into pythonic types, returning Query or VerifiedQuery.

        The caller's dict is not modified; parsing happens on a shallow copy, so
        the same dict can be parsed more than once.

        Dicts that contain a 'verifyToken' key are handed to VerifiedQuery.from_dict.
        Parsing is best-effort: on any error the exception and the record are
        printed and None is returned, so callers must be prepared to skip None.
        '''

        try: 
            # Copy first: the ISO string is replaced by a datetime below, and
            # VerifiedQuery.from_dict renames keys.
            data = dict(data)
            data['timestamp'] = dt.datetime.fromisoformat(data['timestamp'])

            if 'verifyToken' in data:
                return VerifiedQuery.from_dict(data)
            return cls(**data)
        except Exception as e:
            print(f'Error parsing query: {e}')
            pprint(data)
            return None

@dataclass 
class VerifiedQuery(Query):
    ''' A Query whose record also carries a verify token, with the extra fields declared below. '''

    verify_token    : str
    chosen_model    : str
    ground_truth    : str
    shown_times     : Optional[list[dt.datetime]]
    accept_time     : dt.datetime

    @classmethod
    def from_dict(cls, data: dict) -> 'VerifiedQuery':
        ''' parse json dict into pythonic types

        Works on a shallow copy, so the caller's dict keeps its 'verifyToken'
        key and can be parsed again. Raises (KeyError / ValueError / TypeError)
        when a required key is missing or malformed.
        '''

        data = dict(data)

        # The JSON key is camelCase, the dataclass field is snake_case.
        data['verify_token'] = data.pop('verifyToken')
        data['accept_time'] = dt.datetime.fromisoformat(data['accept_time'])

        # Query.from_dict has already parsed the timestamp when it delegates
        # here; parse it ourselves when this method is called directly.
        if isinstance(data.get('timestamp'), str):
            data['timestamp'] = dt.datetime.fromisoformat(data['timestamp'])
        
        # pushed bugfix for shown_times field on client-side
        # (records that do not have the key get None)
        data['shown_times'] = [dt.datetime.fromisoformat(t) for t in data['shown_times']] \
            if 'shown_times' in data else None

        return cls(**data)

In [4]:
def _load_user_queries(user: str) -> list:
    ''' Parse every file in DATA_DIR/<user> into a query object, newest first. '''
    user_dir = os.path.join(DATA_DIR, user)
    parsed = []
    for filename in os.listdir(user_dir):
        # Close each file as soon as it is read instead of leaking the handle.
        with open(os.path.join(user_dir, filename)) as f:
            record = json.load(f)

        query = Query.from_dict({**record, 'user': user})
        # from_dict prints unparsable records and returns None; skip those so
        # the timestamp sort below does not crash on them.
        if query is not None:
            parsed.append(query)

    return sorted(parsed, key=lambda q: q.timestamp, reverse=True)


unfiltered_queries_per_user = {
    user: _load_user_queries(user)
    for user in os.listdir(DATA_DIR) if user != '.DS_Store'
}

# Only queries logged at or after this cutoff are kept.
oops_date = dt.datetime(2024, 3, 21, 17, 17, 50)
queries_per_user = {
    user: [q for q in user_queries if q.timestamp >= oops_date]
    for user, user_queries in unfiltered_queries_per_user.items()
}

# Flatten to one list; the per-user dicts are no longer needed afterwards.
queries = [q for user_queries in queries_per_user.values() for q in user_queries]
del queries_per_user, unfiltered_queries_per_user

n_verified = sum(isinstance(q, VerifiedQuery) for q in queries)
n_with_ground_truth = sum(1 for q in queries if hasattr(q, 'ground_truth') and q.ground_truth)

print(f'''
    {len(queries):,} total (valid) queries

    {n_verified:,} verified queries
    {len(queries) - n_verified:,} unverified queries

    {n_with_ground_truth} with ground truth
    ''')


    64,870 total (valid) queries

    498 verified queries
    64,372 unverified queries

    387 with ground truth
    


In [5]:
# Keep only the queries that expose a `ground_truth` attribute, then release
# the full query list.
gt_queries = list(filter(lambda query: hasattr(query, 'ground_truth'), queries))
len(gt_queries)
del queries

# wrong_qs = [q for q in gt_queries if not q.ground_truth] 
# pprint(wrong_qs[0])

# We're just going to assume that the bsc students that implemented this logic did it correctly. 
# but given my experience, they almost certainly did not

In [6]:
# Parallel lists over gt_queries: the chosen model's prediction, the ground truth, and the language.
preds = [q.predictions[q.chosen_model] for q in gt_queries]
targets = [q.ground_truth for q in gt_queries]
langs = [q.language for q in gt_queries]

len(preds), len(targets)

# Report ground truths that are missing altogether (None), as opposed to merely empty.
for q in gt_queries: 
    if q.ground_truth: continue
    # print(q.predictions, q.ground_truth)
    # `type(x) == None` can never be true (type() returns a class), so test the value itself.
    if q.ground_truth is None: print('found none')

#### CodeBERTScore
We need to group our prediction/ground-truth pairs by language, because CodeBERTScore is called with one language at a time. 
The only languages we can use are `python`, `javascript`, `c`, `cpp`, and `java`. 

- We probably can pool `javascriptreact`, `typescript`, and `typescriptreact` to `javascript`
- Also `ipynb` to `python`.

In [7]:
# Group the ground-truth queries by language in a single pass. Building the
# groups in encounter order (rather than iterating a set of language names)
# keeps the order of equally sized languages independent of string hashing.
queries_per_language = {}
for q in gt_queries:
    queries_per_language.setdefault(q.language, []).append(q)

# sort on list length (largest first); the sort is stable, so ties keep first-seen order
queries_per_language = dict(sorted(queries_per_language.items(), key=lambda x: len(x[1]), reverse=True))
for lang, queries in queries_per_language.items():
    print(f'{lang}: {len(queries):,}')

python: 236
html: 81
php: 48
latex: 35
css: 27
javascript: 18
markdown: 15
typescriptreact: 12
typescript: 9
java: 4
properties: 3
jsonc: 2
json: 2
bicep: 1
yaml: 1
ignore: 1
cpp: 1
scminput: 1
pip-requirements: 1


In [8]:
# NOTE: Let's merge them into their supposed groups 
# python < pip-requirements 
# javascript < typescript, typescriptreact (json and jsonc are NOT pooled by the mapping below)
# php and html are, tentatively, pooled into javascript as well (php maybe because its design is similar to typescriptreact)

def construct_valid_inputs(queries): 
    ''' Bucket queries under the language name they are pooled into.

    Returns a dict with exactly the keys 'python', 'javascript', 'c', 'cpp'
    and 'java' (in that order); each value lists, in input order, the queries
    whose `.language` belongs to that group. Queries in any other language are
    left out, and a group without matches maps to an empty list.

    NOTE(review): the counts recorded under this cell (javascript: 39) equal
    javascript + typescript + typescriptreact alone, so that output looks like
    it predates pooling php and html here. Re-run to confirm.
    '''

    pooled_groups = (
        ('python',      ('python', 'pip-requirements')),
        # TODO: and maybe php & html
        ('javascript',  ('javascript', 'typescript', 'typescriptreact', 'php', 'html')),
        ('c',           ('c',)),
        ('cpp',         ('cpp',)),
        ('java',        ('java',)),
    )

    buckets = {}
    for target, members in pooled_groups:
        buckets[target] = [query for query in queries if query.language in members]

    return buckets

# NOTE: for all filters
# Number of ground-truth queries that land in each pooled language group.
buckets_all_filters = construct_valid_inputs(gt_queries)
for lang in buckets_all_filters:
    queries = buckets_all_filters[lang]
    print(f'{lang}: {len(queries):,}')

python: 237
javascript: 39
c: 0
cpp: 1
java: 4


In [9]:
# let's divide per filter 
# Maps filter type -> list of queries that have a ground truth recorded.
q_per_filter = {}
for q in gt_queries:
    # Identity comparison is the idiomatic None test.
    # NOTE(review): only None is dropped; empty-string ground truths (the
    # 'joint_h' check further down prints several) still reach the scorer.
    # Confirm that is intended.
    if q.ground_truth is None: continue 
    q_per_filter.setdefault(q.filter_type, []).append(q)

for filter_type, queries in q_per_filter.items():
    lang_counts = {lang: len(v) for lang, v in construct_valid_inputs(queries).items()}
    print(f'{filter_type}: {len(queries):,} \t {lang_counts}')

joint_h: 130 	 {'python': 91, 'javascript': 7, 'c': 0, 'cpp': 0, 'java': 0}
context: 109 	 {'python': 45, 'javascript': 3, 'c': 0, 'cpp': 1, 'java': 2}
feature: 42 	 {'python': 20, 'javascript': 3, 'c': 0, 'cpp': 0, 'java': 0}
joint_a: 75 	 {'python': 37, 'javascript': 10, 'c': 0, 'cpp': 0, 'java': 0}
no_filter: 138 	 {'python': 42, 'javascript': 16, 'c': 0, 'cpp': 0, 'java': 2}


In [12]:
# Pair each 'joint_h' query's chosen-model prediction with its ground truth,
# then print every pair in which one side is blank after stripping whitespace.
pairs = [
    (query.predictions[query.chosen_model], query.ground_truth)
    for query in q_per_filter['joint_h']
]
for pred, gt in pairs: 
    if not pred.strip():
        print(f'empty pr: PR:{pred:40} \tGT:{gt:40}')
    if not gt.strip():
        print(f'empty gt: PR:{pred:40} \tGT:{gt:40}')
    # if pred is not gt: 
    #     print(f'{pred:40} \t{gt:40}')

empty gt: PR:={styles.top}>{t('chat')}</span>         	GT:                                        
empty gt: PR:={styles.top}>{t('chat')}</span>         	GT:                                        
empty gt: PR:function createClient(options) {         	GT:                                        
empty gt: PR:ARIFICATION REQUIRED                     	GT:                                        
empty gt: PR:th_accuracy(weights.values()))           	GT:                                        
empty gt: PR:INGO_DECIMAL                             	GT:                                        
empty gt: PR:bd(total)                                	GT:                                        
empty gt: PR:ghts = {issue: Decimal(weights[issue]) for issue in self.domain.getIssues()} 	GT:                                        
empty gt: PR:.get(key)                                	GT:                                        
empty gt: PR:Store                                    	GT:               

In [31]:
def score(q_per_filter, use_sources=False):
    ''' Run CodeBERTScore on (prediction, ground truth) pairs, separately per filter type.

    q_per_filter : dict mapping filter type -> list of queries. Each query must
                   provide .predictions, .chosen_model, .ground_truth and
                   .language (plus .prefix when use_sources is True).
    use_sources  : when True, the tail of every query's prefix is passed to the
                   scorer as `sources`, and the scorer is asked to run on 'mps'.

    Returns a dict: filter type -> {'precision', 'recall', 'f1', 'f3'} -> tensor
    with the scores of all pooled language groups concatenated. A filter type
    with no query in a supported language keeps empty tensors.
    '''
    metric_names = ('precision', 'recall', 'f1', 'f3')
    max_source_chars = 7984 // 2    # let's not kill my computer

    results = {
        filter_type: {name: torch.tensor([]) for name in metric_names}
        for filter_type in q_per_filter.keys()
    }

    for filter_type, filter_queries in q_per_filter.items(): 
        print(f'\ndoing {filter_type}', end='')
        q_per_lang = construct_valid_inputs(filter_queries)

        for lang, lang_queries in q_per_lang.items():
            print(f'\t{lang}: {len(lang_queries)}', end='')

            # The scorer is only invoked for groups that actually have pairs.
            if len(lang_queries) == 0: continue
            preds   = [q.predictions[q.chosen_model] for q in lang_queries]
            gts     = [q.ground_truth for q in lang_queries]

            extra_kwargs = {}
            if use_sources:
                # NOTE(review): 'mps' is hard-coded for this branch only; presumably
                # this needs a machine with that torch backend. Confirm before
                # running elsewhere.
                extra_kwargs = {
                    'sources': [q.prefix[-max_source_chars:] for q in lang_queries],
                    'device': 'mps',
                }

            precision, recall, f1, f3 = code_bert_score.score(
                preds, gts, lang, no_punc=False, verbose=True, batch_size=1, **extra_kwargs
            )

            # Explicit name -> tensor mapping instead of looking the names up in locals().
            new_scores = {'precision': precision, 'recall': recall, 'f1': f1, 'f3': f3}
            for name in metric_names:
                results[filter_type][name] = torch.cat((results[filter_type][name], new_scores[name]))

    return results 

In [21]:
# Score without passing the prefix as context, then print the mean of every metric per filter.
results = score(q_per_filter)
print('', flush=True)

for filter_type, metrics in results.items():
    # Format the mean as a number (rounded to two decimals). Cutting str(float)
    # down to four characters truncates instead of rounding and mangles values
    # printed in scientific notation.
    summary = [f'{metric} {torch.mean(values).item():.2f}' for metric, values in metrics.items()]
    print(f'{filter_type:10}: {summary}')


doing joint_h	python: 91



	javascript: 7



	c: 0	cpp: 0	java: 0
doing context	python: 45



	javascript: 3	c: 0	cpp: 1	java: 2




doing feature	python: 20



	javascript: 3	c: 0	cpp: 0	java: 0
doing joint_a	python: 37



	javascript: 10



	c: 0	cpp: 0	java: 0
doing no_filter	python: 42



	javascript: 16



	c: 0	cpp: 0	java: 2
joint_h   : ['precision 0.75', 'recall 0.74', 'f1 0.74', 'f3 0.74']
context   : ['precision 0.77', 'recall 0.76', 'f1 0.77', 'f3 0.76']
feature   : ['precision 0.86', 'recall 0.84', 'f1 0.85', 'f3 0.85']
joint_a   : ['precision 0.67', 'recall 0.65', 'f1 0.66', 'f3 0.65']
no_filter : ['precision 0.60', 'recall 0.60', 'f1 0.60', 'f3 0.60']




In [32]:
# Score again, this time passing each query's prefix as `sources`, then print the mean of every metric per filter.
results = score(q_per_filter, use_sources=True)
print('', flush=True)

for filter_type, metrics in results.items():
    # Format the mean as a number (rounded to two decimals). Cutting str(float)
    # down to four characters truncates instead of rounding and mangles values
    # printed in scientific notation.
    summary = [f'{metric} {torch.mean(values).item():.2f}' for metric, values in metrics.items()]
    print(f'{filter_type:10}: {summary}')


doing joint_h	python: 91calculating scores...
computing bert embedding.


  0%|          | 0/182 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/91 [00:00<?, ?it/s]



done in 212.05 seconds, 0.43 sentences/sec
	javascript: 7calculating scores...
computing bert embedding.


  0%|          | 0/14 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/7 [00:00<?, ?it/s]



done in 4.96 seconds, 1.41 sentences/sec
	c: 0	cpp: 0	java: 0
doing context	python: 45calculating scores...
computing bert embedding.


  0%|          | 0/90 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/45 [00:00<?, ?it/s]



done in 81.52 seconds, 0.55 sentences/sec
	javascript: 3calculating scores...
computing bert embedding.


  0%|          | 0/6 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/3 [00:00<?, ?it/s]

done in 5.05 seconds, 0.59 sentences/sec
	c: 0	cpp: 1calculating scores...
computing bert embedding.


  0%|          | 0/2 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 5.45 seconds, 0.18 sentences/sec
	java: 2calculating scores...
computing bert embedding.


  0%|          | 0/4 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]



done in 15.02 seconds, 0.13 sentences/sec

doing feature	python: 20calculating scores...
computing bert embedding.


  0%|          | 0/40 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/20 [00:00<?, ?it/s]

done in 96.80 seconds, 0.21 sentences/sec
	javascript: 3calculating scores...
computing bert embedding.


  0%|          | 0/6 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/3 [00:00<?, ?it/s]

done in 5.89 seconds, 0.51 sentences/sec
	c: 0	cpp: 0	java: 0
doing joint_a	python: 37calculating scores...
computing bert embedding.


  0%|          | 0/74 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/37 [00:00<?, ?it/s]



done in 51.56 seconds, 0.72 sentences/sec
	javascript: 10calculating scores...
computing bert embedding.


  0%|          | 0/20 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/10 [00:00<?, ?it/s]



done in 16.51 seconds, 0.61 sentences/sec
	c: 0	cpp: 0	java: 0
doing no_filter	python: 42calculating scores...
computing bert embedding.


  0%|          | 0/84 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/42 [00:00<?, ?it/s]



done in 100.28 seconds, 0.42 sentences/sec
	javascript: 16calculating scores...
computing bert embedding.


  0%|          | 0/32 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/16 [00:00<?, ?it/s]



done in 20.47 seconds, 0.78 sentences/sec
	c: 0	cpp: 0	java: 2calculating scores...
computing bert embedding.


  0%|          | 0/4 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 4.96 seconds, 0.40 sentences/sec

joint_h   : ['precision 0.82', 'recall 0.82', 'f1 0.82', 'f3 0.82']
context   : ['precision 0.85', 'recall 0.85', 'f1 0.85', 'f3 0.85']
feature   : ['precision 0.94', 'recall 0.94', 'f1 0.94', 'f3 0.94']
joint_a   : ['precision 0.89', 'recall 0.88', 'f1 0.88', 'f3 0.88']
no_filter : ['precision 0.76', 'recall 0.77', 'f1 0.76', 'f3 0.76']


### Summary
 
- 387 with ground truth
- 281 supported by CodeBERTScore (others are things like html: 81, php: 48, latex: 35, css: 27 invocations)
    (could consider including html and php as `javascript`, which is supported by CodeBERTScore due to similar syntax)

CodeBERTScore first constructs embeddings, then computes similarity. Embeddings can be constructed together with the `prefix`, but the `prefix` is excluded from the similarity computation (somehow). 

- without prefix for embedding computation
```
joint_a   : ['precision 0.67', 'recall 0.65', 'f1 0.66', 'f3 0.65']
joint_h   : ['precision 0.75', 'recall 0.74', 'f1 0.74', 'f3 0.74']
codeberta : ['precision 0.77', 'recall 0.76', 'f1 0.77', 'f3 0.76']
feature   : ['precision 0.86', 'recall 0.84', 'f1 0.85', 'f3 0.85']
no_filter : ['precision 0.60', 'recall 0.60', 'f1 0.60', 'f3 0.60']
```

- with prefix for embedding computation 
```
joint_a   : ['precision 0.89', 'recall 0.88', 'f1 0.88', 'f3 0.88']
joint_h   : ['precision 0.82', 'recall 0.82', 'f1 0.82', 'f3 0.82']
codeberta : ['precision 0.85', 'recall 0.85', 'f1 0.85', 'f3 0.85']
feature   : ['precision 0.94', 'recall 0.94', 'f1 0.94', 'f3 0.94']
no_filter : ['precision 0.76', 'recall 0.77', 'f1 0.76', 'f3 0.76']
```