### Setup: Install and import

In [None]:
!pip install transformers torch
!pip install nltk
!pip install tqdm
!pip install openai

In [None]:
# Colab only: mount Google Drive at /content/gdrive so the relative
# './gdrive/MyDrive/...' paths used below resolve (assumes the working
# directory is /content, Colab's default).
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

In [3]:
import random
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, AdamW
import nltk

from tqdm import tqdm
import numpy as np

from openai import OpenAI

In [None]:
# Provide the OpenAI key without hardcoding it in the notebook.
# The previous `%env OPENAI_API_KEY # PUT UR API KEY HERE` line did not just
# display the variable: IPython's `%env VAR value` form SET it to the literal
# comment text, clobbering any key already in the environment.
import os
from getpass import getpass

if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [None]:
# Shared OpenAI client (the SDK reads OPENAI_API_KEY from the environment);
# the attacker class further down calls this same module-level `client`.
client = OpenAI()

# Smoke test: one chat completion to confirm the key and model are usable.
completion = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "system",
            "content": "You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.",
        },
        {
            "role": "user",
            "content": "Compose a poem that explains the concept of recursion in programming.",
        },
    ],
)

print(completion.choices[0].message.content)

### Load Pre-trained RoBERTa Model and Tokenizer

In [None]:
# Pretrained RoBERTa-base tokenizer and encoder, fetched from the Hugging Face Hub.
tokenizer, roberta_base = (
    RobertaTokenizer.from_pretrained('roberta-base'),
    RobertaModel.from_pretrained('roberta-base'),
)

### Dataset

In [None]:
!pip install datasets

In [None]:
# GPT-wiki-intro: Wikipedia intros ('wiki_intro', labelled human below) paired
# with GPT-generated intros ('generated_intro', labelled machine below).
# https://huggingface.co/datasets/aadityaubhat/GPT-wiki-intro
from datasets import load_dataset

dataset = load_dataset("aadityaubhat/GPT-wiki-intro", split="train")

In [None]:
# truncate
def truncate(example):
    """
    Cut 'wiki_intro' and 'generated_intro' down to the length of the shorter one.

    Length is measured with len() (characters for str values), so the paired
    human/machine texts end up equally long. Returns a new dict holding the two
    truncated texts followed by the unchanged length/token-count fields.
    """
    keep = min(len(example['wiki_intro']), len(example['generated_intro']))
    result = {
        field: example[field][:keep]
        for field in ('wiki_intro', 'generated_intro')
    }
    for field in ('title_len', 'wiki_intro_len', 'generated_intro_len',
                  'prompt_tokens', 'generated_text_tokens'):
        result[field] = example[field]
    return result


# Apply the pairwise truncation to every row. `Dataset.map` merges the fields
# returned by `truncate` into each row.
Wiki_data = dataset.map(truncate)

In [10]:
# Generate labels: flatten the paired columns into one text list.
# 1 for human generated ('wiki_intro'), 0 for machine generated ('generated_intro').
# list(...) materializes each column.
# NOTE(review): recent `datasets` releases may return a lazy Column object from
# ds['col'], which need not support `+` or the `.index` lookups used downstream;
# confirm against the installed version.
Wiki_texts = list(Wiki_data['wiki_intro']) + list(Wiki_data['generated_intro'])
Wiki_labels = [1] * len(Wiki_data) + [0] * len(Wiki_data)

In [11]:
def downsample_data(texts, labels, num_samples=2000, seed=None):
    """
    Draw `num_samples` aligned (text, label) pairs uniformly at random, without replacement.

    Args:
        texts: sequence of texts.
        labels: sequence of labels aligned with `texts` (extra items in the
            longer of the two are ignored, as with zip()).
        num_samples: number of pairs to draw.
        seed: optional seed for a private `random.Random`; when None the global
            `random` module state is used.

    Returns:
        (sampled_texts, sampled_labels, sampled_indices) as lists, where
        sampled_indices[k] is the position of sampled_texts[k] in `texts`.

    Raises:
        ValueError: if num_samples is negative or exceeds the population (from random.sample).
    """
    rng = random if seed is None else random.Random(seed)
    population_size = min(len(texts), len(labels))
    # Sample positions rather than (text, label) tuples. random.sample's choice
    # depends only on the population length, so a given RNG state picks the same
    # pairs as before. The indices are now exact, whereas texts.index() returned
    # the first duplicate. This also avoids an O(n) scan per sample.
    sampled_indices = rng.sample(range(population_size), num_samples)
    sampled_texts = [texts[i] for i in sampled_indices]
    sampled_labels = [labels[i] for i in sampled_indices]
    return sampled_texts, sampled_labels, sampled_indices


In [12]:
# Seed the global RNG so the 2000-example subsample, and therefore the splits
# below, are identical on every run.
random.seed(42)
Wiki_sampled_texts, Wiki_sampled_labels, Wiki_sampled_indices = downsample_data(Wiki_texts, Wiki_labels)

# Split in sampled order: [0, 1700) train, [1700, 1850) validation, [1850, end) test.
wiki_train_end, wiki_val_end = 1700, 1850

Wiki_train_texts = Wiki_sampled_texts[:wiki_train_end]
Wiki_train_labels = Wiki_sampled_labels[:wiki_train_end]
Wiki_train_indices = Wiki_sampled_indices[:wiki_train_end]

Wiki_val_texts = Wiki_sampled_texts[wiki_train_end:wiki_val_end]
Wiki_val_labels = Wiki_sampled_labels[wiki_train_end:wiki_val_end]
Wiki_val_indices = Wiki_sampled_indices[wiki_train_end:wiki_val_end]

Wiki_test_texts = Wiki_sampled_texts[wiki_val_end:]
Wiki_test_labels = Wiki_sampled_labels[wiki_val_end:]
Wiki_test_indices = Wiki_sampled_indices[wiki_val_end:]

In [None]:
for split_name, split_texts in (
    ("Training", Wiki_train_texts),
    ("Validation", Wiki_val_texts),
    ("Testing", Wiki_test_texts),
):
    print(f"{split_name} Data Size: {len(split_texts)}")

In [None]:
# Balance check: number of label-1 (human) examples in the train, validation
# and test splits, one per line.
for split_labels in (Wiki_train_labels, Wiki_val_labels, Wiki_test_labels):
    print(sum(split_labels))

In [None]:
# PubMedQA
# https://pubmedqa.github.io/

# Files expected in ./gdrive/MyDrive/CPSC_588_dataset/:
#   ori_pqal.json (~2.6 MB)   - from https://github.com/pubmedqa/pubmedqa/blob/master/data/ori_pqal.json
#   ori_pqaa.json (~533.4 MB) - from https://drive.google.com/file/d/15v1x6aQDlZymaHGP7cZJZZYFfeJt2NdS/view
# Each file is a JSON object whose values carry a "LONG_ANSWER" string.
import json
import random

# Labels follow the notebook convention: "0" = machine generated, "1" = human generated.
ori_pqal_path = './gdrive/MyDrive/CPSC_588_dataset/ori_pqal.json'
with open(ori_pqal_path, 'r', encoding='utf-8') as file:
    ori_pqal = json.load(file)
machine_generated_dataset = [{"text": item["LONG_ANSWER"], "label": "0"} for item in ori_pqal.values()]

ori_pqaa_path = './gdrive/MyDrive/CPSC_588_dataset/ori_pqaa.json'
with open(ori_pqaa_path, 'r', encoding='utf-8') as file:
    ori_pqaa = json.load(file)
human_generated_dataset = [{"text": item["LONG_ANSWER"], "label": "1"} for item in ori_pqaa.values()]
# The parsed file is large and not needed once the answers are extracted.
del ori_pqaa

# Keep 1000 human examples; a private seeded RNG makes the choice reproducible.
human_generated_dataset = random.Random(42).sample(human_generated_dataset, 1000)

print(machine_generated_dataset[0])  # {'text': '...', 'label': '0'}
print(human_generated_dataset[0])    # {'text': '...', 'label': '1'}

In [16]:
combined_dataset = machine_generated_dataset + human_generated_dataset

texts = [item['text'] for item in combined_dataset]
labels = [int(item['label']) for item in combined_dataset]

# The PubMedQA train/validation/test split is produced with `downsample_data`
# in the next cell. A train_test_split here would be overwritten before it is
# ever used, so none is done.

In [17]:
# Seed the global RNG so the shuffle, and therefore the splits below, are
# identical on every run.
random.seed(42)
PMQA_sampled_texts, PMQA_sampled_labels, PMQA_sampled_indices = downsample_data(texts, labels, 2000)

# Split in sampled order: [0, 1700) train, [1700, 1850) validation, [1850, end) test.
pmqa_train_end, pmqa_val_end = 1700, 1850

PMQA_train_texts = PMQA_sampled_texts[:pmqa_train_end]
PMQA_train_labels = PMQA_sampled_labels[:pmqa_train_end]
PMQA_train_indices = PMQA_sampled_indices[:pmqa_train_end]

PMQA_val_texts = PMQA_sampled_texts[pmqa_train_end:pmqa_val_end]
PMQA_val_labels = PMQA_sampled_labels[pmqa_train_end:pmqa_val_end]
PMQA_val_indices = PMQA_sampled_indices[pmqa_train_end:pmqa_val_end]

PMQA_test_texts = PMQA_sampled_texts[pmqa_val_end:]
PMQA_test_labels = PMQA_sampled_labels[pmqa_val_end:]
PMQA_test_indices = PMQA_sampled_indices[pmqa_val_end:]

In [None]:
for split_name, split_texts in (
    ("Training", PMQA_train_texts),
    ("Validation", PMQA_val_texts),
    ("Testing", PMQA_test_texts),
):
    print(f"{split_name} Data Size: {len(split_texts)}")

In [None]:
# Ghostbuster test data: a JSON object with "texts" and "labels" lists
# (presumably aligned element-for-element; downsample_data pairs them by position).
load_path = './gdrive/MyDrive/CPSC_588_dataset/ghostbuster/ghostbuster_data.json'

with open(load_path, 'r', encoding='utf-8') as json_file:
    loaded_data = json.load(json_file)

gb_texts, gb_labels = loaded_data['texts'], loaded_data['labels']

# Random 200-example subset; print how many of them carry label 1.
gb_sampled_texts, gb_sampled_labels, gb_sampled_indices = downsample_data(gb_texts, gb_labels, 200)
print(sum(gb_sampled_labels))

### Data Preparation

In [19]:
class TextDataset(Dataset):
    """
    Map-style dataset that tokenizes one text each time an item is accessed.

    Each item is a dict with:
        'text'           - the raw text, coerced to str;
        'input_ids'      - 1-D tensor padded/truncated to `max_len` tokens;
        'attention_mask' - 1-D tensor of the same length;
        'labels'         - the label as a 0-d int64 tensor.
    """

    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        raw_text = str(self.texts[item])
        target = self.labels[item]

        tokenized = self.tokenizer.encode_plus(
            raw_text,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )

        # return_tensors='pt' yields a leading batch dimension of 1; drop it so
        # the DataLoader can stack items into (batch, max_len).
        return {
            'text': raw_text,
            'input_ids': tokenized['input_ids'].reshape(-1),
            'attention_mask': tokenized['attention_mask'].reshape(-1),
            'labels': torch.tensor(target, dtype=torch.long),
        }

Wiki_train_dataset = TextDataset(Wiki_train_texts, Wiki_train_labels, tokenizer)
Wiki_val_dataset = TextDataset(Wiki_val_texts, Wiki_val_labels, tokenizer)
Wiki_test_dataset = TextDataset(Wiki_test_texts, Wiki_test_labels, tokenizer)

# Only the training loaders shuffle. Evaluation loaders keep a fixed order so
# the per-batch mean losses (and their average) are reproducible; the Wiki
# val/test loaders previously shuffled, unlike the PMQA ones.
Wiki_train_loader = DataLoader(Wiki_train_dataset, batch_size=16, shuffle=True)
Wiki_val_loader = DataLoader(Wiki_val_dataset, batch_size=16, shuffle=False)
Wiki_test_loader = DataLoader(Wiki_test_dataset, batch_size=16, shuffle=False)

PMQA_train_dataset = TextDataset(PMQA_train_texts, PMQA_train_labels, tokenizer)
PMQA_val_dataset = TextDataset(PMQA_val_texts, PMQA_val_labels, tokenizer)
PMQA_test_dataset = TextDataset(PMQA_test_texts, PMQA_test_labels, tokenizer)

PMQA_train_loader = DataLoader(PMQA_train_dataset, batch_size=16, shuffle=True)
PMQA_val_loader = DataLoader(PMQA_val_dataset, batch_size=16, shuffle=False)
PMQA_test_loader = DataLoader(PMQA_test_dataset, batch_size=16, shuffle=False)

### Create a Custom Classifier

In [20]:
class RobertaClassifier(nn.Module):
    """
    RoBERTa encoder combined with a vector of statistical text features for 2-way classification.

    fusion_type='early': concatenate the pooled encoder output with the
        ReLU-transformed statistical features and apply a single linear head.
    fusion_type='late': gate the pooled output element-wise with a sigmoid of a
        projection of the transformed features, then sum the logits of a
        RoBERTa head and a separate statistical-feature head.

    Args:
        roberta_base: encoder exposing `config.hidden_size`. Calling it as
            roberta_base(input_ids, attention_mask) must return a sequence whose
            element [1] is the pooled output of shape (batch, hidden_size); for a
            Hugging Face RobertaModel that element is pooler_output.
        stat_emb_dim: width of the statistical feature vector.
        fusion_type: 'early' or 'late'.

    Raises:
        ValueError: for any other fusion_type. Previously any value other than
            'early' (including typos) silently selected late fusion.
    """

    def __init__(self, roberta_base, stat_emb_dim, fusion_type='early'):
        super(RobertaClassifier, self).__init__()
        if fusion_type not in ('early', 'late'):
            raise ValueError(f"fusion_type must be 'early' or 'late', got {fusion_type!r}")
        self.fusion_type = fusion_type
        self.roberta = roberta_base
        hidden_size = roberta_base.config.hidden_size

        # Non-linear transformation (Linear + ReLU) for the statistical features.
        self.stat_emb_transform = nn.Linear(stat_emb_dim, stat_emb_dim)
        self.activation = nn.ReLU()

        if fusion_type == 'early':
            # One head over [pooled encoder output ; transformed statistical features].
            self.classifier = nn.Linear(hidden_size + stat_emb_dim, 2)
        else:
            # Two heads whose logits are summed in forward().
            self.classifier = nn.Linear(hidden_size, 2)
            self.stat_emb_classifier = nn.Linear(stat_emb_dim, 2)
            # Projects the statistical features to a per-dimension gate over the pooled output.
            self.conditional_weights = nn.Linear(stat_emb_dim, hidden_size)

    def forward(self, input_ids, attention_mask, statistical_features):
        """
        Return logits of shape (batch, 2).

        statistical_features: float tensor of shape (batch, stat_emb_dim).
        """
        outputs = self.roberta(input_ids, attention_mask)
        pooled_output = outputs[1]

        transformed_stat_features = self.activation(self.stat_emb_transform(statistical_features))

        if self.fusion_type == 'early':
            combined_output = torch.cat((pooled_output, transformed_stat_features), dim=1)
            return self.classifier(combined_output)

        # Late fusion: gate the encoder output, then add the two heads' logits.
        conditional_weights = torch.sigmoid(self.conditional_weights(transformed_stat_features))
        conditioned_roberta_output = pooled_output * conditional_weights
        logits_from_roberta = self.classifier(conditioned_roberta_output)
        logits_from_stat_emb = self.stat_emb_classifier(transformed_stat_features)
        return logits_from_roberta + logits_from_stat_emb

### Create a Custom Attacker

In [21]:
class LLMAttacker():
    """
    Rewrites a text with an OpenAI chat model to create adversarial training examples.

    Modes of `attack`:
        'paraphrase' - uses `paraphrase_prompt` as the system prompt (the training
                       loops use it for correctly classified label-1 examples);
        'in_context' - uses `in_context_prompt` (used for label-0 examples).

    Args:
        model: chat model name passed to the API.
        paraphrase_prompt, in_context_prompt: system prompts for the two modes.
        client: OpenAI-compatible client. Defaults to the notebook's module-level `client`.
        max_tokens: completion length cap. Rewrites needing more tokens come back truncated.
    """

    def __init__(self, model="gpt-3.5-turbo", paraphrase_prompt=None, in_context_prompt=None,
                 client=None, max_tokens=128):
        self.model = model
        self.paraphrase_prompt = paraphrase_prompt
        self.in_context_prompt = in_context_prompt
        self.client = client
        self.max_tokens = max_tokens

    def generate_in_context_prompt(self, features):
        """Placeholder for a feature-conditioned prompt; not implemented, returns None."""
        return None

    def attack(self, mode, text, features=None):
        """
        Return the model's rewrite of `text`, with surrounding whitespace stripped.

        `features` is accepted for interface compatibility but is currently unused.

        Raises:
            ValueError: if `mode` is not 'paraphrase' or 'in_context'.
        """
        if mode == "paraphrase":
            system_prompt = self.paraphrase_prompt
        elif mode == "in_context":
            system_prompt = self.in_context_prompt
        else:
            raise ValueError("Invalid attack mode")

        # Fall back to the module-level client created earlier in the notebook.
        api_client = self.client if self.client is not None else client
        completion = api_client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                # The f-prefix belongs on the value. It used to sit on the key
                # (f"content"), so the literal "Text: {text}" was sent and the
                # model never saw the input text.
                {"role": "user", "content": f"Text: {text}"},
            ],
            max_tokens=self.max_tokens,
        )
        # The SDK types message.content as optional; treat a missing value as empty.
        return (completion.choices[0].message.content or "").strip()


In [22]:
# # word embedding
# from transformers import RobertaTokenizer, RobertaModel

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# output_file = './gdrive/MyDrive/CPSC_588_dataset/word_embeddings_wiki.pt'

# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# model = RobertaModel.from_pretrained('roberta-base').to(device)

# pbar = tqdm(total=len(Wiki_train_texts), desc="Processing texts")

# embeddings_list = []

# for text in Wiki_train_texts:
#     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
#     outputs = model(**inputs)
#     embeddings = outputs.last_hidden_state.mean(dim=1).detach().cpu()
#     embeddings_list.append(embeddings)
#     pbar.update(1)

# pbar.close()

# embeddings = torch.cat(embeddings_list, dim=0)

# torch.save(embeddings, output_file)

# print(f"Embeddings saved to {output_file}")

In [None]:
# Cached mean-pooled RoBERTa embeddings of the Wiki training texts (produced by
# the commented-out cell above); preview the first rows.
word_embeddings_loaded = torch.load(
    "./gdrive/MyDrive/CPSC_588_dataset/word_embeddings_wiki.pt"
)
print(word_embeddings_loaded[:5])

In [24]:
# # word embedding of PMQA
# from transformers import RobertaTokenizer, RobertaModel

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# output_file = './gdrive/MyDrive/CPSC_588_dataset/word_embeddings_pmqa.pt'

# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# model = RobertaModel.from_pretrained('roberta-base').to(device)

# pbar = tqdm(total=len(PMQA_train_texts), desc="Processing texts")

# embeddings_list = []

# for text in PMQA_train_texts:
#     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
#     outputs = model(**inputs)
#     embeddings = outputs.last_hidden_state.mean(dim=1).detach().cpu()
#     embeddings_list.append(embeddings)
#     pbar.update(1)

# pbar.close()

# embeddings = torch.cat(embeddings_list, dim=0)

# torch.save(embeddings, output_file)

# print(f"Embeddings saved to {output_file}")

In [None]:
# Cached PMQA training-text embeddings (produced by the commented-out cell above):
# preview the first rows, then the row count.
word_embeddings_loaded = torch.load(
    "./gdrive/MyDrive/CPSC_588_dataset/word_embeddings_pmqa.pt"
)
print(word_embeddings_loaded[:5])
print(len(word_embeddings_loaded))

### Intrinsic Dimension Estimation for Robust Detection of AI-Generated Texts

---


Paper here: https://arxiv.org/pdf/2306.04723.pdf

Code here: https://github.com/ArGintum/GPTID/blob/main/IntrinsicDim.py

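Candidate statistical feature: the persistent-homology (PH) intrinsic-dimension estimate of a text, which the paper computes on the point cloud of the text's token embeddings. The `PHD` estimator below is adapted from the linked code, but it is not yet used by `calculate_statistical_features`.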

In [26]:
# https://github.com/ArGintum/GPTID/blob/main/IntrinsicDim.py

import numpy as np

from scipy.spatial.distance import cdist
from threading import Thread

# Constant carried over from the upstream GPTID code; not referenced in this notebook.
MINIMAL_CLOUD = 47

def prim_tree(adj_matrix, alpha=1.0):
    """
    Total weight of a minimum spanning tree (Prim's algorithm), with every edge
    length raised to the power `alpha` before summing.

    Args:
        adj_matrix: square (n, n) array of pairwise distances.
        alpha: exponent applied to each MST edge length.

    Returns:
        Python float: sum over MST edges of length ** alpha. The result is 0.0
        for n <= 1; previously a single point crashed on `.item()` of a plain float.
    """
    n_vertices = adj_matrix.shape[0]
    if n_vertices <= 1:
        return 0.0

    # Sentinel larger than any real edge, used to mask already-visited vertices.
    infty = np.max(adj_matrix) + 10

    dst = np.ones(n_vertices) * infty
    visited = np.zeros(n_vertices, dtype=bool)
    ancestor = -np.ones(n_vertices, dtype=int)

    v, s = 0, 0.0
    for i in range(n_vertices - 1):
        visited[v] = True
        # Record v as the closest tree vertex for every vertex it now reaches more cheaply.
        ancestor[dst > adj_matrix[v]] = v
        dst = np.minimum(dst, adj_matrix[v])
        dst[visited] = infty

        v = np.argmin(dst)
        s += (adj_matrix[v][ancestor[v]] ** alpha)

    return float(s)

def process_string(sss):
    """
    Replace newlines with spaces, then collapse each non-overlapping pair of
    spaces into one in a single pass (longer runs are only partially collapsed).
    """
    without_newlines = sss.replace('\n', ' ')
    return without_newlines.replace(' ' * 2, ' ')

class PHD():
    def __init__(self, alpha=1.0, metric='euclidean', n_reruns=3, n_points=7, n_points_min=3):
        '''
        Initializes the instance of PH-dim computer
        Parameters:
            1) alpha --- real-valued parameter Alpha for computing PH-dim (see the reference paper). Alpha should be chosen lower than
        the ground-truth Intrinsic Dimensionality; however, Alpha=1.0 works just fine for our kind of data.
            2) metric --- String or Callable, distance function for the metric space (see documentation for Scipy.cdist)
            3) n_reruns --- Number of restarts of whole calculations (each restart is made in a separate thread)
            4) n_points --- Number of random subsamples drawn per subsample size; the median of their MST weights is used
            5) n_points_min --- Used instead of n_points when the subsample size is at least half of the point cloud
        '''
        self.alpha = alpha
        self.n_reruns = n_reruns
        self.n_points = n_points
        self.n_points_min = n_points_min
        self.metric = metric
        self.is_fitted_ = False

    def _sample_W(self, W, nSamples):
        '''Rows of W at nSamples distinct random positions (uses the global NumPy RNG).'''
        n = W.shape[0]
        random_indices = np.random.choice(n, size=nSamples, replace=False)
        return W[random_indices]

    def _calc_ph_dim_single(self, W, test_n, outp, thread_id):
        '''
        For each subsample size n in test_n, take the median alpha-weighted MST
        weight over several random subsamples. Then store in outp[thread_id] the
        least-squares slope of log(weight) against log(n).
        '''
        lengths = []
        for n in test_n:
            if W.shape[0] <= 2 * n:
                restarts = self.n_points_min
            else:
                restarts = self.n_points

            reruns = np.ones(restarts)
            for i in range(restarts):
                tmp = self._sample_W(W, n)
                reruns[i] = prim_tree(cdist(tmp, tmp, metric=self.metric), self.alpha)

            lengths.append(np.median(reruns))
        lengths = np.array(lengths)

        x = np.log(np.array(list(test_n)))
        y = np.log(lengths)
        N = len(x)
        outp[thread_id] = (N * (x * y).sum() - x.sum() * y.sum()) / (N * (x ** 2).sum() - x.sum() ** 2)

    def fit_transform(self, X, y=None, min_points=50, max_points=512, point_jump=40):
        '''
        Computing the PH-dim
        Parameters:
            1) X --- point cloud of shape (n_points, n_features),
            2) y --- fictional parameter to fit with Sklearn interface
            3) min_points --- size of minimal subsample to be drawn
            4) max_points --- exclusive upper bound on subsample size; capped so no subsample exceeds the cloud
            5) point_jump --- step between subsamples
        Returns:
            alpha / (1 - m), where m is the mean fitted slope over the n_reruns threads.
        Raises:
            ValueError --- if fewer than two subsample sizes fit in the cloud (the slope is then undefined).
            Any exception raised in a worker thread is re-raised here.
        '''
        X = np.asarray(X)
        # Subsamples are drawn without replacement, so no size may exceed the number of points.
        # Previously, oversized sizes killed the worker threads silently, which
        # left their slopes at 0 so the method returned 1.0.
        test_n = range(min_points, min(max_points, X.shape[0] + 1), point_jump)
        if len(test_n) < 2:
            raise ValueError(
                f"PH-dim needs at least two subsample sizes; a cloud of {X.shape[0]} points "
                f"with min_points={min_points}, point_jump={point_jump} gives {len(test_n)}"
            )

        ms = np.zeros(self.n_reruns)
        errors = []

        def _run(thread_id):
            # Capture failures so they surface in the caller instead of dying with the thread.
            try:
                self._calc_ph_dim_single(X, test_n, ms, thread_id)
            except Exception as exc:
                errors.append(exc)

        threads = [Thread(target=_run, args=(i,)) for i in range(self.n_reruns)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        if errors:
            raise errors[0]

        m = np.mean(ms)
        # The alpha-weighted MST weight scales as n ** ((d - alpha) / d), so the
        # slope m = 1 - alpha / d gives d = alpha / (1 - m). This matches the
        # previous 1 / (1 - m) for the default alpha=1.0.
        return self.alpha / (1 - m)


### Calculate the statistical features

In [None]:
import nltk

# Recent NLTK releases (3.9+) load 'punkt_tab', 'averaged_perceptron_tagger_eng'
# and 'tagsets_json' (for word_tokenize, pos_tag and nltk.help respectively)
# instead of the older pickled resources. Fetch both generations so the feature
# code below works on either NLTK version.
# NOTE(review): on an NLTK that does not know one of these ids, download() is
# expected to report it and return False rather than raise; confirm.
for resource in ('punkt', 'punkt_tab',
                 'averaged_perceptron_tagger', 'averaged_perceptron_tagger_eng',
                 'tagsets', 'tagsets_json'):
    nltk.download(resource)
nltk.help.upenn_tagset()


In [None]:
from nltk.data import load

# Penn Treebank tag names, in the key order of NLTK's tagset-help dict.
# NOTE(review): recent NLTK releases moved away from pickled resources; confirm
# this pickle path still loads with the installed version.
upenn_tagset_info = load('help/tagsets/upenn_tagset.pickle')
upenn_tagset = [*upenn_tagset_info]

# Print each tag with its position so the ranges kept below can be audited.
for index, tag in enumerate(upenn_tagset):
    print(f"index:{index} , tag:{tag}")

# Keep positions [0, 3), [4, 9), [10, 14), [15, 19) and [25, end); positions 3, 9,
# 14 and 19-24 are dropped. Which tags those are depends on the dict's key order.
upenn_tagset_meaningful = [
    kept_tag
    for start, stop in ((0, 3), (4, 9), (10, 14), (15, 19), (25, None))
    for kept_tag in upenn_tagset[start:stop]
]

In [29]:
import json
import pandas as pd
import os


In [30]:
def calculate_tag_dist(text: str):
    """
    Relative frequency of each tag in `upenn_tagset_meaningful` among the NLTK POS tags of `text`.

    Tags outside `upenn_tagset_meaningful` are ignored. The returned list has one
    entry per tag in that list, and the entries sum to 1. When the text contains
    none of those tags (e.g. an empty string, which an LLM rewrite can produce),
    all entries are 0.0 instead of raising ZeroDivisionError.
    """
    tokens = nltk.tokenize.word_tokenize(text)
    tagged_text = nltk.pos_tag(tokens)
    tag_fd = nltk.FreqDist(tag for (_, tag) in tagged_text)
    tag_count = [tag_fd.get(tag, 0) for tag in upenn_tagset_meaningful]
    count_sum = sum(tag_count)
    if count_sum == 0:
        return [0.0] * len(tag_count)
    return [count / count_sum for count in tag_count]

def calculate_statistical_features_pos(input_text):
    """
    POS-tag distribution features for a batch of texts.

    Args:
        input_text: iterable of strings.

    Returns:
        Float tensor with one row per text and one column per tag in
        `upenn_tagset_meaningful`.
    """
    pos_tag_dists = [calculate_tag_dist(text) for text in input_text]
    return torch.tensor(pos_tag_dists)

# pos_embeddings = calculate_statistical_features_pos(Wiki_train_texts)
# torch.save(pos_embeddings, "/content/gdrive/MyDrive/CPSC_588_dataset/pos_embeddings_wiki.pt")


In [31]:
# pos_embeddings = calculate_statistical_features_pos(PMQA_train_texts)
# torch.save(pos_embeddings, "/content/gdrive/MyDrive/CPSC_588_dataset/pos_embeddings_pmqa.pt")

In [32]:
# pos_embeddings_wiki = torch.load("/content/gdrive/MyDrive/CPSC_588_dataset/pos_embeddings_wiki.pt")
# print(pos_embeddings_wiki[:2])
# print(pos_embeddings_wiki.shape)
# pos_embeddings_pmqa = torch.load("/content/gdrive/MyDrive/CPSC_588_dataset/pos_embeddings_pmqa.pt")
# print(pos_embeddings_pmqa[:2])
# print(pos_embeddings_pmqa.shape)

In [None]:
!pip install py-readability-metrics

In [34]:
from readability import Readability

In [35]:
def calculate_readability_metrics(text):
    """
    Eight readability scores for `text`, in this order: Flesch-Kincaid, Flesch,
    Gunning fog, Coleman-Liau, Dale-Chall, ARI, Linsear Write and Spache
    (SMOG is not included).

    Exceptions from the readability library, e.g. for texts it considers too
    short, propagate to the caller.
    """
    scorer = Readability(text)
    metric_methods = (
        scorer.flesch_kincaid,
        scorer.flesch,
        scorer.gunning_fog,
        scorer.coleman_liau,
        scorer.dale_chall,
        scorer.ari,
        scorer.linsear_write,
        scorer.spache,
    )
    return [method().score for method in metric_methods]

def calculate_statistical_features_readability(input_text, max_repeats=100):
    """
    Readability features for a batch of texts.

    The readability library rejects some inputs (presumably texts that are too
    short). A text that fails is repeated, joined with a space, and retried,
    using at most `max_repeats` attempts in total. If every attempt fails, a
    row of 8 zeros is used.

    Args:
        input_text: iterable of strings.
        max_repeats: maximum number of attempts (and therefore copies) per text.

    Returns:
        Float tensor with one row per text and 8 columns
        (see calculate_readability_metrics).
    """
    readability_metrics = []
    for text in input_text:
        text_pad = text
        for _ in range(max_repeats):
            try:
                row = calculate_readability_metrics(text_pad)
            except Exception:
                # Best effort: grow the text and retry. A space separator keeps the
                # last word of one copy from fusing with the first word of the next.
                # The old bare `except:` also swallowed KeyboardInterrupt.
                text_pad = f"{text_pad} {text}"
                continue
            break
        else:
            # 8 = number of metrics returned by calculate_readability_metrics.
            row = [0.0] * 8
        readability_metrics.append(row)
    return torch.tensor(readability_metrics)

# readability_embeddings = calculate_statistical_features_readability(Wiki_train_texts)
# torch.save(readability_embeddings, "/content/gdrive/MyDrive/CPSC_588_dataset/readability_embeddings_wiki.pt")

In [36]:
# readability_embeddings = calculate_statistical_features_readability(PMQA_train_texts)
# torch.save(readability_embeddings, "/content/gdrive/MyDrive/CPSC_588_dataset/readability_embeddings_pmqa.pt")

In [37]:
# read_embeddings_wiki = torch.load("/content/gdrive/MyDrive/CPSC_588_dataset/readability_embeddings_wiki.pt")
# print(read_embeddings_wiki[:5])
# print(read_embeddings_wiki.shape)
# read_embeddings_pmqa = torch.load("/content/gdrive/MyDrive/CPSC_588_dataset/readability_embeddings_pmqa.pt")
# print(read_embeddings_pmqa[:5])
# print(read_embeddings_pmqa.shape)

In [38]:
def calculate_statistical_features(input_text):
    """
    Full statistical feature matrix for a batch of texts.

    Each row is the POS-tag distribution followed by the 8 readability metrics,
    concatenated along the feature axis (dim=1).
    """
    feature_blocks = [
        calculate_statistical_features_pos(input_text),
        calculate_statistical_features_readability(input_text),
    ]
    return torch.cat(feature_blocks, dim=1)

### Model Training

In [39]:
def train_epoch_baseline(model, attacker, data_loader, loss_fn, optimizer, device, n_examples, calculate_stat_features):
    """
    One epoch of plain (non-adversarial) training.

    `attacker` is accepted only so the signature matches the adversarial
    variants; it is not used.

    Returns:
        (accuracy, mean_loss): accuracy is correct / n_examples as a 0-d
        float64 tensor; mean_loss is the mean of the per-batch losses.
    """
    model.train()
    losses = []
    correct_predictions = 0

    for d in tqdm(data_loader, total=len(data_loader), desc="Training"):
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        labels = d["labels"].to(device)
        stat_features = calculate_stat_features(d["text"]).to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, statistical_features=stat_features)
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, labels)

        correct_predictions += torch.sum(preds == labels)
        losses.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    # Return after the loop. This `return` used to be indented inside it, so the
    # "epoch" stopped after the first batch.
    return correct_predictions.double() / n_examples, np.mean(losses)

def train_epoch_adversarial_emphasize(model, attacker, data_loader, loss_fn, optimizer, device, n_examples,
                                      calculate_stat_features, text_tokenizer=None, max_len=512):
    """
    One epoch of adversarial training that adds one extra loss term per attacked example.

    For every example the model already classifies correctly, `attacker`
    rewrites the text ('paraphrase' for label 1, 'in_context' for label 0). The
    rewrite is tokenized and featurized, then scored. Its single-example loss is
    added to the batch loss before the backward pass.

    Args:
        text_tokenizer: tokenizer for the rewritten texts. Defaults to the
            notebook's global `tokenizer`.
        max_len: padding/truncation length for rewrites (TextDataset's default).

    Returns:
        (accuracy, mean_loss): accuracy on the original examples as a 0-d float64
        tensor (correct / n_examples); mean_loss averages, over batches, the
        summed loss divided by the number of terms summed.
    """
    tok = text_tokenizer if text_tokenizer is not None else tokenizer
    model.train()
    losses = []
    correct_predictions = 0

    for d in tqdm(data_loader, total=len(data_loader), desc="Training"):
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        labels = d["labels"].to(device)
        texts = d["text"]
        stat_features = calculate_stat_features(texts).to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, statistical_features=stat_features)
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, labels)
        n_loss_terms = 1

        # Attack and retrain if the prediction is correct
        for idx, (pred, label, text) in enumerate(zip(preds, labels, texts)):
            if pred != label:
                continue
            attack_mode = "paraphrase" if label.item() == 1 else "in_context"  # Assuming 1 is human, 0 is AI
            modified_text = attacker.attack(attack_mode, text, stat_features[idx].unsqueeze(0))
            if not modified_text:
                continue  # nothing to featurize or learn from an empty rewrite
            # Score the rewritten text itself. Previously the whole batch's
            # input_ids were reused, so the encoder never saw the rewrite, and a
            # single feature row was broadcast against every example.
            encoding = tok(
                modified_text,
                add_special_tokens=True,
                max_length=max_len,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_token_type_ids=False,
                return_tensors='pt',
            )
            modified_stat_features = calculate_stat_features([modified_text]).to(device)
            modified_output = model(
                input_ids=encoding['input_ids'].to(device),
                attention_mask=encoding['attention_mask'].to(device),
                statistical_features=modified_stat_features,
            )
            # Summing (rather than averaging) presumably provides this variant's
            # extra emphasis on attacked examples.
            loss = loss + loss_fn(modified_output, label.unsqueeze(0))
            n_loss_terms += 1

        correct_predictions += torch.sum(preds == labels)
        # Log the average over the terms actually summed. Previously the divisor
        # was batch size + 1, regardless of how many examples were attacked.
        losses.append(loss.item() / n_loss_terms)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    return correct_predictions.double() / n_examples, np.mean(losses)

def train_epoch(model, attacker, data_loader, loss_fn, optimizer, device, n_examples, calculate_stat_features,
                text_tokenizer=None, max_len=512):
    """
    One epoch of adversarial training that mixes attacked examples into each batch.

    For each example the model currently classifies correctly, `attacker`
    rewrites the text ('paraphrase' for label 1, 'in_context' for label 0). The
    rewrite is tokenized, featurized and scored. Its logits are concatenated
    with the batch's, so one loss is taken over original and adversarial
    examples together.

    Args:
        text_tokenizer: tokenizer for the rewritten texts. Defaults to the
            notebook's global `tokenizer`.
        max_len: padding/truncation length for rewrites (TextDataset's default).

    Returns:
        (accuracy, mean_loss): accuracy on the original examples only, as a 0-d
        float64 tensor (correct / n_examples); mean_loss is the mean of the
        per-batch combined losses.
    """
    tok = text_tokenizer if text_tokenizer is not None else tokenizer
    model.train()
    losses = []
    correct_predictions = 0

    for d in tqdm(data_loader, total=len(data_loader), desc="Training"):
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        labels = d["labels"].to(device)
        texts = d["text"]
        stat_features = calculate_stat_features(texts).to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, statistical_features=stat_features)
        _, preds = torch.max(outputs, dim=1)

        adv_labels_list = []
        adv_outputs_list = []

        for idx, (pred, label, text) in enumerate(zip(preds, labels, texts)):
            if pred != label:
                continue
            attack_mode = "paraphrase" if label.item() == 1 else "in_context"  # Assuming 1 is human, 0 is AI
            modified_text = attacker.attack(attack_mode, text, stat_features[idx].unsqueeze(0))
            if not modified_text:
                continue  # nothing to featurize or learn from an empty rewrite
            # Encode the *rewritten* text. Previously the original example's
            # input_ids/attention_mask were reused, so the encoder never saw the
            # rewrite and only the statistical features differed.
            encoding = tok(
                modified_text,
                add_special_tokens=True,
                max_length=max_len,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_token_type_ids=False,
                return_tensors='pt',
            )
            modified_stat_features = calculate_stat_features([modified_text]).to(device)
            modified_output = model(
                input_ids=encoding['input_ids'].to(device),
                attention_mask=encoding['attention_mask'].to(device),
                statistical_features=modified_stat_features,
            )

            # Store adversarial examples
            adv_outputs_list.append(modified_output)
            adv_labels_list.append(label.unsqueeze(0))

        if adv_outputs_list:
            combined_outputs = torch.cat([outputs, *adv_outputs_list], dim=0)
            combined_labels = torch.cat([labels, *adv_labels_list], dim=0)
            total_loss = loss_fn(combined_outputs, combined_labels)
        else:
            total_loss = loss_fn(outputs, labels)

        correct_predictions += torch.sum(preds == labels)
        losses.append(total_loss.item())

        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    return correct_predictions.double() / n_examples, np.mean(losses)

def validate_epoch(model, data_loader, loss_fn, device, n_examples, calculate_stat_features):
    """
    Score `model` on `data_loader` in eval mode, without gradient tracking.

    Each batch is a dict with 'input_ids', 'attention_mask' and 'labels' tensors
    plus a 'text' list; `calculate_stat_features(texts)` supplies the
    statistical features.

    Returns:
        (accuracy, mean_loss): accuracy is correct / n_examples as a 0-d float64
        tensor; mean_loss is the unweighted mean of the per-batch losses.
    """
    model.eval()
    batch_losses = []
    n_correct = 0

    with torch.no_grad():
        for batch in data_loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            targets = batch["labels"].to(device)
            feats = calculate_stat_features(batch["text"]).to(device)

            logits = model(input_ids=ids, attention_mask=mask, statistical_features=feats)
            predicted = torch.max(logits, dim=1)[1]
            batch_loss = loss_fn(logits, targets)

            n_correct += torch.sum(predicted == targets)
            batch_losses.append(batch_loss.item())

    return n_correct.double() / n_examples, np.mean(batch_losses)

def eval_model(model, data_loader, loss_fn, device, n_examples, calculate_stat_features):
    """
    Evaluate `model` on `data_loader` without gradients, showing a "Testing" progress bar.

    Returns:
        (accuracy, mean_loss): accuracy is correct / n_examples as a 0-d float64
        tensor; mean_loss is the mean of the per-batch losses.
    """
    model = model.eval()
    batch_losses, n_correct = [], 0

    with torch.no_grad():
        progress = tqdm(data_loader, total=len(data_loader), desc="Testing")
        for batch in progress:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            targets = batch["labels"].to(device)
            feats = calculate_stat_features(batch["text"]).to(device)

            logits = model(input_ids=ids, attention_mask=mask, statistical_features=feats)
            predicted = torch.max(logits, dim=1)[1]
            batch_loss = loss_fn(logits, targets)

            n_correct += torch.sum(predicted == targets)
            batch_losses.append(batch_loss.item())

    return n_correct.double() / n_examples, np.mean(batch_losses)


In [None]:
import copy

# Initialize model
dataset = "PMQA"  # "Wiki" or "PMQA"

num_epochs = 5
# One POS-tag proportion per tag in upenn_tagset_meaningful plus 8 readability
# metrics. Deriving the width keeps the classifier in sync with the feature
# functions instead of relying on a hard-coded 44.
stat_emb_dim = len(upenn_tagset_meaningful) + 8
fusion_type = "late"
LLM_model = "gpt-3.5-turbo"
paraphrase_prompt = "I have a piece of machine-generated text that I need paraphrased to sound more human-like. "
in_context_prompt = "I have a piece of human-written text that I need regenerated to maintain a style as close as possible to the original. The goal is to replicate the unique characteristics of the original text, such as its tone, word choice, sentence structure, and overall flow. "
# Unused, longer prompt variant kept for experiments:
#paraphrase_prompt = "I have a piece of human-written text that I need regenerated to maintain a style as close as possible to the original. The goal is to replicate the unique characteristics of the original text, such as its tone, word choice, sentence structure, and overall flow. Please pay special attention to the nuances in language and the specific manner in which ideas are expressed. The text should read as if it were written by the same author, preserving the essence and subtlety of the original writing. This task is for adversarial training purposes, where the aim is to challenge machine prediction capabilities in distinguishing between human and machine-generated texts. The text in question [include or describe the text here, or specify its main themes or style]. Keep the length and format consistent with the original, and ensure that the regenerated text mirrors the original as closely as possible in all aspects. "

# Seed torch so the freshly initialised classifier heads are reproducible.
torch.manual_seed(42)
# Fine-tune a copy of the encoder. Wrapping the global `roberta_base` directly
# meant a re-run of this cell kept training weights a previous run had already updated.
model = RobertaClassifier(copy.deepcopy(roberta_base), stat_emb_dim, fusion_type)
attacker = LLMAttacker(LLM_model, paraphrase_prompt, in_context_prompt)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# torch.optim.AdamW replaces transformers.AdamW, which is deprecated.
# eps=1e-6 and weight_decay=0.0 reproduce the transformers defaults; torch's
# own defaults are 1e-8 and 0.01.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)
loss_fn = nn.CrossEntropyLoss().to(device)
stat_fn = calculate_statistical_features

# Select the loaders and sizes for the chosen dataset once.
if dataset == "Wiki":
    train_loader, val_loader = Wiki_train_loader, Wiki_val_loader
    n_train, n_val = len(Wiki_train_dataset), len(Wiki_val_dataset)
else:
    train_loader, val_loader = PMQA_train_loader, PMQA_val_loader
    n_train, n_val = len(PMQA_train_dataset), len(PMQA_val_dataset)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print('-' * 10)

    train_acc, train_loss = train_epoch(
        model, attacker, train_loader, loss_fn, optimizer, device, n_train, stat_fn
    )
    print(f'Train loss {train_loss}, accuracy {train_acc}')

    val_acc, val_loss = validate_epoch(model, val_loader, loss_fn, device, n_val, stat_fn)
    print(f'Validation loss {val_loss}, accuracy {val_acc}')


### Model Evaluation

In [None]:
# Final held-out evaluation on the test split of the dataset selected above.
if dataset == "Wiki":
    test_acc, test_loss = eval_model(
        model, Wiki_test_loader, loss_fn, device, len(Wiki_test_dataset), stat_fn
    )
else:
    test_acc, test_loss = eval_model(
        model, PMQA_test_loader, loss_fn, device, len(PMQA_test_dataset), stat_fn
    )

print(f'Test loss {test_loss}, accuracy {test_acc}')

### Test the Trained Model

In [None]:
def preprocess(texts):
    """
    Tokenize one string or a list of strings with the global RoBERTa `tokenizer`.

    The batch is padded to its longest text and truncated to the tokenizer's
    maximum length. The result holds PyTorch tensors.
    """
    return tokenizer(texts, padding=True, truncation=True, return_tensors="pt")


def predict(model, texts, device, calculate_stat_features=None):
    """
    Predict class ids (1 = human, 0 = machine under this notebook's labels) for
    one text or a list of texts.

    The classifier's forward pass requires statistical features. The old code
    never passed them, so every call raised TypeError. They are now computed
    with `calculate_stat_features`, which defaults to the global
    `calculate_statistical_features`.

    Returns:
        NumPy array with one predicted class id per text.
    """
    if isinstance(texts, str):
        # The feature functions iterate over texts; a bare str would be split into characters.
        texts = [texts]
    if calculate_stat_features is None:
        calculate_stat_features = calculate_statistical_features

    model.eval()

    inputs = preprocess(texts)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        stat_features = calculate_stat_features(texts).to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, statistical_features=stat_features)
        predictions = torch.argmax(outputs, dim=1)

    return predictions.cpu().numpy()

In [None]:
# Sanity check on one biomedical passage: print its predicted class id
# (1 = human-written, 0 = machine-generated under this notebook's labels).
text = (
    "As ILC2s are elevated in patients with CRSwNP, they may drive nasal polyp "
    "formation in CRS. ILC2s are also linked with high tissue and blood "
    "eosinophilia and have a potential role in the activation and survival of "
    "eosinophils during the Th2 immune response. The association of innate "
    "lymphoid cells in CRS provides insights into its pathogenesis."
)
single_prediction = predict(model, text, device)
print(single_prediction[0])