In [None]:
import numpy as np
import pandas as pd
import random

import torch
import torch.nn as nn
import re
import pickle as pk

from sklearn.model_selection import cross_validate
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression, LinearRegression
from transformers import AutoTokenizer, AutoModel#RobertaTokenizer, RobertaModel
from scipy.spatial.distance import cdist

from sklearn.utils import shuffle

from tqdm import tqdm

import matplotlib.pyplot as plt
import matplotlib as mpl

# Seed every random number generator the notebook imports so that a fresh
# "Restart & Run All" reproduces the same numbers. NumPy's global generator
# drives the subsampling in PH.sample_W; torch is seeded as well so that any
# stochastic torch op is also repeatable.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# GPT-4o generations: one CSV file per domain (wiki, reddit, stackexchange).
wiki_path_o, reddit_path_o, stackexchange_path_o = (
    "data/GPT-4-o-generation/gpt-4-o-wiki-correct-1500.csv",
    "data/GPT-4-o-generation/gpt-4-o-reddit-1500.csv",
    "data/GPT-4-o-generation/gpt-4-o-stackexchange-1500.csv",
)

# GPT-3 (davinci-003) generations: one JSON Lines file per domain.
wiki_path_gpt3, reddit_path_gpt3, stackexchange_path_gpt3 = (
    "data/davinci_generation/gpt3_davinci_003_300_len_wiki.jsonl",
    "data/davinci_generation/gpt3_davinci_003_300_len_reddit.jsonl",
    "data/davinci_generation/gpt3_davinci_003_300_len_stackexchange.jsonl",
)

In [None]:
def _read_shuffled(reader, path, **reader_kwargs):
    """Load `path` with the given pandas reader and shuffle all rows reproducibly.

    The rows are permuted with `sample(frac=1, random_state=42)`; `reset_index()`
    then renumbers them from 0 and keeps the pre-shuffle position in an
    `index` column.
    """
    return reader(path, **reader_kwargs).sample(frac=1, random_state=42).reset_index()


# GPT-3 files: keep only the first 1500 rows of each shuffled frame.
df_gpt3_w = _read_shuffled(pd.read_json, wiki_path_gpt3, lines=True)[:1500]
df_gpt3_r = _read_shuffled(pd.read_json, reddit_path_gpt3, lines=True)[:1500]
df_gpt3_s = _read_shuffled(pd.read_json, stackexchange_path_gpt3, lines=True)[:1500]

# GPT-4o files: shuffled and used in full.
df_gpt4_o_w = _read_shuffled(pd.read_csv, wiki_path_o)
df_gpt4_o_r = _read_shuffled(pd.read_csv, reddit_path_o)
df_gpt4_o_s = _read_shuffled(pd.read_csv, stackexchange_path_o)

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook still runs (more slowly) on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# The encoder and its tokenizer are loaded from the same checkpoint name.
model_path = 'bert-base-uncased' #'roberta-base'
tokenizer_path = model_path

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModel.from_pretrained(model_path, output_attentions=True, output_hidden_states=True)
model = model.to(DEVICE)
# Inference only: make sure dropout and similar training-time layers are off.
# (from_pretrained is documented to return the model in eval mode already;
# the explicit call just makes that guarantee visible here.)
model.eval()

In [None]:
# Persistent-homology (PH) based estimator of the intrinsic dimension (ID) of a point cloud.
# EPS is the smallest denominator allowed in the least-squares slope computed below.
EPS = 0.000000001
class PH():
    """Estimate the intrinsic dimension of a point cloud from minimum-spanning-tree lengths.

    For a range of subsample sizes n the total (alpha-powered) edge length of
    the Euclidean minimum spanning tree is measured; the slope m of
    log(length) against log(n) is fitted by least squares and the estimate
    is alpha / (1 - m).
    """
    def __init__(self):
        pass

    def fit_transform(self, X):
        """Return the PH-dimension estimate for the point cloud `X`.

        Parameters:
        X --- 2-D array with one point per row

        Returns 0 when `X` has fewer than 40 points; otherwise the value
        computed by `calculate_ph_dim`.
        """
        mx_points = X.shape[0]
        mn_points = 40

        if mx_points < mn_points:
            return 0

        # Spread the subsample sizes over [40, n_points) in about seven steps (never less than 1 apart).
        # NOTE(review): with exactly 40 or 41 points the sweep has fewer than two
        # sizes, the fitted slope degenerates to 0 and the result is simply alpha.
        step = max(1, (mx_points - mn_points) // 7)

        return self.calculate_ph_dim(X, min_points=mn_points, max_points=mx_points, point_jump=step)

    def sample_W(self, W, nSamples):
        '''
        Sample <<nSamples>> distinct points (rows) from the cloud <<W>>,
        using NumPy's global random generator.
        '''
        n = W.shape[0]
        random_indices = np.random.choice(n, size=nSamples, replace=False)
        return W[random_indices]

    def prim_tree(self, adj_matrix, power=1.0):
        '''
        Computation of H0S for a point cloud with distance matrix <<adj_matrix>> by using Prim's algorithm
        for minimal spanning tree: returns the sum of the tree's edge lengths, each raised to <<power>>,
        as a Python float. A single-point cloud has no edges and yields 0.0.
        '''
        n_points = adj_matrix.shape[0]
        # Any value above every real distance works as "unreachable".
        infty = np.max(adj_matrix) + 1.0

        dst = np.ones(n_points) * infty            # cheapest known edge from the tree to each vertex
        visited = np.zeros(n_points, dtype=bool)
        ancestor = -np.ones(n_points, dtype=int)   # tree vertex that realises dst for each vertex

        v, s = 0, 0.0
        for _ in range(n_points - 1):
            visited[v] = True
            # Vertices that get closer to the tree through v now hang off v.
            ancestor[dst > adj_matrix[v]] = v
            dst = np.minimum(dst, adj_matrix[v])
            # Vertices already in the tree must never be picked again.
            dst[visited] = infty

            v = np.argmin(dst)

            s += adj_matrix[v][ancestor[v]] ** power
        # float() handles both the NumPy scalar built up in the loop and the plain
        # 0.0 left over when the loop never runs (fewer than two points).
        return float(s)

    def calculate_ph_dim(self, W, min_points, max_points, point_jump, alpha=1.0, restarts=3, resamples=3):
        '''
        Estimation of the intrinsic (upper-box) dimension of the given point cloud W.
        Parameters:

        min_points --- size of minimal subsample to draw
        max_points --- accepted for interface compatibility only: the sweep always
                       runs up to (but not including) W.shape[0]
        point_jump --- size of step between subsamples
        alpha --- power applied to every spanning-tree edge length
        restarts --- number of independent repetitions of the whole sweep;
                     the fitted slopes are averaged
        resamples --- number of random subsamples per size; their median length is used

        Returns alpha / (1 - m), where m is the mean fitted slope of
        log(tree length) versus log(subsample size).
        '''
        max_points = W.shape[0]
        test_n = range(min_points, max_points, point_jump)

        # The regressor and the slope denominator do not depend on the restart.
        x = np.log(np.array(list(test_n)))
        N = len(x)
        divisor = N * (x ** 2).sum() - x.sum() ** 2
        if divisor < EPS:
            divisor = EPS

        m_candidates = []
        for _ in range(restarts):
            lengths = []

            for n in test_n:
                reruns = np.ones(resamples)
                for k in range(resamples):
                    tmp = self.sample_W(W, n)
                    reruns[k] = self.prim_tree(cdist(tmp, tmp), power=alpha)
                lengths.append(np.median(reruns))

            y = np.log(np.array(lengths))

            # Closed-form ordinary-least-squares slope of y on x.
            m_candidates.append((N * (x * y).sum() - x.sum() * y.sum()) / divisor)
        m = np.mean(m_candidates)
        return alpha / (1 - m)

In [None]:
def text_preprocessing(text):
    """Normalize whitespace in `text`.

    Every run of whitespace characters (spaces, tabs, newlines, ...) is
    replaced by a single space, then leading and trailing whitespace is
    stripped. Returns the cleaned string.
    """
    single_spaced = re.sub(r'\s+', ' ', text)
    return single_spaced.strip()

In [None]:
def get_phd(df, components_to_keep, key='text', real=True, verbose=True):
    """Compute a PH-dimension estimate for every text in one column of `df`.

    Parameters:
    df --- frame holding the texts
    components_to_keep --- indices of the embedding features (columns) that
        are kept before the dimension is estimated
    key --- name of the column to read
    real --- if True each cell is used as the text itself; if False the
        first element of each cell (`cell[0]`) is used instead
    verbose --- wrap the iteration in a tqdm progress bar

    Returns `(dims, lens)`: `dims` is a NumPy array with one estimate per
    row (0 where the text produced fewer than 40 points, see
    `PH.fit_transform`), and `lens` is a list with the number of token
    embeddings that made up each point cloud.

    Relies on the notebook-level `tokenizer`, `model` and `DEVICE`.
    """
    ph = PH()

    dims, lens = [], []
    iterable = tqdm(df[key]) if verbose else df[key]
    for ss in iterable:
        sss = ss if real else ss[0]
        sss = text_preprocessing(sss)

        # Inputs longer than 512 tokens are truncated by the tokenizer.
        inputs = tokenizer(sss, truncation=True, max_length=512, return_tensors="pt")
        inputs = inputs.to(DEVICE)
        with torch.no_grad():
            outp = model(**inputs)

        # outp[0] is indexed as [batch, token, feature]: take the single batch
        # item and drop its first and last token (presumably the tokenizer's
        # special start/end tokens -- confirm for the chosen checkpoint).
        output_matrix = outp[0][0, 1:-1, :].cpu().numpy()
        output_matrix = output_matrix[:, components_to_keep]

        dims.append(ph.fit_transform(output_matrix))
        # Number of points actually passed to the estimator. The previous
        # expression sliced the batch axis instead of the token axis and so
        # reported the length including the two dropped tokens.
        lens.append(output_matrix.shape[0])

    return np.array(dims), lens

In [None]:
to_remove_1 = (330, 551, 77, 217, 397, 219, 286, 724, 276, 22, 461, 201, 131, 689, 414, 526, 122, 620, 767, 766, 746, 763, 707, 545, 506, 450, 761, 757, 762, 755, 117, 759, 734, 730, 725, 723, 521, 248, 704, 744, 743, 497, 760, 738, 722, 721, 715, 711, 710, 713, 712, 708, 706, 700, 703, 697, 681, 685, 742, 747, 748, 686, 208, 690, 680, 682, 675, 672, 668, 658, 654, 657, 665, 673, 660, 688, 659, 692, 649, 650, 602, 648, 666, 636, 642, 635, 634, 621, 644, 653, 719, 670, 640, 669, 628, 597, 576, 702, 645, 425, 452, 599, 544, 728, 718, 638, 739, 632, 460, 593, 629, 579, 619, 590, 571, 613, 616, 705, 693, 578, 677, 671, 601, 695, 716, 569, 567, 317, 758, 741, 691, 633, 379, 696, 133, 534, 529, 401, 631, 358, 618, 622, 605, 683, 410, 623, 607, 242, 581, 553, 192, 651, 472, 751, 417, 501, 470, 637, 726, 754, 591, 584, 566, 565, 594, 626, 556, 575, 598, 547, 543, 532, 552, 583, 525, 509, 555, 562, 524, 577, 630, 592, 549, 537, 559, 563, 493, 531, 522, 515, 533, 518, 513, 507, 505, 495, 587, 489, 512, 617, 431, 494, 487, 558, 516, 514, 406, 300, 500, 476, 477, 484, 717, 346, 374, 372, 483, 482, 481, 646, 427, 480, 474, 508, 475, 469, 510, 490, 471, 260, 456, 463, 435, 416, 432, 535, 387, 676, 110, 568, 151, 504, 464, 458, 438, 370, 423, 465, 383, 412, 402, 395, 274, 277, 426, 434, 288, 8, 709, 442, 603, 188, 457, 326, 341, 444, 335, 322, 294, 304, 291, 336, 378, 313, 321, 266, 367, 437, 314, 385, 347, 141, 337, 301, 275, 298, 297, 279, 338, 389, 305, 256, 267, 303, 376, 400, 429, 355, 268, 391, 281, 340, 352, 253, 272, 83, 143, 43, 145, 318, 290, 224, 407, 271, 356, 250, 139, 280, 334, 459, 405, 246, 228, 366, 295, 273, 73, 354, 306, 557, 351, 14, 445, 610, 390, 546, 572, 436, 369, 420, 466, 375, 230, 384, 243, 241, 502, 179, 284, 212, 156, 261, 252, 249, 48, 152, 81, 190, 56, 200, 237, 198, 240, 220, 296, 196, 129, 210, 270, 174, 162, 178, 142, 269, 173, 121, 315, 125, 186, 182, 107, 184, 137, 104, 101, 49, 4, 114, 181, 185, 118, 94, 119, 130, 161, 462, 44, 199, 92, 57, 
109, 96, 11, 41, 35, 52, 392, 320, 140, 26, 411, 278, 343, 362, 123, 740, 745, 171, 614, 589, 523, 639, 287, 364, 408, 382, 328, 582, 731, 102, 488, 155, 311, 27, 307, 433, 302, 727, 421, 455, 46, 449, 194, 164, 54, 229, 415, 319, 144, 227, 714, 609, 560, 732, 154, 258, 75, 396, 503, 554, 399, 499, 413, 418, 147, 345, 750, 342, 520, 257, 371, 168, 323, 377, 548, 612, 403, 106, 447, 211, 428, 542, 699, 180, 479, 663, 753, 643, 316, 394, 332, 684, 439, 310, 59, 204, 263, 293, 527, 485, 541, 167, 424, 373, 289, 128, 127, 37, 187, 231, 225, 667, 233, 247, 31, 292, 126, 234, 441, 160, 207, 120, 72, 34, 12, 2, 135, 98, 112, 729, 519, 236, 86, 28, 65, 309, 197, 146, 244, 66, 132, 191, 136, 70, 88, 357, 20, 103, 698, 108, 627, 134, 344, 312, 53, 138, 89, 216, 440, 19, 0, 50, 339, 69, 60, 36, 84, 93, 32, 80, 17, 39, 6, 25, 158, 203, 221, 113, 283, 9, 540, 100, 206, 61, 664, 625, 251, 539, 585, 687, 454, 388, 63, 149, 386, 51, 380, 166, 99, 538, 486, 95, 157, 33, 530, 91, 245, 90, 467, 40, 349, 737, 285, 1, 299, 239, 393, 232, 308, 215, 177, 150, 615, 18, 511, 116, 368, 701, 679, 153, 528, 363, 222, 254, 600, 641, 381, 764, 214, 282, 47, 148, 361, 661, 264, 238, 536, 652, 76, 736, 331, 213, 172, 360, 68, 329, 226, 21, 655, 74, 350, 561, 16, 175, 608, 498, 29, 588, 595, 550, 422, 183, 176, 570, 468, 756, 451, 62, 348, 448, 404, 105, 7, 218, 30, 662, 674, 327, 492, 325, 443, 517, 10, 79, 255, 478, 473, 586, 324, 23, 491, 13, 656, 64, 195, 170, 45, 58, 574, 205, 165, 163, 720, 359, 398, 111, 202, 262, 169, 78, 189, 765, 446, 430, 24, 749, 235, 647, 5, 223, 419, 735, 38, 580, 42, 611)
to_remove_2 = (731, 624, 131, 219, 330, 652, 749, 672, 764, 286, 490, 4, 397, 102, 331, 240, 366, 29, 467, 655, 491, 110, 585, 611, 78, 534, 425, 660, 680, 435, 342, 613, 283, 639, 538, 699, 647, 428, 130, 166, 738, 751, 528, 333, 687, 285, 278, 742, 590, 634, 282, 465, 697, 306, 223, 60, 177, 416, 455, 683, 503, 546, 695, 229, 463, 690, 117, 599, 227, 714, 251, 661, 732, 512, 653, 730, 572, 737, 682, 522, 446, 645, 288, 176, 197, 182, 663, 225, 312, 469, 249, 84, 570, 406, 46, 623, 552, 502, 241, 741, 693, 541, 236, 676, 156, 605, 358, 498, 186, 357, 164, 412, 763, 724, 549, 148, 727, 461, 517, 495, 383, 626, 615, 755, 668, 62, 313, 355, 678, 295, 75, 352, 484, 296, 317, 191, 632, 280, 235, 426, 106, 363, 155, 111, 293, 384, 21, 619, 171, 127, 520, 429, 221, 188, 725, 766, 198, 575, 504, 141, 582, 638, 242, 213, 74, 686, 555, 51, 343, 13, 705, 506, 89, 550, 180, 717, 262, 483, 97, 547, 607, 73, 53, 722, 557, 628, 370, 119, 519, 86, 729, 588, 108, 294, 561, 445, 348, 245, 407, 511, 14, 487, 665, 478, 614, 279, 743, 681, 444, 303, 109, 526, 548, 173, 216, 151, 334, 497, 320, 222, 457, 529, 136, 472, 589, 399, 123, 356, 689, 6, 36, 707, 567, 712, 47, 485, 302, 601, 17, 64, 761, 754, 760, 746, 718, 250, 659, 709, 759, 736, 723, 703, 688, 748, 401, 728, 642, 602, 701, 657, 704, 338, 598, 584, 656, 629, 620, 654, 710, 713, 603, 637, 696, 658, 706, 139, 700, 762, 574, 336, 23, 368, 576, 255, 650, 644, 482, 635, 684, 494, 636, 437, 510, 671, 609, 563, 617, 594, 677, 610, 734, 597, 593, 692, 116, 583, 726, 466, 571, 543, 622, 440, 565, 533, 539, 500, 719, 470, 460, 488, 515, 509, 631, 562, 501, 579, 459, 514, 531, 396, 471, 284, 545, 113, 480, 400, 641, 415, 369, 450, 559, 475, 523, 434, 527, 442, 556, 757, 513, 569, 346, 365, 507, 449, 458, 431, 508, 493, 44, 305, 354, 174, 273, 675, 505, 386, 423, 486, 215, 381, 464, 414, 185, 744, 392, 739, 640, 447, 299, 337, 272, 257, 7, 298, 270, 41, 322, 580, 516, 361, 18, 353, 420, 256, 372, 8, 382, 554, 290, 269, 390, 344, 206, 
404, 708, 408, 246, 377, 304, 271, 362, 421, 48, 618, 630, 328, 403, 592, 566, 443, 477, 564, 310, 158, 720, 388, 389, 456, 387, 220, 276, 474, 192, 332, 160, 170, 70, 179, 54, 12, 542, 181, 138, 481, 405, 573, 715, 147, 691, 1, 27, 323, 69, 114, 162, 45, 260, 321, 411, 237, 608, 277, 274, 140, 578, 267, 209, 393, 612, 254, 395, 238, 758, 621, 244, 711, 169, 268, 287, 258, 153, 107, 202, 318, 224, 0, 228, 253, 275, 150, 172, 178, 263, 616, 753, 375, 349, 92, 300, 430, 394, 479, 281, 26, 168, 315, 544, 685, 208, 340, 537, 184, 307, 427, 297, 424, 31, 52, 142, 88, 132, 532, 129, 391, 134, 595, 103, 149, 87, 525, 367, 266, 226, 289, 669, 234, 499, 604, 441, 210, 143, 163, 135, 49, 190, 90, 144, 698, 120, 2, 207, 492, 99, 360, 413, 104, 96, 25, 667, 448, 374, 339, 175, 187, 16, 20, 152, 133, 740, 65, 211, 518, 452, 118, 410, 716, 292, 200, 56, 756, 101, 230)
# Only components listed in BOTH removal tuples are dropped.
to_remove = set(to_remove_1).intersection(to_remove_2)
# 768 feature indices in total (presumably the encoder's hidden size -- confirm if the model changes).
all_components = list(range(768))
# sorted() fixes an ascending column order; iterating a bare set gives no ordering guarantee.
COMPONENTS_TO_KEEP = sorted(set(all_components).difference(to_remove))
components_to_keep = COMPONENTS_TO_KEEP

In [None]:
# GPT-3 frames: for every domain, process the 'gen_completion' column first
# and the 'gold_completion' column second.

# Wikipedia
cls_gpt_w_3, lens_gpt_w_3 = get_phd(
    df=df_gpt3_w, components_to_keep=components_to_keep, key='gen_completion', real=True)
cls_human_w_3, lens_human_w_3 = get_phd(
    df=df_gpt3_w, components_to_keep=components_to_keep, key='gold_completion', real=True)

# Reddit
cls_gpt_r_3, lens_gpt_r_3 = get_phd(
    df=df_gpt3_r, components_to_keep=components_to_keep, key='gen_completion', real=True)
cls_human_r_3, lens_human_r_3 = get_phd(
    df=df_gpt3_r, components_to_keep=components_to_keep, key='gold_completion', real=True)

# StackExchange
cls_gpt_s_3, lens_gpt_s_3 = get_phd(
    df=df_gpt3_s, components_to_keep=components_to_keep, key='gen_completion', real=True)
cls_human_s_3, lens_human_s_3 = get_phd(
    df=df_gpt3_s, components_to_keep=components_to_keep, key='gold_completion', real=True)

In [None]:
# GPT-4o frames: for every domain, process the 'gen_completion' column first
# and the 'gold_completion' column second.

# Wikipedia
cls_gpt_w_4_o, lens_gpt_w_4_o = get_phd(
    df=df_gpt4_o_w, components_to_keep=components_to_keep, key='gen_completion', real=True)
cls_human_w_4_o, lens_human_w_4_o = get_phd(
    df=df_gpt4_o_w, components_to_keep=components_to_keep, key='gold_completion', real=True)

# Reddit
cls_gpt_r_4_o, lens_gpt_r_4_o = get_phd(
    df=df_gpt4_o_r, components_to_keep=components_to_keep, key='gen_completion', real=True)
cls_human_r_4_o, lens_human_r_4_o = get_phd(
    df=df_gpt4_o_r, components_to_keep=components_to_keep, key='gold_completion', real=True)

# StackExchange
cls_gpt_s_4_o, lens_gpt_s_4_o = get_phd(
    df=df_gpt4_o_s, components_to_keep=components_to_keep, key='gen_completion', real=True)
cls_human_s_4_o, lens_human_s_4_o = get_phd(
    df=df_gpt4_o_s, components_to_keep=components_to_keep, key='gold_completion', real=True)

In [None]:
# One entry per (generator, domain): a pair of PH-dimension arrays, first the one
# computed from 'gen_completion', then the one computed from 'gold_completion'.
X_train_sets = {"gpt_3_wikipedia": (cls_gpt_w_3, cls_human_w_3), 
                "gpt_3_reddit": (cls_gpt_r_3, cls_human_r_3), 
                "gpt_3_stackexchange": (cls_gpt_s_3, cls_human_s_3),
                
                "gpt_4_o_wikipedia": (cls_gpt_w_4_o, cls_human_w_4_o),
                "gpt_4_o_reddit": (cls_gpt_r_4_o, cls_human_r_4_o), 
                "gpt_4_o_stackexchange": (cls_gpt_s_4_o, cls_human_s_4_o)} 

# The context manager guarantees the file is flushed and closed, even if pickling fails.
with open("phd_train_sets_Berta_after.pk", "wb") as out_file:
    pk.dump(X_train_sets, out_file)