In [None]:
# Mount Google Drive into the Colab VM so the project folder under MyDrive is reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
%cd /content/drive/MyDrive/rag_project

# Use %pip (not !pip) so packages are installed into the running kernel's environment.
# Install exactly ONE faiss build: faiss-cpu and faiss-gpu both provide the `faiss`
# module, so installing both lets whichever is installed last overwrite the other.
# The passage index in this notebook is built and searched through `datasets` on CPU.
# NOTE(review): the PyPI `faiss-gpu` package may not ship wheels for the current Colab
# Python version; if GPU search is needed, install a faiss GPU build instead of faiss-cpu.
# TODO: pin versions (e.g. transformers==X.Y.Z) once a working combination is confirmed.
%pip install -q transformers datasets faiss-cpu pytorch_lightning

In [None]:
import argparse
import logging
import os
import sys
from pathlib import Path
from collections import defaultdict
from typing import Any, Dict, List, Tuple

# Environment flags are set BEFORE importing torch/faiss: those libraries load the
# OpenMP runtime, and the duplicate-runtime check can fire while they are imported,
# so setting the flag afterwards may be too late to have any effect.
# KMP_DUPLICATE_LIB_OK works around Intel OpenMP "OMP: Error #15" (two copies of the
# OpenMP runtime loaded). The OpenMP error message itself describes this as an unsafe
# workaround that may cause crashes or silently produce incorrect results.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
# NOTE(review): transformers/datasets opt into remote code per call via
# `trust_remote_code=...`; it is not clear that anything reads a TRUST_REMOTE_CODE
# environment variable — verify and remove if unused.
os.environ['TRUST_REMOTE_CODE'] = 'True'

import faiss
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import (
    RagTokenizer,
    RagSequenceForGeneration,
    RagRetriever,
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
    BartTokenizer
)

import pandas as pd
import datasets
from datasets import Dataset, load_from_disk


In [None]:
def compute_embeddings(batch, ctx_tokenizer, ctx_encoder, use_title=True):
    """Embed a batch of passages with a DPR context encoder.

    Intended to be used with ``datasets.Dataset.map(..., batched=True)``.

    Args:
        batch: mapping with a ``'text'`` list and, optionally, a parallel ``'title'`` list.
        ctx_tokenizer: tokenizer matching ``ctx_encoder``.
        ctx_encoder: context encoder; the tokenized inputs are moved to the device
            the encoder lives on.
        use_title: when True and ``batch`` has a ``'title'`` column, each passage is
            tokenized as the (title, text) pair, matching how the Hugging Face RAG
            ``use_own_knowledge_dataset.py`` example embeds passages.

    Returns:
        ``{'embeddings': [...]}`` with one list of floats per passage.
    """
    if use_title and 'title' in batch:
        texts = (batch['title'], batch['text'])
    else:
        texts = (batch['text'],)
    inputs = ctx_tokenizer(*texts, truncation=True, padding=True, return_tensors="pt")

    # Follow the encoder's own device rather than a global `device` variable that is
    # only defined in a later cell.
    device = getattr(ctx_encoder, "device", None)
    if device is None:
        device = next(ctx_encoder.parameters()).device
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        embeddings = ctx_encoder(**inputs).pooler_output
    return {'embeddings': embeddings.cpu().numpy().tolist()}

In [None]:
tsv_file = "psgs_w100.tsv"  # path relative to the project directory selected with %cd
dataset_dict = datasets.load_dataset('csv', data_files=tsv_file, delimiter='\t')

# load_dataset returns a DatasetDict; the single data file lands in the 'train' split.
dataset = dataset_dict['train']

# Load the pretrained DPR context encoder and its tokenizer.
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
ctx_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

# Device selection (use the GPU when one is available).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ctx_encoder.to(device)

embeddings_dataset = dataset.map(
    lambda batch: compute_embeddings(batch, ctx_tokenizer, ctx_encoder),
    batched=True,
    batch_size=128,
)

# DPR ranks passages by inner product. Without metric_type, add_faiss_index builds a
# faiss.IndexFlat(d), whose default metric is L2, so request inner product explicitly.
embeddings_dataset.add_faiss_index(column='embeddings', metric_type=faiss.METRIC_INNER_PRODUCT)

# Save the FAISS index to a file.
faiss.write_index(embeddings_dataset.get_index('embeddings').faiss_index, 'embeddings.faiss')


In [None]:
class Seq2SeqDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset of tokenized (question, answer) pairs for RAG training.

    Deliberately subclasses ``torch.utils.data.Dataset``: the bare name ``Dataset`` in
    this notebook is Hugging Face's arrow-backed ``datasets.Dataset``, whose constructor
    would never run here.

    Args:
        tokenizer: a ``RagTokenizer`` (exposing ``question_encoder`` and ``generator``
            sub-tokenizers) or any single Hugging Face tokenizer.
        hf_dataset: indexable collection whose items are mappings with ``'Question'``
            and ``'Answer'`` keys (e.g. a ``datasets.Dataset``).
        max_length: length every question/answer is padded or truncated to.
    """

    def __init__(self, tokenizer, hf_dataset, max_length=512):
        self.tokenizer = tokenizer
        self.dataset = hf_dataset
        self.max_length = max_length
        # RagTokenizer.__call__ delegates to its current tokenizer, which is the question
        # encoder by default, so calling it directly would encode answers with the DPR
        # vocabulary instead of the generator's. Use the matching sub-tokenizer per side.
        self.source_tokenizer = getattr(tokenizer, "question_encoder", tokenizer)
        self.target_tokenizer = getattr(tokenizer, "generator", tokenizer)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return tokenizer outputs for the question plus ``labels`` for the answer."""
        row = self.dataset[idx]
        encode_kwargs = dict(padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")

        inputs = self.source_tokenizer(row['Question'], **encode_kwargs)
        labels = self.target_tokenizer(row['Answer'], **encode_kwargs)

        # return_tensors="pt" adds a leading batch dimension of size 1; drop only that
        # one (a bare squeeze() would also collapse a length-1 sequence dimension).
        item = {key: val.squeeze(0) for key, val in inputs.items()}
        item['labels'] = labels['input_ids'].squeeze(0)

        return item

def get_dataloaders(tokenizer, data_path, batch_size=8, max_length=512, num_workers=0):
    """Build a shuffled training DataLoader from a question/answer CSV.

    Args:
        tokenizer: tokenizer handed to ``Seq2SeqDataset`` (a ``RagTokenizer`` here).
        data_path: path to a CSV file with ``'Question'`` and ``'Answer'`` columns.
        batch_size: number of examples per batch.
        max_length: pad/truncate length for questions and answers.
        num_workers: DataLoader worker processes (0 loads data in the main process).

    Returns:
        A ``DataLoader`` over ``Seq2SeqDataset`` with ``shuffle=True``.
    """
    train_data = pd.read_csv(data_path)
    # Empty CSV cells are read as NaN (a float), which the tokenizers reject; skip them.
    train_data = train_data.dropna(subset=['Question', 'Answer'])
    # preserve_index=False keeps the post-dropna index from becoming an extra column.
    hf_dataset = datasets.Dataset.from_pandas(train_data, preserve_index=False)

    dataset = Seq2SeqDataset(tokenizer, hf_dataset, max_length=max_length)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
# Logging: a logger for this module, and INFO-level output on the root logger.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class GenerativeQAModule(pl.LightningModule):
    """PyTorch Lightning module that fine-tunes ``facebook/rag-sequence-base`` for QA.

    ``hparams`` is an ``argparse.Namespace``. It is read for:

    - ``indexed_dataset``: the passage dataset with a FAISS index, passed to
      ``RagRetriever``.
    - ``learning_rate``, ``train_file_path``, ``batch_size`` and ``max_length``.
    - optionally ``freeze_generator_layers_above`` (default 6): generator layers with a
      higher index are frozen.
    """

    def __init__(self, hparams):
        super().__init__()

        # Save every hyperparameter except the indexed dataset. Lightning writes hparams
        # into checkpoints and passes them to loggers; the dataset (with its FAISS index)
        # is a large in-memory object, not a config value, and is used directly below.
        self.save_hyperparameters(
            {key: value for key, value in vars(hparams).items() if key != "indexed_dataset"}
        )

        self.question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        self.generator_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
        self.tokenizer = RagTokenizer(self.question_encoder_tokenizer, self.generator_tokenizer)

        self.retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-base",
            indexed_dataset=hparams.indexed_dataset
        )
        self.model = RagSequenceForGeneration.from_pretrained(
            "facebook/rag-sequence-base",
            retriever=self.retriever
        )

        # Other bookkeeping state (not read anywhere in this class).
        self.step_count = 0
        self.metrics = defaultdict(list)

        # Freeze the upper generator layers to reduce memory: frozen parameters get no
        # gradients and are excluded from the optimizer in configure_optimizers.
        self._freeze_generator_layers_above(getattr(hparams, "freeze_generator_layers_above", 6))

    def _freeze_generator_layers_above(self, max_trainable_layer):
        """Set requires_grad=False on generator parameters in layers numbered > max_trainable_layer."""
        for name, param in self.model.generator.named_parameters():
            parts = name.split('.')
            # BART parameter names look like "model.encoder.layers.7.fc1.weight": the layer
            # number follows "layers", while parts[1] is "encoder"/"decoder" and is never
            # a digit, so checking parts[1] never froze anything.
            if 'layers' not in parts:
                continue
            index_pos = parts.index('layers') + 1
            if index_pos < len(parts) and parts[index_pos].isdigit() and int(parts[index_pos]) > max_trainable_layer:
                param.requires_grad = False

    def forward(self, input_ids, attention_mask=None, labels=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        """Run one RAG forward pass with labels; log and return the mean loss."""
        # Lightning has already moved `batch` to self.device. The dataset yields
        # 'labels' (there is no 'decoder_input_ids' key); RAG derives decoder inputs from them.
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch.get("attention_mask"),
            labels=batch["labels"],
        )
        loss = outputs.loss.mean()

        # Tokens per batch: non-padding source + target tokens, as a throughput metric.
        if isinstance(self.tokenizer, RagTokenizer):
            src_pad_token_id = self.tokenizer.question_encoder.pad_token_id
            tgt_pad_token_id = self.tokenizer.generator.pad_token_id
        else:
            src_pad_token_id = tgt_pad_token_id = self.tokenizer.pad_token_id
        tokens_per_batch = (
            batch["input_ids"].ne(src_pad_token_id).sum() + batch["labels"].ne(tgt_pad_token_id).sum()
        )

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('tpb', tokens_per_batch.float(), on_step=True, on_epoch=False, logger=True, sync_dist=True)
        return {"loss": loss}

    def configure_optimizers(self):
        # Only optimize parameters that were not frozen in __init__.
        trainable_params = [param for param in self.parameters() if param.requires_grad]
        return torch.optim.AdamW(trainable_params, lr=self.hparams['learning_rate'])

    def train_dataloader(self):
        """Training DataLoader built from the saved hyperparameters."""
        # Only arguments that get_dataloaders has always accepted are passed, so this
        # works regardless of whether it supports num_workers.
        return get_dataloaders(
            self.tokenizer,
            self.hparams['train_file_path'],
            batch_size=self.hparams['batch_size'],
            max_length=self.hparams['max_length'],
        )


In [None]:
parser = argparse.ArgumentParser(description="Generative QA Model Training Script")

parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate for the optimizer")
parser.add_argument("--train_file_path", type=str, default="/content/drive/MyDrive/rag_project/nq-train.csv", help="Path to the training CSV file")
parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training")
parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the input sequences")
parser.add_argument("--max_epochs", type=int, default=1, help="Number of epochs for training")

# parse_args([]) ignores the notebook kernel's own sys.argv and applies the defaults.
args = parser.parse_args([])

# The FAISS-indexed passage dataset built earlier is an in-memory object, not a
# command-line value: a `type=Dataset` converter could never parse it from a string.
# Attach it to the namespace directly.
args.indexed_dataset = embeddings_dataset

In [None]:
# Build the Lightning module (loads the RAG model, tokenizers and retriever), then a
# training DataLoader over args.train_file_path that uses the module's tokenizer.
model = GenerativeQAModule(args)

train_dataloader = get_dataloaders(
    model.tokenizer,
    args.train_file_path,
    batch_size=args.batch_size,
    max_length=args.max_length,
)

In [None]:
import gc

# Release memory held by objects from earlier cells before training starts.
gc.collect()
torch.cuda.empty_cache()

use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=args.max_epochs,
    accelerator='gpu' if use_gpu else 'cpu',
    # The CPU accelerator rejects devices=0; train on a single device either way.
    devices=1,
    # Mixed-precision training with fp16 ('16-mixed') is CUDA-only in Lightning;
    # use full precision on CPU.
    precision='16-mixed' if use_gpu else '32-true',
    accumulate_grad_batches=2,  # gradient accumulation over 2 batches
    enable_progress_bar=True,
)

trainer.fit(model, train_dataloader)

In [None]:
# Persist the run: first a full Lightning checkpoint, then the Hugging Face artifacts
# (model weights/config, tokenizers, retriever) into the same directory.
save_path = "/content/drive/MyDrive/rag_project/trained_model"
checkpoint_file = os.path.join(save_path, "model_checkpoint.ckpt")
trainer.save_checkpoint(checkpoint_file)

# NOTE(review): RagRetriever.save_pretrained appears to write its own config.json into
# save_path, which would replace the one written by the model; the original order
# (model, tokenizer, retriever) is kept — confirm which config.json should win.
for component in (model.model, model.tokenizer, model.retriever):
    component.save_pretrained(save_path)
