In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.metrics import accuracy_score


class CustomDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        row = self.dataframe.iloc[index]
        inp, hyp, label = row['inp'], row['hyp'], row['label']
        prompt = f'PROMPT: {inp} \n\n PARTIAL RESPONSE: {hyp}'

        inputs = self.tokenizer.encode_plus(
            prompt,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        input_ids = inputs['input_ids'][0]
        attention_mask = inputs['attention_mask'][0]

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': torch.tensor(label, dtype=torch.long)
        }

class T5BinaryClassifier(pl.LightningModule):
    def __init__(self, model_name, tokenizer, learning_rate, max_len=512):
        super().__init__()

        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        
        self.tokenizer = tokenizer
        self.learning_rate = learning_rate
        self.max_len = max_len

    def forward(self, input_ids, attention_mask, labels=None):
        if len(input_ids.shape)==1:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        if labels is not None:
            labels = labels.unsqueeze(-1)
            return self.model(input_ids, attention_mask=attention_mask, labels=labels)
        else:
            return self.model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=2)

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['label']
        outputs = self(input_ids, attention_mask, labels)
        loss = outputs.loss
        self.log('train_loss', loss, sync_dist=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['label']
        logits = self(input_ids, attention_mask)
        preds = torch.argmax(logits, dim=-1)
        
        acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())
        # self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        #self.log("val_accuracy", acc, on_step=False, on_epoch=True, prog_bar=True)
    
        accuracy = (preds == labels).float().mean()
        self.log("val_accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        #self.log('val_accuracy', accuracy, sync_dist=True)
        return accuracy

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=self.learning_rate)

def balance_dataframe(dataframe):
    label_counts = dataframe['label'].value_counts()
    min_count = label_counts.min()
    balanced_data = dataframe.groupby('label').apply(lambda grp: grp.sample(min_count)).reset_index(drop=True)
    return balanced_data

def train_val_split(dataframe, test_size=0.2, random_state=42):
    unique_inp = dataframe['inp'].unique()
    train_inp, test_inp = train_test_split(unique_inp, test_size=test_size, random_state=random_state)
    
    train_df = dataframe[dataframe['inp'].isin(train_inp)].reset_index(drop=True)
    test_df = dataframe[dataframe['inp'].isin(test_inp)].reset_index(drop=True)
    
    return train_df, test_df

def train(dataframe, model_name='t5-small', epochs=2, batch_size=8, learning_rate=3e-5, max_len=512, val_interval=1):
    # Balance DataFrame and split into train and test
    dataframe = balance_dataframe(dataframe)
    train_df, test_df = train_val_split(dataframe)

    tokenizer = T5Tokenizer.from_pretrained(model_name)
    train_dataset = CustomDataset(train_df, tokenizer, max_len)
    test_dataset = CustomDataset(test_df, tokenizer, max_len)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=10)
    val_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=10)

    
    model = T5BinaryClassifier(model_name, tokenizer, learning_rate, max_len)

    checkpoint_callback = ModelCheckpoint(
        dirpath="./checkpoints",
        filename="{epoch:02d}-{val_accuracy:.4f}",
        save_top_k=2,
        monitor="val_accuracy",
        mode="max",
        save_last=True,
        save_weights_only=False,
        verbose=True,
    )
    trainer = pl.Trainer(
        max_epochs=epochs,
        # gpus=torch.cuda.device_count(),
        log_every_n_steps=val_interval,
        #check_val_every_n_epoch=val_interval,
        val_check_interval=500,
        #callbacks=[checkpoint_callback],
        enable_checkpointing=True,
        #early_stop_callback=None
    )
    trainer.fit(model, train_loader, val_loader, ckpt_path="./lightning_logs/version_3/checkpoints/epoch=0-step=2000.ckpt")
    
def validate(dataframe, model_name='t5-small', epochs=2, batch_size=8, learning_rate=3e-5, max_len=512, val_interval=1):

    tokenizer = T5Tokenizer.from_pretrained(model_name)
    test_dataset = CustomDataset(dataframe, tokenizer, max_len)

    val_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=10)

    
    model = T5BinaryClassifier(model_name, tokenizer, learning_rate, max_len)


    trainer = pl.Trainer(
        max_epochs=epochs,
        # gpus=torch.cuda.device_count(),
        log_every_n_steps=val_interval,
        #check_val_every_n_epoch=val_interval,
        val_check_interval=500,
        #callbacks=[checkpoint_callback],
        enable_checkpointing=True,
        #early_stop_callback=None
    )
    #trainer.fit(model, train_loader, val_loader, ckpt_path="./lightning_logs/version_3/checkpoints/epoch=0-step=2000.ckpt")
    trainer.validate(model, val_loader, ckpt_path="./lightning_logs/version_4/checkpoints/epoch=2-step=11896.ckpt")
    

    

In [2]:
torch.set_float32_matmul_precision('medium')

In [5]:
# Replace with your actual DataFrame
inpdf = pd.read_json("std_dataset.jsonl", lines=True, orient="records")
inpdf['label'] = (inpdf['sco']>.05).astype(int)

In [4]:
test_df = pd.read_json("testimp.jsonl", lines=True, orient="records")

In [28]:
# Balance DataFrame and split into train and test
dataframe = balance_dataframe(inpdf)
train_df, test_df = train_val_split(dataframe, 0.1)

In [5]:
test_df

Unnamed: 0,inp,prefix,max_scos_first,mean_scos_first,max_scos_best,mean_scos_best,stdev_first,stdev_best,sco,hyp,label
0,"Why is the word ""'nother"" so typically said af...","Another is not the same as another noun, it's ...",0.900835,0.817461,0.926091,0.907470,0.088781,0.021048,0.021048,"Another is not the same as another noun, it's ...",0
1,What Does make a Western country be considered...,Latin America is the only sub-region that is n...,0.887012,0.780066,0.898514,0.880013,0.156884,0.017371,0.017371,Latin America is the only sub-region that is n...,0
2,What is happening in C# when I call a method?C...,The In portion is the name of the method,0.767647,0.632500,0.770063,0.703528,0.157825,0.049528,0.049528,The In portion is the name of the method,0
3,"Why is ""Argument from authority"" considered a ...",Argument from Authority is used to argue that ...,0.839993,0.799229,0.860857,0.826473,0.036452,0.042735,0.042735,Argument from Authority is used to argue that ...,0
4,What would happen if the US just adopted all o...,The armed forces would have the same level of ...,0.936990,0.904340,0.959570,0.921739,0.024732,0.031181,0.031181,The armed forces would have the same level of ...,0
...,...,...,...,...,...,...,...,...,...,...,...
842,How does the 'scene' community work? Where \ni...,"The answer is, that there is no real money. Pe...",0.883097,0.796359,0.863940,0.818401,0.092224,0.061619,0.061619,"The answer is, that there is no real money. Pe...",1
843,- Why doesn't the moon rotate?This may be a st...,The moon rotates every 23.8 Earth days. That's...,0.812355,0.729764,0.898722,0.730280,0.080822,0.112762,0.112762,The moon rotates every 23.8 Earth days. That's...,1
844,"In detail, how do sperm actually come in conta...",While a man has to ejaculate semen (the male r...,0.805675,0.701103,0.859314,0.722655,0.106865,0.132386,0.132386,While a man has to ejaculate semen (the male r...,1
845,Are humans genetically inclined to stay with o...,A general theory is that long-term monogamy in...,0.895491,0.781531,0.911074,0.858425,0.122413,0.060815,0.060815,A general theory is that long-term monogamy in...,1


In [6]:
test_df['hyp']=test_df['prefix']

In [7]:
model_name = 'stanfordnlp/SteamSHP-flan-t5-large'
max_len = 512
learning_rate = 3e-5
batch_size=8
epochs=1
val_interval=1


In [10]:
del model

In [34]:
start = .05
end = .08
tmpdf = test_df[test_df['sco']<end].reset_index()
tmpdf = tmpdf[tmpdf['sco']>start].reset_index()

In [35]:
tokenizer = T5Tokenizer.from_pretrained(model_name)
test_dataset = CustomDataset(tmpdf, tokenizer, max_len)

val_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=10)

In [9]:
model = T5BinaryClassifier(model_name, tokenizer, learning_rate, max_len)
#model.to("cuda:0")

In [10]:
trainer = pl.Trainer(
    max_epochs=epochs,
    # gpus=torch.cuda.device_count(),
    log_every_n_steps=val_interval,
    devices=1,
    #check_val_every_n_epoch=val_interval,
    val_check_interval=500,
    #callbacks=[checkpoint_callback],
    enable_checkpointing=True,
    #early_stop_callback=None
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [36]:
#trainer.fit(model, train_loader, val_loader, ckpt_path="./lightning_logs/version_3/checkpoints/epoch=0-step=2000.ckpt")
trainer.validate(model, val_loader, ckpt_path="./lightning_logs/version_11/checkpoints/epoch=3-step=3359.ckpt")

Restoring states from the checkpoint path at ./lightning_logs/version_11/checkpoints/epoch=3-step=3359.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at ./lightning_logs/version_11/checkpoints/epoch=3-step=3359.ckpt


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.6558139534883721
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_accuracy': 0.6558139534883721}]

Unnamed: 0,inp,prefix,max_scos_first,mean_scos_first,max_scos_best,mean_scos_best,stdev_first,stdev_best,sco,hyp,label
0,"Why is the word ""'nother"" so typically said af...","Another is not the same as another noun, it's ...",0.900835,0.817461,0.926091,0.907470,0.088781,0.021048,0.021048,"Another is not the same as another noun, it's ...",0
1,What Does make a Western country be considered...,Latin America is the only sub-region that is n...,0.887012,0.780066,0.898514,0.880013,0.156884,0.017371,0.017371,Latin America is the only sub-region that is n...,0
2,What is happening in C# when I call a method?C...,The In portion is the name of the method,0.767647,0.632500,0.770063,0.703528,0.157825,0.049528,0.049528,The In portion is the name of the method,0
3,"Why is ""Argument from authority"" considered a ...",Argument from Authority is used to argue that ...,0.839993,0.799229,0.860857,0.826473,0.036452,0.042735,0.042735,Argument from Authority is used to argue that ...,0
4,What would happen if the US just adopted all o...,The armed forces would have the same level of ...,0.936990,0.904340,0.959570,0.921739,0.024732,0.031181,0.031181,The armed forces would have the same level of ...,0
...,...,...,...,...,...,...,...,...,...,...,...
842,How does the 'scene' community work? Where \ni...,"The answer is, that there is no real money. Pe...",0.883097,0.796359,0.863940,0.818401,0.092224,0.061619,0.061619,"The answer is, that there is no real money. Pe...",1
843,- Why doesn't the moon rotate?This may be a st...,The moon rotates every 23.8 Earth days. That's...,0.812355,0.729764,0.898722,0.730280,0.080822,0.112762,0.112762,The moon rotates every 23.8 Earth days. That's...,1
844,"In detail, how do sperm actually come in conta...",While a man has to ejaculate semen (the male r...,0.805675,0.701103,0.859314,0.722655,0.106865,0.132386,0.132386,While a man has to ejaculate semen (the male r...,1
845,Are humans genetically inclined to stay with o...,A general theory is that long-term monogamy in...,0.895491,0.781531,0.911074,0.858425,0.122413,0.060815,0.060815,A general theory is that long-term monogamy in...,1


In [None]:
newexs = pd.read_json("output/bigdsetp4.jsonl", orient='records', lines=True)

In [41]:
def random_prefix_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    new_rows = []

    for index, row in df.iterrows():
        input_str = row['inp']
        hyps = row['hyps']
        scos = row['scos']
        
        for hyp, sco in zip(hyps, scos):
            words = hyp.split(' ')
            prefix_length = random.randint(1, len(words))
            prefix = ' '.join(words[:prefix_length])

            new_row = {
                'inp': input_str,
                'hyp': prefix,
                'pflen':prefix_length,
                'sco': sco
            }
            new_rows.append(new_row)

    new_df = pd.DataFrame(new_rows)
    return new_df

In [43]:
import random

In [13]:
test_df

Unnamed: 0,inp,prefix,max_scos_first,mean_scos_first,max_scos_best,mean_scos_best,stdev_first,stdev_best,sco,hyp,label
0,"Why is the word ""'nother"" so typically said af...","Another is not the same as another noun, it's ...",0.900835,0.817461,0.926091,0.907470,0.088781,0.021048,0.021048,"Another is not the same as another noun, it's ...",0
1,What Does make a Western country be considered...,Latin America is the only sub-region that is n...,0.887012,0.780066,0.898514,0.880013,0.156884,0.017371,0.017371,Latin America is the only sub-region that is n...,0
2,What is happening in C# when I call a method?C...,The In portion is the name of the method,0.767647,0.632500,0.770063,0.703528,0.157825,0.049528,0.049528,The In portion is the name of the method,0
3,"Why is ""Argument from authority"" considered a ...",Argument from Authority is used to argue that ...,0.839993,0.799229,0.860857,0.826473,0.036452,0.042735,0.042735,Argument from Authority is used to argue that ...,0
4,What would happen if the US just adopted all o...,The armed forces would have the same level of ...,0.936990,0.904340,0.959570,0.921739,0.024732,0.031181,0.031181,The armed forces would have the same level of ...,0
...,...,...,...,...,...,...,...,...,...,...,...
842,How does the 'scene' community work? Where \ni...,"The answer is, that there is no real money. Pe...",0.883097,0.796359,0.863940,0.818401,0.092224,0.061619,0.061619,"The answer is, that there is no real money. Pe...",1
843,- Why doesn't the moon rotate?This may be a st...,The moon rotates every 23.8 Earth days. That's...,0.812355,0.729764,0.898722,0.730280,0.080822,0.112762,0.112762,The moon rotates every 23.8 Earth days. That's...,1
844,"In detail, how do sperm actually come in conta...",While a man has to ejaculate semen (the male r...,0.805675,0.701103,0.859314,0.722655,0.106865,0.132386,0.132386,While a man has to ejaculate semen (the male r...,1
845,Are humans genetically inclined to stay with o...,A general theory is that long-term monogamy in...,0.895491,0.781531,0.911074,0.858425,0.122413,0.060815,0.060815,A general theory is that long-term monogamy in...,1


In [65]:
def valid_subset(start, end):
    tmp = pfdf[pfdf['pflen']>start].reset_index(drop=True).copy()
    tmp = tmp[tmp['pflen']<end]
    print(len(tmp[tmp['sco']>0.85]))
    print(len(tmp))
    tmp['label']= (tmp['sco']>0.85).astype(int)
    tmp = balance_dataframe(tmp)
    tesdataset = CustomDataset(tmp, tokenizer, max_len)

    val_loader = DataLoader(tesdataset, batch_size=batch_size, num_workers=10)
    #trainer.fit(model, train_loader, val_loader, ckpt_path="./lightning_logs/version_3/checkpoints/epoch=0-step=2000.ckpt")
    trainer.validate(model, val_loader, ckpt_path="./lightning_logs/version_4/checkpoints/epoch=2-step=11896.ckpt")

In [72]:
valid_subset(60, 71)

Restoring states from the checkpoint path at ./lightning_logs/version_4/checkpoints/epoch=2-step=11896.ckpt


28
53


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at ./lightning_logs/version_4/checkpoints/epoch=2-step=11896.ckpt


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy                  0.7
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [55]:
pfdf[pfdf['pflen']<60]

Unnamed: 0,inp,hyp,pflen,sco
0,Robert McNamara and why so many people dislike...,He was one of the most,6,0.926491
1,Robert McNamara and why so many people dislike...,I know he was the secretary of defense during ...,25,0.632656
2,Robert McNamara and why so many people dislike...,"He is a very controversial figure, and one who...",48,0.862696
3,Robert McNamara and why so many people dislike...,"Robert McNamara, a member of the United States...",15,0.608410
4,Robert McNamara and why so many people dislike...,He was known for a lot of mistakes and squande...,24,0.798126
...,...,...,...,...
1923,How come people pass out do to an extreme amou...,"When you experience extreme g-force, you exper...",11,0.635739
1924,How come people pass out do to an extreme amou...,"When you experience extreme g-force, you exper...",9,0.820997
1925,How come people pass out do to an extreme amou...,"When you experience extreme g-force, you exper...",11,0.877012
1926,How come people pass out do to an extreme amou...,"When you experience extreme g-force, you exper...",33,0.883407
