In [1]:
!export CUDA_VISIBLE_DEVICES=3
%set_env CUDA_VISIBLE_DEVICES=3

env: CUDA_VISIBLE_DEVICES=3


In [2]:
import os
import torch
import pickle
from transformers import TrainingArguments
from torch.utils.data import RandomSampler, DataLoader
from data.dataloader import CustomDataset
from model.model_architechure_bert import multiBert as Model
from trainers import BertTrainer
from utils.callbacks import EvaluateRecord
from utils.general_utils import seed_all
from utils.multitask_evaluator_all_attributes import Evaluator
from safetensors.torch import load_file
from torch.utils.data.dataloader import default_collate
class NerConfig:
    """Bundle of model hyper-parameters passed to the model as ``args``.

    Every setting is a keyword argument whose default reproduces the value
    that used to be hard-coded, so ``NerConfig()`` is unchanged while
    individual experiments can override single fields, e.g.
    ``NerConfig(alpha=0.5)``.

    Args:
        device: Device string stored on the config (default ``"cuda"``).
        num_trait: Number of traits (default 9).
        alpha: Stored as ``self.alpha``; its meaning is defined by the
            consumer of this config, not here.
        delta: Stored as ``self.delta``; its meaning is defined by the
            consumer of this config, not here.
        filter_num: Stored as ``self.filter_num``.
        chunk_sizes: Iterable of chunk sizes. ``None`` selects
            ``[90, 30, 130, 10]``. The values are copied into a new list so
            instances never share (or alias a caller's) list object.
        data_file: Path stored as ``self.data_file``.
        hidden_dim: Chunk & linear output dimension.
        mhd_head: Stored as ``self.mhd_head``.
    """

    def __init__(
        self,
        device="cuda",
        num_trait=9,
        alpha=0.7,
        delta=0.7,
        filter_num=100,
        chunk_sizes=None,
        data_file="/home/tsaibw/Multi_scale/ckps/feacture",
        hidden_dim=100,
        mhd_head=2,
    ):
        self.device = device
        self.num_trait = num_trait
        self.alpha = alpha
        self.delta = delta
        self.filter_num = filter_num
        # Build a fresh list per instance; a shared mutable default would let
        # one config's edits leak into every other config.
        self.chunk_sizes = [90, 30, 130, 10] if chunk_sizes is None else list(chunk_sizes)
        self.data_file = data_file
        self.hidden_dim = hidden_dim  # chunk & linear output_dim
        self.mhd_head = mhd_head


# Module-level default configuration used by train().
args = NerConfig()


In [3]:
def train(
    test_prompt_id: int = 1,
    experiment_tag: str = "test",
    seed: int = 11,
    num_train_epochs: int = 15,
    batch_size: int = 12,
    gradient_accumulation: int = 1,
    learning_rate: float = 1e-5,
    weight_decay: float = 0.01,
    chunk_sizes: "list[int] | None" = None,
):
    """Train the multi-trait model for one prompt with ``BertTrainer``.

    Builds ``CustomDataset`` objects from the ``encode_prompt_{id}.pkl`` files
    of the train/dev/test splits, constructs the model and an ``Evaluator``
    over the dev and test sets, and runs ``trainer.train()``. Evaluation and
    checkpointing happen once per epoch; at most five checkpoints are kept
    under ``ckpts/trait_var/prompt_{id}`` and the best one according to
    ``eval_test_avg`` (higher is better) is reloaded at the end.

    Args:
        test_prompt_id: Prompt id used to select the dataset files and the
            output/log directories.
        experiment_tag: Sub-directory name under ``logs/`` for this run.
        seed: Passed to ``seed_all``, the ``Evaluator`` and the
            ``TrainingArguments`` ``seed``/``data_seed`` fields.
        num_train_epochs: Number of training epochs.
        batch_size: Per-device batch size for both training and evaluation.
        gradient_accumulation: Gradient accumulation steps.
        learning_rate: Learning rate given to ``TrainingArguments``.
        weight_decay: Weight decay given to ``TrainingArguments``.
        chunk_sizes: Optional chunk sizes for this run. When given, the model
            is built from a shallow copy of the module-level ``args`` whose
            ``chunk_sizes`` attribute is replaced by these values (presumably
            that attribute is what ``Model`` reads -- confirm against the
            model code). When ``None``, ``args`` is used unchanged.

    Returns:
        None.
    """
    import copy  # local import so this cell does not depend on the import cell

    seed_all(seed)

    data_root = "/home/tsaibw/Multi_scale/dataset/var_norm"
    train_dataset = CustomDataset(f"{data_root}/new_train/encode_prompt_{test_prompt_id}.pkl")
    eval_dataset = CustomDataset(f"{data_root}/new_dev/encode_prompt_{test_prompt_id}.pkl")
    test_dataset = CustomDataset(f"{data_root}/new_test/encode_prompt_{test_prompt_id}.pkl")

    # Apply a per-run chunk_sizes override on a copy, leaving the shared
    # module-level config untouched for later calls.
    model_config = args
    if chunk_sizes is not None:
        model_config = copy.copy(args)
        model_config.chunk_sizes = list(chunk_sizes)

    model = Model(
        args=model_config
    )
    evaluator = Evaluator(eval_dataset, test_dataset, seed)

    output_dir = f"ckpts/trait_var/prompt_{test_prompt_id}"
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate = learning_rate,
        # Forward the weight decay requested by the caller; without this line
        # the parameter has no effect on the optimizer configuration.
        weight_decay = weight_decay,
        num_train_epochs = num_train_epochs,
        per_device_train_batch_size = batch_size,
        per_device_eval_batch_size = batch_size,
        gradient_accumulation_steps = gradient_accumulation,
        logging_dir = f"logs/{experiment_tag}/trait_var/prompt_{test_prompt_id}",
        evaluation_strategy = "epoch",
        label_names = ["scaled_score"],
        save_strategy = "epoch",
        save_total_limit = 5,
        do_eval = True,
        load_best_model_at_end = True,
        fp16 = False,
        remove_unused_columns = True,
        metric_for_best_model = "eval_test_avg",
        greater_is_better = True,
        seed = seed,
        data_seed = seed,
        ddp_find_unused_parameters = False
    )

    trainer = BertTrainer(
        model = model,
        args = training_args,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        test_dataset = test_dataset,
        evaluator = evaluator,
        callbacks = [EvaluateRecord(output_dir)],
        data_collator = default_collate,
    )

    print('Trainer is using device:', trainer.args.device)
    print(test_prompt_id)
    trainer.train()

def inference(test_prompt_id, model_path, data_path, batch_size=512, seed=11):
    """Run a saved checkpoint on the dev split of one prompt.

    Builds the model from a default ``NerConfig``, loads the safetensors
    weights at ``model_path`` into it, and calls ``trainer.predict`` on the
    dev dataset through a ``BertTrainer`` that also receives the test dataset
    and an ``Evaluator`` over both splits.

    Args:
        test_prompt_id: Prompt id used to select the ``encode_prompt_{id}.pkl``
            files and the trainer output directory.
        model_path: Path to the ``model.safetensors`` file to load.
        data_path: Directory containing the ``new_dev`` and ``new_test``
            sub-directories.
        batch_size: Per-device evaluation batch size (default 512).
        seed: Seed handed to the ``Evaluator`` (default 11).

    Returns:
        Whatever ``trainer.predict(eval_dataset)`` returns, so callers can
        inspect the predictions instead of losing them.
    """
    # Only the dev and test splits are needed for prediction; the training
    # split is not loaded.
    eval_dataset = CustomDataset(f"{data_path}/new_dev/encode_prompt_{test_prompt_id}.pkl")
    test_dataset = CustomDataset(f"{data_path}/new_test/encode_prompt_{test_prompt_id}.pkl")

    config = NerConfig()
    model = Model(
        args=config
    )
    weights = load_file(model_path)
    model.load_state_dict(weights)

    output_dir = f"ckpts/trait_var/prompt_{test_prompt_id}"
    training_args = TrainingArguments(
        output_dir = output_dir,
        per_device_eval_batch_size = batch_size,
        label_names=["scaled_score"],
        do_train = False,
        do_eval = True,
    )
    evaluator = Evaluator(eval_dataset, test_dataset, seed=seed)

    trainer = BertTrainer(
        model=model,
        args=training_args,
        eval_dataset = eval_dataset,
        test_dataset = test_dataset,
        evaluator = evaluator,
        data_collator = default_collate,
    )

    return trainer.predict(eval_dataset)
if __name__ == "__main__":
    # Evaluate one saved checkpoint for a single prompt.
    test_prompt_id = 1
    data_path = "/home/tsaibw/Multi_scale/dataset/var_norm"
    model_path = f'/home/tsaibw/Multi_scale/ckpts/trait_var/prompt_{test_prompt_id}/checkpoint-793/model.safetensors'
    inference(
        test_prompt_id=test_prompt_id,
        model_path=model_path,
        data_path=data_path,
    )
    # Training loop over all prompts, kept disabled:
    # for i in range(1,9):
    #     train(test_prompt_id = i)



------------------------
CURRENT EPOCH: None
[DEV] AVG QWK: 0.807
[DEV] score QWK: 0.959
[DEV] content QWK: 0.882
[DEV] prompt_adherence QWK: 0.679
[DEV] language QWK: 0.691
[DEV] narrativity QWK: 0.699
[DEV] organization QWK: 0.847
[DEV] word_choice QWK: 0.85
[DEV] sentence_fluency QWK: 0.849
[DEV] conventions QWK: 0.809
------------------------
[TEST] AVG QWK: 0.598
[TEST] score QWK: 0.497
[TEST] content QWK: 0.578
[TEST] organization QWK: 0.643
[TEST] word_choice QWK: 0.636
[TEST] sentence_fluency QWK: 0.626
[TEST] conventions QWK: 0.606
------------------------
[BEST TEST] AVG QWK: 0.598, epoch: None
[BEST TEST] score QWK: 0.497
[BEST TEST] content QWK: 0.578
[BEST TEST] organization QWK: 0.643
[BEST TEST] word_choice QWK: 0.636
[BEST TEST] sentence_fluency QWK: 0.626
[BEST TEST] conventions QWK: 0.606
--------------------------------------------------------------------------------------------------------------------------
