In [1]:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Use the CPU if CUDA is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [2]:
print(device)

cuda


In [3]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [4]:
model.config.num_labels = 1

In [5]:
for param in model.parameters():
    param.requires_grad = False

# bert-base has hidden size H = 768, so the new head takes 768 input features
model.classifier = nn.Sequential(
    nn.Linear(768, 256),
    nn.LeakyReLU(),
    nn.Linear(256,64),
    nn.LeakyReLU(),
    nn.Linear(64, 2),
    nn.Softmax(dim=1)
)
model = model.to(device)

In [6]:
model.load_state_dict(torch.load("final_model.pt", map_location=device))

<All keys matched successfully>

In [7]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
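def preprocess_text(text):
    """Split text into chunks of up to 300 words and encode each chunk as BERT input IDs."""
    words = text.split(' ')
    max_parts = 5
    # One extra chunk per full 300 words, capped at max_parts cuts (at most 6 chunks).
    nb_cuts = min(len(words) // 300, max_parts)

    parts = []
    for i in range(nb_cuts + 1):
        text_part = ' '.join(words[i * 300: (i + 1) * 300])
        # max_length is given without truncation=True, so the tokenizer falls back
        # to 'longest_first' truncation and emits a warning.
        parts.append(tokenizer.encode(text_part, return_tensors="pt", max_length=500).to(device))

    return parts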
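
The word-level slicing in `preprocess_text` can be checked without loading the tokenizer. The following is a minimal sketch, assuming the same 300-word chunks and the same cap of 5 cuts; `split_into_word_chunks` is a hypothetical helper written for illustration and is not part of the notebook.

```python
def split_into_word_chunks(text, chunk_words=300, max_cuts=5):
    """Mirror the slicing done in preprocess_text, without tokenizing (illustrative helper)."""
    words = text.split(' ')
    nb_cuts = min(len(words) // chunk_words, max_cuts)
    return [' '.join(words[i * chunk_words:(i + 1) * chunk_words])
            for i in range(nb_cuts + 1)]


medium = ' '.join(['w'] * 650)    # 650 words
exact = ' '.join(['w'] * 300)     # exactly one full chunk
long_text = ' '.join(['w'] * 2000)

# 650 words give two full chunks and a 50-word remainder.
print([len(c.split(' ')) for c in split_into_word_chunks(medium)])  # [300, 300, 50]

# A length that is an exact multiple of 300 yields a trailing empty chunk.
print(repr(split_into_word_chunks(exact)[1]))  # ''

# At most 6 chunks (1800 words) are kept; the remaining words are dropped.
print(len(split_into_word_chunks(long_text)))  # 6
```

Two things follow from this sketch: a text whose word count is an exact multiple of 300 produces an empty final chunk, and anything past the first 1800 words is ignored.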

In [8]:
test = "Democrats in the Texas Legislature staged a dramatic, late-night walkout on Sunday night to force the failure of a sweeping Republican overhaul of state election laws. The move, which deprived the session of the minimum number of lawmakers required for a vote before a midnight deadline, was a stunning setback for state Republicans who had made a new voting law one of their top priorities. The effort is not entirely dead, however. Gov. Greg Abbott, a Republican, indicated that he would call a special session of the Legislature, which could start as early as June 1, or Tuesday, to restart the process. The governor has said that he strongly supported an election bill, and in a statement he called the failure to reach one on Sunday “deeply disappointing.” He was widely expected to sign whatever measure Republicans passed. Election Integrity & Bail Reform were emergency items for this legislative session,” Mr. Abbott said on Twitter on Sunday night. “They will be added to the special session agenda.” He did not specify when the session would start. While Republicans would still be favored to pass a bill in a special session, the unexpected turn of events on Sunday presents a new hurdle in their push to enact a far-reaching election law that would install some of the most rigid voting restrictions in the country, and cement the state as one of the hardest in which to cast a ballot."
RT = preprocess_text(test)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [9]:
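# Sum the per-chunk class probabilities (the classifier head ends in Softmax).
overall_output = torch.zeros((1, 2)).to(device)

with torch.no_grad():  # inference only, no gradients needed
    for part in RT:
        if len(part) > 0:
            overall_output += model(part.reshape(1, -1))[0]

# Normalize the summed scores and take the index of the larger one (0 or 1).
overall_output = F.softmax(overall_output[0], dim=-1)

result = overall_output.max(0)[1].float().item()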
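
The aggregation in this cell can be illustrated with plain numbers. The sketch below is an assumption-laden stand-in: `aggregate_chunk_scores` is a hypothetical helper, and the per-chunk probabilities are made-up values, not outputs of the model. It shows that the final softmax only rescales the summed scores, so the predicted label is the class with the larger sum.

```python
import math


def aggregate_chunk_scores(chunk_probs):
    """Sum two-class scores over chunks, apply softmax, return (label, probs). Illustrative only."""
    totals = [sum(p[k] for p in chunk_probs) for k in range(2)]
    shift = max(totals)  # subtract the max for numerical stability
    exps = [math.exp(t - shift) for t in totals]
    probs = [e / sum(exps) for e in exps]
    label = probs.index(max(probs))
    return label, probs


# Made-up per-chunk probabilities for a two-chunk text.
label, probs = aggregate_chunk_scores([[0.9, 0.1], [0.4, 0.6]])
print(label)  # 0, because the summed score for class 0 (1.3) exceeds class 1 (0.7)
```

Because softmax is monotonic, skipping it would give the same label; it only matters if the normalized scores themselves are reported.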

In [10]:
if result==0.0:
    print("Real_News")
else:
    print("Fake_News")

Real_News


In [11]:
text_parts ="by nick bernabe following the recent mass arrests of  people at the dakota access pipeline construction site located near standing rock north dakota an anonymous donor just donated  million to bail out everyone who was arrested at the protests the news came after tamara francisfourkiller a tribal leader from the caddo nation tribe in caddo county oklahoma was arrested at standing rock francisfourkiller was released after spending two days in jail but her family says she was just an innocent observer in the clashes between militarized law enforcement and native american activists or water protectors  update   pm est a statement from red owl legal collectivenational lawyers guild that is advising standing rock has issued a statement saying the  million has not been received yet we are waiting on confirmation from the caddo nation tribe according to local news affiliate news on   family members of caddo nation chairwoman tamara francisfourkiller said an anonymous donor paid  million late saturday afternoon to release everyone arrested on thursday at the dakota access pipeline site they said however that francisfourkiller should not have been arrested in the first place though the donor who sent the  million remains anonymous it appears the person is connected to the caddo nation tribe in some way"
FT = preprocess_text(text_parts)

In [12]:
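# Sum the per-chunk class probabilities (the classifier head ends in Softmax).
overall_output = torch.zeros((1, 2)).to(device)

with torch.no_grad():  # inference only, no gradients needed
    for part in FT:
        if len(part) > 0:
            overall_output += model(part.reshape(1, -1))[0]

# Normalize the summed scores and take the index of the larger one (0 or 1).
overall_output = F.softmax(overall_output[0], dim=-1)

result = overall_output.max(0)[1].float().item()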

In [13]:
if result==0.0:
    print("Real_News")
else:
    print("Fake_News")

Fake_News
