# load pretrained model

1. Misra, Rishabh. "News Category Dataset." arXiv preprint arXiv:2209.11429 (2022).
2. Misra, Rishabh and Jigyasa Grover. "Sculpting Data for ML: The first act of Machine Learning." ISBN 9798585463570 (2021).


In [1]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
from transformers import LlamaForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("daily_tokenizer_0612")
model = LlamaForCausalLM.from_pretrained('daily_llama_0612')

model.to(device)
0

0

In [3]:
model_fake = LlamaForCausalLM.from_pretrained('fake_detect_llama')

model_fake.to(device)
model_fake.eval()
0

0

In [4]:
prompt = """Return True if the given article is fake. article: Boeing CEO says he assured Trump about Air Force One costs answer:"""

inputs = tokenizer(prompt, return_tensors="pt")
inputs.to(device)

# Generate
generate_ids = model_fake.generate(inputs.input_ids, max_length=50)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

print(output)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Return True if the given article is fake. article: Boeing CEO says he assured Trump about Air Force One costs answer: True answer: True answer: True answer: True if the given article is fake. article:  answer: False-W


In [5]:
model_fake.eval()

prompt = """What is the topic of the collowing article? article: Boeing CEO says he assured Trump about Air Force One costs answer:"""
inputs = tokenizer(prompt, return_tensors="pt")
inputs.to(device)

# Generate
generate_ids = model_fake.generate(inputs.input_ids, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, 
                    clean_up_tokenization_spaces=False)[0]


'What is the topic of the collowing article? article: Boeing CEO says he assured Trump about Air Force One costs answer: True answer: True answer'

In [6]:
model.eval()

prompt = """\
What is the topic of the collowing article? article: Boeing CEO says he assured Trump about Air Force One costs answer:"""
inputs = tokenizer(prompt, return_tensors="pt")
inputs.to(device)

# Generate
generate_ids = model.generate(inputs.input_ids, max_length=100)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, 
                    clean_up_tokenization_spaces=False)[0]


'What is the topic of the collowing article? article: Boeing CEO says he assured Trump about Air Force One costs answer: True answer: True answer: True answer: False-Wing answer: False-Wing answer: False-Wing answer: False-Wing answer: False-Wing-Wing-Wing-Wing-Wing-Wing-Wing-Wing-Wing-Wing-Wing-Wing-Wing-W'

## load dataset

In [7]:
from datasets import load_dataset
data = 'heegyu/news-category-balanced-top10'
dataset = load_dataset(data)

dataset

DatasetDict({
    train: Dataset({
        features: ['link', 'headline', 'category', 'short_description', 'authors', 'date'],
        num_rows: 83878
    })
})

In [8]:
dataset['train'][0]

{'link': 'https://www.huffpost.com/entry/rei-workers-berkeley-store-union_n_6307a5f4e4b0f72c09ded80d',
 'headline': 'REI Workers At Berkeley Store Vote To Unionize In Another Win For Labor',
 'category': 'BUSINESS',
 'short_description': 'They follow in the footsteps of REI workers in New York City who formed a union earlier this year.',
 'authors': 'Dave Jamieson',
 'date': 1661385600000}

In [9]:
dataset['train'].features

{'link': Value(dtype='string', id=None),
 'headline': Value(dtype='string', id=None),
 'category': Value(dtype='string', id=None),
 'short_description': Value(dtype='string', id=None),
 'authors': Value(dtype='string', id=None),
 'date': Value(dtype='int64', id=None)}

In [10]:
df = dataset['train'].to_pandas()
df

Unnamed: 0,link,headline,category,short_description,authors,date
0,https://www.huffpost.com/entry/rei-workers-ber...,REI Workers At Berkeley Store Vote To Unionize...,BUSINESS,They follow in the footsteps of REI workers in...,Dave Jamieson,1661385600000
1,https://www.huffpost.com/entry/twitter-elon-mu...,Twitter Lawyer Calls Elon Musk 'Committed Enem...,BUSINESS,Delaware Chancery Judge Kathaleen McCormick de...,Marita Vlachou,1658275200000
2,https://www.huffpost.com/entry/starbucks-leave...,"Starbucks Leaving Russian Market, Shutting 130...",BUSINESS,Starbucks' move follows McDonald's exit from t...,"DEE-ANN DURBIN, AP",1653264000000
3,https://www.huffpost.com/entry/coinbase-crypto...,Crypto Crash Leaves Trading Platform Coinbase ...,BUSINESS,Cryptocurrency trading platform Coinbase has l...,"Matt Ott, AP",1652313600000
4,https://www.huffpost.com/entry/us-april-jobs-r...,"US Added 428,000 Jobs In April Despite Surging...",BUSINESS,"At 3.6%, unemployment nearly reached the lowes...","Paul Wiseman, AP",1651795200000
...,...,...,...,...,...,...
83873,https://www.huffingtonpost.com/entry/gratitude...,"Flex Your Gratitude Muscle, and Lift Stress Away",WELLNESS,"For most of us, giving comes a lot easier than...","meQuilibrium, Contributor\nPersonalized Stress...",1369353600000
83874,https://www.huffingtonpost.com/entry/diabetes-...,Don't Wait to Prevent Diabetes: Start Today Wi...,WELLNESS,"Small, reasonable changes can add up to a lot ...","Susan B. Dopart, MS, RD, CDE, Contributor\nHea...",1355443200000
83875,https://www.huffingtonpost.com/entry/dream-lif...,The Real Reason You're Not Living Your Dream L...,WELLNESS,Excuses are artificial creations that mask the...,"Alexis Sclamberg, Contributor\nCEO & Founder, ...",1346025600000
83876,https://www.huffingtonpost.com/entry/sugar-obe...,"Is Sugar Making the World Fat, Diabetic, and H...",WELLNESS,The new study in Public Health Nutrition remin...,"Ayala Laufer-Cahana, M.D., Contributor\nPhysic...",1362096000000


In [11]:
categories = df.category.unique().tolist()
categories.sort()
categories

['BUSINESS',
 'ENTERTAINMENT',
 'FOOD & DRINK',
 'HEALTHY LIVING',
 'PARENTING',
 'POLITICS',
 'QUEER VOICES',
 'STYLE & BEAUTY',
 'TRAVEL',
 'WELLNESS']

In [12]:
categories = categories[:4]

In [13]:
dataset = dataset.filter(lambda element: element['category'] in categories)
dataset

DatasetDict({
    train: Dataset({
        features: ['link', 'headline', 'category', 'short_description', 'authors', 'date'],
        num_rows: 29026
    })
})

In [14]:
categories = [x.split(' ')[0].lower() for x in categories[:5]]
categories

['business', 'entertainment', 'food', 'healthy']

In [15]:
int2label = {i: categories[i] for i in range(len(categories))}
label2int = {int2label[key]:key for key in int2label}

In [16]:
def gen_label(element):
    category = element['category'].split(' ')[0].lower()
    return {'label': label2int[category], 'category': category}

dataset = dataset.map(gen_label)
dataset

DatasetDict({
    train: Dataset({
        features: ['link', 'headline', 'category', 'short_description', 'authors', 'date', 'label'],
        num_rows: 29026
    })
})

In [21]:
from datasets import DatasetDict
import random

prompt_format1 = """Given the article, what is the topic of the article? article: %s  answer: %s"""
prompt_format2 = """Determine the topic of the news article. article: %s answer: %s"""
prompt_format3 = """What is this article about? business/entertainment/food/healthy/parenting article: %s answer: %s"""

prompts = [prompt_format1, prompt_format2, prompt_format3]

def gen_prompt(element):
    prompt_format = prompts[random.randint(0, len(prompts)-1)]
    #return DatasetDict({'input': prompt_format%(element['headline'], int2label[element['category']])})
    return DatasetDict({'input': prompt_format%(element['headline'], int2label[element['label']])})


dataset = dataset.map(gen_prompt)

Map:   0%|          | 0/29026 [00:00<?, ? examples/s]

In [22]:
dataset['train'].to_pandas()

Unnamed: 0,link,headline,category,short_description,authors,date,label,input
0,https://www.huffpost.com/entry/rei-workers-ber...,REI Workers At Berkeley Store Vote To Unionize...,business,They follow in the footsteps of REI workers in...,Dave Jamieson,1661385600000,0,"Given the article, what is the topic of the ar..."
1,https://www.huffpost.com/entry/twitter-elon-mu...,Twitter Lawyer Calls Elon Musk 'Committed Enem...,business,Delaware Chancery Judge Kathaleen McCormick de...,Marita Vlachou,1658275200000,0,"Given the article, what is the topic of the ar..."
2,https://www.huffpost.com/entry/starbucks-leave...,"Starbucks Leaving Russian Market, Shutting 130...",business,Starbucks' move follows McDonald's exit from t...,"DEE-ANN DURBIN, AP",1653264000000,0,"Given the article, what is the topic of the ar..."
3,https://www.huffpost.com/entry/coinbase-crypto...,Crypto Crash Leaves Trading Platform Coinbase ...,business,Cryptocurrency trading platform Coinbase has l...,"Matt Ott, AP",1652313600000,0,Determine the topic of the news article. artic...
4,https://www.huffpost.com/entry/us-april-jobs-r...,"US Added 428,000 Jobs In April Despite Surging...",business,"At 3.6%, unemployment nearly reached the lowes...","Paul Wiseman, AP",1651795200000,0,Determine the topic of the news article. artic...
...,...,...,...,...,...,...,...,...
29021,https://www.huffingtonpost.com/entry/happy-hea...,Why You Need Both a 'Bouncer' and a 'Bartender...,healthy,Instead of judging whether you made the right ...,"Elizabeth Grace Saunders, ContributorFounder, ...",1397779200000,3,Determine the topic of the news article. artic...
29022,https://www.huffingtonpost.com/entry/mental-il...,How Video Games Can Improve Dialogue on Mental...,healthy,While there are strong arguments for the games...,"Mona Shattell, Contributornurse researcher",1397779200000,3,What is this article about? business/entertain...
29023,https://www.huffingtonpost.com/entry/wake-up-c...,Wake-Up Calls Inspired My Change From Overdriv...,healthy,My wake-up call marching orders were clear: No...,"Jane Shure, ContributorLeadership Coach, Psych...",1397779200000,3,Determine the topic of the news article. artic...
29024,https://www.huffingtonpost.com/entry/narcissis...,Loving a Narcissist Without Losing Yourself,healthy,It is very difficult for some people to see an...,"Nancy Colier, ContributorPsychotherapist, inte...",1397779200000,3,What is this article about? business/entertain...


In [23]:
dataset = dataset['train'].train_test_split(test_size=0.1)

In [24]:
def tokenize(element):
    tokenizer.pad_token = tokenizer.eos_token
    outputs = tokenizer(
        element['input'],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=False,
        return_length=True,
        padding=True
    )
    input_batch = []
    for inputs, input_ids in zip(element["input"], outputs["input_ids"]):
        input_batch.append(input_ids)
    return {"input_ids": input_batch}


context_length=128
tokenized_datasets = dataset.map(
    tokenize, batched=True, remove_columns=dataset['train'].column_names
)
tokenized_datasets

Map:   0%|          | 0/26123 [00:00<?, ? examples/s]

Map:   0%|          | 0/2903 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 26123
    })
    test: Dataset({
        features: ['input_ids'],
        num_rows: 2903
    })
})

## train

In [25]:
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [26]:
out = data_collator([tokenized_datasets['train'][i] for i in range(5)])
for key in out:
    print(f"{key} shape: {out[key].shape}")

input_ids shape: torch.Size([5, 55])
attention_mask shape: torch.Size([5, 55])
labels shape: torch.Size([5, 55])


In [27]:
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="topic_llama",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy="steps",
    eval_steps=500,
    logging_steps=500,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=500,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=500,
    fp16=True,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [28]:
trainer.train()

Step,Training Loss,Validation Loss
500,2.797,2.4874


TrainOutput(global_step=816, training_loss=2.642992431042241, metrics={'train_runtime': 221.5971, 'train_samples_per_second': 117.885, 'train_steps_per_second': 3.682, 'total_flos': 363222179389440.0, 'train_loss': 2.642992431042241, 'epoch': 1.0})

## evaluate

In [29]:
prompt = """Determine the topic of the news article. article: Bikini'd Kate Hudson Hits The Beach With Chris Martin answer:"""

inputs = tokenizer(prompt, return_tensors="pt")
inputs.to("cuda:0")

# Generate
generate_ids = model.generate(inputs.input_ids, max_length=30)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

print(output)

Determine the topic of the news article. article: Bikini'd Kate Hudson Hits The Beach With Chris Martin answer: entertainment answer: entertainment


In [30]:
dataset['test'][222]['input']

"Given the article, what is the topic of the article? article: RIP Bob Schiller: Radio Writing Wasn't Working, So He Sent Lucy Out to Stomp Some Grapes  answer: entertainment"

In [31]:
prompt = "Given the article, what is the topic of the article? article: This Simple Menu Change Could Finally Get Us To Stop Overeating  answer:"

inputs = tokenizer(prompt, return_tensors="pt")
inputs.to("cuda:0")

# Generate
generate_ids = model.generate(inputs.input_ids, max_length=50)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

print(output)

Given the article, what is the topic of the article? article: This Simple Menu Change Could Finally Get Us To Stop Overeating  answer: healthy  answer: healthy  answer: healthy  answer: healthy  answer: healthy  answer:


In [32]:
dataset['test'][104]['input']


"What is this article about? business/entertainment/food/healthy/parenting article: Hellmann's vs. Best Foods Mayonnaise: Is There A Difference? answer: food"

In [33]:
prompt = "What is this article about? business/entertainment/food/healthy/parenting article: Kylie Jenner Wants You To Know She's Got 'Chunkiness,' Not Butt Implants, OK? answer:"

inputs = tokenizer(prompt, return_tensors="pt")
inputs.to("cuda:0")

# Generate
generate_ids = model.generate(inputs.input_ids, max_length=50)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

print(output)

What is this article about? business/entertainment/food/healthy/parenting article: Kylie Jenner Wants You To Know She's Got 'Chunkiness,' Not Butt Implants, OK? answer: entertainment entertainment answer: entertainment


In [34]:
tokenizer = AutoTokenizer.from_pretrained("daily_tokenizer_0612", padding_side='left')
prompt_format1 = """Given the article, what is the topic of the article? article: %s  answer:"""
prompt_format2 = """Determine the topic of the news article. article: %s answer:"""
prompt_format3 = """What is this article about? business/entertainment/food/healthy/parenting article: %s answer:"""

prompts = [prompt_format1, prompt_format2, prompt_format3]

def gen_valid_prompt(element):
    prompt_format = prompts[random.randint(0, len(prompts)-1)]
    return DatasetDict({'input': prompt_format%(element['headline'])})




valid_dataset = dataset['test'].select(range(100)).map(gen_valid_prompt)
valid_dataset[0]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

{'link': 'https://www.huffingtonpost.com/entry/star-wars-stayin-alive_us_56795914e4b0b958f657ddfa',
 'headline': "'The Tonight Show' Creates Epic 'Star Wars'-'Stayin' Alive' Mashup",
 'category': 'entertainment',
 'short_description': 'This is thankfully not a video about who does and does not stay alive in "The Force Awakens."',
 'authors': 'Todd Van Luling',
 'date': 1450742400000,
 'label': 1,
 'input': "What is this article about? business/entertainment/food/healthy/parenting article: 'The Tonight Show' Creates Epic 'Star Wars'-'Stayin' Alive' Mashup answer:"}

In [35]:
valid_dataset.column_names

['link',
 'headline',
 'category',
 'short_description',
 'authors',
 'date',
 'label',
 'input']

In [36]:
valid_dataset = valid_dataset.map(
    tokenize, batched=True, remove_columns=['link', 'headline', 'category', 'short_description', 'authors', 'date', 'input']
)
valid_dataset

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'input_ids'],
    num_rows: 100
})

In [37]:
from torch.utils.data import DataLoader

batch_size=4
val_ds = valid_dataset
val_ds.set_format(type='torch')
val_dl = DataLoader(val_ds, batch_size=batch_size)

In [38]:
import re
import torch
from tqdm import tqdm

def acc(pred,label):
    return torch.sum(torch.tensor(pred) == label.squeeze()).item()


In [39]:
model_orig = LlamaForCausalLM.from_pretrained('daily_llama_0612')
model_orig.to(device)
model_orig.eval()

val_losses = []
val_acc = 0

for step, batch in enumerate(tqdm(val_dl)):
    label = batch['label']
    input_id= batch['input_ids'].to(device)

    pred = model_orig.generate(input_id, max_length=150)
    decoded_pred = tokenizer.batch_decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    decoded_pred = [re.findall("answer: ([a-z]+)", x)[0] if re.findall("answer: ([a-z]+)", x) else 'none' for x in decoded_pred]
    decoded_pred = [label2int[x] if x in label2int else -1 for x in decoded_pred]

    val_acc += acc(decoded_pred, label)
    

print("val acc: ", val_acc/len(val_dl.dataset))

  0%|          | 0/25 [00:00<?, ?it/s]This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (128). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
100%|██████████| 25/25 [00:24<00:00,  1.03it/s]

val acc:  0.0





In [40]:
tokenizer.batch_decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=False)

["Given the article, what is the topic of the article? article: 'The Interview': No Laughing Matter  answer: True if the given article is fake. article:  Trump’s ‘Dictators’ To ‘Dictators’ Him’s ‘Dictators’ To ‘Dictators’ (VIDEO) answer: False, But It’s A ‘Dictators’ answer: Falsely Takes Himself With A ‘Dictators’ answer: Falsely Takes Himself With The Most ‘Dictators’ To Help Them’ (VIDEO) answer: Falsely T",
 'Determine the topic of the news article. article: The Truth About Pumpkin Seeds Will Make You Want Them Even More answer: False, It’s A ‘Dreamer’  answer: Falsely Tells The Most ‘Dictators’ To Help Them  answer: Falsely Mocks Them In The Most ‘Drease’ Of The Most ‘Dictators’ To Help Them’ In The Most ‘Dictators’ To Help Them’ In The Most ‘Dictators’ To Help Them’ In The Most ‘Dictators’ To Help Them’ In The Most ‘',
 'Determine the topic of the news article. article: 18 Tweets That Capture Carrie Fisher’s Mental Health Legacy answer: False,000  answer: False,000 Americans  ans

In [41]:
model.eval()
val_losses = []
val_acc = 0

for step, batch in enumerate(tqdm(val_dl)):
    label = batch['label']
    input_id= batch['input_ids'].to(device)

    pred = model.generate(input_id, max_length=65)
    decoded_pred = tokenizer.batch_decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    decoded_pred = [re.findall("answer: ([a-z]+)", x)[0] if re.findall("answer: ([a-z]+)", x) else 'none' for x in decoded_pred]
    decoded_pred = [label2int[x] if x in label2int else -1 for x in decoded_pred]

    val_acc += acc(decoded_pred, label)
    

print("val acc: ", val_acc/len(val_dl.dataset))

100%|██████████| 25/25 [00:05<00:00,  4.88it/s]

val acc:  0.83





In [42]:
model.save_pretrained('topic_llama_0618')

In [None]:
# 모델의 사이즈가 커질수록 더 큰 pre-training 데이터로 학습을 할수록 더 많은 task에 대해서 일반화가 가능하다.