### Multilabel classification tutorial is implemented by transfer learning 'google/bert_uncased_L-4_H-256_A-4' + custom classification head to detect news labels with the title

Download Data

In [1]:
import kagglehub

path = kagglehub.dataset_download("shivanandmn/multilabel-classification-dataset")
print("Path to dataset files:", path)

Using Colab cache for faster access to the 'multilabel-classification-dataset' dataset.
Path to dataset files: /kaggle/input/multilabel-classification-dataset


Load Base Model

In [2]:
from transformers import BertTokenizer, BertModel, BertForSequenceClassification

model_name = "google/bert_uncased_L-4_H-256_A-4"

tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Freeze Layers

In [3]:
for layer in bert_model.parameters():
    layer.require_grad = False

In [4]:
for name, layer in bert_model.named_parameters():
   print(name,"|",layer.require_grad)

embeddings.word_embeddings.weight | False
embeddings.position_embeddings.weight | False
embeddings.token_type_embeddings.weight | False
embeddings.LayerNorm.weight | False
embeddings.LayerNorm.bias | False
encoder.layer.0.attention.self.query.weight | False
encoder.layer.0.attention.self.query.bias | False
encoder.layer.0.attention.self.key.weight | False
encoder.layer.0.attention.self.key.bias | False
encoder.layer.0.attention.self.value.weight | False
encoder.layer.0.attention.self.value.bias | False
encoder.layer.0.attention.output.dense.weight | False
encoder.layer.0.attention.output.dense.bias | False
encoder.layer.0.attention.output.LayerNorm.weight | False
encoder.layer.0.attention.output.LayerNorm.bias | False
encoder.layer.0.intermediate.dense.weight | False
encoder.layer.0.intermediate.dense.bias | False
encoder.layer.0.output.dense.weight | False
encoder.layer.0.output.dense.bias | False
encoder.layer.0.output.LayerNorm.weight | False
encoder.layer.0.output.LayerNorm.bias | 

In [5]:
text = "this is a sentence"
tokens = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)

In [6]:
bert_model(**tokens).last_hidden_state.shape

torch.Size([1, 3, 256])

Model

In [7]:
from torch.nn import Linear, Flatten, Module, BCEWithLogitsLoss
from torch.optim import AdamW
import torch

class TopicModel(Module):
    def __init__(self):
        super().__init__()
        self.base_model = bert_model
        self.classifier = Linear(in_features=256, out_features=6)

    def forward(self, inputs):
        with torch.no_grad():
            logits = self.base_model(**inputs).last_hidden_state
        pooling = logits.mean(dim=1)
        classify = self.classifier(pooling)
        return classify

In [8]:
topic_model = TopicModel()
optimizer = AdamW(topic_model.parameters(), lr = 0.001)
loss_func = BCEWithLogitsLoss()

Training & Validation

In [9]:
def training_function(model, train_dataset, loss_func, optimizer):
    no_sample = len(train_dataset.dataset)
    no_of_batch = len(train_dataset)
    model.train()
    for batch,(x,y) in enumerate(train_dataset):
        prediction = model(x)
        loss = loss_func(prediction, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch%100 == 0:
           print(f"training loss : {loss.item()} | completed : [{32 * batch}/{no_sample}]")



def validation_function(model, test_dataset, loss_func):
    no_sample = len(test_dataset.dataset)
    no_of_batch = len(test_dataset)
    model.eval()
    loss, accuracy = 0, 0
    with torch.no_grad():
        for x, y in test_dataset:
           prediction = model(x)
           loss += loss_func(prediction, y).item()
           accuracy += torch.sum((torch.sigmoid(prediction)>0.5)==y) / 6

    loss/=no_of_batch
    accuracy/=no_sample
    print(f"validation loss : {loss} | accuracy : {accuracy}")
    return loss


Early Stopping

In [10]:
class EarlyStopping:
      def __init__(self, patience, min_delta):
         self.patience = patience
         self.min_delta = min_delta
         self.minimum = float("inf")
         self.count = 0

      def stop(self, validation_loss):
          if validation_loss < self.minimum - self.min_delta:
             self.minimum = validation_loss
             self.count = 0
          else:
             self.count += 1
             if self.count > self.patience:
                return True
          return False

Prepare Dataset

In [11]:
path

'/kaggle/input/multilabel-classification-dataset'

In [12]:
%cd "/root/.cache/kagglehub/datasets/shivanandmn/multilabel-classification-dataset/versions/1"

/root/.cache/kagglehub/datasets/shivanandmn/multilabel-classification-dataset/versions/1


In [13]:
import pandas as pd

topic_dataset = pd.read_csv("train.csv")

In [14]:
topic_dataset

Unnamed: 0,ID,TITLE,ABSTRACT,Computer Science,Physics,Mathematics,Statistics,Quantitative Biology,Quantitative Finance
0,1,Reconstructing Subject-Specific Effect Maps,Predictive models allow subject-specific inf...,1,0,0,0,0,0
1,2,Rotation Invariance Neural Network,Rotation invariance and translation invarian...,1,0,0,0,0,0
2,3,Spherical polyharmonics and Poisson kernels fo...,We introduce and develop the notion of spher...,0,0,1,0,0,0
3,4,A finite element approximation for the stochas...,The stochastic Landau--Lifshitz--Gilbert (LL...,0,0,1,0,0,0
4,5,Comparative study of Discrete Wavelet Transfor...,Fourier-transform infra-red (FTIR) spectra o...,1,0,0,1,0,0
...,...,...,...,...,...,...,...,...,...
20967,20968,Contemporary machine learning: a guide for pra...,Machine learning is finding increasingly bro...,1,1,0,0,0,0
20968,20969,Uniform diamond coatings on WC-Co hard alloy c...,Polycrystalline diamond coatings have been g...,0,1,0,0,0,0
20969,20970,Analysing Soccer Games with Clustering and Con...,We present a new approach for identifying si...,1,0,0,0,0,0
20970,20971,On the Efficient Simulation of the Left-Tail o...,The sum of Log-normal variates is encountere...,0,0,1,1,0,0


In [15]:
X = topic_dataset[["Computer Science","Physics","Mathematics","Statistics","Quantitative Biology","Quantitative Finance"]]
Y = topic_dataset["TITLE"]

In [16]:
from torch.utils.data import Dataset, DataLoader, random_split

class PandasDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        feature = tokenizer(
                self.Y.iloc[idx],
                truncation=True,
                max_length=512,
                padding="max_length"
              )

        feature = {k: torch.tensor(v) for k, v in feature.items()}

        target = torch.tensor(self.X.iloc[idx].values, dtype=torch.float)

        return feature, target


In [17]:
dataset = PandasDataset(X,Y)
train_size = int(len(X)*0.8)
test_size = len(X)-train_size
train_dataset, test_dataset = random_split(dataset,[train_size, test_size])

In [18]:
train_dataset = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataset = DataLoader(test_dataset, batch_size=32, shuffle = True)

Model Cycle

In [19]:
early = EarlyStopping(3, 0.001)
epoch = 5

for e in range(1, epoch+1):
    print(f"Epoch : {e}")
    training_function(topic_model, train_dataset, loss_func, optimizer)
    val_loss = validation_function(topic_model, test_dataset, loss_func)
    if early.stop(val_loss):
      break

Epoch : 1
training loss : 0.6856486797332764 | completed : [0/16777]
training loss : 0.32483360171318054 | completed : [3200/16777]
training loss : 0.3042384386062622 | completed : [6400/16777]
training loss : 0.30047184228897095 | completed : [9600/16777]
training loss : 0.3099982440471649 | completed : [12800/16777]
training loss : 0.3932584524154663 | completed : [16000/16777]
validation loss : 0.2980820751099875 | accuracy : 0.8732220530509949
Epoch : 2
training loss : 0.34503138065338135 | completed : [0/16777]
training loss : 0.2725656032562256 | completed : [3200/16777]
training loss : 0.4048542082309723 | completed : [6400/16777]
training loss : 0.24619154632091522 | completed : [9600/16777]
training loss : 0.3411550223827362 | completed : [12800/16777]
training loss : 0.2527824640274048 | completed : [16000/16777]
validation loss : 0.29142352893497003 | accuracy : 0.8771952986717224
Epoch : 3
training loss : 0.28217363357543945 | completed : [0/16777]
training loss : 0.2657084