In [1]:

#一定要先，不然torch會偵測不到
!export CUDA_VISIBLE_DEVICES=3
%set_env CUDA_VISIBLE_DEVICES=3


env: CUDA_VISIBLE_DEVICES=3


In [2]:
import os
import torch
import pickle
import random
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset, Subset
from transformers import BertModel, BertConfig
from torch.optim import Adam
from sklearn.preprocessing import StandardScaler
from torch.nn.utils import clip_grad_value_
from tqdm import tqdm
from data.dataloader import CustomDataset
from model.multi_bert import multiBert
from data.scale import get_scaled_down_scores, separate_and_rescale_attributes_for_scoring
from utils.multitask_evaluator_all_attributes import Evaluator
from torch.utils.data import random_split


In [4]:

# Seed every random number generator this notebook imports (Python's `random`,
# NumPy and PyTorch) from a single constant, so a Restart & Run All does not
# depend on an unseeded generator.
SEED = 11
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

class NerConfig:
    """Plain container for the hyper-parameters and paths used by this notebook.

    Every value is set as an instance attribute in ``__init__``; the constructor
    takes no arguments.
    """

    def __init__(self):
        defaults = {
            # Read directly by the notebook: optimizer learning rate, number of
            # epochs, DataLoader batch size and the device the model is moved to.
            "lr": 1e-5,
            "epoch": 15,
            "batch_size": 10,
            "device": "cuda",
            # Not read by the notebook cells themselves; presumably consumed by
            # multiBert(args) -- confirm against the model code.
            "num_trait": 9,
            "alpha": 0.7,
            "delta": 0.7,
            "filter_num": 100,
            "chunk_sizes": [90, 30, 130, 10],
            # Root directory under which per-prompt results and checkpoints are written.
            "data_file": "/home/tsaibw/Multi_scale/ckps/feacture",
            # Also presumably model-side settings -- confirm against the model code.
            "hidden_dim": 100,
            "mhd_head": 2,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)


# Single shared configuration instance used by the rest of the notebook.
args = NerConfig()

In [13]:
# Prompt-1 training split and an unshuffled loader over it.
train_dataset = CustomDataset("/home/tsaibw/Multi_scale/dataset/new_train/encode_prompt_1.pkl")
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=4,
)


In [16]:
# Peek at the first batch only: for each field print the tensor shape, or the
# Python type when the field is not a tensor.
for batch in train_loader:
    for key, value in batch.items():
        summary = value.shape if isinstance(value, torch.Tensor) else type(value)
        print(f"{key}: {summary}")

    # The list-valued fields hold tensors; show the shapes of a few entries.
    chunk_groups = batch['chunked_documents']
    print(chunk_groups[0].shape)
    print(chunk_groups[1].shape)
    print(batch['length'][0].shape)
    break  # one batch is enough for this check


prompt_id: torch.Size([10])
document_single: torch.Size([10, 3, 3, 512])
chunked_documents: <class 'list'>
length: <class 'list'>
hand_craft: torch.Size([10, 52])
readability: torch.Size([10, 34])
scaled_score: torch.Size([10, 9])
torch.Size([10, 17, 3, 90])
torch.Size([10, 53, 3, 30])
torch.Size([10])


In [5]:
# train normalize
# multi trait
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to persist (the notebook passes a dict of training state).
        filename: Destination path; defaults to ``checkpoint.pth.tar`` in the
            current working directory.
    """
    torch.save(obj=state, f=filename)


def print_gradients(model):
    """Print one line per named parameter of ``model``.

    Parameters that currently hold a gradient are reported with the norm of
    that gradient; parameters whose ``.grad`` is ``None`` are reported as
    having no gradient.
    """
    for param_name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            print(f"{param_name} - No gradient")
            continue
        print(f"{param_name} - Gradient Norm: {grad.norm().item()}")


def get_reduced_dataset(dataset, ratio=0.1):
    """Return a ``Subset`` holding the leading fraction of ``dataset``.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        ratio: Fraction of items to keep; the count is ``int(len(dataset) * ratio)``,
            i.e. truncated toward zero.

    Returns:
        A ``torch.utils.data.Subset`` over indices ``0 .. count - 1``.
    """
    keep_count = int(len(dataset) * ratio)
    leading_indices = range(keep_count)
    return Subset(dataset, leading_indices)


# Train one model per prompt (prompts 4 through 8). After every epoch the model is
# evaluated on the dev and test loaders, a line is appended to
# {args.data_file}/prompt{i}/result.txt and a checkpoint is written next to it.
for i in range(4,9):
    # A new model and optimizer are built for every prompt.
    multi_bert_model = multiBert(args)
    multi_bert_model.to(args.device)
    optimizer = Adam(multi_bert_model.parameters(), lr=args.lr)

    train_dataset = CustomDataset(f"/home/tsaibw/Multi_scale/dataset/new_train/encode_prompt_{i}.pkl")
    eval_dataset = CustomDataset(f"/home/tsaibw/Multi_scale/dataset/new_dev/encode_prompt_{i}.pkl")
    test_dataset = CustomDataset(f"/home/tsaibw/Multi_scale/dataset/new_test/encode_prompt_{i}.pkl")

    # NOTE(review): the training loader is not shuffled; confirm that a fixed
    # sample order during training is intended.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    dev_loader = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    evaluator = Evaluator(eval_dataset, test_dataset, 11)

    train_loss_list, eval_loss_list, test_loss_list = [], [], []
    prompt_dir = f"{args.data_file}/prompt{i}"
    os.makedirs(prompt_dir, exist_ok=True)

    for epoch in range(args.epoch):
        multi_bert_model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epoch}"):
            # Index required fields directly so a missing key fails with a clear
            # KeyError instead of silently turning into None.
            document_single = batch["document_single"].to(args.device)
            chunked_documents = batch["chunked_documents"]
            scaled_score = batch["scaled_score"]
            prompt_id = batch["prompt_id"]
            readability = batch["readability"].to(args.device)
            hand_craft = batch["hand_craft"].to(args.device)
            # The batch inspected earlier in this notebook exposes this field as
            # "length", so looking up only "lengths" yields None. Prefer "lengths"
            # when present and fall back to "length".
            lengths = batch.get("lengths", batch.get("length"))

            optimizer.zero_grad()

            # The model returns (loss, predicted scores, scaled scores); only the
            # loss is used here.
            loss, predict_score, scaled_score = multi_bert_model(
                    prompt_id = prompt_id,
                    document_single=document_single,
                    chunked_documents=chunked_documents,
                    device=args.device,
                    lengths=lengths,
                    readability = readability,
                    hand_craft = hand_craft,
                    scaled_score = scaled_score
            )

            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        # Mean training loss per batch for this epoch, computed once and reused.
        avg_train_loss = total_loss / len(train_loader)

        eval_loss, test_loss, result = multi_bert_model.evaluate(dev_loader, test_loader, epoch, evaluator, device=args.device)

        # Report the epoch 1-based, matching the progress bar, result.txt and the
        # checkpoint file name.
        print(f"Epoch {epoch + 1}, Train Loss: {avg_train_loss}")
        print(f"Eval Loss: {eval_loss}")
        print(f"Test Loss: {test_loss}")
        train_loss_list.append(avg_train_loss)
        eval_loss_list.append(eval_loss)
        test_loss_list.append(test_loss)

        # Append (not overwrite) so results of every epoch accumulate in one file.
        qwk_path = f"{prompt_dir}/result.txt"
        with open(qwk_path, "a") as f:
            f.write(f"Epoch {epoch + 1}/{args.epoch}, result:{result}, train_loss: {train_loss_list[-1]}, eval_loss: {eval_loss_list[-1]}, test_loss: {test_loss_list[-1]} \n")

        # One checkpoint per epoch, holding model and optimizer state plus losses.
        checkpoint_path = f"{prompt_dir}/epoch_{epoch+1}_checkpoint.pth.tar"
        save_checkpoint({
          'epoch': epoch + 1,
          'state_dict': multi_bert_model.state_dict(),
          'optimizer': optimizer.state_dict(),
          'train_loss': avg_train_loss,
          'eval_loss': eval_loss
        }, filename = checkpoint_path)


Epoch 1/15:   0% 0/953 [00:01<?, ?it/s]


TypeError: unsupported operand type(s) for +=: 'int' and 'str'