In [1]:
from transformers import GemmaTokenizer, AutoModelForCausalLM
import torch
import pickle

import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, confusion_matrix

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_id = "google/codegemma-2b"
tokenizer = GemmaTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id
).to('cuda:0')

Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 17.56it/s]
Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.38s/it]


In [3]:
def confusion(labels, predictions, logits):
    cm = confusion_matrix(labels, predictions)

    accuracy = 100 * np.sum(np.diag(cm)) / np.sum(cm)

    true_positives = cm[1, 1]
    false_positives = cm[0, 1]
    false_negatives = cm[1, 0]
    true_negatives = cm[0, 0]
    
    precision = true_positives / (true_positives + false_positives + 1e-5)
    recall = true_positives / (true_positives + false_negatives + 1e-5)
    f1 = 2 * precision * recall / (precision + recall + 1e-5)
    return accuracy, precision, recall, f1, roc_auc_score(labels, logits)


def eval(key, stage, model, tokenizer, max_length=1024):
    file_path = f'{key}_{stage}.pickle'
    with open(file_path, 'rb') as file:
        sources = pickle.load(file)
        
    logits = []
    labels = []
    predictions = []

    model.eval()
    with torch.no_grad():
        for source, dest, label in tqdm(sources):
            if len(source) > max_length:
                source = source[:max_length]
            if len(dest) > max_length:
                dest = dest[:max_length]
            inputs = tokenizer([source, dest], return_tensors="pt", padding=True).to(model.device)
            outputs = model(**inputs).logits[:, -1, :]
            outputs /= outputs.norm(dim=1, keepdim=True)
            logit = (outputs[0] * outputs[1]).sum().detach().item()
            
            predictions.append(0 if logit < 0 else 1)
            logits.append((logit + 1) / 2)
            labels.append(label)
                
        accuracy, precision, recall, f1, auroc = confusion(labels, predictions, labels)
        print('{}_{}: accuracy, precision, recall, f1, auroc: {:.2f}, {:.2f}, {:.2f}, {:.2f}, {:.2f}'\
              .format(key, stage, accuracy, auroc, precision, recall, f1))
            
            
            
        

In [4]:
mapping = {
    'exp_all_m1': 'w_sco',
    'exp_var_m1': 'w_object',
    'exp_val_m1': 'w_string',
    'exp_ast_m1': 'w_control',
    'exp_all_m2': 'sco',
    'exp_var_m2': 'object',
    'exp_val_m2': 'string',
    'exp_ast_m2': 'control',
    'exp_all_m3': 'sco_dead',
    'exp_var_m3': 'object_dead',
    'exp_val_m3': 'string_dead',
    'exp_ast_m3': 'control_dead',
}

In [5]:
for key in mapping.keys():
    eval(key, 'test', model, tokenizer)

100%|██████████| 3602/3602 [19:07<00:00,  3.14it/s]


exp_all_m1_test: accuracy, precision, recall, f1, auroc: 49.53, 1.00, 0.49, 1.00, 0.66


100%|██████████| 3602/3602 [13:13<00:00,  4.54it/s]


exp_var_m1_test: accuracy, precision, recall, f1, auroc: 49.19, 1.00, 0.49, 1.00, 0.66


100%|██████████| 3602/3602 [14:33<00:00,  4.12it/s]


exp_val_m1_test: accuracy, precision, recall, f1, auroc: 50.44, 1.00, 0.50, 1.00, 0.67


100%|██████████| 3602/3602 [15:18<00:00,  3.92it/s]


exp_ast_m1_test: accuracy, precision, recall, f1, auroc: 50.28, 1.00, 0.50, 1.00, 0.67


100%|██████████| 3602/3602 [26:28<00:00,  2.27it/s]


exp_all_m2_test: accuracy, precision, recall, f1, auroc: 49.75, 1.00, 0.50, 0.99, 0.66


100%|██████████| 3602/3602 [20:04<00:00,  2.99it/s]


exp_var_m2_test: accuracy, precision, recall, f1, auroc: 50.47, 1.00, 0.51, 1.00, 0.67


100%|██████████| 3602/3602 [22:16<00:00,  2.70it/s]


exp_val_m2_test: accuracy, precision, recall, f1, auroc: 50.28, 1.00, 0.50, 1.00, 0.67


100%|██████████| 3602/3602 [21:49<00:00,  2.75it/s]


exp_ast_m2_test: accuracy, precision, recall, f1, auroc: 50.03, 1.00, 0.50, 0.99, 0.66


100%|██████████| 3602/3602 [30:46<00:00,  1.95it/s]


exp_all_m3_test: accuracy, precision, recall, f1, auroc: 49.56, 1.00, 0.50, 0.99, 0.66


100%|██████████| 3602/3602 [25:02<00:00,  2.40it/s]


exp_var_m3_test: accuracy, precision, recall, f1, auroc: 50.64, 1.00, 0.51, 1.00, 0.67


100%|██████████| 3602/3602 [29:04<00:00,  2.06it/s]


exp_val_m3_test: accuracy, precision, recall, f1, auroc: 49.78, 1.00, 0.50, 1.00, 0.66


100%|██████████| 3602/3602 [26:11<00:00,  2.29it/s]

exp_ast_m3_test: accuracy, precision, recall, f1, auroc: 49.67, 1.00, 0.50, 0.99, 0.66



