In [None]:
# Mount Google Drive at /content/drive (requires the Colab-only `google.colab` package).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Stylometry + LM

## Requirements

In [None]:
!pip install numpy==1.19.4
!pip install PyYAML>=5.4
!pip install spacy==2.2.4
!pip install torch==1.7.0
!pip install torchtext==0.3.1
!pip install tqdm==4.53.0
!pip install pandas==1.1.5
!pip install transformers==4.3.2
!pip install fire==0.4.0
!pip install requests==2.23.0
!pip install tensorboard==2.4.1
!pip install download==0.3.5
!pip install nltk>=3.6.6

!pip install py-readability-metrics
!python -m nltk.downloader punkt
!pip install lexicalrichness

## Utils

In [None]:
import sys
from functools import reduce

from torch import nn
import torch.distributed as dist


def summary(model: nn.Module, file=sys.stdout):
    """Print a nested, ``repr``-style description of ``model`` with parameter counts.

    Every module line is suffixed with the number of parameter elements owned by
    that module and its children. The count is highlighted with ANSI colour codes
    only when writing to ``sys.stdout``.

    Args:
        model: the module to describe.
        file: where to write the description. May be a writable file object
            (default ``sys.stdout``), a path string (the file is created or
            overwritten and closed afterwards), or ``None`` to skip printing.

    Returns:
        The total number of parameter elements found while walking the module
        tree. A parameter registered on several modules is counted once per
        registration.
    """
    # Decide on colouring up front: ``file`` may later be a path string.
    colorize = file is sys.stdout

    def describe(module):
        """Return ``(text, param_count)`` for ``module`` and its sub-modules."""
        # We treat the extra repr like the sub-module, one item per line
        extra_repr = module.extra_repr()
        # An empty string would otherwise be split into ['']
        extra_lines = extra_repr.split('\n') if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            if child is None:
                # Unset sub-module slots are legal in nn.Module; show them like repr() does.
                child_str, child_params = 'None', 0
            else:
                child_str, child_params = describe(child)
            child_str = nn.modules.module._addindent(child_str, 2)
            child_lines.append('(' + key + '): ' + child_str)
            total_params += child_params
        lines = extra_lines + child_lines

        for param in module._parameters.values():
            # Unset parameter slots are stored as None; numel() also handles 0-dim
            # (scalar) parameters, for which a product over an empty shape would fail.
            if param is not None:
                total_params += param.numel()

        main_str = module._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'
        main_str += ')'

        if colorize:
            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
        else:
            main_str += ', {:,} params'.format(total_params)
        return main_str, total_params

    text, count = describe(model)

    if file is None:
        return count

    if isinstance(file, str):
        # We opened it, so we are responsible for closing it.
        with open(file, 'w') as handle:
            print(text, file=handle)
    else:
        print(text, file=file)
        file.flush()

    return count


def grad_norm(model: nn.Module):
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters without a gradient (frozen, unused in the last backward pass, or
    before any backward pass) are skipped instead of raising ``AttributeError``.

    Args:
        model: module whose ``parameters()`` are inspected.

    Returns:
        ``sqrt(sum_i ||grad_i||_2 ** 2)`` as a Python float; ``0.0`` when no
        parameter has a gradient.
    """
    squared_sum = 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        squared_sum += p.grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5


def distributed():
    """Report whether a torch.distributed process group is usable right now.

    Returns:
        A falsy value when the distributed package is unavailable; otherwise
        whatever ``dist.is_initialized()`` reports.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()

## Feature Extractor

In [None]:
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize 
from nltk.tokenize import sent_tokenize
import re
import numpy as np
import string 

from lexicalrichness import LexicalRichness
from readability import Readability

class Stylometry():
  """Extract hand-crafted stylometric features from a raw text document.

  Three feature groups can be toggled independently in the constructor:

  * phraseology (9 values): word / sentence / paragraph counts followed by the
    (mean, std) of words per sentence, words per paragraph and sentences per
    paragraph;
  * diversity (5 values): Flesch-Kincaid, Flesch and ARI readability scores,
    then MATTR and MTLD lexical-richness scores;
  * punctuation (1 + 3 * len(special_puncts) values): total punctuation count,
    then for each special punctuation mark its share of all punctuation, its
    mean count per sentence and its mean count per paragraph.

  A word, sentence or paragraph is only counted when it contains at least one
  ASCII letter or digit. Paragraphs are the lines of ``str.splitlines()``.
  """

  # "Must contain a letter or digit". search() is used rather than
  # match('.*X.*') because '.' does not cross newlines, which would wrongly
  # drop multi-line sentences whose first line has no letter or digit.
  _HAS_ALNUM = re.compile(r'[A-Za-z0-9]')

  def __init__(self, phraseology_features= True, diversity_features = True, punct_analysis_features = True):
    """Store which feature groups ``get_features`` should emit."""
    self.phraseology_features = phraseology_features
    self.diversity_features = diversity_features
    self.punct_analysis_features = punct_analysis_features

  # ------------------------------------------------------------------ helpers

  @classmethod
  def _keep_alnum(cls, items):
    """Return the items that contain at least one ASCII letter or digit."""
    return [item for item in items if cls._HAS_ALNUM.search(item)]

  @staticmethod
  def _mean_and_std(values):
    """Return (mean, population std) of ``values``; ``(0, 0)`` when empty."""
    if len(values) == 0:
      return 0, 0
    mean = sum(values) / len(values)
    variance = sum((x - mean) ** 2 for x in values) / len(values)
    return mean, variance ** 0.5

  @staticmethod
  def _mean_punct_per_unit(units, special_puncts):
    """Mean occurrences of each mark per unit (sentence or paragraph).

    Returns a list aligned position-by-position with ``special_puncts``
    (all zeros when ``units`` is empty).
    """
    if not units:
      return [0 for _ in special_puncts]
    return [float(sum(unit.count(punct) for unit in units)) / len(units)
            for punct in special_puncts]

  def _sentences(self, document):
    """Sentences of ``document`` (NLTK punkt) that contain a letter or digit."""
    return self._keep_alnum(sent_tokenize(document))

  def _paragraphs(self, document):
    """Lines of ``document`` that contain a letter or digit."""
    return self._keep_alnum(document.splitlines())

  # ------------------------------------------------------------ phraseology

  def word_count(self, document):
    """Number of NLTK word tokens that contain a letter or digit."""
    return len(self._keep_alnum(word_tokenize(document)))

  def sentence_count(self, document):
    """Number of sentences that contain a letter or digit."""
    return len(self._sentences(document))

  def paragraph_count(self, document):
    """Number of lines that contain a letter or digit."""
    return len(self._paragraphs(document))

  def word_count_sent(self, document):
    """(mean, std) of the word count per sentence; ``(0, 0)`` if no sentence."""
    return self._mean_and_std([self.word_count(sent) for sent in self._sentences(document)])

  def word_count_para(self, document):
    """(mean, std) of the word count per paragraph; ``(0, 0)`` if no paragraph."""
    return self._mean_and_std([self.word_count(para) for para in self._paragraphs(document)])

  def sent_count_para(self, document):
    """(mean, std) of the sentence count per paragraph; ``(0, 0)`` if no paragraph."""
    return self._mean_and_std([self.sentence_count(para) for para in self._paragraphs(document)])

  # ------------------------------------------------------------ punctuation

  def total_punc_count(self, document):
    """Number of characters of ``document`` found in ``string.punctuation``."""
    return sum(1 for char in document if char in string.punctuation)

  def special_punc_count(self, document, special_puncts):
    """Occurrences of each special mark divided by the total punctuation count.

    Returns a list aligned with ``special_puncts``; all zeros when the document
    has no ``string.punctuation`` character at all.
    """
    punct_count = [document.count(punct) for punct in special_puncts]
    total_puncts = self.total_punc_count(document)
    if total_puncts == 0:
      return [0 for _ in punct_count]
    return [float(count) / total_puncts for count in punct_count]

  def special_punc_count_sent(self, document, special_puncts):
    """Mean occurrences of each special mark per sentence."""
    return self._mean_punct_per_unit(self._sentences(document), special_puncts)

  def special_punc_count_para(self, document, special_puncts):
    """Mean occurrences of each special mark per paragraph."""
    return self._mean_punct_per_unit(self._paragraphs(document), special_puncts)

  # -------------------------------------------------------------- diversity

  def readability_score(self, document):
    """Return (Flesch-Kincaid, Flesch, ARI) scores, or ``(0, 0, 0)`` on failure.

    This is deliberately best-effort: any error raised by the readability
    library (presumably for texts it considers too short - TODO confirm)
    yields zeros. Only ``Exception`` is caught so that KeyboardInterrupt and
    SystemExit still propagate.
    """
    try:
      r = Readability(document)
      fk = r.flesch_kincaid()
      f = r.flesch()
      ari = r.ari()
    except Exception:
      return 0, 0, 0
    return fk.score, f.score, ari.score

  def lexical_richness(self, document):
    """Return ``[MATTR, MTLD]`` lexical-richness scores for ``document``.

    MATTR uses a 25-word window for documents longer than 45 whitespace
    separated words, otherwise a third of the word count (at least 1, so that
    one- and two-word documents do not request an invalid zero-width window).
    An empty document yields ``[0, 0]``.
    """
    words = document.split()
    if not words:
      return [0, 0]

    lex = LexicalRichness(document)
    window_size = 25 if len(words) > 45 else max(1, len(words) // 3)
    return [lex.mattr(window_size=window_size), lex.mtld(threshold=0.72)]

  # ------------------------------------------------------------ entry point

  def get_features(self, document, special_puncts):
    """Build the flat feature row for ``document``.

    Args:
      document: raw text (before any normalisation).
      special_puncts: list of punctuation strings to profile individually.

    Returns:
      A list of numbers: the enabled groups concatenated in the order
      phraseology, diversity, punctuation (see the class docstring for the
      layout of each group).
    """
    feature_row = []

    if self.phraseology_features:
      feature_row.append(self.word_count(document))
      feature_row.append(self.sentence_count(document))
      feature_row.append(self.paragraph_count(document))
      # (mean, std) pairs
      feature_row.extend(self.word_count_sent(document))
      feature_row.extend(self.word_count_para(document))
      feature_row.extend(self.sent_count_para(document))

    if self.diversity_features:
      feature_row.extend(self.readability_score(document))
      feature_row.extend(self.lexical_richness(document))

    if self.punct_analysis_features:
      feature_row.append(self.total_punc_count(document))
      feature_row.extend(self.special_punc_count(document, special_puncts))
      feature_row.extend(self.special_punc_count_sent(document, special_puncts))
      feature_row.extend(self.special_punc_count_para(document, special_puncts))

    return feature_row

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Data Loader

In [None]:
import json
from typing import List

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer

import re
import unicodedata

import nltk
from nltk.corpus import stopwords
from nltk.tag import pos_tag
# from pycontractions import Contractions
# NLTK resources for PreProcess: presumably the perceptron tagger backs pos_tag()
# (remove_proper_nouns), 'stopwords' backs stopwords.words('english') (stopword_remove)
# and 'wordnet' backs WordNetLemmatizer (lemmatization) - TODO confirm.
nltk.download('averaged_perceptron_tagger')
nltk.download('stopwords')
nltk.download('wordnet')


# Contraction -> expansion lookup used by PreProcess.expand_contractions.
# Keys are case-sensitive; both "I..." and "i..." variants are listed explicitly.
CONTRACTION_MAP = { "ain't": "is not",
                    "aren't": "are not",
                    "can't": "cannot",
                    "can't've": "cannot have",
                    "'cause": "because",
                    "could've": "could have",
                    "couldn't": "could not",
                    "couldn't've": "could not have",
                    "didn't": "did not",
                    "doesn't": "does not",
                    "don't": "do not",
                    "hadn't": "had not",
                    "hadn't've": "had not have",
                    "hasn't": "has not",
                    "haven't": "have not",
                    "he'd": "he would",
                    "he'd've": "he would have",
                    "he'll": "he will",
                    "he'll've": "he will have",
                    "he's": "he is",
                    "how'd": "how did",
                    "how'd'y": "how do you",
                    "how'll": "how will",
                    "how's": "how is",
                    "I'd": "I would",
                    "I ain't": "I am not",
                    "I'd've": "I would have",
                    "I'll": "I will",
                    "I'll've": "I will have",
                    "I'm": "I am",
                    "I've": "I have",
                    "i'd": "i would",
                    "i'd've": "i would have",
                    "i'll": "i will",
                    "i'll've": "i will have",
                    "i'm": "i am",
                    "i've": "i have",
                    "isn't": "is not",
                    "it'd": "it would",
                    "it'd've": "it would have",
                    "it'll": "it will",
                    "it'll've": "it will have",
                    "it's": "it is",
                    "let's": "let us",
                    "ma'am": "madam",
                    "mayn't": "may not",
                    "might've": "might have",
                    "mightn't": "might not",
                    "mightn't've": "might not have",
                    "must've": "must have",
                    "mustn't": "must not",
                    "mustn't've": "must not have",
                    "needn't": "need not",
                    "needn't've": "need not have",
                    "o'clock": "of the clock",
                    "oughtn't": "ought not",
                    "oughtn't've": "ought not have",
                    "shan't": "shall not",
                    "sha'n't": "shall not",
                    "shan't've": "shall not have",
                    "she'd": "she would",
                    "she'd've": "she would have",
                    "she'll": "she will",
                    "she'll've": "she will have",
                    "she's": "she is",
                    "should've": "should have",
                    "shouldn't": "should not",
                    "shouldn't've": "should not have",
                    "so've": "so have",
                    "so's": "so as",
                    "that'd": "that would",
                    "that'd've": "that would have",
                    "that's": "that is",
                    "there'd": "there would",
                    "there'd've": "there would have",
                    "there's": "there is",
                    "they'd": "they would",
                    "they'd've": "they would have",
                    "they'll": "they will",
                    "they'll've": "they will have",
                    "they're": "they are",
                    "they've": "they have",
                    "to've": "to have",
                    "wasn't": "was not",
                    "we'd": "we would",
                    "we'd've": "we would have",
                    "we'll": "we will",
                    "we'll've": "we will have",
                    "we're": "we are",
                    "we've": "we have",
                    "weren't": "were not",
                    "what'll": "what will",
                    "what'll've": "what will have",
                    "what're": "what are",
                    "what's": "what is",
                    "what've": "what have",
                    "when's": "when is",
                    "when've": "when have",
                    "where'd": "where did",
                    "where's": "where is",
                    "where've": "where have",
                    "who'll": "who will",
                    "who'll've": "who will have",
                    "who's": "who is",
                    "who've": "who have",
                    "why's": "why is",
                    "why've": "why have",
                    "will've": "will have",
                    "won't": "will not",
                    "won't've": "will not have",
                    "would've": "would have",
                    "wouldn't": "would not",
                    "wouldn't've": "would not have",
                    "y'all": "you all",
                    "y'all'd": "you all would",
                    "y'all'd've": "you all would have",
                    "y'all're": "you all are",
                    "y'all've": "you all have",
                    "you'd": "you would",
                    "you'd've": "you would have",
                    "you'll": "you will",
                    "you'll've": "you will have",
                    "you're": "you are",
                    "you've": "you have"
                    }


class PreProcess:
    """Configurable text-normalisation pipeline for a single document.

    Each ``*_norm`` constructor flag enables one step of ``fit``. The step
    methods (except the two plain-string helpers ``lowercase_normalization``
    and ``period_remove``) take and return a *list* of document strings.

    Note: ``contractions_norm`` is accepted and stored but ``fit`` does not
    apply it; call ``expand_contractions`` directly if needed.
    """

    def __init__(self, lowercase_norm=False, period_norm=False, special_chars_norm=False, accented_norm=False, contractions_norm=False,
                 stemming_norm=False, lemma_norm=False, stopword_norm=False, proper_norm=False):

        self.lowercase_norm = lowercase_norm
        self.period_norm = period_norm
        self.special_chars_norm = special_chars_norm
        self.accented_norm = accented_norm
        self.contractions_norm = contractions_norm
        self.stemming_norm = stemming_norm
        self.lemma_norm = lemma_norm
        self.stopword_norm = stopword_norm
        self.proper_norm = proper_norm

    def lowercase_normalization(self, data):
        """Return the string ``data`` lower-cased."""
        return data.lower()

    def period_remove(self, data):
        """Return the string ``data`` with every '.' replaced by a space."""
        return data.replace(".", " ")

    def _rejoin(self, data):
        """Return each document with runs of whitespace collapsed to one space.

        The result has no leading or trailing whitespace (the original code
        called ``rstrip()`` but discarded its result, leaving a trailing space).
        """
        return [" ".join(words) for words in self.tokenization(data)]

    def special_char_remove(self, data, remove_digits=False):  # Remove special characters
        """Strip markup-like tags, non-ASCII characters, backslashes and commas.

        Hyphens are replaced by a space. ``remove_digits`` is accepted for
        interface compatibility and is currently ignored.
        """
        tag_pattern = re.compile('<.*?>')
        special_char_norm_data = []

        for sentence in self._rejoin(data):
            norm_sentence = tag_pattern.sub('', sentence)
            norm_sentence = re.sub(r'[^\x00-\x7F]+', '', norm_sentence)
            norm_sentence = norm_sentence.replace("\\", "")
            norm_sentence = norm_sentence.replace("-", " ")
            norm_sentence = norm_sentence.replace(",", "")
            special_char_norm_data.append(norm_sentence)

        return special_char_norm_data

    def accented_word_normalization(self, data):  # Normalize accented chars/words
        """Fold accented characters to ASCII (NFKD decomposition, non-ASCII dropped)."""
        return [
            unicodedata.normalize('NFKD', sentence).encode('ascii', 'ignore').decode('utf-8', 'ignore')
            for sentence in self._rejoin(data)
        ]

    def expand_contractions(self, data, pycontrct=False):  # Expand contractions
        """Expand contractions using ``CONTRACTION_MAP`` and drop leftover apostrophes.

        Matching is case-insensitive; the first character of the matched text is
        preserved when it is the same letter as the expansion's first character.
        ``pycontrct`` is accepted for interface compatibility and is ignored.
        """
        contraction_mapping = CONTRACTION_MAP
        # Fallback for case-insensitive matches whose exact spelling is not a key.
        lowercase_mapping = {key.lower(): value for key, value in contraction_mapping.items()}
        # Longest alternatives first: regex alternation takes the first branch
        # that matches, so "can't" must not shadow "can't've".
        alternatives = sorted(contraction_mapping.keys(), key=len, reverse=True)
        contractions_pattern = re.compile('({})'.format('|'.join(re.escape(key) for key in alternatives)),
                                          flags=re.IGNORECASE | re.DOTALL)

        def expand_match(contraction):
            match = contraction.group(0)
            expanded_contraction = contraction_mapping.get(match) or lowercase_mapping.get(match.lower())
            if not expanded_contraction:
                return match
            # Keep the writer's capitalisation, but only when the first letters
            # agree ("ain't" -> "is not" must not become "as not").
            if match[0].lower() == expanded_contraction[0].lower():
                return match[0] + expanded_contraction[1:]
            return expanded_contraction

        contraction_norm_data = []
        for sentence in self._rejoin(data):
            expanded_text = contractions_pattern.sub(expand_match, sentence)
            expanded_text = re.sub("'", "", expanded_text)
            contraction_norm_data.append(expanded_text)

        return contraction_norm_data

    def stemming(self, data):
        """Apply the Porter stemmer to every whitespace token."""
        stemmer = nltk.stem.PorterStemmer()
        return [" ".join(stemmer.stem(word) for word in words) for words in self.tokenization(data)]

    def lemmatization(self, data):
        """Apply the WordNet lemmatizer to every whitespace token."""
        lemma = nltk.stem.WordNetLemmatizer()
        return [" ".join(lemma.lemmatize(word) for word in words) for words in self.tokenization(data)]

    def stopword_remove(self, data):
        """Drop English stop words (case-insensitive).

        Each kept word is prefixed with a space, so a non-empty result starts
        with a leading space (kept for compatibility with existing outputs).
        """
        stop_words = set(stopwords.words('english'))
        return [
            "".join(" " + word for word in words if word.lower() not in stop_words)
            for words in self.tokenization(data)
        ]

    def remove_proper_nouns(self, data):
        """Drop every token that the POS tagger labels 'NNP' at least once.

        Each kept word is prefixed with a space, so a non-empty result starts
        with a leading space (kept for compatibility with existing outputs).
        """
        common_words = []
        for words in self.tokenization(data):
            proper_nouns = {word for word, pos in pos_tag(words) if pos == 'NNP'}
            common_words.append("".join(" " + word for word in words if word not in proper_nouns))
        return common_words

    def tokenization(self, data):
        """Split every document on whitespace; returns one token list per document."""
        # str.split() with no argument splits on runs of whitespace and drops
        # empty strings; assumed equivalent to nltk's WhitespaceTokenizer, without
        # building a tokenizer object per document.
        return [document.split() for document in data]

    def fit(self, data):
        """Run the enabled normalisation steps on one document.

        Args:
            data: the document; converted with ``str()``.

        Returns:
            The normalised document as a single string.
        """
        data = [str(data)]

        # Accent folding must run before special-character removal: the latter
        # deletes all non-ASCII characters, which would turn "café" into "caf"
        # and leave nothing for the accent step to normalise.
        if self.accented_norm:
            data = self.accented_word_normalization(data)

        if self.special_chars_norm:
            data = self.special_char_remove(data, remove_digits=False)

        # NOTE: contractions_norm is intentionally not applied here (the step
        # was disabled in the original pipeline).

        if self.stemming_norm:
            data = self.stemming(data)

        # Proper-noun removal runs before lower-casing; presumably so the tagger
        # still sees the original capitalisation - TODO confirm.
        if self.proper_norm:
            data = self.remove_proper_nouns(data)

        if self.stopword_norm:
            data = self.stopword_remove(data)

        if self.lemma_norm:
            data = self.lemmatization(data)

        data = data[0]

        if self.lowercase_norm:
            data = self.lowercase_normalization(str(data))

        if self.period_norm:
            data = self.period_remove(str(data))

        return data

def load_texts(data_file, label=False, expected_size=None):
    """Load the ``'text'`` (and optionally ``'label'``) field of a JSON-lines file.

    The file is read in a single pass and is closed on return. Blank lines are
    skipped.

    Args:
        data_file: path to a file with one JSON object per line.
        label: when true, also collect each record's ``'label'`` field.
        expected_size: accepted for interface compatibility; currently unused.

    Returns:
        ``texts`` when ``label`` is false, otherwise ``(texts, labels)`` with
        both lists in file order.
    """
    texts = []
    labels = []

    with open(data_file, encoding='utf-8') as handle:
        for line in tqdm(handle, desc=f'Loading {data_file}'):
            if not line.strip():
                continue
            record = json.loads(line)
            texts.append(record['text'])
            if label:
                labels.append(record['label'])

    if label:
        return texts, labels

    return texts


class Corpus:
    """Texts of one named dataset, read from JSON-lines files under ``data_dir``.

    With ``single_file`` the file ``<data_dir>/<name>.jsonl`` is loaded into
    ``self.data`` (plus ``self.label`` when ``label`` is true). Otherwise the
    ``.train`` / ``.test`` / ``.valid`` splits are loaded into ``self.train``
    (``None`` when ``skip_train``), ``self.test`` and ``self.valid``.
    """

    def __init__(self, name, data_dir='data', label=False, skip_train=False, single_file=False):
        self.name = name
        stem = f'{data_dir}/{name}'

        if not single_file:
            # Split mode: labels are never read here.
            self.train = None if skip_train else load_texts(f'{stem}.train.jsonl')
            self.test = load_texts(f'{stem}.test.jsonl')
            self.valid = load_texts(f'{stem}.valid.jsonl')
            return

        if label:
            self.data, self.label = load_texts(f'{stem}.jsonl', label=True)
        else:
            self.data = load_texts(f'{stem}.jsonl')


class _StyloEncodedBase(Dataset):
    """Shared machinery of the three datasets below.

    Holds the tokenizer, the stylometry extractor and the text preprocessor and
    turns one raw text into ``(input_ids, attention_mask, stylo_features)``.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizer", special_puncts: List[str],
                 max_sequence_length: int = None, min_sequence_length: int = None):
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        # Stored for callers; not used by the encoding below.
        self.min_sequence_length = min_sequence_length
        self.special_puncts = special_puncts
        self.style_extractor= Stylometry(phraseology_features= True, diversity_features = False, punct_analysis_features = True)
        # The preprocessor only holds configuration flags, so one instance is
        # shared by all items instead of being rebuilt on every __getitem__.
        self.preprocessor = PreProcess(special_chars_norm=True, lowercase_norm=True, period_norm=True, proper_norm=True, accented_norm=True)

    def _encode(self, text):
        """Return ``(input_ids, attention_mask, stylo_features)`` tensors for ``text``."""
        # Stylometric features are computed on the RAW text, before normalisation.
        stylo_features = self.style_extractor.get_features(text, self.special_puncts)

        clean_text = self.preprocessor.fit(text)
        padded_sequences = self.tokenizer(clean_text, padding='max_length', max_length=self.max_sequence_length, truncation=True)

        return (
            torch.tensor(padded_sequences['input_ids']),
            torch.tensor(padded_sequences['attention_mask']),
            # Explicit dtype: a row consisting only of ints (e.g. for an empty
            # text) would otherwise be inferred as int64 and no longer match the
            # float rows of other items.
            torch.tensor(stylo_features, dtype=torch.float),
        )


class EncodedDataset(_StyloEncodedBase):
    """Real texts (label 0) followed by fake texts (label 1).

    Items are ``(input_ids, attention_mask, stylo_features, label)``.
    """

    def __init__(self, real_texts: List[str], fake_texts: List[str], tokenizer: "PreTrainedTokenizer", special_puncts: List[str],
                 max_sequence_length: int = None, min_sequence_length: int = None):
        super().__init__(tokenizer, special_puncts, max_sequence_length, min_sequence_length)
        self.real_texts = real_texts
        self.fake_texts = fake_texts

    def __len__(self):
        return len(self.real_texts) + len(self.fake_texts)

    def __getitem__(self, index):
        if index < len(self.real_texts):
            text = self.real_texts[index]
            label = 0
        else:
            text = self.fake_texts[index - len(self.real_texts)]
            label = 1

        return self._encode(text) + (label,)


class EncodedSingleDataset(_StyloEncodedBase):
    """Texts with caller-supplied labels.

    Items are ``(input_ids, attention_mask, stylo_features, label)``.
    """

    def __init__(self, input_texts: List[str], input_labels: List[int], tokenizer: "PreTrainedTokenizer", special_puncts: List[str],
                 max_sequence_length: int = None, min_sequence_length: int = None):
        super().__init__(tokenizer, special_puncts, max_sequence_length, min_sequence_length)
        self.input_texts = input_texts
        self.input_labels = input_labels

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, index):
        return self._encode(self.input_texts[index]) + (self.input_labels[index],)


class EncodeEvalData(_StyloEncodedBase):
    """Unlabelled texts; items are ``(input_ids, attention_mask, stylo_features)``."""

    def __init__(self, input_texts: List[str], tokenizer: "PreTrainedTokenizer", special_puncts: List[str],
                 max_sequence_length: int = None, min_sequence_length: int = None):
        super().__init__(tokenizer, special_puncts, max_sequence_length, min_sequence_length)
        self.input_texts = input_texts

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, index):
        return self._encode(self.input_texts[index])


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


## Model Code

### Roberta Model

In [None]:
import torch
from torch.nn import Softmax
from torch.nn import CrossEntropyLoss, MSELoss
from typing import Optional, Tuple

from transformers import RobertaForSequenceClassification

from transformers.modeling_outputs import SequenceClassifierOutput

from dataclasses import dataclass

@dataclass
class SequenceClassifierOutputWithLastLayer(SequenceClassifierOutput):
    """``SequenceClassifierOutput`` extended with a ``last_hidden_state`` field.

    Built by ``RobertaForFusion.forward`` when it returns a dict-style output.
    """

    # Loss value, when one was computed; ``None`` otherwise.
    loss: Optional[torch.FloatTensor] = None
    # Classifier output. NOTE: RobertaForFusion stores the *softmax* of the raw
    # classifier logits in this field.
    logits: torch.FloatTensor = None
    # First element of the encoder output (``outputs[0]`` in RobertaForFusion.forward).
    last_hidden_state: torch.FloatTensor = None
    # Passed through from the encoder output's ``hidden_states``.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Passed through from the encoder output's ``attentions``.
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class RobertaForFusion(RobertaForSequenceClassification):
    """RoBERTa sequence classifier whose output also exposes the encoder's hidden states.

    Differences from the parent's ``forward`` visible here: the returned
    ``logits`` are passed through a softmax over dim 1, and the dict-style
    output carries ``last_hidden_state`` (the encoder's first output).
    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)

        self.soft_max = Softmax(dim=1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run the encoder and the classification head.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            When given, a loss is computed on the raw (pre-softmax) classifier output:
            Mean-Square loss if :obj:`config.num_labels == 1`, Cross-Entropy otherwise.

        Returns a :class:`SequenceClassifierOutputWithLastLayer` when ``return_dict`` resolves to
        true, else a tuple ``(loss?, softmax_logits, *outputs[2:])`` where the loss is only
        present when it was computed.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Single output unit: treated as regression.
                loss = MSELoss()(logits.view(-1), labels.view(-1))
            else:
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        softmax_logits = self.soft_max(logits)

        if return_dict:
            return SequenceClassifierOutputWithLastLayer(
                loss=loss,
                logits=softmax_logits,
                last_hidden_state=sequence_output,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (softmax_logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output



### Joined Ensemble (Stylo + LM)

In [None]:
class FusedClassifier(torch.nn.Module):
    """Classifier over a frozen language model embedding fused with hand-crafted features.

    The LM embedding and the custom feature vector are concatenated, reduced to
    64 dimensions and mapped to a softmax distribution over two classes.

    Args:
        lm: language model whose call returns a mapping with a
            ``"last_hidden_state"`` entry; its parameters are frozen here.
        device: device the LM and the new layers are moved to.
        FUSED_INPUT_SIZE: width of the concatenated vector, i.e. LM hidden size
            plus the number of custom features.
    """

    def __init__(self, lm, device, FUSED_INPUT_SIZE):
        super().__init__()

        self.lm = lm

        # move to device
        self.lm.to(device)

        self.reducer = nn.Sequential(
            nn.Linear(FUSED_INPUT_SIZE, 512),
            nn.ReLU(),
            nn.Linear(512, 64)
        ).to(device)

        self.classifier = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
            nn.Softmax(dim=-1)
        ).to(device)

        # the LM is already pre-trained, no need to calc grads anymore
        for param in self.lm.parameters():
            param.requires_grad = False

    def forward(self, data, custom_features):
        """Return class probabilities of shape ``(batch, 2)``.

        Args:
            data: sequence ``(input_ids, attention_mask[, labels])``; labels,
                when present, are forwarded to the LM.
            custom_features: tensor whose last dimension, together with the LM
                hidden size, adds up to ``FUSED_INPUT_SIZE``. It is cast to the
                dtype and device of the LM embedding before concatenation.
        """
        # The LM is frozen and its output was detached anyway, so skip building
        # an autograd graph for it altogether.
        with torch.no_grad():
            if len(data) < 3:
                output_dic = self.lm(data[0], attention_mask=data[1])
            else:
                output_dic = self.lm(data[0], attention_mask=data[1], labels=data[2])

        # NOTE(review): this takes the hidden state at the LAST sequence position.
        # If inputs are right-padded to a fixed length this may be a padding
        # position rather than the first (<s>) token - confirm this is intended
        # before changing, since trained fused checkpoints depend on it.
        lm_emb_output = output_dic["last_hidden_state"][:, -1, :].detach()

        # torch.cat needs matching device; align dtype too so integer feature
        # tensors do not change the dtype fed to the linear layers.
        custom_features = custom_features.to(device=lm_emb_output.device, dtype=lm_emb_output.dtype)

        # append manual features to the RoBERTa features
        x = torch.cat((lm_emb_output, custom_features), dim=-1)
        c = self.reducer(x)

        return self.classifier(c)

## Train Code

### LM Attribution Fine-Tuning

In [None]:
"""Training code for the detector model"""

import argparse
import os
import subprocess
import sys
from itertools import count
from multiprocessing import Process

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from tqdm import tqdm
from transformers import *
from torch.nn import CrossEntropyLoss

import sys

# Fix the global PyTorch RNG seed so repeated runs of this cell are comparable.
torch.manual_seed(1000)

def setup_distributed(port=29500):
    """Initialise ``torch.distributed`` when more than one CUDA device is usable.

    Args:
        port: value written to ``MASTER_PORT`` on the MPI code path.

    Returns:
        ``(rank, world_size)``; ``(0, 1)`` when distributed support, CUDA, or a
        second GPU is missing, in which case no process group is created.
    """
    can_distribute = (
        dist.is_available()
        and torch.cuda.is_available()
        and torch.cuda.device_count() > 1
    )
    if not can_distribute:
        return 0, 1

    launched_by_mpi = 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' in os.environ
    if not launched_by_mpi:
        # Rank and world size are read from the environment by the process group.
        dist.init_process_group(backend="nccl", init_method="env://")
        return dist.get_rank(), dist.get_world_size()

    # MPI launch: take rank and size from the MPI world communicator.
    from mpi4py import MPI
    world = MPI.COMM_WORLD
    mpi_rank, mpi_size = world.Get_rank(), world.Get_size()

    os.environ.update(MASTER_ADDR='127.0.0.1', MASTER_PORT=str(port))

    dist.init_process_group(backend="nccl", world_size=mpi_size, rank=mpi_rank)
    return mpi_rank, mpi_size


def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, special_puncts, batch_size,
                  max_sequence_length, random_sequence_length):
    """Build the training and validation data loaders.

    When ``fake_dataset`` is ``"TWO"`` the ``grover_fake`` and ``gpt2_fake``
    corpora are concatenated and the real corpus is repeated twice; otherwise
    a single fake corpus named ``fake_dataset`` is paired with the real one.

    Returns:
        ``(train_loader, validation_loader)``. The training loader uses
        ``batch_size``; the validation loader uses a batch size of 1.
    """
    real_corpus = Corpus(real_dataset, data_dir=data_dir)

    if fake_dataset != "TWO":
        fake_corpus = Corpus(fake_dataset, data_dir=data_dir)

        real_train, real_valid = real_corpus.train, real_corpus.valid
        fake_train, fake_valid = fake_corpus.train, fake_corpus.valid
    else:
        real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['grover_fake', 'gpt2_fake']]

        fake_train = []
        for corpus in fake_corpora:
            fake_train = fake_train + corpus.train
        fake_valid = []
        for corpus in fake_corpora:
            fake_valid = fake_valid + corpus.valid

    # Shard across processes only when a multi-process group is active.
    if distributed() and dist.get_world_size() > 1:
        Sampler = DistributedSampler
    else:
        Sampler = RandomSampler

    min_sequence_length = None
    if random_sequence_length:
        min_sequence_length = 10

    train_dataset = EncodedDataset(real_train, fake_train, tokenizer, special_puncts,
                                   max_sequence_length, min_sequence_length)
    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)

    validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer, special_puncts,
                                        max_sequence_length, min_sequence_length)
    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))

    return train_loader, validation_loader


def accuracy_sum(logits, labels):
    """Count how many predictions in a batch match ``labels``.

    If ``logits`` has the shape of ``labels`` plus a trailing axis of size 2,
    the prediction is 1 where the second score exceeds the first. Otherwise
    the prediction is 1 wherever the score is positive.

    Returns:
        The number of correct predictions as a Python float.
    """
    two_way_scores = tuple(logits.shape) == tuple(labels.shape) + (2,)
    if two_way_scores:
        predicted = (logits[..., 1] > logits[..., 0]).long().flatten()
    else:
        predicted = logits.gt(0).long().flatten()

    assert predicted.shape == labels.shape
    return predicted.eq(labels).float().sum().item()


def train(model: nn.Module, optimizer, device: str, loader: DataLoader, desc='Train'):
    """Run one optimisation pass over ``loader``.

    Each batch is expected to unpack into four items; the third one (the
    custom features) is ignored here. The model is called with
    ``attention_mask`` and ``labels`` and must return a mapping with
    ``"loss"`` and ``"logits"``.

    Returns:
        Dict with the *summed* number of correct predictions, the number of
        examples seen and the example-weighted loss sum for the epoch.
    """
    model.train()

    correct_so_far = 0
    examples_so_far = 0
    loss_so_far = 0

    # Only the first rank shows a progress bar in distributed runs.
    silence_bar = distributed() and dist.get_rank() > 0

    with tqdm(loader, desc=desc, disable=silence_bar) as progress:
        for token_ids, attention, _unused_features, targets in progress:
            token_ids = token_ids.to(device)
            attention = attention.to(device)
            targets = targets.to(device)
            n_examples = token_ids.shape[0]

            optimizer.zero_grad()
            result = model(token_ids, attention_mask=attention, labels=targets)
            loss = result["loss"]
            logits = result["logits"]
            loss.backward()
            optimizer.step()

            correct_so_far += accuracy_sum(logits, targets)
            examples_so_far += n_examples
            loss_so_far += loss.item() * n_examples

            progress.set_postfix(loss=loss.item(), acc=correct_so_far / examples_so_far)

    metrics = {}
    metrics["train/accuracy"] = correct_so_far
    metrics["train/epoch_size"] = examples_so_far
    metrics["train/loss"] = loss_so_far
    return metrics


def validate(model: nn.Module, device: str, loader: DataLoader, votes=1, desc='Validation'):
    """Evaluate ``model`` on ``loader``, averaging over ``votes`` passes.

    The loader is iterated ``votes`` times up front; the i-th batches of all
    passes are then grouped, and their losses and logits averaged before
    accuracy is computed.

    Returns:
        Dict with the *summed* number of correct predictions, the number of
        examples seen and the example-weighted loss sum.
    """
    model.eval()

    correct_so_far = 0
    examples_so_far = 0
    loss_so_far = 0

    # Materialise every pass over the loader, one list of batches per vote.
    passes = []
    for v in range(votes):
        silence_bar = distributed() and dist.get_rank() > 0
        passes.append(list(tqdm(loader, desc=f'Preloading data ... {v}', disable=silence_bar)))

    # Regroup so that each entry holds the i-th batch of every pass.
    grouped = [[one_pass[i] for one_pass in passes] for i in range(len(loader))]

    silence_bar = distributed() and dist.get_rank() > 0
    with tqdm(grouped, desc=desc, disable=silence_bar) as progress, torch.no_grad():
        for group in progress:
            vote_losses = []
            vote_logits = []

            for texts, masks, _, labels in group:
                texts = texts.to(device)
                masks = masks.to(device)
                labels = labels.to(device)
                batch_size = texts.shape[0]

                result = model(texts, attention_mask=masks, labels=labels)
                vote_losses.append(result["loss"])
                vote_logits.append(result["logits"])

            mean_loss = torch.stack(vote_losses).mean(dim=0)
            mean_logits = torch.stack(vote_logits).mean(dim=0)

            # ``labels`` and ``batch_size`` are those of the last vote in the group.
            correct_so_far += accuracy_sum(mean_logits, labels)
            examples_so_far += batch_size
            loss_so_far += mean_loss.item() * batch_size

            progress.set_postfix(loss=mean_loss.item(), acc=correct_so_far / examples_so_far)

    metrics = {}
    metrics["validation/accuracy"] = correct_so_far
    metrics["validation/epoch_size"] = examples_so_far
    metrics["validation/loss"] = loss_so_far
    return metrics


def _all_reduce_dict(d, device):
    # wrap in tensor and use reduce to gpu0 tensor
    output_d = {}
    for (key, value) in sorted(d.items()):
        tensor_input = torch.tensor([[value]]).to(device)
        # torch.distributed.all_reduce(tensor_input)
        output_d[key] = tensor_input.item()
    return output_d


def run(max_epochs=None,
        device=None,
        batch_size=8,
        max_sequence_length=256,
        random_sequence_length=False,
        epoch_size=None,
        seed=None,
        data_dir='data',
        real_dataset='real',
        fake_dataset='grover_fake',
        token_dropout=None,
        large=True,
        learning_rate=2e-5,
        weight_decay=0,
        load_from_checkpoint=False,
        checkpoint_name='neuralnews',
        special_puncts= [],
        **kwargs):
    """Fine-tune the RoBERTa detector and checkpoint the best epoch.

    Trains with Adam, validates after every epoch, logs the metrics to
    TensorBoard (rank 0 only) and writes ``roberta_ft.pt`` whenever the
    validation accuracy improves. Training stops after ``max_epochs`` epochs
    or after 3 consecutive epochs without a validation-accuracy improvement.

    Args:
        max_epochs: upper bound on epochs; ``None`` means no bound.
        device: torch device string; defaults to ``cuda:<rank>`` or ``cpu``.
        batch_size: training batch size.
        max_sequence_length, random_sequence_length: forwarded to ``load_datasets``.
        data_dir, real_dataset, fake_dataset: corpus location and names.
        large: use ``roberta-large`` instead of ``roberta-base``.
        learning_rate, weight_decay: Adam hyper-parameters.
        load_from_checkpoint, checkpoint_name: when set, weights are loaded
            from ``data_dir + checkpoint_name + '.pt'`` before training.
        special_puncts: forwarded to ``load_datasets``.
        epoch_size, seed, token_dropout, **kwargs: accepted and recorded in the
            checkpoint's ``args`` but not otherwise used by this function.
    """
    # Must stay the first statement so that only the call arguments are captured.
    args = locals()
    rank, world_size = setup_distributed()

    print(args)

    if device is None:
        device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'

    print('rank:', rank, 'world_size:', world_size, 'device:', device)

    import torch.distributed as dist
    # Non-zero ranks wait here; rank 0 reaches its barrier after the model is built.
    if distributed() and rank > 0:
        dist.barrier()

    model_name = 'roberta-large' if large else 'roberta-base'
    tokenization_utils.logger.setLevel('ERROR')
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    model = RobertaForFusion.from_pretrained(model_name).to(device)

    # Load the model from checkpoints
    if load_from_checkpoint:
        checkpoint_path = (data_dir + '{}.pt').format(checkpoint_name)
        # Remap storages only for CPU runs; otherwise keep torch.load's default.
        map_location = 'cpu' if device == "cpu" else None
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        model.load_state_dict(checkpoint['model_state_dict'])

    if rank == 0:
        summary(model)
        if distributed():
            dist.barrier()

    if world_size > 1:
        model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True)

    train_loader, validation_loader = load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, special_puncts, batch_size,
                                                    max_sequence_length, random_sequence_length)

    optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1)

    logdir = os.environ.get("OPENAI_LOGDIR", "logs")
    os.makedirs(logdir, exist_ok=True)

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(logdir) if rank == 0 else None
    best_validation_accuracy = 0
    without_progress = 0
    earlystop_epochs = 3

    try:
        for epoch in epoch_loop:
            if world_size > 1:
                train_loader.sampler.set_epoch(epoch)
                validation_loader.sampler.set_epoch(epoch)

            train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}')
            validation_metrics = validate(model, device, validation_loader)

            combined_metrics = _all_reduce_dict({**validation_metrics, **train_metrics}, device)

            combined_metrics["train/accuracy"] /= combined_metrics["train/epoch_size"]
            combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
            combined_metrics["validation/accuracy"] /= combined_metrics["validation/epoch_size"]
            combined_metrics["validation/loss"] /= combined_metrics["validation/epoch_size"]

            if rank == 0:
                for key, value in combined_metrics.items():
                    writer.add_scalar(key, value, global_step=epoch)

            # Track progress on every rank so all ranks apply the same stopping
            # rule; only rank 0 writes the checkpoint.
            # NOTE(review): metrics are not reduced across ranks (all_reduce is
            # disabled), so ranks may still disagree in multi-GPU runs.
            if combined_metrics["validation/accuracy"] > best_validation_accuracy:
                without_progress = 0
                best_validation_accuracy = combined_metrics["validation/accuracy"]

                if rank == 0:
                    model_to_save = model.module if hasattr(model, 'module') else model
                    torch.save(dict(
                            epoch=epoch,
                            model_state_dict=model_to_save.state_dict(),
                            optimizer_state_dict=optimizer.state_dict(),
                            args=args
                        ),
                        os.path.join("", "roberta_ft.pt")
                    )
            else:
                # Count only epochs that did not improve, so an improving epoch
                # leaves the counter at zero.
                without_progress += 1

            if without_progress >= earlystop_epochs:
                break
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        if writer is not None:
            writer.close()


if __name__ == '__main__':
    # Entry point of the cell: parse the hard-coded notebook arguments and start
    # either a single training run or one process per visible GPU.
    parser = argparse.ArgumentParser()

    parser.add_argument('--max-epochs', type=int, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--max-sequence-length', type=int, default=256)
    parser.add_argument('--random-sequence-length', action='store_true')
    parser.add_argument('--epoch-size', type=int, default=None)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--data-dir', type=str, default="")
    parser.add_argument('--real-dataset', type=str, default='')
    parser.add_argument('--fake-dataset', type=str, default='')
    parser.add_argument('--token-dropout', type=float, default=None)

    parser.add_argument('--large', action='store_true', help='use the roberta-large model instead of roberta-base')
    parser.add_argument('--learning-rate', type=float, default=1e-5)
    parser.add_argument('--weight-decay', type=float, default=0)
    parser.add_argument('--load-decay', type=float, default=0)
    # With type=list a value given on the command line is split into characters.
    parser.add_argument('--special_puncts', type=list, default=["!","'", ",", "-", ":", ";", "?", "@", "\"", "=", "#"])

    # Arguments are fixed here because the notebook kernel's argv is not usable.
    args = parser.parse_args(args=['--max-epochs=20'])

    # Count GPUs in a child interpreter so CUDA is not initialised in this process.
    nproc = int(subprocess.check_output([sys.executable, '-c', "import torch;"
                                         "print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))
    if nproc > 1:
        print(f'Launching {nproc} processes ...', file=sys.stderr)

        os.environ["MASTER_ADDR"] = '127.0.0.1'
        os.environ["MASTER_PORT"] = str(29500)
        os.environ['WORLD_SIZE'] = str(nproc)
        # The OpenMP variable is OMP_NUM_THREADS; the singular spelling is ignored.
        os.environ['OMP_NUM_THREADS'] = str(1)
        subprocesses = []

        for i in range(nproc):
            # Each child inherits the environment as it is when it is started.
            os.environ['RANK'] = str(i)
            os.environ['LOCAL_RANK'] = str(i)
            process = Process(target=run, kwargs=vars(args))
            process.start()
            subprocesses.append(process)

        for process in subprocesses:
            process.join()
    else:
        run(**vars(args))

### LM + Stylo Fusion Training

In [None]:
"""Training code for the detector model"""

import argparse
import os
import subprocess
import sys
from itertools import count
from multiprocessing import Process

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from tqdm import tqdm
from transformers import *


# Fix the global PyTorch RNG seed so repeated runs of this cell are comparable.
torch.manual_seed(1000)

def setup_distributed(port=29500):
    """Initialise ``torch.distributed`` when more than one CUDA device is usable.

    Args:
        port: value written to ``MASTER_PORT`` on the MPI code path.

    Returns:
        ``(rank, world_size)``; ``(0, 1)`` when distributed support, CUDA, or a
        second GPU is missing, in which case no process group is created.
    """
    can_distribute = (
        dist.is_available()
        and torch.cuda.is_available()
        and torch.cuda.device_count() > 1
    )
    if not can_distribute:
        return 0, 1

    launched_by_mpi = 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' in os.environ
    if not launched_by_mpi:
        # Rank and world size are read from the environment by the process group.
        dist.init_process_group(backend="nccl", init_method="env://")
        return dist.get_rank(), dist.get_world_size()

    # MPI launch: take rank and size from the MPI world communicator.
    from mpi4py import MPI
    world = MPI.COMM_WORLD
    mpi_rank, mpi_size = world.Get_rank(), world.Get_size()

    os.environ.update(MASTER_ADDR='127.0.0.1', MASTER_PORT=str(port))

    dist.init_process_group(backend="nccl", world_size=mpi_size, rank=mpi_rank)
    return mpi_rank, mpi_size


def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, special_puncts, batch_size,
                  max_sequence_length, random_sequence_length):
    """Build the training and validation data loaders.

    When ``fake_dataset`` is ``"TWO"`` the ``grover_fake`` and ``gpt2_fake``
    corpora are concatenated and the real corpus is repeated twice; otherwise
    a single fake corpus named ``fake_dataset`` is paired with the real one.

    Returns:
        ``(train_loader, validation_loader)``. The training loader uses
        ``batch_size``; the validation loader uses a batch size of 1.
    """
    real_corpus = Corpus(real_dataset, data_dir=data_dir)

    if fake_dataset != "TWO":
        fake_corpus = Corpus(fake_dataset, data_dir=data_dir)

        real_train, real_valid = real_corpus.train, real_corpus.valid
        fake_train, fake_valid = fake_corpus.train, fake_corpus.valid
    else:
        real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['grover_fake', 'gpt2_fake']]

        fake_train = []
        for corpus in fake_corpora:
            fake_train = fake_train + corpus.train
        fake_valid = []
        for corpus in fake_corpora:
            fake_valid = fake_valid + corpus.valid

    # Shard across processes only when a multi-process group is active.
    if distributed() and dist.get_world_size() > 1:
        Sampler = DistributedSampler
    else:
        Sampler = RandomSampler

    min_sequence_length = None
    if random_sequence_length:
        min_sequence_length = 10

    train_dataset = EncodedDataset(real_train, fake_train, tokenizer, special_puncts,
                                   max_sequence_length, min_sequence_length)
    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)

    validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer, special_puncts,
                                        max_sequence_length, min_sequence_length)
    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))

    return train_loader, validation_loader



def accuracy_sum(logits, labels):
    """Count how many predictions in a batch match ``labels``.

    If ``logits`` has the shape of ``labels`` plus a trailing axis of size 2,
    the prediction is 1 where the second score exceeds the first. Otherwise
    the prediction is 1 wherever the score is positive.

    Returns:
        The number of correct predictions as a Python float.
    """
    two_way_scores = tuple(logits.shape) == tuple(labels.shape) + (2,)
    if two_way_scores:
        predicted = (logits[..., 1] > logits[..., 0]).long().flatten()
    else:
        predicted = logits.gt(0).long().flatten()

    assert predicted.shape == labels.shape
    return predicted.eq(labels).float().sum().item()


def train(model: nn.Module, optimizer, device: str, loader: DataLoader, desc='Train'):
    """Run one optimisation pass of the fused classifier over ``loader``.

    Each batch must unpack into ``(texts, masks, custom_features, labels)``.
    The model is called as ``model(data=[texts, masks, labels],
    custom_features=...)`` and its output is scored against ``labels`` with
    cross-entropy.

    Returns:
        Dict with the *summed* number of correct predictions, the number of
        examples seen and the example-weighted loss sum for the epoch.
    """
    model.train()

    train_accuracy = 0
    train_epoch_size = 0
    train_loss = 0

    # Built once, through ``nn`` (imported in this cell), so the function does
    # not depend on a name imported by another cell.
    # NOTE(review): FusedClassifier ends in Softmax while CrossEntropyLoss
    # applies log-softmax itself; confirm the double normalisation is intended.
    loss_fct = nn.CrossEntropyLoss()

    with tqdm(loader, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop:
        for texts, masks, custom_features, labels in loop:

            texts, masks, custom_features, labels = texts.to(device), masks.to(device), custom_features.to(device), labels.to(device)
            batch_size = texts.shape[0]

            optimizer.zero_grad()
            predict_label = model(data=[texts, masks, labels], custom_features=custom_features)

            loss = loss_fct(predict_label, labels)

            loss.backward()
            optimizer.step()

            batch_accuracy = accuracy_sum(predict_label, labels)
            train_accuracy += batch_accuracy
            train_epoch_size += batch_size
            train_loss += loss.item() * batch_size

            loop.set_postfix(loss=loss.item(), acc=train_accuracy / train_epoch_size)

    return {
        "train/accuracy": train_accuracy,
        "train/epoch_size": train_epoch_size,
        "train/loss": train_loss
    }


def validate(model: nn.Module, device: str, loader: DataLoader, votes=1, desc='Validation'):
    """Evaluate the fused classifier on ``loader``, averaging over ``votes`` passes.

    The loader is iterated ``votes`` times up front; the i-th batches of all
    passes are grouped, and their losses and model outputs averaged before
    accuracy is computed. Each batch must unpack into
    ``(texts, masks, custom_features, labels)``.

    Returns:
        Dict with the *summed* number of correct predictions, the number of
        examples seen and the example-weighted loss sum.
    """
    model.eval()

    validation_accuracy = 0
    validation_epoch_size = 0
    validation_loss = 0

    # Built once, through ``nn`` (imported in this cell), so the function does
    # not depend on a name imported by another cell.
    loss_fct = nn.CrossEntropyLoss()

    records = [record for v in range(votes) for record in tqdm(loader, desc=f'Preloading data ... {v}',
                                                               disable=distributed() and dist.get_rank() > 0)]
    # Regroup the flat list so that entry i holds the i-th batch of every pass.
    records = [[records[v * len(loader) + i] for v in range(votes)] for i in range(len(loader))]

    with tqdm(records, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop, torch.no_grad():
        for example in loop:
            losses = []
            logit_votes = []

            for texts, masks, custom_features, labels in example:

                texts, masks, custom_features, labels = texts.to(device), masks.to(device), custom_features.to(device), labels.to(device)
                batch_size = texts.shape[0]

                predict_label = model(data=[texts, masks, labels], custom_features=custom_features)

                loss = loss_fct(predict_label, labels)
                losses.append(loss)
                logit_votes.append(predict_label)

            loss = torch.stack(losses).mean(dim=0)
            logits = torch.stack(logit_votes).mean(dim=0)

            # ``labels`` and ``batch_size`` are those of the last vote in the group.
            batch_accuracy = accuracy_sum(logits, labels)
            validation_accuracy += batch_accuracy
            validation_epoch_size += batch_size
            validation_loss += loss.item() * batch_size

            loop.set_postfix(loss=loss.item(), acc=validation_accuracy / validation_epoch_size)

    return {
        "validation/accuracy": validation_accuracy,
        "validation/epoch_size": validation_epoch_size,
        "validation/loss": validation_loss
    }


def _all_reduce_dict(d, device):
    # wrap in tensor and use reduce to gpu0 tensor
    output_d = {}
    for (key, value) in sorted(d.items()):
        tensor_input = torch.tensor([[value]]).to(device)
        # torch.distributed.all_reduce(tensor_input)
        output_d[key] = tensor_input.item()
    return output_d


def run(max_epochs=None,
        device=None,
        batch_size=16,
        max_sequence_length=256,
        random_sequence_length=False,
        epoch_size=None,
        seed=None,
        data_dir='data',
        real_dataset='real',
        fake_dataset='grover_fake',
        token_dropout=None,
        large=True,
        learning_rate=2e-5,
        weight_decay=0,
        load_from_checkpoint=False,
        checkpoint_name='neuralnews',
        special_puncts= [],
        FUSED_INPUT_SIZE = 811,
        **kwargs):
    """Train the fused (LM + stylometry) classifier and checkpoint the best epoch.

    Builds a ``RobertaForFusion`` language model (optionally restored from a
    checkpoint), wraps it in ``FusedClassifier``, trains with Adam, validates
    after every epoch, logs the metrics to TensorBoard (rank 0 only) and writes
    ``robertagenattr_fusion_grover.pt`` whenever the validation accuracy
    improves. Training stops after ``max_epochs`` epochs or after 3 consecutive
    epochs without a validation-accuracy improvement.

    Args:
        max_epochs: upper bound on epochs; ``None`` means no bound.
        device: torch device string; defaults to ``cuda:<rank>`` or ``cpu``.
        batch_size: training batch size.
        max_sequence_length, random_sequence_length: forwarded to ``load_datasets``.
        data_dir, real_dataset, fake_dataset: corpus location and names.
        large: use ``roberta-large`` instead of ``roberta-base``.
        learning_rate, weight_decay: Adam hyper-parameters.
        load_from_checkpoint, checkpoint_name: when set, LM weights are loaded
            from ``data_dir + checkpoint_name + '.pt'`` before training.
        special_puncts: forwarded to ``load_datasets``.
        FUSED_INPUT_SIZE: input width of the fused classifier's first layer.
        epoch_size, seed, token_dropout, **kwargs: accepted and recorded in the
            checkpoint's ``args`` but not otherwise used by this function.
    """
    # Must stay the first statement so that only the call arguments are captured.
    args = locals()
    rank, world_size = setup_distributed()

    if device is None:
        device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'

    print('rank:', rank, 'world_size:', world_size, 'device:', device)

    import torch.distributed as dist
    # Non-zero ranks wait here; rank 0 reaches its barrier after the model is built.
    if distributed() and rank > 0:
        dist.barrier()

    model_name = 'roberta-large' if large else 'roberta-base'
    tokenization_utils.logger.setLevel('ERROR')
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    lm = RobertaForFusion.from_pretrained(model_name).to(device)

    # Load the model from checkpoints
    if load_from_checkpoint:
        checkpoint_path = (data_dir + '{}.pt').format(checkpoint_name)
        # Remap storages only for CPU runs; otherwise keep torch.load's default.
        map_location = 'cpu' if device == "cpu" else None
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        lm.load_state_dict(checkpoint['model_state_dict'])

    model = FusedClassifier(lm=lm, device=device, FUSED_INPUT_SIZE=FUSED_INPUT_SIZE)

    if rank == 0:
        summary(model)
        if distributed():
            dist.barrier()

    if world_size > 1:
        model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True)

    train_loader, validation_loader = load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, special_puncts, batch_size,
                                                    max_sequence_length, random_sequence_length)

    optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1)

    logdir = os.environ.get("OPENAI_LOGDIR", "logs")
    os.makedirs(logdir, exist_ok=True)

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(logdir) if rank == 0 else None
    best_validation_accuracy = 0
    without_progress = 0
    earlystop_epochs = 3

    try:
        for epoch in epoch_loop:
            if world_size > 1:
                train_loader.sampler.set_epoch(epoch)
                validation_loader.sampler.set_epoch(epoch)

            train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}')
            validation_metrics = validate(model, device, validation_loader)

            combined_metrics = _all_reduce_dict({**validation_metrics, **train_metrics}, device)

            combined_metrics["train/accuracy"] /= combined_metrics["train/epoch_size"]
            combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
            combined_metrics["validation/accuracy"] /= combined_metrics["validation/epoch_size"]
            combined_metrics["validation/loss"] /= combined_metrics["validation/epoch_size"]

            if rank == 0:
                for key, value in combined_metrics.items():
                    writer.add_scalar(key, value, global_step=epoch)

            # Track progress on every rank so all ranks apply the same stopping
            # rule; only rank 0 writes the checkpoint.
            # NOTE(review): metrics are not reduced across ranks (all_reduce is
            # disabled), so ranks may still disagree in multi-GPU runs.
            if combined_metrics["validation/accuracy"] > best_validation_accuracy:
                without_progress = 0
                best_validation_accuracy = combined_metrics["validation/accuracy"]

                if rank == 0:
                    model_to_save = model.module if hasattr(model, 'module') else model
                    torch.save(dict(
                            epoch=epoch,
                            model_state_dict=model_to_save.state_dict(),
                            optimizer_state_dict=optimizer.state_dict(),
                            args=args
                        ),
                        os.path.join("", "robertagenattr_fusion_grover.pt")
                    )
            else:
                # Count only epochs that did not improve, so an improving epoch
                # leaves the counter at zero.
                without_progress += 1

            if without_progress >= earlystop_epochs:
                break
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        if writer is not None:
            writer.close()


if __name__ == '__main__':
    # Entry point of the cell: parse the hard-coded notebook arguments and start
    # either a single fusion-training run or one process per visible GPU.
    parser = argparse.ArgumentParser()

    parser.add_argument('--max-epochs', type=int, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--max-sequence-length', type=int, default=256)
    parser.add_argument('--random-sequence-length', action='store_true')
    parser.add_argument('--epoch-size', type=int, default=None)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--data-dir', type=str, default="")
    parser.add_argument('--real-dataset', type=str, default='')
    parser.add_argument('--fake-dataset', type=str, default='')
    parser.add_argument('--token-dropout', type=float, default=None)

    parser.add_argument('--large', action='store_true', help='use the roberta-large model instead of roberta-base')
    parser.add_argument('--learning-rate', type=float, default=2e-5)
    parser.add_argument('--weight-decay', type=float, default=0)
    parser.add_argument('--load-decay', type=float, default=0)

    # With type=list a value given on the command line is split into characters.
    parser.add_argument('--special_puncts', type=list, default=["!","'", ",", "-", ":", ";", "?", "@", "\"", "=", "#"])
    parser.add_argument('--FUSED_INPUT_SIZE', type=int, default=811)

    # Arguments are fixed here because the notebook kernel's argv is not usable.
    args = parser.parse_args(args=['--max-epochs=20'])

    # Count GPUs in a child interpreter so CUDA is not initialised in this process.
    nproc = int(subprocess.check_output([sys.executable, '-c', "import torch;"
                                         "print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))
    if nproc > 1:
        print(f'Launching {nproc} processes ...', file=sys.stderr)

        os.environ["MASTER_ADDR"] = '127.0.0.1'
        os.environ["MASTER_PORT"] = str(29500)
        os.environ['WORLD_SIZE'] = str(nproc)
        # The OpenMP variable is OMP_NUM_THREADS; the singular spelling is ignored.
        os.environ['OMP_NUM_THREADS'] = str(1)
        subprocesses = []

        for i in range(nproc):
            # Each child inherits the environment as it is when it is started.
            os.environ['RANK'] = str(i)
            os.environ['LOCAL_RANK'] = str(i)
            process = Process(target=run, kwargs=vars(args))
            process.start()
            subprocesses.append(process)

        for process in subprocesses:
            process.join()
    else:
        run(**vars(args))

## Evaluation

In [None]:
import math
import torch
import argparse
from tqdm import tqdm
import pandas as pd
import random
import time

from torch.utils.data import DataLoader


from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from matplotlib import pyplot

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import decimal


from transformers import *


def float_range(start, stop, step):
    """Yield floats from ``start`` up to (but excluding) ``stop``.

    The running value is advanced by ``decimal.Decimal(step)`` after every
    yielded item, and each item is converted with ``float`` on the way out.
    """
    current = start
    while True:
        if not (current < stop):
            return
        yield float(current)
        current = current + decimal.Decimal(step)


def calculate_program_metrics(far, pd):
    """Print detection metrics derived from a ROC curve.

    Args:
        far: sequence of false-alarm rates, one per operating point.
        pd: sequence of detection probabilities aligned with ``far``.
            (The name shadows the ``pandas`` alias inside this function only.)

    Prints three lines:
        * pD at the last operating point before ``far`` first exceeds 0.1,
        * pD and FAR at the equal-error point, taken as the midpoint between
          the first point where ``pd > 1 - far`` and its predecessor.

    Every metric stays at 0.0 when its condition is never met. When a
    condition is already met at index 0 there is no predecessor: pD @ FAR is
    left at 0.0 and the equal-error values are taken from point 0 itself,
    instead of wrapping around to the last element via index ``-1``.
    """
    pd_at_far = 0.0
    pd_at_eer = 0.0
    far_at_eer = 0.0

    for i in range(len(far)):
        if far[i] > 0.1:
            # No operating point at or below the FAR budget exists when i == 0.
            if i > 0:
                pd_at_far = pd[i - 1]
            break

    for i in range(len(far)):
        if pd[i] > 1 - far[i]:
            if i > 0:
                pd_at_eer = (pd[i - 1] + pd[i]) / 2
                far_at_eer = (far[i - 1] + far[i]) / 2
            else:
                pd_at_eer = pd[i]
                far_at_eer = far[i]
            break

    print("pD @ 0.1 FAR = %.3f" % (pd_at_far))
    print("pD @ EER = %.3f" % (pd_at_eer))
    print("FAR @ EER = %.3f" % (far_at_eer))



class GeneratedTextDetection:
    """
    Artifact class: loads a detector from a checkpoint and scores input texts.
    """

    # Bounds applied to a probability before the log-odds are taken. They only
    # affect inputs that are exactly 0.0 or 1.0 (e.g. a saturated softmax),
    # which would otherwise raise a math domain error / ZeroDivisionError.
    _MIN_PROB = 1e-300
    _MAX_PROB = 1.0 - 2.0 ** -53

    def __init__(self, args):
        """Seed torch and build the detector described by ``args``."""
        torch.manual_seed(1000)

        self.args = args

        # Load the model from checkpoints
        self.init_dict = self._init_detector()

    def _init_detector(self):
        """Return a dict of models/tokenizers; only the "fused" method fills it.

        For ``args.init_method == "fused"`` a ``FusedClassifier`` is built on
        top of ``RobertaForFusion`` and its weights are loaded from
        ``args.check_point + args.known_model_name + '.pt'``. For any other
        method the dict is returned with every entry set to ``None``.
        """
        init_dict = {"kn_model": None, "kn_tokenizer": None,
                    "unk_model": None, "unk_tokenizer": None,
                   "attr_model": None, "attr_tokenizer": None, }

        if self.args.init_method == "fused":
            model_name = 'roberta-large' if self.args.kn_large else 'roberta-base'
            tokenization_utils.logger.setLevel('ERROR')
            tokenizer = RobertaTokenizer.from_pretrained(model_name)
            lm = RobertaForFusion.from_pretrained(model_name).to(self.args.device)

            model = FusedClassifier(lm=lm, device=self.args.device, FUSED_INPUT_SIZE=self.args.FUSED_INPUT_SIZE)

            # Load the model from checkpoints
            checkpoint_path = (self.args.check_point + '{}.pt').format(self.args.known_model_name)
            if self.args.device == "cpu":
                checkpoint = torch.load(checkpoint_path, map_location='cpu')
            else:
                print(checkpoint_path)
                checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])

            init_dict["kn_model"] = model
            init_dict["kn_tokenizer"] = tokenizer

        # Returned on every path so ``self.init_dict`` is never ``None``.
        return init_dict

    @staticmethod
    def _log_likelihood_ratio(prob, prior_llr):
        """Return ``log10(prob / (1 - prob)) + prior_llr`` as a finite float.

        ``prob`` is clamped to ``[_MIN_PROB, _MAX_PROB]`` first, so values of
        exactly 0.0 or 1.0 give a large finite score instead of an exception.
        """
        clamped = min(max(prob, GeneratedTextDetection._MIN_PROB), GeneratedTextDetection._MAX_PROB)
        return math.log10(clamped / (1 - clamped)) + prior_llr

    def evaluate(self, input_text):
        """
           Method that runs the evaluation and generate scores and evidence

           Returns a dict with the predicted class ("cls"), the log-likelihood
           ratio of class 1 ("LLR_score"), the per-class probabilities
           ("prob_score") and a "generator" entry that is left as ``None``.
        """

        # Encapsulate the inputs
        eval_dataset = EncodeEvalData(input_text, self.init_dict["kn_tokenizer"], self.args.special_puncts, self.args.max_sequence_length)
        eval_loader = DataLoader(eval_dataset)

        # Dictionary will contain all the scores and evidences generated by the model
        results = {"cls": [], "LLR_score": [], "prob_score": {"cls_0": [], "cls_1": []}, "generator": None}

        is_fused = self.args.init_method == "fused"

        # Set eval mode
        if is_fused:
            self.init_dict["kn_model"].eval()

        with torch.no_grad():
            for texts, masks, custom_features in eval_loader:
                texts, masks, custom_features = texts.to(self.args.device), masks.to(self.args.device), custom_features.to(self.args.device)

                if is_fused:
                    # Individual model take care all the probes
                    disc_out = self.init_dict["kn_model"](data=[texts, masks], custom_features=custom_features)

                    cls0_prob = disc_out[:, 0].tolist()
                    cls1_prob = disc_out[:, 1].tolist()

                    results["prob_score"]["cls_0"].extend(cls0_prob)
                    results["prob_score"]["cls_1"].extend(cls1_prob)

                    prior_llr = math.log10(self.args.kn_priors[0]/self.args.kn_priors[1])

                    results["LLR_score"].extend(
                        [self._log_likelihood_ratio(prob, prior_llr) for prob in cls1_prob])

                    _, predicted = torch.max(disc_out, 1)

                    results["cls"].extend(predicted.tolist())

        return results



def _str2bool(value):
    """Parse a command-line boolean (plain ``bool`` would turn "False" into True)."""
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or 0.0 for a zero denominator."""
    return float(numerator) / denominator if denominator else 0.0


def main():
    """
    Evaluate the generated-text detector on a CSV file and report metrics.

    Builds the argument namespace from the declared defaults, creates a
    ``GeneratedTextDetection`` instance, scores every row of the test CSV
    (columns ``text`` and ``label``) one text at a time, then prints the
    confusion counts, accuracy, recall, precision and F1, the ROC AUC of the
    LLR scores, and finally plots the ROC curve against the chance diagonal.
    """
    parser = argparse.ArgumentParser(
        description='Generated Text: Detection'
    )

    # Input data and files
    parser.add_argument('--known_model_name', default="robertagenattr_fusion_grover", type=str,
                        help='name of the known generator detector model')

    parser.add_argument('--init_method', default="fused", type=str,
                        help='name of the generator attribution model')

    parser.add_argument('--check_point', default="/content/", type=str,
                        help='saved model checkpoint directory')

    parser.add_argument('--test_file', default="test.csv", type=str,
                        help='CSV file with "text" and "label" columns to evaluate on')

    # Model parameters
    parser.add_argument('--device', type=str, default=None)

    # Two floats each; type=list would split a command-line string into characters.
    parser.add_argument('--kn_priors', type=float, nargs=2, default=[0.5, 0.5])
    parser.add_argument('--unk_priors', type=float, nargs=2, default=[0.5, 0.5])

    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--max-sequence-length', type=int, default=256)
    parser.add_argument('--kn_large', type=_str2bool, default=False)

    # type=list turns a command-line string into its individual characters.
    parser.add_argument('--special_puncts', type=list, default=["!","'", ",", "-", ":", ";", "?", "@", "\"", "=", "#"])
    parser.add_argument('--FUSED_INPUT_SIZE', type=int, default=811)

    # Parse an explicit empty list so the declared defaults are used and the
    # notebook kernel's own sys.argv is never read. The previous value
    # '--check_point="' set the checkpoint directory to a lone quote character.
    args = parser.parse_args(args=[])

    if args.device is None:
        args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    predict_prob = []
    y = []

    artifact = GeneratedTextDetection(args)

    test_data = pd.read_csv(args.test_file)

    tp = 0
    tn = 0
    fn = 0
    fp = 0

    for value in tqdm(test_data.itertuples(), total=len(test_data)):
        main_body_text = value.text

        # Skip rows without usable text; empty CSV cells are read as NaN
        # (a float), not as "".
        if not isinstance(main_body_text, str) or main_body_text == "":
            continue

        results = artifact.evaluate([main_body_text])

        label = value.label
        y.append(label)
        predict_prob.append(results["LLR_score"][0])

        predicted = results["cls"][0]

        # Labels other than 0/1 are scored but do not enter the confusion counts.
        if predicted == label:
            if label == 1:
                tp += 1
            elif label == 0:
                tn += 1
        else:
            if label == 1:
                fn += 1
            elif label == 0:
                fp += 1

    # Ratios fall back to 0.0 instead of raising when a denominator is empty
    # (e.g. no positive labels, or the detector never predicts class 1).
    recall = _safe_ratio(tp, tp + fn)
    precision = _safe_ratio(tp, tp + fp)
    f1_score = _safe_ratio(2 * precision * recall, precision + recall)

    print('TP: %d' % (
        tp))
    print('TN: %d' % (
        tn))
    print('FP: %d' % (
        fp))
    print('FN: %d' % (
        fn))

    print('Accuracy of the discriminator: %d %%' % (
        _safe_ratio(100 * (tp + tn), tp + tn + fp + fn)))
    print('Recall of the discriminator: %d %%' % (
        100 * recall))
    print('Precision of the discriminator: %d %%' % (
        100 * precision))
    print('f1_score of the discriminator: %d %%' % (
        100 * f1_score))

    # ROC metrics need both classes among the evaluated labels.
    if len(set(y)) < 2:
        print("ROC AUC is undefined: the evaluated labels contain fewer than two classes.")
        return

    # calculate scores
    lr_auc = roc_auc_score(y, predict_prob)

    # summarize scores
    print("\n")
    print(" ----- Extra Metrics -----")
    print()
    print('Classifier: ROC AUC=%.3f' % (lr_auc))

    # calculate roc curves
    lr_fpr, lr_tpr, _ = roc_curve(y, predict_prob)

    calculate_program_metrics(lr_fpr, lr_tpr)

    # Chance diagonal sampled with as many steps as the ROC curve has points
    eq_fpr = list(float_range(0, 1, 1 / len(lr_fpr)))
    eq_tpr = list(eq_fpr)

    # plot the roc curve for the model
    pyplot.plot(lr_fpr, lr_tpr, marker='.', label='Roberta')
    pyplot.plot(eq_fpr, eq_tpr, marker='.', label='Random Chance')

    # axis labels
    pyplot.xlabel('Probability of False Alarm')
    pyplot.ylabel('Probability of Detection')
    # show the legend
    pyplot.legend()
    # show the plot
    pyplot.show()
        

# Entry point: run the evaluation when this code is executed directly
# (as a script or a notebook cell) rather than imported as a module.
if __name__ == "__main__":
    main()
