### Prepare data (100 samples)

In [3]:
from openai import OpenAI
from secret import OPENAI_API_KEY
# One shared OpenAI client for every labelling request; the key is imported
# from the local `secret` module above.
client = OpenAI(api_key=OPENAI_API_KEY)
import json
import re

In [4]:
# Settings for the GPT labelling call made in gpt_api().
mname = 'gpt-3.5-turbo'
temperature = 0  # lowest sampling temperature, for reproducible labels
max_token = 1024  # upper bound on completion tokens per request

# Instruction + output template prepended to every document. _response_process
# parses the reply by locating the "Important Events:" and "Summary:" headers,
# so those header lines must keep their exact spelling.
prefix_prompt = """Given a document, do the following tasks:
(1) According to the document, find at least 3 important events.
(2) With the retrieved events, compose a summary in 3 sentences.

Example:
============Example============
Prompt:
Document: [document]
Update:
Important Events:
1. [EVENT_1]
2. [EVENT_2]
3. [EVENT_3]
...

Summary:
[summary]
===============================
"""

In [5]:
def _response_process(content: str, document: str):
    event = content.split("Important Events:\n")[1].split("Summary")[0]
    eventlog = event.split("\n")
    rationale = ""
    for e in eventlog:
        if len(e) == 0: continue
        rationale += re.sub(r'^\d+\. ', '', e)

    summary = content.split("Summary:\n")[1]

    result = {
        "article": document,
        "rationale": rationale,
        "summary": summary
    }

    return result
    

In [6]:
def _store_as_jsonl(results: list):
    with open("data/cnndm/rationale.jsonl", mode="a") as f:
        for r in results:
            r = json.dumps(r)
            f.write(r + "\n")
        f.close()
    return

In [7]:
def gpt_api(data_size: int):
    """Label the first ``data_size`` CNN/DM training articles with GPT.

    For each article, ``prefix_prompt`` plus the document is sent to model
    ``mname``. The reply is printed, parsed by ``_response_process`` into an
    ``{"article", "rationale", "summary"}`` record, and appended to the
    rationale JSONL file by ``_store_as_jsonl``.

    Args:
        data_size: Number of articles, taken from the start of
            ``data/cnndm/train.jsonl``, to label. If the file has fewer rows,
            all of them are used.
    """
    from itertools import islice

    # Read only as many rows as will be labelled, rather than loading the
    # whole file and capping at a hard-coded 100 (which made data_size > 100
    # fail with IndexError).
    with open("data/cnndm/train.jsonl", encoding="utf-8") as f:
        documents = [
            json.loads(line)["article"] for line in islice(f, max(data_size, 0))
        ]

    for document in documents:
        prompt = f"{prefix_prompt}\nPrompt:\n[Document]: {document}\n\nUpdate:"
        response = client.chat.completions.create(
            model=mname,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_token,
        )
        content = response.choices[0].message.content
        print(content)
        if not content:
            print("Skipping empty response")
            continue
        try:
            result = _response_process(content, document)
        except (ValueError, IndexError) as err:
            # A reply that ignores the template should not abort the run and
            # waste the API calls already made; report it and move on.
            print(f"Skipping unparsable response: {err}")
            continue
        # Persist each record as soon as it exists. Re-passing the cumulative
        # results list to the append-mode writer duplicated earlier rows.
        _store_as_jsonl([result])
    return
        

### Train

In [1]:
from transformers import BartTokenizer, BartForConditionalGeneration

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn

from tqdm import tqdm
import numpy as np
import os
import json

# Compute device for the models: the GPU if torch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


### 目標：利用兩個bart當作extractor & abstractor，看看有沒有能力生出更好的Summary

In [9]:
# DATASET
from datasets import load_dataset
class CNNDMDataset(Dataset):
    """Tokenized (article, summary, rationale) triples for BART training.

    Each element of ``data`` must be a mapping with ``'article'``,
    ``'summary'`` and ``'rationale'`` text fields, such as the records written
    by the GPT labelling step. Each field is tokenized with a BART tokenizer,
    then truncated and padded to ``max_len`` tokens.

    Args:
        data: Sequence of record dicts, indexed by position.
        max_len: Token length that every field is truncated/padded to.
        data_len: Maximum number of records to expose. The dataset length is
            ``min(data_len, len(data))``, so indexing never runs past ``data``.
        tokenizer_name: Hugging Face checkpoint to load the tokenizer from.
    """

    def __init__(self, data, max_len: int = 1024, data_len: int = 1000,
                 tokenizer_name: str = 'facebook/bart-large-cnn'):
        super().__init__()
        self.data = data
        self.tok = BartTokenizer.from_pretrained(tokenizer_name)
        self.max_len = max_len
        self.data_len = data_len

    def __len__(self):
        # Cap at the number of loaded records. The default data_len (1000) can
        # exceed a small rationale file, which made the DataLoader request
        # out-of-range indices.
        return min(self.data_len, len(self.data))

    def _encode(self, text):
        """Tokenize ``text`` to fixed-length ``(input_ids, attention_mask)`` 1-D tensors."""
        enc = self.tok.encode_plus(text, max_length=self.max_len, return_tensors='pt',
                                   truncation=True, padding='max_length')
        # return_tensors='pt' on a single string adds a batch dim of 1; take
        # row 0 (unlike squeeze(), this keeps a 1-D shape even if max_len == 1).
        return enc['input_ids'][0], enc['attention_mask'][0]

    def __getitem__(self, idx):
        record = self.data[idx]
        src_input_ids, src_attention_mask = self._encode(record['article'])
        tgt_input_ids, _ = self._encode(record['summary'])
        ral_input_ids, ral_attention_mask = self._encode(record['rationale'])
        return {
            'src_input_ids': src_input_ids,
            'tgt_input_ids': tgt_input_ids,
            'ral_input_ids': ral_input_ids,
            # Extra keys so padded positions can be masked out by the model.
            'src_attention_mask': src_attention_mask,
            'ral_attention_mask': ral_attention_mask,
        }

In [None]:
# Load the GPT-labelled {article, rationale, summary} records (one JSON per line).
with open('data/cnndm/rationale.jsonl', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

# Use at most 100 samples, and never more than the file actually holds, so
# the dataset never indexes past the loaded rows.
dataset = CNNDMDataset(data, data_len=min(100, len(data)))

# `batch_size` is the integer batch size. `batch_sampler` expects a Sampler
# object and cannot be combined with shuffle=True.
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

In [11]:
# Three independent copies of the same pretrained checkpoint: an extractor,
# an abstractor, and `model` (the one run() builds its optimizer over).
# They are loaded in that order.
extractor_model, abstractor_model, model = (
    BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
    for _ in range(3)
)

In [None]:
def run():
    epoch = 3
    accumulate_count = 0
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    model.to(device)
    for (i, batch) in enumerate(train_loader):
        input_ids = batch['src_input_ids'].to(device)
        tgt_ids = batch['tgt_input_ids'].to(device)
        ral_ids = batch['ral_input_ids'].to(device)

        