In [None]:

#一定要先，不然torch會偵測不到
!export CUDA_VISIBLE_DEVICES=4
%set_env CUDA_VISIBLE_DEVICES=4


In [4]:
import os
import torch
import pickle
import random
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertModel, BertConfig
from torch.optim import Adam
from sklearn.preprocessing import StandardScaler
from torch.nn.utils import clip_grad_value_
from tqdm import tqdm
from data.dataloader import CustomDataset
from model.multi_bert import multiBert
from data.scale import get_scaled_down_scores, separate_and_rescale_attributes_for_scoring
from utils.evaluate import evaluation



In [5]:

torch.manual_seed(11)

class NerConfig:
    def __init__(self):
        self.lr = 1e-5
        self.epoch = 15
        self.batch_size = 12
        self.device = "cuda"
        self.chunk_sizes = [90]
        self.data_file = "/home/tsaibw/Multi_scale/ckps/chunk_90"
args = NerConfig()

In [6]:
# train normalize

def save_checkpoint(state, filename="checkpoint.pth.tar"):
    torch.save(state, filename)


def print_gradients(model):
    for name, parameter in model.named_parameters():
        if parameter.grad is not None:
            print(f"{name} - Gradient Norm: {parameter.grad.norm().item()}")
        else:
            print(f"{name} - No gradient")


for i in range(1,9):
    multi_bert_model = multiBert(args.chunk_sizes)  
    multi_bert_model.to(args.device)  
    optimizer = Adam(multi_bert_model.parameters(), lr = args.lr) 
    
    train_dataset = CustomDataset(f"/home/tsaibw/Multi_scale/dataset/train/encode_prompt_{i}.pkl")
    eval_dataset = CustomDataset(f"/home/tsaibw/Multi_scale/dataset/test/encode_prompt_{i}.pkl")
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    eval_loader = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    train_loss_list , eval_loss_list = [] ,[] 
    os.makedirs(f"{args.data_file}/prompt{i}", exist_ok=True)
    
    for epoch in range(args.epoch):
        multi_bert_model.train()
        total_loss = 0

        for document_single, chunked_documents, label, id_, lengths in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epoch}"):
            document_single = document_single.to(args.device)
            optimizer.zero_grad()
            
            predictions = multi_bert_model(
                    document_single=document_single,
                    chunked_documents=chunked_documents,
                    device=args.device,
                    lengths=lengths
            )
            
            loss, inverse_predictions, inverse_labels = multi_bert_model.compute_loss(predictions, label, id_, args.device)
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
        
        eval_loss, qwk_score, pearson_score = multi_bert_model.evaluate(eval_loader, device = args.device)
        
        print(f"Epoch {epoch}, Train Loss: {total_loss / len(train_loader)}")
        print(f"Test Loss: {eval_loss}")
        train_loss_list.append(total_loss / len(train_loader))
        eval_loss_list.append(eval_loss)

        qwk_path = f"{args.data_file}/prompt{i}/result.txt"
        with open(qwk_path, "a") as f:
            f.write(f"Epoch {epoch + 1}/{args.epoch}, QWK: {qwk_score}, Pearson: {pearson_score}, train_loss: {train_loss_list[-1]}, eval_loss: {eval_loss_list[-1]}\n")
  
        checkpoint_path = f"{args.data_file}/prompt{i}/epoch_{epoch+1}_checkpoint.pth.tar"
        save_checkpoint({
          'epoch': epoch + 1,
          'state_dict': multi_bert_model.state_dict(),
          'optimizer': optimizer.state_dict(),
          'train_loss': total_loss / len(train_loader),
          'eval_loss': eval_loss
        }, filename = checkpoint_path)


Epoch 1/15: 100% 75/75 [00:57<00:00,  1.30it/s]


Pearson :  -0.642
QWK :  -0.137
Epoch 0, Train Loss: 0.5873666714131832
Test Loss: 0.09719806015491486


Epoch 2/15: 100% 75/75 [00:57<00:00,  1.31it/s]


Pearson :  0.508
QWK :  0.113
Epoch 1, Train Loss: 0.12487845321496328
Test Loss: 0.04140287674963474


Epoch 3/15: 100% 75/75 [00:58<00:00,  1.28it/s]


Pearson :  0.484
QWK :  0.143
Epoch 2, Train Loss: 0.09129460168381533
Test Loss: 0.02152123786509037


Epoch 4/15: 100% 75/75 [00:57<00:00,  1.31it/s]


Pearson :  0.588
QWK :  0.356
Epoch 3, Train Loss: 0.05846137084066868
Test Loss: 0.0168944222231706


Epoch 5/15: 100% 75/75 [00:57<00:00,  1.30it/s]


Pearson :  0.589
QWK :  0.358
Epoch 4, Train Loss: 0.0481412306924661
Test Loss: 0.01852664121737083


Epoch 6/15: 100% 75/75 [00:58<00:00,  1.28it/s]


Pearson :  0.207
QWK :  0.016
Epoch 5, Train Loss: 0.045498618533213936
Test Loss: 0.02126310607418418


Epoch 7/15: 100% 75/75 [00:57<00:00,  1.29it/s]


Pearson :  0.207
QWK :  0.016
Epoch 6, Train Loss: 0.04639702626814445
Test Loss: 0.023182655808826287


Epoch 8/15: 100% 75/75 [00:59<00:00,  1.26it/s]


Pearson :  0.207
QWK :  0.016
Epoch 7, Train Loss: 0.045787829694648584
Test Loss: 0.022919519680241744


Epoch 9/15: 100% 75/75 [00:58<00:00,  1.29it/s]


Pearson :  0.207
QWK :  0.016
Epoch 8, Train Loss: 0.04424444705247879
Test Loss: 0.02536022998392582


Epoch 10/15: 100% 75/75 [01:00<00:00,  1.23it/s]


Pearson :  0.318
QWK :  0.08
Epoch 9, Train Loss: 0.0407459406927228
Test Loss: 0.030371446597079434


Epoch 11/15: 100% 75/75 [00:59<00:00,  1.26it/s]


Pearson :  0.342
QWK :  0.14
Epoch 10, Train Loss: 0.03502129915480812
Test Loss: 0.026542170842488607


Epoch 12/15: 100% 75/75 [01:00<00:00,  1.24it/s]


Pearson :  0.377
QWK :  0.18
Epoch 11, Train Loss: 0.030374603973080713
Test Loss: 0.031469534647961456


Epoch 13/15: 100% 75/75 [01:00<00:00,  1.24it/s]


Pearson :  0.455
QWK :  0.24
Epoch 12, Train Loss: 0.027945026823629936
Test Loss: 0.03252855117122332


Epoch 14/15: 100% 75/75 [01:01<00:00,  1.22it/s]


Pearson :  0.404
QWK :  0.189
Epoch 13, Train Loss: 0.024981539969642957
Test Loss: 0.03944007282455762


Epoch 15/15: 100% 75/75 [01:00<00:00,  1.25it/s]


Pearson :  0.489
QWK :  0.274
Epoch 14, Train Loss: 0.022414098136747877
Test Loss: 0.03403735980391502


Epoch 1/15: 100% 75/75 [01:00<00:00,  1.24it/s]


Pearson :  -0.608
QWK :  -0.586
Epoch 0, Train Loss: 0.8197053261597951
Test Loss: 0.061110729227463405


Epoch 2/15:  93% 70/75 [00:57<00:04,  1.23it/s]


KeyboardInterrupt: 