In [None]:
import torch
import sys
sys.path.append('../')
from model.continuous_prompt import ContinuousPromptingLLM
from model.recsys_encoder import RecsysContinuousPromptModel
from model.projection import BasicProjection
from dataset import RecsysDataset

from util import plot_and_save

In [None]:
# --- Run configuration -------------------------------------------------------
# Which task is being trained and which split of it to read.
TASK = 'recommendation'
MODE = 'train'

# Name used for the encoder checkpoint and for every artifact saved from it.
MODEL_NAME = 'light-gcn'

# Local snapshot directory of the language model handed to ContinuousPromptingLLM.
LLM_DIR = "/SSL_NAS/bonbak/model/models--yanolja--EEVE-Korean-Instruct-2.8B-v1.0/snapshots/482db2d0ba911253d09342c34d0e42ac871bfea3"

# Per-task locations: input data is read from TASKS_DIR, outputs go under SAVE_DIR.
TASKS_DIR = f'/home/bonbak/continuous-prompting/task/{TASK}'
SAVE_DIR = f'/home/bonbak/continuous-prompting/output/{TASK}'

# Device that the models are moved to.
DEVICE = 'cuda:1'

In [None]:
# Build the dataset for the selected split from its JSONL file and the task's
# edge list.
train_dataset = RecsysDataset(
    f"{TASKS_DIR}/{MODE}.jsonl",
    f"{TASKS_DIR}/edge.csv",
)

# One example per batch, drawn in shuffled order.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
)

# Number of entries in the dataset's user and item mappings.
num_users = len(train_dataset.user_mapping)
num_items = len(train_dataset.item_mapping)

In [None]:
# Encoder that supplies the continuous-prompt embeddings, constructed from the
# user/item counts and the saved training edge index.
continuous_prompt_model = RecsysContinuousPromptModel(num_users, num_items, f'{TASKS_DIR}/train_edge_index.pt')

# NOTE(review): this standalone projection is not passed to
# ContinuousPromptingLLM below; the rest of the notebook trains and saves
# model.projection_module instead. Kept so the notebook global still exists —
# confirm whether it is actually needed.
projection_module = BasicProjection(continuous_prompt_model.model.embedding_dim)

model = ContinuousPromptingLLM(
    LLM_DIR,
    continuous_prompt_model,
    continuous_prompt_model.model.embedding_dim
)

# Load the encoder checkpoint onto the CPU first. Without map_location,
# torch.load restores tensors onto whatever device they were saved from, which
# fails if that device is unavailable and otherwise wastes memory on it. The
# model is moved to DEVICE only afterwards, so CPU tensors are what we want.
encoder_state = torch.load(f'{SAVE_DIR}/model/{MODEL_NAME}.bin', map_location='cpu')
model.continuous_prompt_model.model.load_state_dict(encoder_state)

continuous_prompt_model.to(DEVICE)
model.to(DEVICE)

# Freeze every parameter, then re-enable gradients only for the projection
# module and the continuous-prompt encoder.
for param in model.parameters():
    param.requires_grad = False
for param in model.projection_module.parameters():
    param.requires_grad = True
for param in model.continuous_prompt_model.parameters():
    param.requires_grad = True

# Hand the optimizer a concrete list of the trainable parameters (rather than a
# one-shot filter iterator) so the selection is easy to inspect.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)

In [5]:
model.train()
c = 0  # number of micro-batches fully processed so far
loss_log_list = []  # unscaled loss of every micro-batch, in order
min_loss = float('inf')  # best (lowest) logged loss seen so far
accumulate_step = 8  # micro-batches accumulated per optimizer step
log_interval = 80  # evaluate / log / checkpoint every this many micro-batches


def mean(l):
    """Return the arithmetic mean of a non-empty sequence of numbers."""
    return sum(l)/len(l)


# Start from clean gradients in case the cell is re-run after an interrupted
# run that left a partial accumulation behind.
optimizer.zero_grad()

for epoch in range(10):
    for input_text, continuous_prompt_input, answer_list in train_dataloader:
        inputs_embeds, attention_mask, labels = model.make_seq2seq_input_label(
            input_text, continuous_prompt_input, answer_list, embedding_first=True
        )

        generated_output = model.llm_model.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )
        # Scale so that the gradient accumulated over `accumulate_step`
        # micro-batches is their mean rather than their sum.
        (generated_output.loss / accumulate_step).backward()

        # `c` is zero-based, so step once (c + 1) micro-batches fill a window.
        # Testing `c % accumulate_step == 0` would step after the very first
        # micro-batch and leave every later window shifted by one.
        if (c + 1) % accumulate_step == 0:
            optimizer.step()
            optimizer.zero_grad()
        # Log the unscaled loss so the reported values stay comparable.
        loss_log_list.append(generated_output.loss.item())

        if c % log_interval == 0 and c != 0:
            # Smoothed loss over the most recent `accumulate_step` micro-batches.
            cur_loss = mean(loss_log_list[-accumulate_step:])
            if min_loss > cur_loss:
                # New best: checkpoint the trainable modules from the CPU and
                # print a one-token sample generation for the current example.
                model.eval()
                model.to('cpu')
                min_loss = cur_loss
                torch.save(model.projection_module.state_dict(), f'{SAVE_DIR}/model/{MODEL_NAME}-projection.bin')
                torch.save(model.continuous_prompt_model.state_dict(), f'{SAVE_DIR}/model/{MODEL_NAME}-encoder.bin')

                # The preview is inference only; no autograd graph is needed.
                with torch.no_grad():
                    inputs_embeds, attention_mask = model.make_input_embed(input_text, continuous_prompt_input, embedding_first=True)
                    output = model.llm_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, pad_token_id=model.llm_tokenizer.eos_token_id, max_new_tokens=1)
                print(input_text[0], model.llm_tokenizer.decode(output[0]))
                plot_and_save(loss_log_list, f'{SAVE_DIR}/loss/{MODEL_NAME}.png')

                model.train()
                model.to(DEVICE)

            print(f'step {c} | cur_loss : {cur_loss:.4f} | min_loss : {min_loss:.4f} ')
        c += 1

# Apply any gradients left over from a final, partially filled window.
if c % accumulate_step != 0:
    optimizer.step()
    optimizer.zero_grad()