In [1]:
%%capture
import sys
!{sys.executable} -m pip install -U pandas-profiling[notebook]
!jupyter nbextension enable --py widgetsnbextension
!pip install spacy
!pip install multiprocessing
!pip install plotly

In [84]:
import pandas as pd
import torch
import torch.nn as nn
import multiprocessing
import numpy as np
from sklearn.model_selection import train_test_split
from torchtext.vocab import build_vocab_from_iterator
from sklearn.metrics import accuracy_score, recall_score
from torch.nn import functional as F
import plotly.graph_objects as go
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset

# Подготовка данных
Так как у нас в одном файле лежат true news, а в другом fake news, необходимо два DF объединить в один и установить классовые значения. Так как у нас только два класса, то для fake news значение target установим равное 1, а для true news - 0.

In [85]:
data_true = pd.read_csv('True.csv')
data_true.sample(10)

Unnamed: 0,title,text,subject,date
1679,Anti-Assad nations say no to Syria reconstruct...,"NEW YORK (Reuters) - The United States, Britai...",politicsNews,"September 18, 2017"
15897,Seven bodies found in migrant boat off Libya,ROME (Reuters) - The bodies of seven migrants ...,worldnews,"November 1, 2017"
6546,House speaker: House ethics watchdog will rema...,WASHINGTON (Reuters) - U.S. House Speaker Paul...,politicsNews,"January 3, 2017"
7709,Former Secretary of State Powell will vote for...,"(Reuters) - Colin Powell, who served as secret...",politicsNews,"October 25, 2016"
6601,Trump taps RNC's Spicer for White House spokesman,WASHINGTON (Reuters) - The Republican National...,politicsNews,"December 22, 2016"
7586,Factbox: Wall Street's take on possible outcom...,(Reuters) - U.S. Democratic presidential candi...,politicsNews,"November 3, 2016"
2963,Senate Republicans struggle to salvage healthc...,WASHINGTON (Reuters) - The top U.S. Senate Rep...,politicsNews,"June 28, 2017"
14795,"Zimbabwe soldiers, armored vehicles seal road ...",HARARE (Reuters) - Zimbabwean soldiers and arm...,worldnews,"November 15, 2017"
9179,"Alabama speaker convicted of ethics charges, r...",(Reuters) - Alabama House Speaker Mike Hubbard...,politicsNews,"June 11, 2016"
5819,Senate Democrats delay committee votes on Sess...,WASHINGTON (Reuters) - U.S. Senate Democrats o...,politicsNews,"January 31, 2017"


In [86]:
data_fake = pd.read_csv('Fake.csv')
data_fake.sample(10)

Unnamed: 0,title,text,subject,date
9783,WATCH: OFFICER WHO CRASHED Motorcycle While Es...,It doesn t matter how many times President Tru...,politics,"Sep 29, 2017"
16140,SHOCKING! HERE’S HOW OFTEN Airlines Bump Passe...,Who knew that the Department of Transportation...,Government News,"Apr 11, 2017"
14287,OBAMA TRIES TO CONDEMN TRUMP SUPPORTERS But Ge...,How dare Obama the worst president and most di...,politics,"Mar 15, 2016"
2438,Chuck Schumer: Republicans Have ‘Real Problem...,Republicans may act like they have faith in Do...,News,"February 21, 2017"
21333,STUDENTS THREATEN YALE PRESIDENT: Give Us $8 M...,Isn t there a legal term for those kinds of th...,left-news,"Nov 14, 2015"
13720,BILL CLINTON EX-LOVER Spills The Beans On “Lum...,Wow! This first-hand assessment of Bill clinto...,politics,"Jun 10, 2016"
1871,Nunes’ Replacement In Trump Investigation Def...,The new leader of what can laughably be called...,News,"April 6, 2017"
17597,"BREAKING: Democrat Congressman, Vocal ILLEGAL ...","Rep. Luis Guti rrez (Ill.), one of the most vo...",left-news,"Nov 27, 2017"
548,Donald Trump Tried To Manipulate Stock Market...,"Donald Trump hates Jeff Bezos, the CEO of Amaz...",News,"August 16, 2017"
5745,WATCH: Ex-Trump Campaign Manager’s First Appe...,The announcement that Donald Trump s problemat...,News,"June 24, 2016"


In [87]:
data_fake['target'] = 1
data_true['target'] = 0

In [88]:
data = pd.concat([data_true, data_fake], ignore_index=True)
data.head(10)

Unnamed: 0,title,text,subject,date,target
0,"As U.S. budget fight looms, Republicans flip t...",WASHINGTON (Reuters) - The head of a conservat...,politicsNews,"December 31, 2017",0
1,U.S. military to accept transgender recruits o...,WASHINGTON (Reuters) - Transgender people will...,politicsNews,"December 29, 2017",0
2,Senior U.S. Republican senator: 'Let Mr. Muell...,WASHINGTON (Reuters) - The special counsel inv...,politicsNews,"December 31, 2017",0
3,FBI Russia probe helped by Australian diplomat...,WASHINGTON (Reuters) - Trump campaign adviser ...,politicsNews,"December 30, 2017",0
4,Trump wants Postal Service to charge 'much mor...,SEATTLE/WASHINGTON (Reuters) - President Donal...,politicsNews,"December 29, 2017",0
5,"White House, Congress prepare for talks on spe...","WEST PALM BEACH, Fla./WASHINGTON (Reuters) - T...",politicsNews,"December 29, 2017",0
6,"Trump says Russia probe will be fair, but time...","WEST PALM BEACH, Fla (Reuters) - President Don...",politicsNews,"December 29, 2017",0
7,Factbox: Trump on Twitter (Dec 29) - Approval ...,The following statements were posted to the ve...,politicsNews,"December 29, 2017",0
8,Trump on Twitter (Dec 28) - Global Warming,The following statements were posted to the ve...,politicsNews,"December 29, 2017",0
9,Alabama official to certify Senator-elect Jone...,WASHINGTON (Reuters) - Alabama Secretary of St...,politicsNews,"December 28, 2017",0


In [89]:
data['target'].value_counts()

target
1    23481
0    21417
Name: count, dtype: int64

Так как колонки 'subject' и 'date' нам особо никакой информации полезной не несут, их можно убрать из df, чтобы они не мешали. В некоторых записях встречаются пустые строки с полем 'text', поэтому было решено соединить 'title' (в ней тоже может быть полезная информация, которая поможет классифицировать новости) и 'text' в общую колонку 'title_text', и по ней уже производить предсказания.

In [90]:
data.drop(['subject', 'date'], axis=1, inplace=True)

In [91]:
data['title_text'] = data.apply(
    lambda row: (row['title']  + ' ' + row['text'])[:512],
    axis=1,
)

In [9]:
data.sample(7)

Unnamed: 0,title,text,target,title_text
1502,NFL protests are protected speech but 'misguid...,WASHINGTON (Reuters) - U.S. House Speaker Paul...,0,NFL protests are protected speech but 'misguid...
17666,Palestinian accord must abide by international...,JERUSALEM (Reuters) - Any Palestinian reconcil...,0,Palestinian accord must abide by international...
31560,TRUMP CHALLENGES FAKE MEDIA: “Are we going to ...,You have to give it to President Trump who wen...,1,TRUMP CHALLENGES FAKE MEDIA: “Are we going to ...
23407,Fox News Frantically Tries To Cover Up Trump’...,It s not clear whether Trump called in a favor...,1,Fox News Frantically Tries To Cover Up Trump’...
11914,Ramaphosa's ANC election win lifts South Afric...,JOHANNESBURG (Reuters) - South African banking...,0,Ramaphosa's ANC election win lifts South Afric...
9241,Sanders plants seeds for a lasting U.S. progre...,WASHINGTON (Reuters) - Bernie Sanders’ upstart...,0,Sanders plants seeds for a lasting U.S. progre...
44866,"BOILER ROOM – EP #43 – Cloppers, OR Osmosis, M...",Tune in to the Alternate Current Radio Network...,1,"BOILER ROOM – EP #43 – Cloppers, OR Osmosis, M..."


In [92]:
pd.isnull(data['title_text']).values.sum()
n = pd.isnull(data['title_text'])
data[n]

Unnamed: 0,title,text,target,title_text


Так как записей довольно много и обрабатываются они долго, использовалось распараллеливание. Вся обработка происходит в файле processing.py в функции multipreprocessing_text_spacy.

In [93]:
from preprocessing import multipreprocessing_text

with multiprocessing.Pool(processes=20) as pool:
    data['cleaned_text'] = pool.map(multipreprocessing_text, data['title_text'])
data.sample(10)

Unnamed: 0,title,text,target,title_text,cleaned_text
14615,Cambodia's main opposition party dissolved by ...,PHNOM PENH (Reuters) - Cambodia s highest cour...,0,Cambodia's main opposition party dissolved by ...,"[cambodias, main, opposition, party, dissolved..."
14013,Mexican leftist Lopez Obrador leads presidenti...,(This version of the November 22nd story corr...,0,Mexican leftist Lopez Obrador leads presidenti...,"[mexican, leftist, lopez, obrador, leads, pres..."
42290,HOW WE KNOW AMERICA IS FINALLY WINNING: Popula...,There hasn t been a time in decades when the L...,1,HOW WE KNOW AMERICA IS FINALLY WINNING: Popula...,"[how, we, know, america, is, finally, winning,..."
12925,OAS may recommend new Honduras election unless...,(Reuters) - The Organization of American State...,0,OAS may recommend new Honduras election unless...,"[oas, may, recommend, new, honduras, election,..."
31747,NEW WH COMMUNICATIONS DIRECTOR: I’ll Bring CNN...,Jake Tapper and the new White House Communicat...,1,NEW WH COMMUNICATIONS DIRECTOR: I’ll Bring CNN...,"[new, wh, communications, director, ill, bring..."
16269,'Death to blasphemers' increasing as political...,"SWABI, Pakistan (Reuters) - Three police offic...",0,'Death to blasphemers' increasing as political...,"[death, to, blasphemers, increasing, as, polit..."
1425,"Warren Buffett, Larry Fink criticize Trump tax...",WASHINGTON (Reuters) - President Donald Trump’...,0,"Warren Buffett, Larry Fink criticize Trump tax...","[warren, buffett, larry, fink, criticize, trum..."
4473,"Infrastructure overhaul may top $1 trillion, c...",WASHINGTON (Reuters) - President Donald Trump ...,0,"Infrastructure overhaul may top $1 trillion, c...","[infrastructure, overhaul, may, top, 1, trilli..."
41643,ALL KIDDING ASIDE…DID HILLARY JUST HAVE A SEIZ...,Whoa! We asked the question yesterday in more ...,1,ALL KIDDING ASIDE…DID HILLARY JUST HAVE A SEIZ...,"[all, kidding, asidedid, hillary, just, have, ..."
37071,DOJ LEADER PUTS BLAME FOR RIOTS ON…SLAVERY?,It s no accident that President Obama named Va...,1,DOJ LEADER PUTS BLAME FOR RIOTS ON…SLAVERY? It...,"[doj, leader, puts, blame, for, riots, onslave..."


In [94]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Для того, чтобы иметь возможность обучить модели, необходимо представить данные в числовом виде. Для этого был построен словарь с помощью функции build_vocab_from_iterator из PyTorch. Для незнакомых слов было добавлено слово "<UNK>". Словарь получился размером в 76149 слов. Далее была добавлена колонка 'indexes', в которой для каждого слова из текста сопоставляется индекс из полученного словаря. Таким образов был получен вектор, характеризующий каждое предложение в датасете. Так как размеры у векторов разные, был найден наибольший (max), после чего вектора меньшей длины "добились" 0 до max.

In [95]:
vocab_my = build_vocab_from_iterator((data['cleaned_text']), min_freq=1, specials=["<UNK>"])
vocab_my.set_default_index(vocab_my["<UNK>"])
n_tokens = len(vocab_my)
print(n_tokens)

76149


In [96]:
data['indexes'] = data['cleaned_text'].apply(lambda tokens: [vocab_my[token] for token in tokens])

In [97]:
data.sample(10)

Unnamed: 0,title,text,target,title_text,cleaned_text,indexes
9169,Ryan on Orlando shooting: 'We are a nation at ...,WASHINGTON (Reuters) - House of Representative...,0,Ryan on Orlando shooting: 'We are a nation at ...,"[ryan, on, orlando, shooting, we, are, a, nati...","[353, 7, 1463, 652, 53, 35, 3, 372, 27, 195, 1..."
17645,U.S. believes current North Korea nuclear thre...,WASHINGTON (Reuters) - White House Chief of St...,0,U.S. believes current North Korea nuclear thre...,"[us, believes, current, north, korea, nuclear,...","[15, 1367, 544, 110, 146, 221, 657, 12, 39183,..."
12357,North Korea shows no sign it is serious about ...,WASHINGTON (Reuters) - The State Department sa...,0,North Korea shows no sign it is serious about ...,"[north, korea, shows, no, sign, it, is, seriou...","[110, 146, 447, 83, 773, 19, 12, 1045, 36, 970..."
17686,"Palestinian rivals Hamas, Fatah agree to compl...",CAIRO (Reuters) - Palestinian rivals Hamas and...,0,"Palestinian rivals Hamas, Fatah agree to compl...","[palestinian, rivals, hamas, fatah, agree, to,...","[1539, 2644, 4108, 7303, 1067, 2, 1457, 2876, ..."
7305,Abe aims to underscore importance of Japan-U.S...,TOKYO (Reuters) - Japanese Prime Minister Shin...,0,Abe aims to underscore importance of Japan-U.S...,"[abe, aims, to, underscore, importance, of, ja...","[1632, 2822, 2, 12135, 3118, 4, 21356, 1476, 1..."
18933,"U.S. Commerce Secretary says market access, pr...",HONG KONG (Reuters) - U.S. Commerce Secretary ...,0,"U.S. Commerce Secretary says market access, pr...","[us, commerce, secretary, says, market, access...","[15, 2087, 157, 72, 1062, 946, 8374, 179, 183,..."
17688,Trump resists pressure to soften stance on Ira...,WASHINGTON (Reuters) - President Donald Trump ...,0,Trump resists pressure to soften stance on Ira...,"[trump, resists, pressure, to, soften, stance,...","[8, 42301, 802, 2, 10270, 1782, 7, 199, 221, 1..."
14273,Tunisia PM will go ahead with painful policy d...,TUNIS (Reuters) - Tunisia will continue with a...,0,Tunisia PM will go ahead with painful policy d...,"[tunisia, pm, will, go, ahead, with, painful, ...","[6628, 343, 37, 241, 473, 13, 6083, 249, 449, ..."
2980,Trump says Senate Republicans likely to pass h...,WASHINGTON (Reuters) - U.S. President Donald T...,0,Trump says Senate Republicans likely to pass h...,"[trump, says, senate, republicans, likely, to,...","[8, 72, 107, 124, 387, 2, 875, 407, 101, 41, 1..."
9809,"Driven up the wall by Trump, Mexico looks to r...","MEXICO CITY (Reuters) - At first, Mexico’s gov...",0,"Driven up the wall by Trump, Mexico looks to r...","[driven, up, the, wall, by, trump, mexico, loo...","[4074, 68, 1, 397, 22, 8, 327, 1230, 2, 18216,..."


In [98]:
max_sequence_length = max([len(index_list) for index_list in data['indexes']])
for i in range(len(data['indexes'])):
    padding_length = max_sequence_length - len(data['indexes'][i])
    data['indexes'][i] += [0] * padding_length

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [99]:
random_state = 42
n_layers = 1
batch_size = 64
torch.manual_seed(random_state)
torch.cuda.manual_seed(random_state)
torch.backends.cudnn.deterministic = True

In [100]:
x_train, x_test, y_train, y_test = train_test_split(data['indexes'], data['target'],
                                                    random_state=random_state)

In [101]:
x_train_list = list(x_train)
y_train_list = list(y_train)

x_train_array = np.array(x_train_list)
y_train_array = np.array(y_train_list)

x_train_tensor = torch.LongTensor(x_train_array)
y_train_tensor = torch.LongTensor(y_train_array)

x_test_list = list(x_test)
y_test_list = list(y_test)
x_test_array = np.array(x_test_list)
y_test_array = np.array(y_test_list)

x_test_tensor = torch.LongTensor(x_test_array)
y_test_tensor = torch.LongTensor(y_test_array)

tensor_train = TensorDataset(x_train_tensor, y_train_tensor)
tensor_test = TensorDataset(x_test_tensor, y_test_tensor)

dataloader_train = DataLoader(tensor_train, batch_size=batch_size, shuffle=True, num_workers=20)
dataloader_test = DataLoader(tensor_test, batch_size=batch_size, shuffle=True, num_workers=20)

Для обучения была прописана функция fit, в которой регистрируются loss на каждой эпохе и значения метрики accuracy на тестовой выборке. В функции plot_training реализовано построение графиков изменения значений loss для train и test, и построен график изменения accuracy на test. В функции model_predictions производит предсказание модели на всех данных test, возвращая предсказанные данные и те, которые реальные. Далее полученные массивы используются в функциях подсчета метрики recall. 

In [102]:
def fit(epochs, model, loss_func, opt, train_dl, test_dl):
    train_losses = []
    test_losses = []
    test_accuracies = []
    for epoch in range(epochs):
        model.train()
        loss_sum = 0
        for xb, yb in tqdm(train_dl):
            xb, yb = xb.to(device), yb.to(device)
            
            loss = loss_func(model(xb), yb)
            loss_sum += loss.item()
            
            loss.backward()
            opt.step()
            opt.zero_grad()
        train_losses.append(loss_sum)

        print(f'Epoch [{epoch+1}], Loss: {loss_sum}')
        model.eval()
        loss_sum = 0
        correct = 0
        num = 0
    
    
        with torch.no_grad():
            for xb, yb in tqdm(test_dl):
                xb, yb = xb.to(device), yb.to(device)
                
                probs = model(xb)
                loss_sum += loss_func(probs, yb).item()
                
                _, pred = torch.max(probs, axis=-1)
                correct += (pred == yb).sum().item()
                num += len(xb)
                
        accuracy = correct / num
        test_losses.append(loss_sum)
        test_accuracies.append(accuracy)
   
        print(f'Epoch [{epoch+1}], Loss: {loss_sum}, Accuracy: {accuracy}')    
    return train_losses, test_losses, test_accuracies

In [103]:
def plot_training(train_losses, test_losses, test_accuracy):
    epochs = list(range(1, len(train_losses) + 1))
    
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=epochs, y=train_losses, mode='lines', name='Train Loss'))
    fig.add_trace(go.Scatter(x=epochs, y=test_losses, mode='lines', name='Test Loss'))

    fig.update_layout(
        title='Train and Test Loss',
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Loss'),
        template='plotly_dark',
        showlegend=True
    )

    fig_accuracy = go.Figure()
    fig_accuracy.add_trace(go.Scatter(x=epochs, y=test_accuracy, mode='lines', name='Test Accuracy'))
    
    fig_accuracy.update_layout(
        title='Test Accuracy',
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Accuracy'),
        template='plotly_dark',
        showlegend=True
    )

    fig.show()
    fig_accuracy.show()

In [104]:
def model_predictions(model, loader):
    real_target, pred_target = [], []
    for X, Y in loader:
        pred = model(X)
        pred_target.append(pred)
        real_target.append(Y)
    pred_target, real_target = torch.cat(pred_target), torch.cat(real_target)

    return real_target.detach().numpy(), F.softmax(pred_target, dim=-1).argmax(dim=-1).cpu().detach().numpy()

# LSTM

In [105]:
class ClassificationModelLSTM(nn.Module):
    def __init__(self, emb_size=16, hid_size=128, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=n_tokens, embedding_dim=emb_size)
        self.lstm = nn.LSTM(input_size=emb_size, hidden_size=hid_size, batch_first=True)
        self.linear = nn.Linear(in_features=hid_size, out_features=num_classes)

    def forward(self, input_ix):
        input_ix = torch.as_tensor(input_ix).to(device)
        embeddings = self.embedding(input_ix)
        hn = torch.zeros(n_layers, batch_size, self.lstm.hidden_size).to(input_ix.device)
        cn = torch.zeros(n_layers, batch_size, self.lstm.hidden_size).to(input_ix.device)
        out, (hn, cn) = self.lstm(embeddings)
        out = self.linear(out)
        return out[:,-1]

In [106]:
torch.manual_seed(random_state)

<torch._C.Generator at 0x2347fd83970>

In [107]:
model = ClassificationModelLSTM().to(device)
opt = torch.optim.Adam(model.parameters())

In [108]:
model

ClassificationModelLSTM(
  (embedding): Embedding(76149, 16)
  (lstm): LSTM(16, 128, batch_first=True)
  (linear): Linear(in_features=128, out_features=2, bias=True)
)

In [111]:
epochs = 10
info = fit(epochs, model, nn.CrossEntropyLoss(), opt, dataloader_train, dataloader_test)
plot_training(*info)

100%|██████████| 527/527 [00:05<00:00, 95.40it/s] 


Epoch [1], Loss: 3.003543375642039


100%|██████████| 176/176 [00:04<00:00, 40.66it/s] 


Epoch [1], Loss: 3.020379475914524, Accuracy: 0.995456570155902


100%|██████████| 527/527 [00:05<00:00, 100.26it/s]


Epoch [2], Loss: 3.9085776543361135


100%|██████████| 176/176 [00:04<00:00, 41.48it/s]


Epoch [2], Loss: 2.297282113577239, Accuracy: 0.9973273942093541


100%|██████████| 527/527 [00:05<00:00, 98.22it/s] 


Epoch [3], Loss: 15.015052857575938


100%|██████████| 176/176 [00:04<00:00, 41.74it/s]


Epoch [3], Loss: 3.316279571969062, Accuracy: 0.9957238307349666


100%|██████████| 527/527 [00:05<00:00, 95.56it/s] 


Epoch [4], Loss: 3.887448985944502


100%|██████████| 176/176 [00:04<00:00, 40.69it/s]


Epoch [4], Loss: 2.7429524579492863, Accuracy: 0.9966146993318485


100%|██████████| 527/527 [00:05<00:00, 96.50it/s] 


Epoch [5], Loss: 2.8281535581336357


100%|██████████| 176/176 [00:04<00:00, 40.77it/s]


Epoch [5], Loss: 6.117588619352318, Accuracy: 0.9920712694877506


100%|██████████| 527/527 [00:05<00:00, 98.56it/s] 


Epoch [6], Loss: 2.684734739479609


100%|██████████| 176/176 [00:04<00:00, 41.40it/s]


Epoch [6], Loss: 2.331713298917748, Accuracy: 0.9970601336302896


100%|██████████| 527/527 [00:05<00:00, 98.51it/s] 


Epoch [7], Loss: 2.0584152437804732


100%|██████████| 176/176 [00:04<00:00, 40.38it/s] 


Epoch [7], Loss: 2.9400285180599894, Accuracy: 0.9967037861915368


100%|██████████| 527/527 [00:05<00:00, 98.32it/s] 


Epoch [8], Loss: 1.6707232468761504


100%|██████████| 176/176 [00:04<00:00, 40.55it/s]


Epoch [8], Loss: 2.480083775881212, Accuracy: 0.9970601336302896


100%|██████████| 527/527 [00:05<00:00, 96.85it/s] 


Epoch [9], Loss: 1.5042709352564998


100%|██████████| 176/176 [00:04<00:00, 40.75it/s]


Epoch [9], Loss: 2.830445627318113, Accuracy: 0.9965256124721603


100%|██████████| 527/527 [00:05<00:00, 95.97it/s] 


Epoch [10], Loss: 1.228446291432192


100%|██████████| 176/176 [00:04<00:00, 41.55it/s] 

Epoch [10], Loss: 2.724084989342373, Accuracy: 0.9972383073496659





In [112]:
real_target_LSTM, pred_target_LSTM = model_predictions(model, dataloader_test)

In [113]:
print("Accuracy: {}".format(accuracy_score(real_target_LSTM, pred_target_LSTM)))
print("Recall: {}".format(recall_score(real_target_LSTM, pred_target_LSTM)))

Accuracy: 0.9972383073496659
Recall: 0.9967404357522731


# CNN

In [114]:
kernel_size = 5
stride = 1
class ClassificationModelCNN(nn.Module):
    def __init__(self, emb_size=16, hid_size=128, num_classes=2):  
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, emb_size)
        self.conv_1 = nn.Conv1d(in_channels=emb_size, out_channels=hid_size, kernel_size=kernel_size, stride=stride)
        self.conv_2 = nn.Conv1d(in_channels=hid_size, out_channels=hid_size, kernel_size=kernel_size, stride=stride)
        self.linear = nn.Linear(in_features=hid_size, out_features=num_classes)

    def forward(self, input_ix):
        input_ix = torch.as_tensor(input_ix).to(device)
        embeddings = self.embedding(input_ix)
        out = self.conv_1(embeddings.permute(0, 2, 1))  
        out = self.conv_2(out)
        out = torch.max(out, dim=2)[0]  
        out = self.linear(out)
        return out

In [115]:
model_CNN = ClassificationModelCNN().to(device)
opt_CNN = torch.optim.Adam(model_CNN.parameters())

In [116]:
model_CNN

ClassificationModelCNN(
  (embedding): Embedding(76149, 16)
  (conv_1): Conv1d(16, 128, kernel_size=(5,), stride=(1,))
  (conv_2): Conv1d(128, 128, kernel_size=(5,), stride=(1,))
  (linear): Linear(in_features=128, out_features=2, bias=True)
)

In [117]:
epochs = 10
info_CNN = fit(epochs, model_CNN, nn.CrossEntropyLoss(), opt_CNN, dataloader_train, dataloader_test)
plot_training(*info_CNN)

100%|██████████| 527/527 [00:05<00:00, 104.32it/s]


Epoch [1], Loss: 66.1779842969845


100%|██████████| 176/176 [00:04<00:00, 41.92it/s]


Epoch [1], Loss: 3.0647618734510615, Accuracy: 0.9964365256124722


100%|██████████| 527/527 [00:05<00:00, 105.28it/s]


Epoch [2], Loss: 4.744091896231112


100%|██████████| 176/176 [00:04<00:00, 41.20it/s]


Epoch [2], Loss: 2.54709131384152, Accuracy: 0.9970601336302896


100%|██████████| 527/527 [00:05<00:00, 104.39it/s]


Epoch [3], Loss: 2.0819032820691064


100%|██████████| 176/176 [00:04<00:00, 42.05it/s]


Epoch [3], Loss: 2.787452163422131, Accuracy: 0.996792873051225


100%|██████████| 527/527 [00:05<00:00, 103.06it/s]


Epoch [4], Loss: 0.8353481942267535


100%|██████████| 176/176 [00:04<00:00, 42.80it/s]


Epoch [4], Loss: 2.890138496637519, Accuracy: 0.9971492204899778


100%|██████████| 527/527 [00:04<00:00, 106.69it/s]


Epoch [5], Loss: 0.3620844869356006


100%|██████████| 176/176 [00:04<00:00, 40.81it/s]


Epoch [5], Loss: 2.8067308649478946, Accuracy: 0.9972383073496659


100%|██████████| 527/527 [00:05<00:00, 102.94it/s]


Epoch [6], Loss: 0.10665833158486748


100%|██████████| 176/176 [00:04<00:00, 41.79it/s]


Epoch [6], Loss: 3.203924149631348, Accuracy: 0.9970601336302896


100%|██████████| 527/527 [00:05<00:00, 104.02it/s]


Epoch [7], Loss: 0.025823418554864475


100%|██████████| 176/176 [00:04<00:00, 42.98it/s]


Epoch [7], Loss: 3.4633832142053507, Accuracy: 0.9971492204899778


100%|██████████| 527/527 [00:05<00:00, 103.97it/s]


Epoch [8], Loss: 0.01207351333061979


100%|██████████| 176/176 [00:04<00:00, 41.90it/s]


Epoch [8], Loss: 3.689043616046092, Accuracy: 0.9972383073496659


100%|██████████| 527/527 [00:04<00:00, 106.37it/s]


Epoch [9], Loss: 0.008443841980522393


100%|██████████| 176/176 [00:04<00:00, 41.65it/s]


Epoch [9], Loss: 3.674695024363473, Accuracy: 0.9972383073496659


100%|██████████| 527/527 [00:05<00:00, 102.17it/s]


Epoch [10], Loss: 0.005817373430232919


100%|██████████| 176/176 [00:04<00:00, 43.11it/s]

Epoch [10], Loss: 3.8500602446413836, Accuracy: 0.9972383073496659





In [119]:
real_target_CNN, pred_target_CNN = model_predictions(model_CNN, dataloader_test)

In [120]:
print("Accuracy: {}".format(accuracy_score(real_target_CNN, pred_target_CNN)))
print("Recall: {}".format(recall_score(real_target_CNN, pred_target_CNN)))

Accuracy: 0.9972383073496659
Recall: 0.998627551895694


# Итоги
В результате получилось, что обе модели показали score примерно равный 0.997-0.998. Обучать их долго не пришлось. Для получения такого результата вполне достаточно 10 эпох. Так как нам важно понимать, как много фейков мы распознаем, необходимо проверить метрику recall. Она тоже у обеих моделей имеет примерно одинаковое значение.