In [1]:
from datasets import load_dataset
import os, sys
import numpy as np
import torch, pandas, torch.nn, tempfile,os,pickle
from collections import namedtuple
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
from accelerate import Accelerator
from tuning import get_model_architecture,get_model
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = 'cuda'

#tuning_names = ["full_ft","prefix","prompt","lora"]
#model_name = ["llama","dialogpt"]

tuning_name = "full_ft"
model_name = "llama"

tokenizer,model = get_model(model_name,device)
model = get_model_architecture(tuning_name, model_name, model)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.14s/it]


In [2]:
torch.__version__

'2.2.1+cu121'

In [3]:
def collate_fn(samples):
    input_ids = [torch.LongTensor(sample.tokens) for sample in samples]
    mask = [torch.LongTensor(sample.mask) for sample in samples]
    labels = [torch.LongTensor(sample.labels) for sample in samples]

    slen = torch.LongTensor([ len(sample.tokens) for sample in samples ])
    max_slen = max(slen)
    padded_input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True,
                                                      padding_value = tokenizer.pad_token_id)
    padded_mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True).type(torch.float32)
    padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True,
                                                      padding_value = -100)

    attention_mask = (torch.arange(max_slen)[None, :] < slen[:, None]).type(torch.long)


    return {'input_ids': padded_input_ids.contiguous(),
            'mask': padded_mask.contiguous(),
            'attention_mask': attention_mask,
           'labels': padded_labels.contiguous()}
TextDataExample = namedtuple('TextDataExample', ['dialogue', 'tokens', 'mask','labels'])
with open('preprocessed_datasets_llama.pickle', 'rb') as fr:
    datasets = pickle.load(fr)
train_ds, valid_ds, test_ds = datasets['train'][:5000] ,datasets['valid'][:500],datasets['test'][:2000]

batch_size = 10
gradient_accumulation_steps = 4

train_dl = DataLoader(train_ds,batch_size = batch_size, collate_fn = collate_fn,shuffle=True)
valid_dl = DataLoader(valid_ds,batch_size = batch_size, collate_fn = collate_fn,shuffle = True)
test_dl = DataLoader(test_ds,batch_size = batch_size, collate_fn = collate_fn)

In [4]:
n_epochs = 20
lr = 1e-4
WARMUP_PROPORTION = 0.05
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
train_steps = n_epochs * (len(datasets['train']) // batch_size + 1)
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
     num_warmup_steps=int(train_steps * WARMUP_PROPORTION), # learning rate 조절하는거래
    num_training_steps=n_epochs * (len(datasets['train']) // batch_size + 1))

accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, train_dl, lr_scheduler
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [5]:

train_losses = []
valid_losses = []
best_valid_loss =float('inf')

for epoch in range(n_epochs):
    model.train()
    average_train_loss = 0.0

    for index, batch in enumerate(tqdm(train_dl, ncols=80)): ###############################
        with accelerator.accumulate(model):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device) # input이랑, am은 마지막꺼 하나뗌
            labels = batch['labels'].to(device) # label은 처음거 하나뗌

            outputs = model(input_ids=input_ids, attention_mask=attention_mask,labels = labels)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()   ###############################
            lr_scheduler.step()  ###############################
            optimizer.zero_grad() ###############################
            average_train_loss += loss.item()/len(train_dl)   ###############################

    #average_train_loss = train_loss_sum / len(train_dl) ######################
    train_losses.append(average_train_loss)
    
    model.eval()
    valid_loss_sum = 0.0

    with torch.no_grad():
        for batch in tqdm(valid_dl, ncols=80, desc=f'Epoch {epoch} Validation'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device) # input이랑, am은 마지막꺼 하나뗌
            labels = batch['labels'].to(device) # label은 처음거 하나뗌

            outputs = model(input_ids=input_ids, attention_mask=attention_mask,labels = labels)
            loss = outputs.loss

            valid_loss_sum += loss.item()

    # Calculate average validation loss
    average_valid_loss = valid_loss_sum / len(valid_dl)
    valid_losses.append(average_valid_loss)
    
    print("train_loss:",average_train_loss,"valid_loss: ",average_valid_loss)
    
    #Early stopping
    if average_valid_loss < best_valid_loss:
        best_valid_loss = average_valid_loss
        if epoch !=0:
            torch.save(model.state_dict(), '{}.pt'.format(tuning_name))

    else:
        print("*******경고*******loss 안 줄어든다!!!!!!!!*******")

100%|███████████████████████████████████████| 500/500 [1:01:42<00:00,  7.40s/it]
Epoch 0 Validation: 100%|███████████████████████| 50/50 [02:00<00:00,  2.42s/it]


train_loss: 4.1753833699226375 valid_loss:  4.107889900207519


100%|███████████████████████████████████████| 500/500 [1:02:39<00:00,  7.52s/it]
Epoch 1 Validation: 100%|███████████████████████| 50/50 [02:06<00:00,  2.52s/it]


train_loss: 4.1774588475227326 valid_loss:  4.1265633678436275
*******경고*******loss 안 줄어든다!!!!!!!!*******


100%|███████████████████████████████████████| 500/500 [1:03:34<00:00,  7.63s/it]
Epoch 2 Validation: 100%|███████████████████████| 50/50 [02:01<00:00,  2.42s/it]


train_loss: 4.178846688747405 valid_loss:  4.1230002117156985
*******경고*******loss 안 줄어든다!!!!!!!!*******


100%|███████████████████████████████████████| 500/500 [1:03:12<00:00,  7.58s/it]
Epoch 3 Validation: 100%|███████████████████████| 50/50 [02:04<00:00,  2.48s/it]


train_loss: 4.1741813492774975 valid_loss:  4.10860969543457
*******경고*******loss 안 줄어든다!!!!!!!!*******


100%|███████████████████████████████████████| 500/500 [1:02:38<00:00,  7.52s/it]
Epoch 4 Validation: 100%|███████████████████████| 50/50 [02:01<00:00,  2.43s/it]


train_loss: 4.17718098974228 valid_loss:  4.111264171600342
*******경고*******loss 안 줄어든다!!!!!!!!*******


100%|███████████████████████████████████████| 500/500 [1:03:30<00:00,  7.62s/it]
Epoch 5 Validation: 100%|███████████████████████| 50/50 [02:04<00:00,  2.49s/it]


train_loss: 4.1771445441246025 valid_loss:  4.1189279460906985
*******경고*******loss 안 줄어든다!!!!!!!!*******


 63%|█████████████████████████▋               | 314/500 [40:18<23:52,  7.70s/it]


KeyboardInterrupt: 

In [None]:

epochs = range(1,epoch+2 )
fig, ax = plt.subplots()
ax.plot(epochs,np.array(train_losses),label ='training loss')
ax.plot(epochs, np.array(valid_losses), label = 'validation loss' )
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()