In [1]:
!export CUDA_VISIBLE_DEVICES=3
%set_env CUDA_VISIBLE_DEVICES=3

env: CUDA_VISIBLE_DEVICES=3


In [2]:
import fire
import os
import torch
import pickle
from transformers import TrainingArguments
from torch.utils.data import RandomSampler, DataLoader
from data.dataloader import CustomDataset
from model.multi_bert import multiBert as Model
from trainers import BertTrainer
from utils.callbacks import EvaluateRecord
from utils.general_utils import seed_all
from utils.multitask_evaluator_all_attributes import Evaluator
from safetensors.torch import load_file
from torch.utils.data.dataloader import default_collate
class NerConfig:
    def __init__(self):
        self.lr = 1e-5
        self.epoch = 15
        self.batch_size = 10
        self.device = "cuda"
        self.num_trait = 9
        self.alpha = 0.7
        self.delta = 0.7
        self.filter_num = 100
        self.chunk_sizes = [90, 30, 130, 10]
        self.data_file = "/home/tsaibw/Multi_scale/ckps/feacture"
        self.hidden_dim = 100 # chunk & linear output_dim
        self.mhd_head = 2
args = NerConfig()
from torch.utils.data import Subset
import numpy as np

# 使用固定隨機種子以保證子集一致性
def get_subset(dataset, fraction=0.1, seed=42):
    np.random.seed(seed)
    dataset_size = len(dataset)
    indices = np.random.permutation(dataset_size)[:int(fraction * dataset_size)]
    return Subset(dataset, indices)

# 修改 train_dataset



In [None]:
def train(
    test_prompt_id: int = 1,
    experiment_tag: str = "test",
    seed: int = 11,
    num_train_epochs: int = 14,
    batch_size: int = 8,
    gradient_accumulation: int = 2,
    learning_rate: float = 1e-5,
    weight_decay: float = 0.01,
    chunk_sizes: int = [90, 30, 130, 10]
):
    seed_all(seed)

#     train_dataset = get_subset(
#     CustomDataset(f"/home/tsaibw/Multi_scale/dataset/new_train/encode_prompt_{test_prompt_id}.pkl"), 
#     fraction=0.001
# )
    train_dataset = CustomDataset(f"/home/tsaibw/Multi_scale/dataset/new_train/encode_prompt_{test_prompt_id}.pkl")
    eval_dataset = CustomDataset(f"/home/tsaibw/Multi_scale/dataset/new_dev/encode_prompt_{test_prompt_id}.pkl")
    test_dataset = CustomDataset(f"/home/tsaibw/Multi_scale/dataset/new_test/encode_prompt_{test_prompt_id}.pkl")
    model = Model(
        args=args
    )
    evaluator = Evaluator(eval_dataset, test_dataset, seed)

    output_dir = f"ckpts/Curriculum/Epoch_20/prompt_{test_prompt_id}"
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate = learning_rate,
        num_train_epochs = num_train_epochs,
        per_device_train_batch_size = batch_size,
        per_device_eval_batch_size = batch_size,
        gradient_accumulation_steps = gradient_accumulation,
        logging_dir = f"logs/{experiment_tag}/Curriculum/Epoch_20/prompt_{test_prompt_id}",
        evaluation_strategy = "epoch",
        label_names = ["scaled_score"],
        save_strategy = "epoch",
        save_total_limit = 5,
        do_eval = True,
        load_best_model_at_end = True, 
        fp16 = False,
        remove_unused_columns = True,
        metric_for_best_model = "eval_test_avg",
        greater_is_better = True,
        seed = seed,
        data_seed = seed,
        ddp_find_unused_parameters = False
    )
            
    trainer = BertTrainer(
        model = model,
        args = training_args,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        test_dataset = test_dataset,
        evaluator = evaluator,
        callbacks = [EvaluateRecord(output_dir)],
        data_collator = default_collate,
    )

    print('Trainer is using device:', trainer.args.device)
    print(test_prompt_id)
    trainer.train()
    # trainer.train(resume_from_checkpoint = f"/home/tsaibw/ProTACT_pytorch/ckpts/Curriculum/prompt_{test_prompt_id}/checkpoint-epoch_10")

if __name__ == "__main__":
    for i in range(6,9):
        train(test_prompt_id = i)
    



Trainer is using device: cuda:0
6


Epoch,Training Loss,Validation Loss,Dev Avg,Test Avg,Dev Score,Dev Content,Dev Organization,Dev Word Choice,Dev Sentence Fluency,Dev Conventions,Dev Prompt Adherence,Dev Language,Dev Narrativity,Test Score,Test Content,Test Prompt Adherence,Test Language,Test Narrativity
1,0.0698,0.043584,0.662015,0.285415,0.942,0.819,0.673,0.719,0.709,0.601,0.491,0.499,0.504,0.269,0.259,0.26,0.331,0.308
2,0.0447,0.036089,0.673677,0.266574,0.943,0.813,0.68,0.745,0.736,0.642,0.488,0.504,0.512,0.225,0.254,0.252,0.312,0.289
3,0.0388,0.042767,0.733965,0.480573,0.954,0.811,0.738,0.802,0.752,0.673,0.622,0.628,0.625,0.499,0.467,0.445,0.496,0.496


Processing key: score
Pred shape: 1800, Original shape: 1800
Processing key: content
Pred shape: 1800, Original shape: 1800
Processing key: prompt_adherence
Pred shape: 1800, Original shape: 1800
Processing key: language
Pred shape: 1800, Original shape: 1800
Processing key: narrativity
Pred shape: 1800, Original shape: 1800
CURRENT EPOCH: 1.0
[DEV] AVG QWK: 0.662
[DEV] score QWK: 0.942
[DEV] content QWK: 0.819
[DEV] organization QWK: 0.673
[DEV] word_choice QWK: 0.719
[DEV] sentence_fluency QWK: 0.709
[DEV] conventions QWK: 0.601
[DEV] prompt_adherence QWK: 0.491
[DEV] language QWK: 0.499
[DEV] narrativity QWK: 0.504
------------------------
[TEST] AVG QWK: 0.285
[TEST] score QWK: 0.269
[TEST] content QWK: 0.259
[TEST] prompt_adherence QWK: 0.26
[TEST] language QWK: 0.331
[TEST] narrativity QWK: 0.308
------------------------
[BEST TEST] AVG QWK: 0.285, epoch: 1.0
[BEST TEST] score QWK: 0.269
[BEST TEST] content QWK: 0.259
[BEST TEST] prompt_adherence QWK: 0.26
[BEST TEST] language QW