# Fine Tuning Transformer for MultiLabel Text Classification

### Introduction

In this tutorial we will fine-tune a transformer model for the **multilabel text classification** problem. 
This is one of the most common business problems: a given piece of text (a sentence or a document) needs to be classified into one or more categories from a given list. For example, a movie can be assigned to one or more genres.

#### Flow of the notebook

The notebook is divided into separate sections to provide an organized walkthrough of the process used. This process can be modified for individual use cases. The sections are:

1. [Importing Python Libraries and preparing the environment](#section01)
2. [Importing and Pre-Processing the domain data](#section02)
3. [Preparing the Dataset and Dataloader](#section03)
4. [Creating the Neural Network for Fine Tuning](#section04)
5. [Fine Tuning the Model](#section05)
6. [Validating the Model Performance](#section06)


#### Data:

- We are using the Jigsaw toxic comment data from [Kaggle](https://www.kaggle.com/).
- The source dataset is provided by the [Toxic Comment Competition](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge).
- We refer only to the first csv file from the data dump: `train.csv`.
- Each row of the data has the following data points:
	- Comment Text
	- `toxic`
	- `severe_toxic`
	- `obscene`
	- `threat`
	- `insult`
	- `identity_hate`

Each comment can be marked with multiple categories. If a comment is `toxic` and `obscene`, the value in both of those columns is `1` and the value in the other columns is `0`.
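
As a minimal illustration of this encoding, the sketch below maps a set of active labels to a six-element 0/1 vector in the column order listed above. The helper name `to_multi_hot` is hypothetical and is not part of the notebook or of any library.

```python
# Label columns in the order listed above.
LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def to_multi_hot(active_labels):
    """Return a 0/1 list with a 1 at the position of every active label."""
    active = set(active_labels)
    unknown = active - set(LABELS)
    if unknown:
        raise ValueError(f"Unknown labels: {sorted(unknown)}")
    return [1 if label in active else 0 for label in LABELS]

print(to_multi_hot(["toxic", "obscene"]))  # [1, 0, 1, 0, 0, 0]
print(to_multi_hot([]))                    # [0, 0, 0, 0, 0, 0]
```

A comment with no toxic label is therefore represented by an all-zero vector, which is the most common case in the sample rows shown later.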



---
***NOTE***
- *The overall mechanisms for multiclass and multilabel problems are similar, except for a few differences, namely:*
	- *The loss function is designed to evaluate the probability of each category individually rather than relative to the other categories. Hence the use of `BCE` rather than `Cross Entropy` when defining the loss.*
	- *Sigmoid, rather than Softmax, is applied to the outputs, for the reason given in the previous point.*


---
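
The difference described in the note can be sketched in plain Python. The logits and targets below are made-up values for three labels, and `bce_with_logits` is a numerically naive, hypothetical helper that follows the same idea as `torch.nn.BCEWithLogitsLoss` with its default mean reduction (sigmoid on each logit, then binary cross-entropy averaged over the labels). It is not the library implementation.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softmax(logits):
    m = max(logits)                          # subtract the max for stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy, scoring every label independently."""
    losses = []
    for x, y in zip(logits, targets):
        p = sigmoid(x)
        losses.append(-(y * math.log(p) + (1 - y) * math.log(1 - p)))
    return sum(losses) / len(losses)

logits = [2.0, -1.0, 1.5]    # made-up scores for three labels
targets = [1.0, 0.0, 1.0]    # two labels are active at the same time

sig = [sigmoid(x) for x in logits]
soft = softmax(logits)

print([round(p, 3) for p in sig])    # independent probabilities, need not sum to 1
print([round(p, 3) for p in soft])   # competing probabilities, always sum to 1
print(round(bce_with_logits(logits, targets), 4))
```

Because softmax forces the probabilities to sum to 1, it cannot assign a high probability to two labels at once, which is why sigmoid is the natural choice when several labels can be active together.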

<a id='section01'></a>
### Importing Python Libraries and preparing the environment


In [1]:
# Installing the transformers library

!pip install -q transformers




In [0]:

import numpy as np
import pandas as pd
from sklearn import metrics
import transformers
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertModel, BertConfig


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:

from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
device

'cuda'

<a id='section02'></a>
### Importing and Pre-Processing the domain data

We will load the data and prepare it for fine-tuning. 
*This assumes that `train.csv` has already been downloaded, unzipped and saved in your `data` folder.*

* Import the file into a dataframe, with the column headers read from the file.
* Take the values of all the category columns and convert them into a list.
* The list is appended as a new column and the other columns are removed.

In [5]:
df = pd.read_csv('./drive/My Drive/data/class5/train.csv')
df['list'] = df[df.columns[2:]].values.tolist()
new_df = df[['comment_text', 'list']].copy()
new_df.head()

| | comment_text | list |
|---|---|---|
| 0 | Explanation\nWhy the edits made under my usern... | [0, 0, 0, 0, 0, 0] |
| 1 | D'aww! He matches this background colour I'm s... | [0, 0, 0, 0, 0, 0] |
| 2 | Hey man, I'm really not trying to edit war. It... | [0, 0, 0, 0, 0, 0] |
| 3 | "\nMore\nI can't make any real suggestions on ... | [0, 0, 0, 0, 0, 0] |
| 4 | You, sir, are my hero. Any chance you remember... | [0, 0, 0, 0, 0, 0] |


<a id='section03'></a>
### Preparing the Dataset and Dataloader

We start by defining a few key variables that will be used later during the training/fine-tuning stage.
We then create the `CustomDataset` class, which defines how the text is pre-processed before it is sent to the neural network. We also define the Dataloader that feeds the data to the neural network in batches for training and processing. 
Dataset and Dataloader are constructs of the PyTorch library for defining and controlling the data pre-processing and its passage to the neural network. For further reading on Dataset and Dataloader, see the [docs at PyTorch](https://pytorch.org/docs/stable/data.html).

#### *CustomDataset* Dataset Class
- This class accepts the `tokenizer`, `dataframe` and `max_len` as input and generates the tokenized output and tags that are used by the BERT model for training. 
- We are using the BERT tokenizer to tokenize the data in the `comment_text` column of the dataframe.
- The tokenizer uses the `encode_plus` method to perform tokenization and generate the necessary outputs, namely: `ids`, `attention_mask`, `token_type_ids`

---
- *This is the first difference between DistilBERT and BERT: the tokenizer generates `token_type_ids` in the case of BERT.*

---
- To read further into the tokenizer, [refer to this document](https://huggingface.co/transformers/model_doc/bert.html#berttokenizer)
- `targets` is the list of categories labeled as `0` or `1` in the dataframe. 
- The *CustomDataset* class is used to create 2 datasets, for training and for validation.
- *Training Dataset* is used to fine-tune the model: **80% of the original data**
- *Validation Dataset* is used to evaluate the performance of the model. The model has not seen this data during training. 
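
To show what fixed-length tokenizer output looks like, here is a simplified sketch of padding and masking in plain Python. The helper `pad_and_mask` is hypothetical and only imitates the shape of the `ids` and `attention_mask` outputs. In `bert-base-uncased`, `101` and `102` are the `[CLS]` and `[SEP]` ids and `0` is the padding id; the two ids in between are arbitrary example values. Unlike the real tokenizer, this naive truncation does not keep the closing `[SEP]` token.

```python
def pad_and_mask(token_ids, max_len, pad_id=0):
    """Truncate or right-pad token_ids to max_len and build the matching mask."""
    ids = list(token_ids[:max_len])
    mask = [1] * len(ids)            # 1 marks a real token
    padding = max_len - len(ids)
    ids += [pad_id] * padding
    mask += [0] * padding            # 0 marks padding the model should ignore
    return ids, mask

ids, mask = pad_and_mask([101, 7592, 2088, 102], max_len=8)
print(ids)   # [101, 7592, 2088, 102, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

Every item returned by the dataset then has the same length, which is what allows the Dataloader to stack items into a batch tensor.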

#### Dataloader
- Dataloaders are created for training and validation; they load data into the neural network in a defined manner. This is needed because all the data from the dataset cannot be loaded into memory at once, so the amount of data loaded into memory and then passed to the neural network needs to be controlled.
- This control is achieved using parameters such as `batch_size` and `max_len`.
- The training and validation dataloaders are used in the training and validation parts of the flow respectively.
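
The effect of `batch_size` can be sketched without PyTorch. The generator `iter_batches` below is a hypothetical stand-in for the slicing a dataloader performs (without shuffling), and the row count is the size of the training split reported in this notebook. The step count assumes the last, smaller batch is kept, which is the `DataLoader` default.

```python
import math

def iter_batches(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# Number of optimisation steps per epoch for the training split.
n_train_rows = 127657   # rows in the 80% training split
batch_size = 16         # TRAIN_BATCH_SIZE
steps_per_epoch = math.ceil(n_train_rows / batch_size)
print(steps_per_epoch)  # 7979

# The last batch is smaller when the data does not divide evenly.
batches = list(iter_batches(list(range(10)), batch_size=4))
print([len(b) for b in batches])  # [4, 4, 2]
```

A larger batch size reduces the number of steps per epoch but increases the memory needed for each step.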

In [0]:
# Sections of config

# Defining some key variables that will be used later on in the training
MAX_LEN = 100
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 16
EPOCHS = 1
LEARNING_RATE = 1e-05
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [0]:
class CustomDataset(Dataset):

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.comment_text = dataframe.comment_text
        self.targets = self.data.list
        self.max_len = max_len

    def __len__(self):
        return len(self.comment_text)

    def __getitem__(self, index):
        comment_text = str(self.comment_text[index])
        comment_text = " ".join(comment_text.split())
         
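        inputs = self.tokenizer.encode_plus(
            comment_text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            # `padding` and `truncation` replace the deprecated
            # `pad_to_max_length` argument in transformers >= 3.0.
            padding='max_length',
            truncation=True,
            return_token_type_ids=True
        )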
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]


        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'targets': torch.tensor(self.targets[index], dtype=torch.float)
        }

In [27]:
# Creating the dataset and dataloader for the neural network

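train_size = 0.8
# Sample the training rows first and drop them by their ORIGINAL index so that
# the two splits are disjoint; only then reset the index of each split.
train_dataset = new_df.sample(frac=train_size, random_state=200)
test_dataset = new_df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)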


print("FULL Dataset: {}".format(new_df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = CustomDataset(train_dataset, tokenizer, MAX_LEN)
testing_set = CustomDataset(test_dataset, tokenizer, MAX_LEN)

FULL Dataset: (159571, 2)
TRAIN Dataset: (127657, 2)
TEST Dataset: (31914, 2)


In [12]:
training_set[10005]

{'ids': tensor([  101, 10166, 14719,  2615,  2003,  8915,  2232,  4047,  2023,  6643,
         13910,   999,   999,   999,  8840,  2140,  1010,  2798,  1013, 14719,
          2615,  3849,  2000,  2022,  1037,  2613,  6904,  2290,  1012,  8239,
          4632,  9148,  5051,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,  

In [0]:
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

<a id='section04'></a>
### Creating the Neural Network for Fine Tuning

 

In [29]:
# Creating the customized model by adding a dropout and a dense layer on top of BERT to get the final output for the model.

class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.l2 = torch.nn.Dropout(0.3)
        self.l3 = torch.nn.Linear(768, 6)
        
    def forward(self, ids, mask, token_type_ids):
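        # Index 1 is the pooled output; indexing works whether the model returns
        # a tuple (older transformers) or a ModelOutput object (newer versions).
        output_1 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)[1]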
        output_2 = self.l2(output_1)
        output = self.l3(output_2)
        return output

model = BERTClass()
model.to(device)

BERTClass(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  

In [0]:
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

In [0]:
optimizer = torch.optim.Adam(params =  model.parameters(), lr=LEARNING_RATE)

<a id='section05'></a>
### Fine Tuning the Model


In [0]:
from tqdm import tqdm

def train(epoch):
    model.train()
    epoch_step = 0
    for data in tqdm(training_loader):
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.float)

        outputs = model(ids, mask, token_type_ids)

        loss = loss_fn(outputs, targets)

        if epoch_step%100==0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_step+=1

In [32]:
for epoch in range(EPOCHS):
    train(epoch)


Epoch: 0, Loss:  0.6522731781005859
Epoch: 0, Loss:  0.1833244413137436
Epoch: 0, Loss:  0.22625665366649628
Epoch: 0, Loss:  0.07685349881649017
Epoch: 0, Loss:  0.10536889731884003
Epoch: 0, Loss:  0.0667152926325798
Epoch: 0, Loss:  0.08836552500724792

  8%|▊         | 619/7979 [02:02<24:49,  4.94it/s]

KeyboardInterrupt: ignored

<a id='section06'></a>
### Validating the Model

During the validation stage we pass the unseen data (the testing dataset) to the model. This step determines how well the model performs on unseen data. 

This unseen data is the 20% of `train.csv` that was separated during the Dataset creation stage. 
During the validation stage the weights of the model are not updated. Only the final output is compared to the actual value. This comparison is then used to calculate the accuracy of the model. 

To measure our model's performance we use the following metrics: 
- Accuracy Score
- F1 Micro
- F1 Macro

With training limited to at most 1 epoch, the run shown here reaches an accuracy of about 0.92 and a micro-averaged F1 of about 0.72, while the macro-averaged F1 is lower at about 0.40.
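
To make the difference between these three metrics concrete, the sketch below computes them by hand on a made-up example with two labels, where the rare second label is never predicted. The function names are hypothetical. The accuracy here is the exact-match (subset) accuracy, which is what `metrics.accuracy_score` reports for multilabel input; micro F1 pools the counts over all labels, while macro F1 averages the per-label F1 scores.

```python
def f1_from_counts(tp, fp, fn):
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def multilabel_scores(y_true, y_pred):
    """Return (subset accuracy, micro F1, macro F1) for 0/1 label rows."""
    n_labels = len(y_true[0])
    tp, fp, fn = [0] * n_labels, [0] * n_labels, [0] * n_labels
    for t_row, p_row in zip(y_true, y_pred):
        for j, (t, p) in enumerate(zip(t_row, p_row)):
            if t and p:
                tp[j] += 1
            elif p and not t:
                fp[j] += 1
            elif t and not p:
                fn[j] += 1
    # A row counts as correct only if every label matches.
    subset_acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    micro = f1_from_counts(sum(tp), sum(fp), sum(fn))
    macro = sum(f1_from_counts(*c) for c in zip(tp, fp, fn)) / n_labels
    return subset_acc, micro, macro

# Made-up data: label 0 is common and predicted well, label 1 is rare and missed.
y_true = [[1, 0], [1, 0], [0, 0], [0, 1]]
y_pred = [[1, 0], [1, 0], [0, 0], [0, 0]]
print(multilabel_scores(y_true, y_pred))  # (0.75, 0.8, 0.5)
```

In this toy case a single missed rare label pulls macro F1 well below micro F1, which is one possible reading of a gap between the two scores.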

In [0]:
def validation(epoch):
    model.eval()
    fin_targets=[]
    fin_outputs=[]
    with torch.no_grad():
        for _, data in enumerate(testing_loader, 0):
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float)
            outputs = model(ids, mask, token_type_ids)
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
    return fin_outputs, fin_targets

In [34]:
for epoch in range(EPOCHS):
    outputs, targets = validation(epoch)
    outputs = np.array(outputs) >= 0.5
    accuracy = metrics.accuracy_score(targets, outputs)
    f1_score_micro = metrics.f1_score(targets, outputs, average='micro')
    f1_score_macro = metrics.f1_score(targets, outputs, average='macro')
    print(f"Accuracy Score = {accuracy}")
    print(f"F1 Score (Micro) = {f1_score_micro}")
    print(f"F1 Score (Macro) = {f1_score_macro}")

Accuracy Score = 0.9214451337970796
F1 Score (Micro) = 0.7240212663122282
F1 Score (Macro) = 0.4037705040668253
