In [1]:
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device} device")

Using cuda device


# Load Dataset

In [3]:
from datasets import load_dataset, DatasetDict

# Fetch the KDD2020 fake-news dataset from the Hugging Face Hub, reusing the
# local "dataset" cache directory when it already holds a copy.
dataset_repo = "LittleFish-Coder/Fake_News_KDD2020"
dataset: DatasetDict = load_dataset(  # type: ignore
    dataset_repo,
    download_mode="reuse_cache_if_exists",
    cache_dir="dataset",
)

Generating train split: 100%|██████████| 4487/4487 [00:00<00:00, 35809.87 examples/s]
Generating test split: 100%|██████████| 499/499 [00:00<00:00, 34915.22 examples/s]


In [4]:
# Show the split summary, then keep handles to the train and test splits.
print(f"Dataset: {dataset}")
train_dataset, test_dataset = dataset["train"], dataset["test"]

Dataset: DatasetDict({
    train: Dataset({
        features: ['text', 'embeddings', 'label'],
        num_rows: 4487
    })
    test: Dataset({
        features: ['text', 'embeddings', 'label'],
        num_rows: 499
    })
})


In [6]:
# Quick look at the first training example: its fields, its text, and its label.
first_train = train_dataset[0]
print("First training sample")  # plain literal: there is nothing to interpolate
print(f"Keys: {first_train.keys()}")
print(f"Text: {first_train['text']}")
print(f"Label: {first_train['label']}")

First training sample
Keys: dict_keys(['text', 'embeddings', 'label'])
Text: Oops. Something went wrong. Please try again later  Looks like we are having a problem on the server.
Label: 0


In [7]:
# The first test article is the running example reused by every section below.
first_test = test_dataset[0]
text = first_test["text"]
print(f"Text: {text}")

Text: Reports of Lawrence and Pitt dating spread in December 2017. Watch What Happens Live/YouTube and Jason Kempin/Getty Images for Netflix  Jennifer Lawrence appeared on Bravo's "Watch What Happens Live" on Thursday, March 1 and responded to the reports that she was dating Brad Pitt.  When a caller asked her if the two stars were "secretly dating," the "Red Sparrow" actress denied it.  Even though she says the reports weren't true, Lawrence admitted that she didn't mind the speculation too much.  "No," she said. "I've met him once in like, 2013, so it was very random, but I also wasn't like, in a hurry to debunk it."  In December 2017, it was speculated that Lawrence and Pitt were dating.  In December 2017, it was reported that Jennifer Lawrence was dating Brad Pitt, but the "Red Sparrow" actress has now cleared up the speculation and revealed it's not true.  While appearing on Bravo's "Watch What Happens Live" on Thursday, Lawrence answered questions from fans who called in, and one

# Direct Text Classification (Pipeline)

In [9]:
# Selects the fine-tuned hub checkpoint "LittleFish-Coder/{model_name}-{dataset_name}".
# dataset_name options: 'fake-news-tfg', 'kdd2020'
# model_name options:   'bert-base-uncased', 'distilbert-base-uncased', 'roberta-base'
dataset_name, model_name = 'kdd2020', 'distilbert-base-uncased'

In [11]:
# Use a pipeline as a high-level helper
from transformers import pipeline

# Fine-tuned checkpoint on the Hugging Face model hub; truncation=True asks the
# tokenizer to cut over-long articles instead of feeding them through whole.
checkpoint = f"LittleFish-Coder/{model_name}-{dataset_name}"
pipe = pipeline(
    "text-classification",
    model=checkpoint,
    truncation=True,
    device=device,
)

In [12]:
# Classify the running example; the cell displays a list holding one {'label', 'score'} dict.
pipe(text)

[{'label': 'fake', 'score': 0.7897346615791321}]

# Tokenizer and Pretrained Model

In [14]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Same hub checkpoint as the pipeline, loaded as a separate tokenizer + classifier.
repo_id = f"LittleFish-Coder/{model_name}-{dataset_name}"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(repo_id)

## Predict via tokenizer & model

### Tokenize the text and get the class

In [17]:
# Encode the running example as PyTorch tensors (a batch of one); truncation=True cuts over-long text.
inputs = tokenizer(text, truncation=True, return_tensors="pt")

In [18]:
# Show the encoded token ids / attention mask.
input_report = f"Input: {inputs}"
print(input_report)

Input: {'input_ids': tensor([[  101,  4311,  1997,  5623,  1998, 15091,  5306,  3659,  1999,  2285,
          2418,  1012,  3422,  2054,  6433,  2444,  1013,  7858,  1998,  4463,
         20441,  2378,  1013,  2131,  3723,  4871,  2005, 20907,  7673,  5623,
          2596,  2006, 17562,  1005,  1055,  1000,  3422,  2054,  6433,  2444,
          1000,  2006,  9432,  1010,  2233,  1015,  1998,  5838,  2000,  1996,
          4311,  2008,  2016,  2001,  5306,  8226, 15091,  1012,  2043,  1037,
         20587,  2356,  2014,  2065,  1996,  2048,  3340,  2020,  1000, 10082,
          5306,  1010,  1000,  1996,  1000,  2417, 19479,  1000,  3883,  6380,
          2009,  1012,  2130,  2295,  2016,  2758,  1996,  4311,  4694,  1005,
          1056,  2995,  1010,  5623,  4914,  2008,  2016,  2134,  1005,  1056,
          2568,  1996, 12143,  2205,  2172,  1012,  1000,  2053,  1010,  1000,
          2016,  2056,  1012,  1000,  1045,  1005,  2310,  2777,  2032,  2320,
          1999,  2066,  1010,  

In [21]:
# Run the classifier without tracking gradients, then map the top logit to its label name.
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits   # (batch, num_labels) — (1, 2) for this example

# Take argmax along the label axis. A bare logits.argmax() indexes the *flattened*
# tensor, which stops being a class id once the batch holds more than one row;
# with dim=-1 a multi-row batch instead fails loudly at .item().
predicted_class_id = logits.argmax(dim=-1).item()
prediction = model.config.id2label[predicted_class_id]
print(f"Output: {outputs}")
print(f"Logits: {logits}")
print(f"Prediction: {prediction}")

Output: SequenceClassifierOutput(loss=None, logits=tensor([[-0.8020,  0.5214]]), hidden_states=None, attentions=None)
Logits: tensor([[-0.8020,  0.5214]])
Prediction: fake


### Get the embedding of the text

In [32]:
# Request hidden states for THIS call only. Setting model.config.output_hidden_states
# would persist on the model and silently change what every later call returns
# (e.g. re-running the prediction cell would then also carry hidden_states).
with torch.no_grad():  # Disable gradient calculation for inference
    outputs = model(**inputs, output_hidden_states=True)

# Tuple of hidden states; the last entry is the final layer's output.
hidden_states = outputs.hidden_states
last_hidden_state = hidden_states[-1]   # (batch, seq_len, hidden)

# get the cls token embedding (position 0 of every sequence)
cls_embeddings = last_hidden_state[:, 0, :]   # (1, 768)

# flatten the batch-of-one embedding to a plain vector
cls_embeddings = cls_embeddings.flatten()   # (768,)

In [33]:
# Inspect the container holding the per-layer hidden states.
hidden_states_kind = type(hidden_states)
hidden_states_count = len(hidden_states)
print(f"hidden states type: {hidden_states_kind}")
print(f"hidden states Length: {hidden_states_count}")

hidden states type: <class 'tuple'>
hidden states Length: 7


In [34]:
# Inspect the final layer's output tensor.
last_kind, last_shape = type(last_hidden_state), last_hidden_state.shape
print(f"last hidden state type: {last_kind}")
print(f"last hidden state shape: {last_shape}")

last hidden state type: <class 'torch.Tensor'>
last hidden state shape: torch.Size([1, 498, 768])


In [37]:
# Inspect the flattened [CLS] vector: its type, its shape, then its values.
for caption, value in (
    ("cls embeddings type", type(cls_embeddings)),
    ("cls embeddings shape", cls_embeddings.shape),
    ("cls embeddings", cls_embeddings),
):
    print(f"{caption}: {value}")

cls embeddings type: <class 'torch.Tensor'>
cls embeddings shape: torch.Size([768])
cls embeddings: tensor([ 0.7229, -0.4721, -0.5581, -0.1805,  0.0620,  0.3655,  0.3687,  0.8344,
        -0.2113, -0.0896, -0.1241,  0.5803, -0.4522,  1.1088,  0.6336, -0.0772,
        -0.7328,  0.0427, -0.8334,  0.3931,  0.3962,  0.4526,  0.1150, -0.4167,
        -0.5893, -0.2429,  0.6974, -0.2206,  0.3283,  0.0179,  0.0660, -0.0030,
         0.0931,  1.0225, -0.9738, -0.4682, -0.5479,  0.5818, -0.0042, -0.7241,
        -0.0964, -1.0804,  0.0816, -0.1570,  0.0300,  0.6269, -2.5275,  1.3813,
        -0.3521,  0.4846,  0.6566,  1.1796,  0.0407,  1.1897,  1.3038,  1.0105,
        -0.4753, -1.2677, -0.7465, -0.0726, -0.0122,  0.3052,  0.1873, -1.2204,
         0.4671, -0.8569, -0.2887,  0.9403,  0.4929, -0.8999, -0.1070,  0.0155,
        -0.5269,  0.3997, -0.6080,  0.2152, -0.1855,  0.8197, -0.3379, -0.0529,
        -0.0265,  0.0366, -0.0286,  1.3668, -0.5931, -0.6238, -0.1626,  0.2438,
        -0.5348,  1.

# Embedding Encoder

In [39]:
# Checkpoint selection for the encoder section (same hub naming scheme as above).
dataset_name, model_name = 'kdd2020', 'distilbert-base-uncased'

In [40]:
from transformers import AutoModel, AutoTokenizer

# Load the fine-tuned checkpoint as a bare encoder (no classification head) so
# its last_hidden_state can be used directly as text embeddings.
encoder_repo = f"LittleFish-Coder/{model_name}-{dataset_name}"
tokenizer = AutoTokenizer.from_pretrained(encoder_repo)
model = AutoModel.from_pretrained(encoder_repo)

In [41]:
# Re-encode the running example with the encoder's tokenizer (PyTorch tensors, truncated).
inputs = tokenizer(text, truncation=True, return_tensors="pt")

In [50]:
# Pure inference: without no_grad() PyTorch builds an autograd graph for the whole
# forward pass, wasting memory and leaving a grad_fn attached to the embedding.
with torch.no_grad():
    last_hidden = model(**inputs).last_hidden_state
# Position 0 is the [CLS] token; flatten the batch-of-one slice to a plain vector.
cls_embeddings = last_hidden[:, 0, :].flatten()

In [51]:
# The encoder-only [CLS] vector should match the one taken from the classifier above.
embedding_shape = cls_embeddings.shape
print(f"cls embeddings shape: {embedding_shape}")
print(f"cls embeddings: {cls_embeddings}")

cls embeddings shape: torch.Size([768])
cls embeddings: tensor([ 0.7229, -0.4721, -0.5581, -0.1805,  0.0620,  0.3655,  0.3687,  0.8344,
        -0.2113, -0.0896, -0.1241,  0.5803, -0.4522,  1.1088,  0.6336, -0.0772,
        -0.7328,  0.0427, -0.8334,  0.3931,  0.3962,  0.4526,  0.1150, -0.4167,
        -0.5893, -0.2429,  0.6974, -0.2206,  0.3283,  0.0179,  0.0660, -0.0030,
         0.0931,  1.0225, -0.9738, -0.4682, -0.5479,  0.5818, -0.0042, -0.7241,
        -0.0964, -1.0804,  0.0816, -0.1570,  0.0300,  0.6269, -2.5275,  1.3813,
        -0.3521,  0.4846,  0.6566,  1.1796,  0.0407,  1.1897,  1.3038,  1.0105,
        -0.4753, -1.2677, -0.7465, -0.0726, -0.0122,  0.3052,  0.1873, -1.2204,
         0.4671, -0.8569, -0.2887,  0.9403,  0.4929, -0.8999, -0.1070,  0.0155,
        -0.5269,  0.3997, -0.6080,  0.2152, -0.1855,  0.8197, -0.3379, -0.0529,
        -0.0265,  0.0366, -0.0286,  1.3668, -0.5931, -0.6238, -0.1626,  0.2438,
        -0.5348,  1.4884, -1.6267,  0.4665,  0.2794,  0.4028,  0