In [None]:
# Reference tutorial:
# https://velog.io/@na2na8/ELECTRA%EB%A1%9C-Binary-Classification#electra-with-pytorch-lightning

In [None]:
!pip install transformers --quiet

[K     |████████████████████████████████| 4.2 MB 7.3 MB/s 
[K     |████████████████████████████████| 6.6 MB 45.0 MB/s 
[K     |████████████████████████████████| 86 kB 5.8 MB/s 
[K     |████████████████████████████████| 596 kB 66.3 MB/s 
[?25h

In [None]:
!pip install git+https://github.com/PyTorchLightning/pytorch-lightning --quiet
import pytorch_lightning as pl
print(pl.__version__)

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 140 kB 4.6 MB/s 
[K     |████████████████████████████████| 409 kB 13.1 MB/s 
[K     |████████████████████████████████| 1.1 MB 8.8 MB/s 
[K     |████████████████████████████████| 94 kB 2.4 MB/s 
[K     |████████████████████████████████| 271 kB 54.1 MB/s 
[K     |████████████████████████████████| 144 kB 43.8 MB/s 
[?25h  Building wheel for pytorch-lightning (PEP 517) ... [?25l[?25hdone
1.7.0dev


In [None]:
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
!wandb login

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
import os
import re

import numpy as np
import pandas as pd

import torch
import torchmetrics
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import loggers as pl_loggers
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import transformers
from transformers import ElectraForSequenceClassification, ElectraTokenizer, AdamW

device = torch.device("cuda")


In [None]:
!git clone https://github.com/AyushiM1102/Electra_classification_fake_vs_real_news.git

Cloning into 'Electra_classification_fake_vs_real_news'...
remote: Enumerating objects: 63, done.[K
remote: Counting objects: 100% (28/28), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 63 (delta 16), reused 22 (delta 12), pack-reused 35[K
Unpacking objects: 100% (63/63), done.


In [None]:
!unzip /content/Electra_classification_fake_vs_real_news/data/WELFake_Dataset.csv.zip -d /content/Electra_classification_fake_vs_real_news/dataset

Archive:  /content/Electra_classification_fake_vs_real_news/data/WELFake_Dataset.csv.zip
  inflating: /content/Electra_classification_fake_vs_real_news/dataset/WELFake_Dataset.csv  
  inflating: /content/Electra_classification_fake_vs_real_news/dataset/__MACOSX/._WELFake_Dataset.csv  


In [None]:
# Overview of database
# df = pd.read_csv('/content/Electra_classification_fake_vs_real_news/sample_dataset/train.csv', sep=',')
datapath = f'/content/Electra_classification_fake_vs_real_news/dataset/WELFake_Dataset.csv'
df = pd.read_csv(datapath, sep=',')
df = df.dropna(axis=0)
df.drop_duplicates(inplace=True)

In [None]:
df['label'].value_counts()

1    36509
0    35028
Name: label, dtype: int64

In [None]:
train_size = int(0.8 * len(df))
val_size = int(0.5*(len(df) - train_size))
test_size = int(val_size)
train_size, val_size, test_size

(57229, 7154, 7154)

In [None]:
train_dataset, val_dataset, test_dataset = df[:train_size],df[train_size:train_size+val_size],df[train_size+val_size:train_size+val_size+test_size]

In [None]:
train_dataset.to_csv('/content/Electra_classification_fake_vs_real_news/dataset/train.csv', index = False)
val_dataset.to_csv('/content/Electra_classification_fake_vs_real_news/dataset/val.csv', index = False)
test_dataset.to_csv('/content/Electra_classification_fake_vs_real_news/dataset/test.csv', index = False)

In [None]:
class ElectraClassificationDataset(Dataset) :
    def __init__(self, path, sep, doc_col, label_col, max_length, num_workers=1, labels_dict=None) :

        self.tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
        self.max_length = max_length
        self.doc_col = doc_col
        self.label_col = label_col

        # labels, ex : {True : 1, False : 0}
        self.labels_dict = labels_dict

        # dataset
        df = pd.read_csv(path, sep=sep)
        df = df.dropna(axis=0)
        df.drop_duplicates(subset=[self.doc_col], inplace=True)
        self.dataset = df

    def __len__(self) :
        return len(self.dataset)
    
    # Clean text
    def cleanse(self, text) :
        url_pattern = re.compile(r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)')
        processed = url_pattern.sub(' ', text)
        processed = processed.replace('#', '')
        processed = processed.replace('@', '')
        processed = processed.strip()
        return processed

    def __getitem__(self, idx) :
        document = self.cleanse(self.dataset[self.doc_col].iloc[idx])
        #print(document)
        inputs = self.tokenizer(
            document,
            return_tensors='pt',
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            add_special_tokens=True
        )

        if self.labels_dict :
            label = self.labels_dict[self.dataset[self.label_col].iloc[idx]]
        else :
            label = self.dataset[self.label_col].iloc[idx]

        return {
            'input_ids' : inputs['input_ids'][0],
            'attention_mask' : inputs['attention_mask'][0],
            'label' : int(label)
        }

In [None]:
class ElectraClassificationDataModule(pl.LightningDataModule) :
    def __init__(self, train_path, valid_path, test_path, max_length, batch_size, sep,
                doc_col, label_col, num_workers=1, labels_dict=None) :
        super().__init__()
        self.batch_size = batch_size
        self.train_path = train_path
        self.valid_path = valid_path
        self.test_path = test_path
        self.max_length = max_length
        self.doc_col = doc_col
        self.label_col = label_col
        self.sep = sep
        self.num_workers = num_workers
        self.labels_dict = labels_dict

    def setup(self, stage=None) :
        self.set_train = ElectraClassificationDataset(self.train_path, sep=self.sep,
                                            doc_col=self.doc_col, label_col=self.label_col,
                                            max_length = self.max_length, labels_dict=self.labels_dict)
        self.set_valid = ElectraClassificationDataset(self.valid_path, sep=self.sep,
                                            doc_col=self.doc_col, label_col=self.label_col,
                                            max_length = self.max_length, labels_dict=self.labels_dict)
        self.set_test = ElectraClassificationDataset(self.test_path, sep=self.sep,
                                            doc_col=self.doc_col, label_col=self.label_col,
                                            max_length = self.max_length, labels_dict=self.labels_dict)
        

    def train_dataloader(self) :
        train = DataLoader(self.set_train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
        return train
    
    def val_dataloader(self) :
        val = DataLoader(self.set_valid, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
        return val
    
    def test_dataloader(self) :
        test = DataLoader(self.set_test, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
        return test

In [None]:
# https://medium.com/huggingface/multi-label-text-classification-using-bert-the-mighty-transformer-69714fa3fb3d
# https://huggingface.co/docs/transformers/v4.15.0/en/model_doc/electra#transformers.ElectraForSequenceClassification

In [None]:
class ElectraClassification(pl.LightningModule) :
    def __init__(self, learning_rate) :
        super().__init__()
        self.learning_rate = learning_rate
        self.save_hyperparameters()
        self.electra = ElectraForSequenceClassification.from_pretrained("google/electra-small-discriminator")

        self.metric_acc = torchmetrics.Accuracy()
        self.metric_f1 = torchmetrics.F1Score(num_classes=2)
        self.metric_rec = torchmetrics.Recall(num_classes=2)
        self.metric_pre = torchmetrics.Precision(num_classes=2)

        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None) :
        output = self.electra(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return output

    def training_step(self, batch, batch_idx) :
        '''
        ##########################################################
        electra forward input shape information
        * input_ids.shape (batch_size, max_length)
        * attention_mask.shape (batch_size, max_length)
        * label.shape (batch_size,)
        ##########################################################
        '''

        # change label shape (list -> torch.Tensor((batch_size, 1)))
        label = batch['label'].view([-1,1])

        output = self(input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        labels=label.to(device))
        '''
        ##########################################################
        electra forward output shape information
        * loss.shape (1,)
        * logits.shape (batch_size, config.num_labels=2)
        '''
        logits = output.logits

        loss = output.loss
        # loss = self.loss_func(logits.to(device), batch['label'].to(device))

        softmax = nn.functional.softmax(logits, dim=1)
        preds = softmax.argmax(dim=1)

        self.log("train_loss", loss, prog_bar=True)
        
        return {
            'loss' : loss,
            'pred' : preds,
            'label' : batch['label']
        }

    def training_epoch_end(self, outputs, state='train') :
        y_true = []
        y_pred = []
        for i in outputs :
            y_true += i['label'].tolist()
            y_pred += i['pred'].tolist()

        acc = accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred)
        rec = recall_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        # self.log(state+'_acc', acc, on_epoch=True, prog_bar=True)
        # self.log(state+'_precision', prec, on_epoch=True, prog_bar=True)
        # self.log(state+'_recall', rec, on_epoch=True, prog_bar=True)
        # self.log(state+'_f1', f1, on_epoch=True, prog_bar=True)
        print(f'[Epoch {self.trainer.current_epoch} {state.upper()}] Acc: {acc}, Prec: {prec}, Rec: {rec}, F1: {f1}')

    def validation_step(self, batch, batch_idx) :
        '''
        ##########################################################
        electra forward input shape information
        * input_ids.shape (batch_size, max_length)
        * attention_mask.shape (batch_size, max_length)
        ##########################################################
        '''
        output = self(input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device))
        logits = output.logits
        preds = nn.functional.softmax(logits, dim=1).argmax(dim=1)
        labels = batch['label']
        accuracy = self.metric_acc(preds, labels)
        f1 = self.metric_f1(preds, labels)
        recall = self.metric_rec(preds, labels)
        precision = self.metric_pre(preds, labels)
        self.log('val_accuracy', accuracy, on_epoch=True, prog_bar=True)
        self.log('val_f1', f1, on_epoch=True, prog_bar=True)
        self.log('val_recall', recall, on_epoch=True, prog_bar=True)
        self.log('val_precision', precision, on_epoch=True, prog_bar=True)
        return {
            'accuracy' : accuracy,
            'f1' : f1,
            'recall' : recall,
            'precision' : precision
        }

    def validation_epoch_end(self, outputs) :
        val_acc = torch.stack([i['accuracy'] for i in outputs]).mean()
        val_f1 = torch.stack([i['f1'] for i in outputs]).mean()
        val_rec = torch.stack([i['recall'] for i in outputs]).mean()
        val_pre = torch.stack([i['precision'] for i in outputs]).mean()
        # self.log('val_f1', val_f1, on_epoch=True, prog_bar=True)
        # self.log('val_acc', val_acc, on_epoch=True, prog_bar=True)
        print(f'val_accuracy : {val_acc}, val_f1 : {val_f1}, val_recall : {val_rec}, val_precision : {val_pre}')
    
    def test_step(self, batch, batch_idx):
        output = self(input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device))
        logits = output.logits
        preds = nn.functional.softmax(logits, dim=1).argmax(dim=1)
        labels = batch['label']
        accuracy = self.metric_acc(preds, labels)
        f1 = self.metric_f1(preds, labels)
        recall = self.metric_rec(preds, labels)
        precision = self.metric_pre(preds, labels)
        self.log('test_accuracy', accuracy, on_epoch=True, prog_bar=True)
        self.log('test_f1', f1, on_epoch=True, prog_bar=True)
        self.log('test_recall', recall, on_epoch=True, prog_bar=True)
        self.log('test_precision', precision, on_epoch=True, prog_bar=True)

        return {
            'accuracy' : accuracy,
            'f1' : f1,
            'recall' : recall,
            'precision' : precision
        }


    def test_end(self, outputs):
        test_acc = torch.stack([i['accuracy'] for i in outputs]).mean()
        test_f1 = torch.stack([i['f1'] for i in outputs]).mean()
        test_rec = torch.stack([i['recall'] for i in outputs]).mean()
        test_pre = torch.stack([i['precision'] for i in outputs]).mean()
        # self.log('val_f1', val_f1, on_epoch=True, prog_bar=True)
        # self.log('val_acc', val_acc, on_epoch=True, prog_bar=True)
        print(f'test_accuracy : {test_acc}, test_f1 : {test_f1}, test_recall : {test_rec}, test_precision : {test_pre}')
        

    # def test_epoch_end(self, outputs):
    #     all_preds, all_labels = [], []
    #     for output in outputs:
    #         probs = list(output['logits'].cpu().detach().numpy()) # predicted values
    #         labels = list(output['labels'].flatten().cpu().detach().numpy())
    #         all_preds.extend(probs)
    #         all_labels.extend(labels)

    #     # you can calculate R2 here or save results as file
    #     r2 = ...
    
    # def predict_step(self, test_batch):
    #   x, y = test_batch
    #   logits = self.forward(x)
    #   return {'logits': logits, 'labels':y}

    def configure_optimizers(self) :
        optimizer = torch.optim.AdamW(self.electra.parameters(), lr=self.learning_rate)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
        
        return {
            'optimizer' : optimizer,
            'lr_scheduler' : lr_scheduler
        }

In [None]:
df

Unnamed: 0.1,Unnamed: 0,title,text,label
0,0,LAW ENFORCEMENT ON HIGH ALERT Following Threat...,No comment is expected from Barack Obama Membe...,1
2,2,UNBELIEVABLE! OBAMA’S ATTORNEY GENERAL SAYS MO...,"Now, most of the demonstrators gathered last ...",1
3,3,"Bobby Jindal, raised Hindu, uses story of Chri...",A dozen politically active pastors came here f...,0
4,4,SATAN 2: Russia unvelis an image of its terrif...,"The RS-28 Sarmat missile, dubbed Satan 2, will...",1
5,5,About Time! Christian Group Sues Amazon and SP...,All we can say on this one is it s about time ...,1
...,...,...,...,...
72129,72129,Russians steal research on Trump in hack of U....,WASHINGTON (Reuters) - Hackers believed to be ...,0
72130,72130,WATCH: Giuliani Demands That Democrats Apolog...,"You know, because in fantasyland Republicans n...",1
72131,72131,Migrants Refuse To Leave Train At Refugee Camp...,Migrants Refuse To Leave Train At Refugee Camp...,0
72132,72132,Trump tussle gives unpopular Mexican leader mu...,MEXICO CITY (Reuters) - Donald Trump’s combati...,0


In [None]:
# Main to train the model

# Initialize WandB 
wandb_logger = WandbLogger(project='Electra Classification', # group runs in "MNIST" project
                           log_model='all')
model = ElectraClassification(learning_rate=0.0001)

wandb.watch(model)

dm = ElectraClassificationDataModule(batch_size=8, train_path='/content/Electra_classification_fake_vs_real_news/dataset/train.csv', valid_path='/content/Electra_classification_fake_vs_real_news/dataset/val.csv',
                                     test_path='/content/Electra_classification_fake_vs_real_news/dataset/test.csv',
                                max_length=256, sep=',', doc_col='text', label_col='label', num_workers=1)
dm.setup()
train_dataset = dm.train_dataloader()
valid_dataset = dm.val_dataloader()

checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='val_accuracy',
                                                dirpath='./sample_electra_binary_nsmc_chpt',
                                                filename='ELECTRA/{epoch:02d}-{val_accuracy:.3f}',
                                                verbose=True,
                                                save_last=True,
                                                mode='max',
                                                save_top_k=-1,
                                                )

tb_logger = pl_loggers.TensorBoardLogger(os.path.join('./sample_electra_binary_nsmc_chpt', 'tb_logs'))

lr_logger = pl.callbacks.LearningRateMonitor()

trainer = pl.Trainer(
    default_root_dir='./sample_electra_binary_nsmc_chpt/checkpoints',
    logger = wandb_logger,
    callbacks = [checkpoint_callback, lr_logger],
    max_epochs=5,
    gpus=1)

trainer.fit(model, train_dataset, valid_dataset)

  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
Some weights of the model checkpoint at google/electra-small-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized fr

Sanity Checking: 0it [00:00, ?it/s]

val_accuracy : 0.625, val_f1 : 0.625, val_recall : 0.625, val_precision : 0.625


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

val_accuracy : 0.9807249903678894, val_f1 : 0.9807249903678894, val_recall : 0.9807249903678894, val_precision : 0.9807249903678894


Epoch 0, global step 6387: 'val_accuracy' reached 0.98072 (best 0.98072), saving model to '/content/sample_electra_binary_nsmc_chpt/ELECTRA/epoch=00-val_accuracy=0.981.ckpt' as top 1


[Epoch 0 TRAIN] Acc: 0.9700759354939721, Prec: 0.9679245283018868, Rec: 0.9665567593028733, F1: 0.9672401602639642


Validation: 0it [00:00, ?it/s]

val_accuracy : 0.9777042865753174, val_f1 : 0.9777042865753174, val_recall : 0.9777042865753174, val_precision : 0.9777042865753174


Epoch 1, global step 12774: 'val_accuracy' reached 0.97770 (best 0.98072), saving model to '/content/sample_electra_binary_nsmc_chpt/ELECTRA/epoch=01-val_accuracy=0.978.ckpt' as top 2


[Epoch 1 TRAIN] Acc: 0.9832667919210897, Prec: 0.9815496575342466, Rec: 0.9818438744486789, F1: 0.9816967439470811


Validation: 0it [00:00, ?it/s]

val_accuracy : 0.9807044863700867, val_f1 : 0.9807044863700867, val_recall : 0.9807044863700867, val_precision : 0.9807044863700867


Epoch 2, global step 19161: 'val_accuracy' reached 0.98072 (best 0.98072), saving model to '/content/sample_electra_binary_nsmc_chpt/ELECTRA/epoch=02-val_accuracy=0.981.ckpt' as top 3


[Epoch 2 TRAIN] Acc: 0.9877876937529356, Prec: 0.9865981588524941, Rec: 0.9866826531923093, F1: 0.986640404213411


Validation: 0it [00:00, ?it/s]

val_accuracy : 0.9873417615890503, val_f1 : 0.9873417615890503, val_recall : 0.9873417615890503, val_precision : 0.9873417615890503


Epoch 3, global step 25548: 'val_accuracy' reached 0.98734 (best 0.98734), saving model to '/content/sample_electra_binary_nsmc_chpt/ELECTRA/epoch=03-val_accuracy=0.987.ckpt' as top 4


[Epoch 3 TRAIN] Acc: 0.9905276342570847, Prec: 0.989260194257841, Rec: 0.9900226951569392, F1: 0.9896412978340896


Validation: 0it [00:00, ?it/s]

val_accuracy : 0.988204836845398, val_f1 : 0.988204836845398, val_recall : 0.988204836845398, val_precision : 0.988204836845398


Epoch 4, global step 31935: 'val_accuracy' reached 0.98820 (best 0.98820), saving model to '/content/sample_electra_binary_nsmc_chpt/ELECTRA/epoch=04-val_accuracy=0.988.ckpt' as top 5


[Epoch 4 TRAIN] Acc: 0.991232190386723, Prec: 0.9897788992002737, Rec: 0.9910504003768252, F1: 0.9904142416980486


In [None]:
test_dataset = dm.test_dataloader()
trainer.test(dataloaders=test_dataset)

  + f" You can pass `.{fn}(ckpt_path='best')` to use the best model or"
Restoring states from the checkpoint path at /content/sample_electra_binary_nsmc_chpt/ELECTRA/epoch=04-val_accuracy=0.988.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /content/sample_electra_binary_nsmc_chpt/ELECTRA/epoch=04-val_accuracy=0.988.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.9886916875839233
         test_f1            0.9886916875839233
     test_precision         0.9886916875839233
       test_recall          0.9886916875839233
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_accuracy': 0.9886916875839233,
  'test_f1': 0.9886916875839233,
  'test_precision': 0.9886916875839233,
  'test_recall': 0.9886916875839233}]

In [None]:
wandb.finish()

### Code below this line is extraneous. 

In [None]:
electra = ElectraForSequenceClassification.from_pretrained("google/electra-small-discriminator")

# Check parameters
dm = ElectraClassificationDataModule(batch_size=8, train_path='/content/Electra_classification_fake_vs_real_news/sample_dataset/train.csv', valid_path='/content/Electra_classification_fake_vs_real_news/sample_dataset/val.csv',
                                    max_length=256, sep=',', doc_col='Tweet', label_col='is_retweet', num_workers=1)

dm.setup()

t = dm.train_dataloader()

print(t)
for idx, data in enumerate(t):
    print(idx, data['input_ids'].shape, data['attention_mask'].shape, data['label'].shape)

# Concatenate the batches ?? ********* PENDING *********** HOW TO DO THIS ?? 
#idx, data = enumerate(t)

v = dm.val_dataloader()

for idx, data in enumerate(v) :
  print(idx, data['input_ids'].shape, data['attention_mask'].shape, data['label'].shape)
  # print(idx, data['input_ids'], data['attention_mask'], data['label'])

  output = electra.forward(data['input_ids'], attention_mask=data['attention_mask'], labels=data['label'].view([-1,1]))

  print("This is the loss")
  print(output.loss)
  # print(output.loss.shape)
  # print(output.logits)
  print(output.logits.shape)

  softmax = nn.functional.softmax(output.logits, dim=1)
  print('softmax', softmax)
  pred = softmax.argmax(dim=1)
  print('pred', pred)

  y_true = data['label'].tolist()
  y_pred = pred.tolist()

acc = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred)
rec = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

print(f'acc : {acc}, prec : {prec}, rec : {rec}, f1 : {f1}')


Some weights of the model checkpoint at google/electra-small-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['classifier

<torch.utils.data.dataloader.DataLoader object at 0x7f92b8844890>
0 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
1 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
2 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
3 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
4 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
5 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
6 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
7 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
8 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
9 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
10 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
11 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
12 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
13 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
14 torch.Size([8, 256]) torch.Size([8, 256]) torch.Size([8])
15 torch.Size([8, 256]) torch

  _warn_prf(average, modifier, msg_start, len(result))
