In [30]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer

import torch
from torch.utils.data import DataLoader

In [146]:
# Load the dataset. Only ds["train"] is used below, so the validation and
# test subsets are carved out of it.
ds = load_dataset("mrm8488/fake-news")

print(ds["train"].features)

# Hold out 20% of all rows as the test set.
ds_tmp0 = ds["train"].train_test_split(test_size=0.2, seed=42)

# Take the validation set from the *remaining* 80% instead of splitting the
# full dataset a second time; otherwise rows of the test set also end up in
# the training/validation data (data leakage).
# 0.125 of the remaining 80% corresponds to ~10% of the full dataset.
ds_tmp1 = ds_tmp0["train"].train_test_split(test_size=0.125, seed=42)

# Final, mutually disjoint splits (roughly 70% / 10% / 20%).
dsd = DatasetDict({
    "train": ds_tmp1["train"],
    "validation": ds_tmp1["test"],
    "test": ds_tmp0["test"],
})

{'text': Value('string'), 'label': Value('int64')}


In [161]:
# Model hyper-parameters (passed to TextClassificationModel further down).
num_class = 2  # number of output classes of the classifier
emsize = 64    # embedding dimension

# Pretrained tokenizer; its vocabulary size is used as the size of the
# model's embedding table.
tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
vocab_size = tok.vocab_size

In [148]:
#print("Tokens:", tokens)
#print("IDs:", ids)
#print("Zurück:", back)
#print("Decoded:", decoded)

In [149]:
#len_trainSet = len(dsd["train"])
#len_testSet = len(dsd["test"])
#len_validSet = len(dsd["validation"])
#len_all = len(ds["train"])

#print(len_trainSet)
#print(len_testSet)
#print(len_validSet)
#print(len_all)

40408
8980
4490
44898


In [150]:
# Use the GPU when PyTorch can see a CUDA device, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [151]:
def text_pipeline(texts):
    """Convert raw text into token ids using the global tokenizer ``tok``.

    Args:
        texts: Input forwarded unchanged to ``tok`` (a single string in the
            callers visible in this notebook).

    Returns:
        The ``"input_ids"`` entry of the tokenizer output. Truncation is
        enabled with ``max_length=512``; ``padding=True`` is forwarded to the
        tokenizer as well.
    """
    encoding = tok(texts, truncation=True, max_length=512, padding=True)
    token_ids = encoding["input_ids"]
    return token_ids

In [152]:
def label_pipeline(x):
    """Return the label unchanged.

    Identity hook: the dataset labels are used as class indices as-is, so no
    conversion is applied here.
    """
    label = x
    return label

In [153]:
def collate_batch(batch):
    """Collate dataset items into the flat format consumed by ``nn.EmbeddingBag``.

    Args:
        batch: Iterable of items that expose a ``"label"`` and a ``"text"`` key.

    Returns:
        A tuple ``(labels, token_ids, offsets)``:
        * ``labels``    - int64 tensor with one (pipeline-processed) label per item.
        * ``token_ids`` - int64 tensor holding the token ids of all items,
          concatenated in batch order.
        * ``offsets``   - start index of every item inside ``token_ids``.
    """
    labels, token_chunks, lengths = [], [], []

    for sample in batch:
        # Label of this sample
        labels.append(label_pipeline(sample["label"]))

        # Token ids of this sample's text
        ids = torch.tensor(text_pipeline(sample["text"]), dtype=torch.int64)
        token_chunks.append(ids)

        # Number of tokens, needed to derive the offset of the next sample
        lengths.append(ids.size(0))

    # Sample i starts where samples 0..i-1 end: an exclusive running sum of
    # the lengths (the length of the last sample is not needed).
    start_offsets = torch.tensor([0] + lengths[:-1]).cumsum(dim=0)

    # Merge the per-sample pieces into single tensors
    label_tensor = torch.tensor(labels, dtype=torch.int64)
    token_tensor = torch.cat(token_chunks)

    return label_tensor, token_tensor, start_offsets
        

In [154]:
BATCH_SIZE = 64

# Only the training loader is shuffled. Shuffling the evaluation loaders does
# not change aggregate metrics, but it makes the evaluation order (and any
# per-example inspection of the outputs) non-reproducible between runs.
train_dataloader = DataLoader(dsd["train"], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(dsd["test"], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
validation_dataloader = DataLoader(dsd["validation"], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)


In [164]:
# Pull a single mini-batch from the training loader to inspect what
# collate_batch produces: labels, concatenated token ids and start offsets.
(label, text, offsets) = next(iter(train_dataloader))
(label, text, offsets)

(tensor([1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
         1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1,
         1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0]),
 tensor([ 101, 2144, 1996,  ..., 2015, 1012,  102]),
 tensor([    0,   341,   353,   792,  1304,  1538,  1935,  2447,  2855,  3306,
          3690,  4202,  4714,  5226,  5583,  6095,  6607,  7119,  7631,  7734,
          8177,  8689,  8786,  9117,  9623, 10135, 10647, 11024, 11080, 11467,
         11852, 12364, 12876, 13349, 13351, 13714, 14119, 14631, 14766, 15278,
         15617, 15641, 15800, 16312, 16824, 16858, 17370, 17493, 17644, 17908,
         18420, 18773, 19285, 19797, 20160, 20454, 20597, 20964, 21347, 21496,
         22008, 22400, 22912, 23332]))

In [162]:
from torch import nn

class TextClassificationModel(nn.Module):
    """Bag-of-embeddings classifier: ``nn.EmbeddingBag`` followed by a linear layer.

    Args:
        vocab_size: Number of rows of the embedding table.
        embed_dim: Size of every embedding vector.
        num_class: Number of output scores produced per input bag.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Re-initialise the weights uniformly in [-0.5, 0.5] and zero the bias."""
        bound = 0.5
        nn.init.uniform_(self.embedding.weight, -bound, bound)
        nn.init.uniform_(self.fc.weight, -bound, bound)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text, offsets):
        """Return one row of class scores per bag.

        Args:
            text: Flat tensor of token ids for all bags, concatenated.
            offsets: Start index of every bag inside ``text``.
        """
        pooled = self.embedding(text, offsets)
        scores = self.fc(pooled)
        return scores

In [163]:
# Build the classifier from the hyper-parameters defined earlier and move its
# parameters to the selected device.
model = TextClassificationModel(
    vocab_size=vocab_size,
    embed_dim=emsize,
    num_class=num_class,
)
model = model.to(device)
model

TextClassificationModel(
  (embedding): EmbeddingBag(30522, 64, mode='mean')
  (fc): Linear(in_features=64, out_features=2, bias=True)
)

In [165]:
# The model was moved to `device`, while the tensors produced by collate_batch
# are created on the CPU. Move the inputs to the same device before the
# forward pass (a no-op when `device` is the CPU) to avoid a device-mismatch
# error on CUDA machines.
predicted_label = model(text.to(device), offsets.to(device))

In [166]:
# One row of scores per sample in the batch, one column per class.
predicted_label.size()

torch.Size([64, 2])

In [167]:
# Raw output of the final linear layer (no softmax applied) for the first
# sample of the batch.
predicted_label[0]

tensor([ 0.1216, -0.0561], grad_fn=<SelectBackward0>)

In [198]:
def predict(text, text_pipeline):
    """Classify a single piece of text with the global ``model``.

    Args:
        text: Raw input text; passed unchanged to ``text_pipeline``.
        text_pipeline: Callable mapping ``text`` to a flat sequence of
            integer token ids.

    Returns:
        ``"Correct!"`` when the arg-max class index is 1, otherwise
        ``"Fake news!"``.
    """
    # Build the inputs on the same device as the model's weights. Tensors
    # created with torch.tensor() default to the CPU, which raises a
    # device-mismatch error once the model has been moved to a GPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        token_ids = torch.tensor(text_pipeline(text), dtype=torch.int64, device=target_device)
        # A single bag that starts at position 0 and spans all tokens.
        offsets = torch.zeros(1, dtype=torch.int64, device=target_device)
        output = model(token_ids, offsets)
        label = output.argmax(1).item()

    return "Correct!" if label == 1 else "Fake news!"

In [200]:
# Smoke test of predict() on a hand-written sentence.
# NOTE(review): no training step is visible before this call; if the model is
# still randomly initialised, the returned class is arbitrary - confirm.
predict(text="This a real news", text_pipeline=text_pipeline)

'Fake news!'