### Setup: Install and import

In [1]:
!pip install transformers torch
!pip install nltk
!pip install tqdm
!pip install openai



In [2]:
# In order to make things work on google drive
from google.colab import drive

drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [3]:
import random
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, AdamW
import nltk

from tqdm import tqdm
import numpy as np

import openai

In [4]:
openai.api_key = "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"

# A small example
completion = openai.chat.completions.create(
  model="gpt-3.5-turbo",
  messages=[
    {"role": "system", "content": "You are a poetic assistant, skilled in explaining complex programming concepts with creative flair."},
    {"role": "user", "content": "Compose a poem that explains the concept of recursion in programming."}
  ]
)

print(completion.choices[0].message)

ChatCompletionMessage(content="In the coding expanse, where logic blooms,\nA concept profound, recursion consumes.\nThrough loops and loops, it takes an elegant twist,\nAn enchanting dance in the code it insists.\n\nLike a sly serpent, it slithers through lines,\nUnwinding the puzzle with mystic designs.\nA self-repeating loop, where magic abounds,\nAn intricate pattern that astounds.\n\nImagine a mirror reflecting a reflection,\nWhere code calls itself, without an exception.\nIn a mesmerizing loop, it seeks its own truth,\nUnfolding a tale, infinite in its youth.\n\nA function, they say, is a portal quite rare,\nTo traverse deep realms, without worry or care.\nIt goes on and on, in a rhythmical chase,\nAn endless loop, like a gossamer lace.\n\nStacked like nesting dolls, it unfolds its might,\nUnraveling problems with dizzying height.\nA cascade of actions, like echoes subsist,\nSolving grand puzzles, persistently kissed.\n\nYet, beware, dear programmer, of infinite dreams,\nFor a run

### Load Pre-trained RoBERTa Model and Tokenizer

In [5]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_base = RobertaModel.from_pretrained('roberta-base')

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Dataset

In [6]:
!pip install datasets



In [7]:
# GPT- wiki-intro
# https://huggingface.co/datasets/aadityaubhat/GPT-wiki-intro
from datasets import load_dataset

dataset = load_dataset("aadityaubhat/GPT-wiki-intro")['train']

In [8]:
# truncate
def truncate(example):
    """
    Truncate 'wiki_intro' and 'generated_intro' to shorter length
    """
    min_length = min(len(example['wiki_intro']), len(example['generated_intro']))
    truncated_wiki_intro = example['wiki_intro'][:min_length]
    truncated_generated_intro = example['generated_intro'][:min_length]

    return {
        'wiki_intro': truncated_wiki_intro,
        'generated_intro': truncated_generated_intro,
        'title_len': example['title_len'],
        'wiki_intro_len': example['wiki_intro_len'],
        'generated_intro_len': example['generated_intro_len'],
        'prompt_tokens': example['prompt_tokens'],
        'generated_text_tokens': example['generated_text_tokens']
    }


# Wiki_train_data = Wiki_train_data.map(truncate)
# Wiki_val_data = Wiki_val_data.map(truncate)
# Wiki_test_data = Wiki_test_data.map(truncate)
Wiki_data = dataset.map(truncate)

In [9]:
# Generate labels
Wiki_texts = Wiki_data['wiki_intro'] + Wiki_data['generated_intro']

# 1 for human generated, 0 for machine generated
Wiki_labels = [1] * len(Wiki_data['wiki_intro']) + [0] * len(Wiki_data['generated_intro'])

In [10]:
def downsample_data(texts, labels, num_samples=2000):
    random.seed(42)
    combined_data = list(zip(texts, labels))
    sampled_data = random.sample(combined_data, num_samples)
    sampled_texts, sampled_labels = zip(*sampled_data)
    sampled_indices = [texts.index(text) for text, label in sampled_data]

    return list(sampled_texts), list(sampled_labels), sampled_indices

In [11]:
Wiki_sampled_texts, Wiki_sampled_labels, Wiki_sampled_indices = downsample_data(Wiki_texts, Wiki_labels, 2000)

Wiki_train_texts = Wiki_sampled_texts[:1700]
Wiki_train_labels = Wiki_sampled_labels[:1700]
Wiki_train_indices = Wiki_sampled_indices[:1700]

Wiki_val_texts = Wiki_sampled_texts[1700:1850]
Wiki_val_labels = Wiki_sampled_labels[1700:1850]
Wiki_val_indices = Wiki_sampled_indices[1700:1850]

Wiki_test_texts = Wiki_sampled_texts[1850:]
Wiki_test_labels = Wiki_sampled_labels[1850:]
Wiki_test_indices = Wiki_sampled_indices[1850:]

In [12]:
print(f"Training Data Size: {len(Wiki_train_texts)}")
print(f"Validation Data Size: {len(Wiki_val_texts)}")
print(f"Testing Data Size: {len(Wiki_test_texts)}")

print(Wiki_train_indices)
print(Wiki_val_indices)
print(Wiki_test_indices)

print("first string in Wiki_train_texts:", Wiki_train_texts[0])
print("first label in Wiki_train_texts:", Wiki_train_labels[0])
print("first index in Wiki_train_texts:", Wiki_train_indices[0])
print("this index in texts:", Wiki_texts[Wiki_train_indices[0]])

Training Data Size: 1700
Validation Data Size: 150
Testing Data Size: 150
[58369, 13112, 144194, 128393, 117026, 73158, 53736, 285929, 45580, 221208, 16663, 15622, 49123, 114629, 121981, 264951, 13912, 294254, 104248, 285706, 219949, 115574, 235514, 145852, 3407, 83707, 221571, 178389, 145684, 81516, 112886, 176472, 53587, 48625, 199191, 50707, 188208, 180331, 138685, 22780, 240870, 281137, 65444, 198461, 41313, 289428, 153709, 189600, 100814, 36466, 24025, 119484, 151722, 41833, 122049, 52953, 199295, 145738, 237717, 191277, 85277, 194081, 186264, 109842, 139973, 37435, 89725, 280043, 128351, 85669, 242357, 198943, 141530, 292002, 115141, 170017, 29327, 120087, 16829, 165388, 210325, 140373, 34701, 110615, 297365, 164981, 111477, 261740, 207425, 240570, 74905, 138873, 73206, 129303, 294318, 282579, 137752, 224622, 209400, 189790, 114987, 72525, 267138, 258744, 47662, 24702, 57487, 80132, 83876, 221333, 33306, 201728, 200078, 245392, 277408, 131813, 290049, 6019, 60058, 281527, 139893,

In [13]:
# PubMedQA
# https://pubmedqa.github.io/

# a directory structure in Files:
# data/ori_pqaa.json      - 2.6 MB Downloaded from https://drive.google.com/file/d/15v1x6aQDlZymaHGP7cZJZZYFfeJt2NdS/view
# data/ori_pqal.json      - 533.4 MB Downloaded from https://github.com/pubmedqa/pubmedqa/blob/master/data/ori_pqal.json
import json
import random

ori_pqal_path = './gdrive/MyDrive/CPSC_588_dataset/ori_pqal.json'
with open(ori_pqal_path, 'r') as file:
    ori_pqal = json.load(file)
machine_generated_dataset = [{"text": item["LONG_ANSWER"], "label": "0"} for item in ori_pqal.values()]

ori_pqaa_path = './gdrive/MyDrive/CPSC_588_dataset/ori_pqaa.json'
with open(ori_pqaa_path, 'r') as file:
    ori_pqaa = json.load(file)
human_generated_dataset = [{"text": item["LONG_ANSWER"], "label": "1"} for item in ori_pqaa.values()]

random.seed(0)
human_generated_dataset = random.sample(human_generated_dataset, 1000)

print(machine_generated_dataset[0]) # {'text': '...', 'label': 'machine_generated'}
print(human_generated_dataset[0]) # {'text': '...', 'label': 'human_generated'}

{'text': 'Results depicted mitochondrial dynamics in vivo as PCD progresses within the lace plant, and highlight the correlation of this organelle with other organelles during developmental PCD. To the best of our knowledge, this is the first report of mitochondria and chloroplasts moving on transvacuolar strands to form a ring structure surrounding the nucleus during developmental PCD. Also, for the first time, we have shown the feasibility for the use of CsA in a whole plant system. Overall, our findings implicate the mitochondria as playing a critical and early role in developmentally regulated PCD in the lace plant.', 'label': '0'}
{'text': 'Transgender patients are not accessing the same level of preventive cervical screening care as non-transgender female patients. There is a need to better understand barriers to care in this population. Contrary to findings in other settings, history of sex with women was not negatively associated with Pap utilization.', 'label': '1'}


In [14]:
combined_dataset = machine_generated_dataset + human_generated_dataset

texts = [item['text'] for item in combined_dataset]
labels = [int(item['label']) for item in combined_dataset]

PMQA_train_texts, temp_texts, PMQA_train_labels, temp_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42)

PMQA_val_texts, PMQA_test_texts, PMQA_val_labels, PMQA_test_labels = train_test_split(
    temp_texts, temp_labels, test_size=0.5, random_state=42)

In [15]:
PMQA_sampled_texts, PMQA_sampled_labels, PMQA_sampled_indices = downsample_data(texts, labels, 2000)

PMQA_train_texts = PMQA_sampled_texts[:1700]
PMQA_train_labels = PMQA_sampled_labels[:1700]
PMQA_train_indices = PMQA_sampled_indices[:1700]

PMQA_val_texts = PMQA_sampled_texts[1700:1850]
PMQA_val_labels = PMQA_sampled_labels[1700:1850]
PMQA_val_indices = PMQA_sampled_indices[1700:1850]

PMQA_test_texts = PMQA_sampled_texts[1850:]
PMQA_test_labels = PMQA_sampled_labels[1850:]
PMQA_test_indices = PMQA_sampled_indices[1850:]

In [16]:
print(f"Training Data Size: {len(PMQA_train_texts)}")
print(f"Validation Data Size: {len(PMQA_val_texts)}")
print(f"Testing Data Size: {len(PMQA_test_texts)}")

print("first string in PMQA_train_texts:", PMQA_train_texts[0])
print("first label in PMQA_train_texts:", PMQA_train_labels[0])
print("first index in PMQA_train_texts:", PMQA_train_indices[0])
print("this index in texts:", texts[PMQA_train_indices[0]])

Training Data Size: 1700
Validation Data Size: 150
Testing Data Size: 150
first string in PMQA_train_texts: The results suggest that rTMS is an effective and safe therapy in patients with PHN.
first label in PMQA_train_texts: 1
first index in PMQA_train_texts: 1309
this index in texts: The results suggest that rTMS is an effective and safe therapy in patients with PHN.


### Data Preparation

In [17]:
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts[item])
        label = self.labels[item]

        encoding = self.tokenizer.encode_plus(
          text,
          add_special_tokens=True,
          max_length=self.max_len,
          return_token_type_ids=False,
          padding='max_length',
          return_attention_mask=True,
          return_tensors='pt',
          truncation=True
        )

        return {
          'text': text,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'labels': torch.tensor(label, dtype=torch.long)
        }

Wiki_train_dataset = TextDataset(Wiki_train_texts, Wiki_train_labels, tokenizer)
Wiki_val_dataset = TextDataset(Wiki_train_texts, Wiki_train_labels, tokenizer)
Wiki_test_dataset = TextDataset(Wiki_train_texts, Wiki_train_labels, tokenizer)

Wiki_train_loader = DataLoader(Wiki_train_dataset, batch_size=16, shuffle=True)
Wiki_val_loader = DataLoader(Wiki_val_dataset, batch_size=16, shuffle=True)
Wiki_test_loader = DataLoader(Wiki_test_dataset, batch_size=16, shuffle=True)

PMQA_train_dataset = TextDataset(PMQA_train_texts, PMQA_train_labels, tokenizer)
PMQA_val_dataset = TextDataset(PMQA_val_texts, PMQA_val_labels, tokenizer)
PMQA_test_dataset = TextDataset(PMQA_test_texts, PMQA_test_labels, tokenizer)

PMQA_train_loader = DataLoader(PMQA_train_dataset, batch_size=16, shuffle=True)
PMQA_val_loader = DataLoader(PMQA_val_dataset, batch_size=16, shuffle=False)
PMQA_test_loader = DataLoader(PMQA_test_dataset, batch_size=16, shuffle=False)

### Create a Custom Classifier

In [18]:
class BaselineRobertaClassifier(nn.Module):
    def __init__(self, roberta_base):
        super(BaselineRobertaClassifier, self).__init__()
        self.roberta = roberta_base
        self.classifier = nn.Linear(roberta_base.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(input_ids, attention_mask)
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        return logits

class RobertaClassifier(nn.Module):
    def __init__(self, roberta_base, stat_emb_dim, fusion_type='early'):
        super(RobertaClassifier, self).__init__()
        self.fusion_type = fusion_type
        self.roberta = roberta_base

        # Non-linear transformation for statistical embeddings
        self.stat_emb_transform = nn.Linear(stat_emb_dim, stat_emb_dim)
        self.activation = nn.ReLU()

        if fusion_type == 'early':
            self.classifier = nn.Linear(roberta_base.config.hidden_size + stat_emb_dim, 2)
        else:  # late fusion
            self.classifier = nn.Linear(roberta_base.config.hidden_size, 2)
            self.stat_emb_classifier = nn.Linear(stat_emb_dim, 2)

            # Conditional layer
            self.conditional_weights = nn.Linear(stat_emb_dim, roberta_base.config.hidden_size)

    def forward(self, input_ids, attention_mask, statistical_features):
        outputs = self.roberta(input_ids, attention_mask)
        pooled_output = outputs[1]

        # Apply non-linear transformation to statistical features
        transformed_stat_features = self.activation(self.stat_emb_transform(statistical_features))

        if self.fusion_type == 'early':
            combined_output = torch.cat((pooled_output, transformed_stat_features), dim=1)
            return self.classifier(combined_output)
        else:  # late fusion
            # Apply conditional layer
            conditional_weights = torch.sigmoid(self.conditional_weights(transformed_stat_features))
            conditioned_roberta_output = pooled_output * conditional_weights

            logits_from_roberta = self.classifier(conditioned_roberta_output)
            logits_from_stat_emb = self.stat_emb_classifier(transformed_stat_features)
            combined_logits = logits_from_roberta + logits_from_stat_emb
            return combined_logits

### Calculate the statistical features

In [19]:
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('tagsets')
nltk.help.upenn_tagset()


$: dollar
    $ -$ --$ A$ C$ HK$ M$ NZ$ S$ U.S.$ US$
'': closing quotation mark
    ' ''
(: opening parenthesis
    ( [ {
): closing parenthesis
    ) ] }
,: comma
    ,
--: dash
    --
.: sentence terminator
    . ! ?
:: colon or ellipsis
    : ; ...
CC: conjunction, coordinating
    & 'n and both but either et for less minus neither nor or plus so
    therefore times v. versus vs. whether yet
CD: numeral, cardinal
    mid-1890 nine-thirty forty-two one-tenth ten million 0.5 one forty-
    seven 1987 twenty '79 zero two 78-degrees eighty-four IX '60s .025
    fifteen 271,124 dozen quintillion DM2,000 ...
DT: determiner
    all an another any both del each either every half la many much nary
    neither no some such that the them these this those
EX: existential there
    there
FW: foreign word
    gemeinschaft hund ich jeux habeas Haementeria Herr K'ang-si vous
    lutihaw alai je jour objets salutaris fille quibusdam pas trop Monte
    terram fiche oui corporis ...
IN: preposition or

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package tagsets to /root/nltk_data...
[nltk_data]   Package tagsets is already up-to-date!


In [20]:
from nltk.data import load
upenn_tagset_info = load('help/tagsets/upenn_tagset.pickle')
upenn_tagset = list(upenn_tagset_info.keys())
#print(upenn_tagset)
#print(len(upenn_tagset))
for index, tag in enumerate(upenn_tagset):
    print(f"index:{index} , tag:{tag}")
upenn_tagset_meaningful = upenn_tagset[0:3] + upenn_tagset[4:9] + upenn_tagset[10:14] + upenn_tagset[15:19] + upenn_tagset[25:]
#print(upenn_tagset_meaningful)
#print(len(upenn_tagset_meaningful))

index:0 , tag:LS
index:1 , tag:TO
index:2 , tag:VBN
index:3 , tag:''
index:4 , tag:WP
index:5 , tag:UH
index:6 , tag:VBG
index:7 , tag:JJ
index:8 , tag:VBZ
index:9 , tag:--
index:10 , tag:VBP
index:11 , tag:NN
index:12 , tag:DT
index:13 , tag:PRP
index:14 , tag::
index:15 , tag:WP$
index:16 , tag:NNPS
index:17 , tag:PRP$
index:18 , tag:WDT
index:19 , tag:(
index:20 , tag:)
index:21 , tag:.
index:22 , tag:,
index:23 , tag:``
index:24 , tag:$
index:25 , tag:RB
index:26 , tag:RBR
index:27 , tag:RBS
index:28 , tag:VBD
index:29 , tag:IN
index:30 , tag:FW
index:31 , tag:RP
index:32 , tag:JJR
index:33 , tag:JJS
index:34 , tag:PDT
index:35 , tag:MD
index:36 , tag:VB
index:37 , tag:WRB
index:38 , tag:NNP
index:39 , tag:EX
index:40 , tag:NNS
index:41 , tag:SYM
index:42 , tag:CC
index:43 , tag:CD
index:44 , tag:POS


In [21]:
import json
import pandas as pd

In [22]:
def calculate_tag_dist(text: str):
    text = nltk.tokenize.word_tokenize(text)
    tagged_text = nltk.pos_tag(text)
    tag_fd = nltk.FreqDist(tag for (word, tag) in tagged_text)
    tag_count = [tag_fd.get(tag, 0) for tag in upenn_tagset_meaningful]
    count_sum = sum(tag_count)
    tag_dist = [count / count_sum for count in tag_count]
    # tag_dist = [tag_fd.freq(tag) for tag in tag_fd]
    # print(dict(tag_fd))
    # print("length", len(tag_dist))
    # print(tag_dist)
    return tag_dist

def calculate_statistical_features_pos(input_text):
    # Implement the logic to calculate statistical features
    # This function should return a tensor of shape [batch_size, stat_emb_dim]
    # pos tag distribution
    pos_tag_dists = [calculate_tag_dist(text) for text in input_text]
    return torch.tensor(pos_tag_dists)

In [23]:
!pip install py-readability-metrics



In [24]:
from readability import Readability

In [25]:
def calculate_readability_metrics(text):
    metrics = []
    r = Readability(text)
    metrics.append(r.flesch_kincaid().score)
    metrics.append(r.flesch().score)
    metrics.append(r.gunning_fog().score)
    metrics.append(r.coleman_liau().score)
    metrics.append(r.dale_chall().score)
    metrics.append(r.ari().score)
    metrics.append(r.linsear_write().score)
    # metrics.append(r.smog().score)
    metrics.append(r.spache().score)
    return metrics

def calculate_statistical_features_readability(input_text):
    # Implement the logic to calculate statistical features
    # This function should return a tensor of shape [batch_size, stat_emb_dim]
    # readability measure
    readability_metrics = []
    for text in input_text:
        appended = False
        text_pad = text
        for i in range(100):
            try:
                readability_metrics.append(calculate_readability_metrics(text_pad))
                appended = True
            except:
                text_pad += text
                continue
            break
        if i == 99:
            if not appended:
                readability_metrics.append(([0] * 8))
    return torch.tensor(readability_metrics)

In [26]:
def calculate_statistical_features(input_text):
    pos_embs = calculate_statistical_features_pos(input_text)
    read_embs = calculate_statistical_features_readability(input_text)
    return torch.cat((pos_embs, read_embs), dim=1)

### Model Training

In [27]:
def train_epoch_baseline(model, data_loader, loss_fn, optimizer, device, n_examples):
    model.train()
    losses = []
    correct_predictions = 0

    for d in tqdm(data_loader, total=len(data_loader), desc="Training"):
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        labels = d["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, labels)

        correct_predictions += torch.sum(preds == labels)
        losses.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    return correct_predictions.double() / n_examples, np.mean(losses)

def train_epoch_stat_only(model, data_loader, loss_fn, optimizer, device, n_examples, calculate_stat_features):
    model.train()
    losses = []
    correct_predictions = 0

    for d in tqdm(data_loader, total=len(data_loader), desc="Training"):
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        labels = d["labels"].to(device)
        texts = d["text"]
        stat_features = calculate_stat_features(texts).to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, statistical_features=stat_features)
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, labels)

        correct_predictions += torch.sum(preds == labels)
        losses.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    return correct_predictions.double() / n_examples, np.mean(losses)

def validate_epoch(model, data_loader, loss_fn, device, n_examples, calculate_stat_features):
    model.eval()
    losses = []
    correct_predictions = 0

    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            labels = d["labels"].to(device)
            texts = d["text"]
            stat_features = calculate_stat_features(texts).to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, statistical_features=stat_features)
            _, preds = torch.max(outputs, dim=1)
            loss = loss_fn(outputs, labels)

            correct_predictions += torch.sum(preds == labels)
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses)

def eval_model(model, data_loader, loss_fn, device, n_examples, calculate_stat_features):
    model = model.eval()
    losses = []
    correct_predictions = 0

    with torch.no_grad():
        for d in tqdm(data_loader, total=len(data_loader), desc="Testing"):
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            labels = d["labels"].to(device)
            stat_features = calculate_stat_features(d["text"]).to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, statistical_features=stat_features)
            _, preds = torch.max(outputs, dim=1)
            loss = loss_fn(outputs, labels)

            correct_predictions += torch.sum(preds == labels)
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses)


In [72]:
# Initialize model
dataset = "Wiki"
# dataset = "PMQA"

num_epochs = 5
stat_emb_dim = 8
fusion_type = "late"

model = RobertaClassifier(roberta_base, stat_emb_dim, fusion_type)
stat_fn = calculate_statistical_features_readability

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss().to(device)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print('-' * 10)

    train_acc, train_loss = train_epoch_stat_only(
        model,
        Wiki_train_loader if dataset == "Wiki" else PMQA_train_loader,
        loss_fn,
        optimizer,
        device,
        len(Wiki_train_dataset) if dataset == "Wiki" else len(PMQA_train_dataset),
        stat_fn
    )
    print(f'Train loss {train_loss}, accuracy {train_acc}')

    val_acc, val_loss = validate_epoch(
        model,
        Wiki_val_loader if dataset == "Wiki" else PMQA_val_loader,
        loss_fn,
        device,
        len(Wiki_val_dataset) if dataset == "Wiki" else len(PMQA_val_dataset),
        stat_fn
    )
    print(f'Validation loss {val_loss}, accuracy {val_acc}')


Epoch 1/5
----------


Training: 100%|██████████| 107/107 [01:03<00:00,  1.68it/s]


Train loss 1.3131952709823012, accuracy 0.6711764705882353
Validation loss 0.06731624963963143, accuracy 0.9970588235294117
Epoch 2/5
----------


Training: 100%|██████████| 107/107 [01:03<00:00,  1.68it/s]


Train loss 0.026237429165342285, accuracy 0.9988235294117647
Validation loss 0.008272262047497538, accuracy 0.9988235294117647
Epoch 3/5
----------


Training: 100%|██████████| 107/107 [01:03<00:00,  1.68it/s]


Train loss 0.00638241946880018, accuracy 0.9988235294117647
Validation loss 0.004416423296516709, accuracy 0.9988235294117647
Epoch 4/5
----------


Training: 100%|██████████| 107/107 [01:03<00:00,  1.68it/s]


Train loss 0.004136909768268222, accuracy 0.9988235294117647
Validation loss 0.0029417350931772003, accuracy 0.9994117647058822
Epoch 5/5
----------


Training: 100%|██████████| 107/107 [01:03<00:00,  1.68it/s]


Train loss 0.0031387488603609325, accuracy 0.9994117647058822
Validation loss 0.0019651686796639985, accuracy 0.9994117647058822


### Model Evaluation

In [73]:
test_acc, test_loss = eval_model(
    model,
    Wiki_test_loader if dataset == "Wiki" else PMQA_test_loader,
    loss_fn,
    device,
    len(Wiki_test_dataset) if dataset == "Wiki" else len(PMQA_test_dataset),
    stat_fn
)

print(f'Test loss {test_loss}, accuracy {test_acc}')

Testing: 100%|██████████| 107/107 [00:31<00:00,  3.36it/s]

Test loss 0.0019573090134501875, accuracy 0.9994117647058822





In [None]:
"""
def preprocess(texts):
    # Tokenize the texts - this can be a single string or a list of strings
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    return inputs

def predict(model, texts, device):
    model.eval()  # Set the model to evaluation mode

    # Preprocess and tokenize the texts
    inputs = preprocess(texts)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Get predictions
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.argmax(outputs, dim=1)

    return predictions.cpu().numpy()  # Return predictions as numpy array

text = "As ILC2s are elevated in patients with CRSwNP, they may drive nasal polyp formation in CRS. ILC2s are also linked with high tissue and blood eosinophilia and have a potential role in the activation and survival of eosinophils during the Th2 immune response. The association of innate lymphoid cells in CRS provides insights into its pathogenesis."
single_prediction = predict(model, text, device)
print(single_prediction[0])
"""