# Transformers for Multilabel Classification

Based on the blog post [Transformers for Multilabel Classification](https://towardsdatascience.com/transformers-for-multilabel-classification-71a1a0daf5e1).

## 0. 导入库


In [4]:
import os
import sys


import pickle

import pandas as pd
import numpy as np

import torch
import torch.nn as no
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import BertTokenizer, BertModel


## 1. 处理数据

### 1.1 可视化数据

In [None]:
# Directory containing train.csv. Set the TOXIC_DATA_DIR environment variable to
# point elsewhere instead of editing this hard-coded local path.
DATA_DIR = os.environ.get('TOXIC_DATA_DIR', '/mnt/HDD2/lyp/DATASET/toxic')

df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))

df.head()

In [None]:
# Basic data-quality checks: are all comments distinct, are any values missing,
# and how many whitespace-separated words does a comment have on average?
word_counts = df['comment_text'].str.split().str.len()
all_unique = df['comment_text'].nunique() == len(df)
has_nulls = df.isna().values.any()

print('Unique comments:', all_unique)
print('Null values: ', has_nulls)

print('average sentence length: ', word_counts.mean())
print('stdev sentence length: ', word_counts.std())

In [None]:
# Label columns are every column after the first two (presumably `id` and
# `comment_text` — confirm against train.csv). The derived 'one_hot_labels'
# column is excluded so that re-running this cell does not count it as a label.
cols = df.columns
label_cols = [col for col in cols[2:] if col != 'one_hot_labels']
num_labels = len(label_cols)

print('Count of 1 per label: \n', df[label_cols].sum(), '\n') # Label counts, may need to downsample or upsample
print('Count of 0 per label: \n', df[label_cols].eq(0).sum())


# Shuffle rows with a fixed seed so the order (and the stratified split built on
# it later) is reproducible. reset_index makes row labels equal row positions,
# which later cells rely on when indexing the tokenizer outputs by df index.
df = df.sample(frac=1, random_state=2020).reset_index(drop=True)
# One array of label values per row, used as the (multi)label target and as the
# stratification key.
df['one_hot_labels'] = list(df[label_cols].values)
df.head()

In [None]:
# Materialize the comment texts and per-row label arrays as plain Python lists.
comments = [text for text in df['comment_text'].values]
labels = [label_vec for label_vec in df['one_hot_labels'].values]

## 2. 构造 DataLoader

In [None]:
# Tokenize every comment with the BERT tokenizer. Each sequence is padded and
# truncated to exactly MAX_LENGTH tokens so the outputs can be stacked into tensors.
MAX_LENGTH = 100  # maximum tokens per comment; longer comments are truncated (TODO: tune)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encodings = tokenizer(comments,
                      padding='max_length',
                      truncation=True,
                      max_length=MAX_LENGTH)

print('tokenizer outputs: ', encodings.keys())

# Plain Python lists (no return_tensors), consumed by the train/validation split below.
input_ids = encodings['input_ids']
token_type_ids = encodings['token_type_ids']
attention_masks = encodings['attention_mask']

### 2.1 划分训练集与验证集（Train / Validation Split）

In [None]:
# Stringify each row's label array so identical label combinations compare equal,
# then find the combinations that occur exactly once. Those rows cannot take part
# in a stratified split, so they are forced into the training set later.
label_keys = df['one_hot_labels'].astype(str)
label_counts = label_keys.value_counts()
one_freq = label_counts[label_counts.eq(1)].index
# Sorted in descending order, so rows can be removed by position (e.g. list.pop)
# without shifting the positions of rows still to be removed.
one_freq_idxs = sorted(df.loc[label_keys.isin(one_freq)].index, reverse=True)
print('df label indices with only one instance: ', one_freq_idxs)

In [None]:
# Label combinations that occur only once cannot be stratified (every class needs
# at least two members), so those rows are held out of the split and appended to
# the training set afterwards.
# Rows are selected by index into NEW lists instead of being popped from
# input_ids / token_type_ids / attention_masks / labels: popping mutated those
# lists in place, so re-running this cell removed the wrong rows.
one_freq_idx_set = set(one_freq_idxs)
one_freq_input_ids = [input_ids[i] for i in one_freq_idxs]
one_freq_token_types = [token_type_ids[i] for i in one_freq_idxs]
one_freq_attention_masks = [attention_masks[i] for i in one_freq_idxs]
one_freq_labels = [labels[i] for i in one_freq_idxs]

# All remaining rows, in their original order, go through the stratified split.
split_idxs = [i for i in range(len(labels)) if i not in one_freq_idx_set]
split_input_ids = [input_ids[i] for i in split_idxs]
split_token_type_ids = [token_type_ids[i] for i in split_idxs]
split_attention_masks = [attention_masks[i] for i in split_idxs]
split_labels = [labels[i] for i in split_idxs]


# Use train_test_split to split our data into train (90%) and validation (10%) sets,
# stratified on the full label combination.
# NOTE(review): `train_test_split` is not imported in the imports cell; it is
# presumably sklearn.model_selection.train_test_split — add that import there.
(train_inputs, validation_inputs,
 train_labels, validation_labels,
 train_token_types, validation_token_types,
 train_masks, validation_masks) = train_test_split(
    split_input_ids, split_labels, split_token_type_ids, split_attention_masks,
    random_state=2020, test_size=0.10, stratify=split_labels)

# Add one frequency data to train data
train_inputs.extend(one_freq_input_ids)
train_labels.extend(one_freq_labels)
train_masks.extend(one_freq_attention_masks)
train_token_types.extend(one_freq_token_types)

# Convert all of our data into torch tensors, the required datatype for our model.
# The labels are lists of numpy arrays; stacking them with np.array first avoids
# torch.tensor's very slow list-of-ndarrays path (same values and dtype).
train_inputs = torch.tensor(train_inputs)
train_labels = torch.tensor(np.array(train_labels))
train_masks = torch.tensor(train_masks)
train_token_types = torch.tensor(train_token_types)

validation_inputs = torch.tensor(validation_inputs)
validation_labels = torch.tensor(np.array(validation_labels))
validation_masks = torch.tensor(validation_masks)
validation_token_types = torch.tensor(validation_token_types)

In [None]:
batch_size = 32

# Each dataset yields (input_ids, attention_mask, labels, token_type_ids) tuples,
# in that order, so consumers must unpack batches in the same order.
train_data = TensorDataset(train_inputs, train_masks, train_labels, train_token_types)
# Training batches are drawn in random order.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels, validation_token_types)
# Validation batches are read in a fixed, sequential order.
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)

# Persist both loaders. The bare `with` statements here were a SyntaxError;
# torch.save needs no context manager.
# NOTE(review): these files pickle whole DataLoader objects; on recent PyTorch
# versions, torch.load may need weights_only=False to read them back — confirm.
torch.save(validation_dataloader, 'validation_data_loader')
torch.save(train_dataloader, 'train_data_loader')