# Selective Masking 流程準備
參考論文：[Train No Evil: Selective Masking for Task-Guided Pre-Training](https://arxiv.org/abs/2004.09733)

## 1. Fine-tune BERT

In [1]:
import pandas as pd
import matplotlib.pyplot as plt

datapath = '../bbc-text.csv'
df = pd.read_csv(datapath)
df.head()

Unnamed: 0,category,text
0,tech,tv future in the hands of viewers with home th...
1,business,worldcom boss left books alone former worldc...
2,sport,tigers wary of farrell gamble leicester say ...
3,sport,yeading face newcastle in fa cup premiership s...
4,entertainment,ocean s twelve raids box office ocean s twelve...


In [2]:
from transformers import BertTokenizer
import torch
import numpy as np
from transformers import BertTokenizer

# 決定 tokenizer 類型
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
# 決定資料集中各分類對應的 id
labels = {'business':0,
          'entertainment':1,
          'sport':2,
          'tech':3,
          'politics':4
          }

# 資料集處理
class Dataset(torch.utils.data.Dataset):

    def __init__(self, df):
        # 把每一筆資料的類別改成 id
        self.labels = [labels[label] for label in df['category']]  
        # 對每筆資料做 BERT tokenize
        self.texts = [tokenizer(text, 
                               padding='max_length', max_length = 512, truncation=True,
                                return_tensors="pt") for text in df['text']]

    # 回傳資料集各類別 (id)
    def classes(self):
        return self.labels

    # 回傳該 label 的資料數
    def __len__(self):
        return len(self.labels)

    # 取得當前資料的 label
    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    # 取得當前資料的 text
    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):
        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

In [3]:
np.random.seed(112)
df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), 
                                     [int(.8*len(df)), int(.9*len(df))])

print(len(df_train),len(df_val), len(df_test))

1780 222 223


In [4]:
from torch import nn
from transformers import BertForSequenceClassification
from torch.optim import Adam
from tqdm import tqdm
import os

def train(model, train_data, val_data, learning_rate, epochs, batch_size, model_name, save_path):
    if os.path.isfile("fine_tune_record_epoch.csv"):
        rec = pd.read_csv("fine_tune_record_epoch.csv")
    else:
        rec = pd.DataFrame({"model_name":[], "train_acc":[], "train_loss":[], "val_acc":[], "val_loss":[]})

    # 把原本的資料經過 Dataset 類別包裝起來
    train, val = Dataset(train_data), Dataset(val_data)

    # 把訓練、驗證資料集丟進 Dataloader 定義取樣資訊 (ex: 設定 batch_size...等等)
    train_dataloader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=batch_size)

    # 偵測有 GPU，有就用
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.CrossEntropyLoss()                       # Loss Function: Categorical cross entropy
    optimizer = Adam(model.parameters(), lr= learning_rate) # Optimizer: Adam

    if use_cuda:

            model = model.cuda()
            criterion = criterion.cuda()
    
    train_acc = []
    train_loss = []
    val_acc = []
    val_loss = []
    # 每次完整訓練 (每個 epoch) 要做的事
    for epoch_num in range(epochs):

            # ---------- 訓練的部分 ----------
            total_acc_train = 0
            total_loss_train = 0

            # 這邊加上 tqdm 模組來顯示 dataloader 處理進度條
            # 所以在程式意義上，可以直接把這行當作 for train_input, train_label in train_dataloader:
            for train_input, train_label in tqdm(train_dataloader):
                # .to(device): 把東西 (tensor) 丟到 GPU 的概念
                train_label = train_label.type(torch.LongTensor).to(device)
                mask = train_input['attention_mask'].squeeze(1).to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)

                # 把 data 丟進 BERT
                output = model(input_ids=input_id, attention_mask=mask, labels=train_label)
                
                # 計算 Cross Entropy，以此計算 loss
                batch_loss = output[0]
                total_loss_train += batch_loss.item()               # .item(): tensor 轉 純量
                
                # 看 model output "可能性最高" 的 label 是不是和 data 一樣，是的話，acc + 1
                logits = output[1]
                pred_label = logits.argmax(dim=1)
                acc = (pred_label == train_label).sum().item()
                total_acc_train += acc

                model.zero_grad()       # 清空前一次 Gradient
                batch_loss.backward()   # 根據 lost 計算 back propagation
                optimizer.step()        # 做 Gradient Decent
            
            # ---------- 驗證的部分 ----------
            total_acc_val = 0
            total_loss_val = 0

            # 步驟和訓練時差不多，差在沒做 Gradient Decent
            with torch.no_grad():

                for val_input, val_label in val_dataloader:

                    val_label = val_label.type(torch.LongTensor).to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)

                    output = model(input_ids=input_id, attention_mask=mask, labels=val_label)

                    batch_loss = output[0]
                    total_loss_val += batch_loss.item()
                    
                    logits = output[1]
                    pred_label = logits.argmax(dim=1)
                    acc = (pred_label == val_label).sum().item()
                    total_acc_val += acc
            
            train_loss.append(total_loss_train / len(train_data))
            train_acc.append(total_acc_train / len(train_data))
            val_loss.append(total_loss_val / len(val_data))
            val_acc.append(total_acc_val / len(val_data))
            print(
                f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} \
                | Train Accuracy: {total_acc_train / len(train_data): .3f} \
                | Val Loss: {total_loss_val / len(val_data): .3f} \
                | Val Accuracy: {total_acc_val / len(val_data): .3f}')
    
    new_rec = pd.concat([rec, pd.DataFrame(pd.DataFrame({'model_name': model_name, 'train_acc': [train_acc], 'train_loss': [train_loss], 'val_acc': [val_acc], 'val_loss': [val_loss]}))], ignore_index=True)
    new_rec.to_csv("fine_tune_record_epoch.csv", index = None)
    model.save_pretrained(save_path)
    model = None

In [5]:
EPOCHS = 8
LR = 2e-5
batch_size = 8
model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=5)     
train(model, df_train, df_val, LR, EPOCHS, batch_size, "Fine-tuned_BERT", "fine_tuned_bert")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

Epochs: 1 | Train Loss:  0.065                 | Train Accuracy:  0.842                 | Val Loss:  0.007                 | Val Accuracy:  0.986


100%|██████████| 223/223 [01:24<00:00,  2.63it/s]


Epochs: 2 | Train Loss:  0.008                 | Train Accuracy:  0.984                 | Val Loss:  0.006                 | Val Accuracy:  0.991


100%|██████████| 223/223 [01:24<00:00,  2.65it/s]


Epochs: 3 | Train Loss:  0.005                 | Train Accuracy:  0.990                 | Val Loss:  0.003                 | Val Accuracy:  0.986


100%|██████████| 223/223 [01:23<00:00,  2.66it/s]


Epochs: 4 | Train Loss:  0.003                 | Train Accuracy:  0.994                 | Val Loss:  0.012                 | Val Accuracy:  0.977


100%|██████████| 223/223 [01:23<00:00,  2.66it/s]


Epochs: 5 | Train Loss:  0.003                 | Train Accuracy:  0.993                 | Val Loss:  0.011                 | Val Accuracy:  0.977


100%|██████████| 223/223 [01:23<00:00,  2.68it/s]


Epochs: 6 | Train Loss:  0.001                 | Train Accuracy:  0.997                 | Val Loss:  0.007                 | Val Accuracy:  0.991


100%|██████████| 223/223 [01:22<00:00,  2.71it/s]


Epochs: 7 | Train Loss:  0.001                 | Train Accuracy:  0.998                 | Val Loss:  0.006                 | Val Accuracy:  0.991


100%|██████████| 223/223 [01:23<00:00,  2.67it/s]


Epochs: 8 | Train Loss:  0.004                 | Train Accuracy:  0.993                 | Val Loss:  0.011                 | Val Accuracy:  0.986


## 2. Downstream Mask

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from transformers import BertTokenizer
import torch
import numpy as np
from transformers import BertTokenizer
import torch.nn.functional as F
from torch import nn
from transformers import BertForSequenceClassification

In [2]:
datapath = '../bbc-text.csv'
df = pd.read_csv(datapath)
df.head()

Unnamed: 0,category,text
0,tech,tv future in the hands of viewers with home th...
1,business,worldcom boss left books alone former worldc...
2,sport,tigers wary of farrell gamble leicester say ...
3,sport,yeading face newcastle in fa cup premiership s...
4,entertainment,ocean s twelve raids box office ocean s twelve...


In [3]:
sentences = []

for i in range(len(df)):
    sentences.append(df.iloc[i, 1])

In [4]:
most_important = set()
second_important = set()
third_important = set()
not_important = set()
most_threshold = 0.05
second_threshold = 0.07
third_threshold = 0.10

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForSequenceClassification.from_pretrained('fine_tuned_bert')

In [5]:
def isIn(token, set_list, ifDel = False):
    return_val = False
    for s in set_list:
        if token in s:
            if ifDel:
                s.remove(token)
            return_val = True
    return return_val

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for i in range(len(sentences)):
    inputs_sentence = tokenizer(sentences[i], padding='max_length', \
                                max_length = 512, truncation=True, return_tensors="pt")
    
    model.to(device)
    inputs_sentence = inputs_sentence.to(device)
    outputs_sentence = model(**inputs_sentence)

    logits_sentence = outputs_sentence.logits
    probs_sentence = F.softmax(logits_sentence, dim=-1)
    pred_label_sentence = torch.argmax(probs_sentence, dim=-1).item()
    confidence_sentence = probs_sentence[0][pred_label_sentence].item()

    sentence2_idx = 0
    sentence2_input = torch.full(inputs_sentence['input_ids'].shape, 0)
    sentence2_tkn_type = torch.full(inputs_sentence['token_type_ids'].shape, 0)
    sentence2_att_mast = torch.full(inputs_sentence['attention_mask'].shape, 0)

    for j in range(512):
        token_now = inputs_sentence['input_ids'][0][j]
        token_now_int = token_now.item()
        
        sentence2_input[0][sentence2_idx] = token_now
        sentence2_att_mast[0][sentence2_idx] = 1
        sentence2_idx += 1
        if token_now_int == 101:
            continue
        elif token_now_int == 102:
            break
        
        sentence2_input = sentence2_input.to(device)
        sentence2_tkn_type = sentence2_tkn_type.to(device)
        sentence2_att_mast = sentence2_att_mast.to(device)
        # print(tokenizer.decode(sentence2_input[0]))
        outputs_sentence2 = model(input_ids=sentence2_input, attention_mask=sentence2_att_mast, token_type_ids=sentence2_tkn_type)

        logits_sentence2 = outputs_sentence2.logits
        probs_sentence2 = F.softmax(logits_sentence2, dim=-1)
        pred_label_sentence2 = torch.argmax(probs_sentence2, dim=-1).item()
        confidence_sentence2 = probs_sentence2[0][pred_label_sentence2].item()
        
        if pred_label_sentence != pred_label_sentence2:
            if isIn(token_now_int, [most_important, second_important, third_important, not_important]) == False:
                not_important.add(token_now_int)
        else:
            score = abs(confidence_sentence - confidence_sentence2)
            if score <= most_threshold:
                most_important.add(token_now_int)
                isIn(token_now_int, [second_important, third_important, not_important], ifDel = True)
                sentence2_input[0][sentence2_idx] = 0
                sentence2_att_mast[0][sentence2_idx] = 0
                sentence2_idx -= 1
            elif score <= second_threshold and not isIn(token_now_int, [most_important]):
                second_important.add(token_now_int)
                isIn(token_now_int, [third_important, not_important], ifDel = True)
                sentence2_input[0][sentence2_idx] = 0
                sentence2_att_mast[0][sentence2_idx] = 0
                sentence2_idx -= 1
            elif score <= third_threshold and not isIn(token_now_int, [most_important, second_important]):
                third_important.add(token_now_int)
                isIn(token_now_int, [not_important], ifDel = True)
                sentence2_input[0][sentence2_idx] = 0
                sentence2_att_mast[0][sentence2_idx] = 0
                sentence2_idx -= 1
            elif not isIn(token_now_int, [most_important, second_important, third_important]):
                not_important.add(token_now_int)
    break
# print(most_important)
# print(second_important)
# print(third_important)
# print(not_important)

In [6]:
print(most_important)
print(second_important)
print(third_important)
print(not_important)

{1888, 15139, 2621, 12263, 2984, 1449, 6379, 1964, 5197, 8171, 7951, 11216, 26577, 18898, 3539, 2357, 2557, 2815}
{13441, 5989, 2344, 4045, 2258, 1555, 5754, 3643, 1116, 189}
{2443, 1292, 6412, 1169, 1176, 2086, 15403, 1197, 1977, 1211, 4159, 6984, 15435, 1103, 6095, 1105, 1106, 1104, 1234, 1499, 1383, 2538, 1132, 3438, 21359, 113, 1524, 1141, 119}
{4097, 4609, 2052, 15370, 3094, 6678, 2595, 4644, 2084, 1575, 11305, 4653, 1605, 2120, 5710, 1107, 8276, 1110, 1111, 17496, 3673, 1114, 1115, 1112, 1118, 1119, 1120, 4193, 1122, 1121, 1126, 4711, 1129, 1133, 1134, 1647, 1136, 114, 3186, 116, 118, 1142, 1144, 6265, 1146, 2683, 1145, 27772, 2174, 1151, 1152, 1147, 1154, 12411, 1665, 131, 1158, 1159, 1671, 27271, 1162, 1163, 5260, 1164, 1165, 1167, 1168, 8844, 1172, 3223, 1175, 1177, 1690, 1179, 1178, 3741, 1184, 1185, 1696, 10915, 1190, 1193, 1194, 6827, 172, 173, 170, 1199, 13482, 1201, 1195, 179, 171, 175, 11955, 1207, 2232, 1209, 185, 12986, 3772, 188, 1721, 1208, 1217, 1218, 3265, 1228, 12

In [7]:
text1 = tokenizer.decode(most_important)
text1

'video portable allow providers store system networksv devices boxes technologies electronics broadband recorder digital personal companies technology'

### Rewrite

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from transformers import BertTokenizer
import torch
import numpy as np
from transformers import BertTokenizer
import torch.nn.functional as F
from torch import nn
from transformers import BertForSequenceClassification

In [2]:
# datapath = '../bbc-text.csv'
# df = pd.read_csv(datapath)
# df.head()
# df["important_labels"] = [" " for _ in range(len(df.index))]
# df["most_important"] = [" " for _ in range(len(df.index))]
# df["second_important"] = [" " for _ in range(len(df.index))]
# df["third_important"] = [" " for _ in range(len(df.index))]
# df.to_csv("bbc-text-with-important.csv", index=None)
# df.head()

Unnamed: 0,category,text,important_labels,most_important,second_important,third_important
0,tech,tv future in the hands of viewers with home th...,,,,
1,business,worldcom boss left books alone former worldc...,,,,
2,sport,tigers wary of farrell gamble leicester say ...,,,,
3,sport,yeading face newcastle in fa cup premiership s...,,,,
4,entertainment,ocean s twelve raids box office ocean s twelve...,,,,


In [None]:
datapath = 'bbc-text-with-important.csv'
df = pd.read_csv(datapath)
df[2202:]

In [3]:
sentences = []

for i in range(len(df)):
    sentences.append(df.iloc[i, 1])

In [4]:
most_threshold = 0.05
second_threshold = 0.07
third_threshold = 0.10

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForSequenceClassification.from_pretrained('fine_tuned_bert')

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for i in range(len(sentences)):
    inputs_sentence = tokenizer(sentences[i], padding='max_length', \
                                max_length = 512, truncation=True, return_tensors="pt")
    
    model.to(device)
    inputs_sentence = inputs_sentence.to(device)
    outputs_sentence = model(**inputs_sentence)

    logits_sentence = outputs_sentence.logits
    probs_sentence = F.softmax(logits_sentence, dim=-1)
    pred_label_sentence = torch.argmax(probs_sentence, dim=-1).item()
    confidence_sentence = probs_sentence[0][pred_label_sentence].item()

    sentence2_idx = 0
    sentence2_input = torch.full(inputs_sentence['input_ids'].shape, 0)
    sentence2_tkn_type = torch.full(inputs_sentence['token_type_ids'].shape, 0)
    sentence2_att_mast = torch.full(inputs_sentence['attention_mask'].shape, 0)

    for j in range(512):
        token_now = inputs_sentence['input_ids'][0][j]
        token_now_int = token_now.item()
        
        sentence2_input[0][sentence2_idx] = token_now
        sentence2_att_mast[0][sentence2_idx] = 1
        sentence2_idx += 1
        if token_now_int == 101:
            df.iloc[i, 2] += '0'
            continue
        elif token_now_int == 102:
            df.iloc[i, 2] += ", 0"
            break
        
        sentence2_input = sentence2_input.to(device)
        sentence2_tkn_type = sentence2_tkn_type.to(device)
        sentence2_att_mast = sentence2_att_mast.to(device)
        # print(tokenizer.decode(sentence2_input[0]))
        outputs_sentence2 = model(input_ids=sentence2_input, attention_mask=sentence2_att_mast, token_type_ids=sentence2_tkn_type)

        logits_sentence2 = outputs_sentence2.logits
        probs_sentence2 = F.softmax(logits_sentence2, dim=-1)
        pred_label_sentence2 = torch.argmax(probs_sentence2, dim=-1).item()
        confidence_sentence2 = probs_sentence2[0][pred_label_sentence2].item()
        
        if pred_label_sentence != pred_label_sentence2:
            df.iloc[i, 2] += ", 0"
        else:
            score = abs(confidence_sentence - confidence_sentence2)
            
            if score <= most_threshold:
                df.iloc[i, 2] += ", 3"
                df.iloc[i, 3] += ", " + str(token_now_int)
                sentence2_input[0][sentence2_idx] = 0
                sentence2_att_mast[0][sentence2_idx] = 0
                sentence2_idx -= 1
            elif score <= second_threshold:
                df.iloc[i, 2] += ", 2"
                df.iloc[i, 4] += ", " + str(token_now_int)
                sentence2_input[0][sentence2_idx] = 0
                sentence2_att_mast[0][sentence2_idx] = 0
                sentence2_idx -= 1
            elif score <= third_threshold:
                df.iloc[i, 2] += ", 1"
                df.iloc[i, 5] += ", " + str(token_now_int)
                sentence2_input[0][sentence2_idx] = 0
                sentence2_att_mast[0][sentence2_idx] = 0
                sentence2_idx -= 1
            else:
                df.iloc[i, 2] += ", 0"
df.to_csv("bbc-text-with-important.csv", index=None)

In [12]:
df.iloc[1885]

category                                                entertainment
text                label withdraws mcfadden s video the new video...
important_labels     0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 3, 0, 0, ...
most_important       , 3107, 1888, 1888, 14430, 2483, 1461, 1888, ...
second_important                                               , 3481
third_important      , 1488, 1144, 5627, 7662, 1637, 20715, 14430,...
Name: 1885, dtype: object

In [9]:
test = df.iloc[1885,3].split(', ')[1:]
for k in range(len(test)):
    test[k] = int(test[k])
text1 = tokenizer.decode(test)
text1

'label video videolife singer song video song lyrics performer label video television stations label pop star song solo album song autobiographical songs group solo single chart top band'

In [10]:
test = df.iloc[1885,4].split(', ')[1:]
for k in range(len(test)):
    test[k] = int(test[k])
text1 = tokenizer.decode(test)
text1

'chart'

In [11]:
test = df.iloc[1885,5].split(', ')[1:]
for k in range(len(test)):
    test[k] = int(test[k])
text1 = tokenizer.decode(test)
text1

'son has replacement topics written collaboratorlifepers boy'

## Rewrite-0.03

In [6]:
import pandas as pd
import matplotlib.pyplot as plt
from transformers import BertTokenizer
import torch
import numpy as np
from transformers import BertTokenizer
import torch.nn.functional as F
from torch import nn
from transformers import BertForSequenceClassification

In [7]:
datapath = '../bbc-text.csv'
df = pd.read_csv(datapath)

df["important_labels"] = [" " for _ in range(len(df.index))]
df["most_important"] = [" " for _ in range(len(df.index))]
df["second_important"] = [" " for _ in range(len(df.index))]
df["third_important"] = [" " for _ in range(len(df.index))]
df.to_csv("bbc-text-with-important-003.csv", index=None)
df.head()

Unnamed: 0,category,text,important_labels,most_important,second_important,third_important
0,tech,tv future in the hands of viewers with home th...,,,,
1,business,worldcom boss left books alone former worldc...,,,,
2,sport,tigers wary of farrell gamble leicester say ...,,,,
3,sport,yeading face newcastle in fa cup premiership s...,,,,
4,entertainment,ocean s twelve raids box office ocean s twelve...,,,,


In [8]:
sentences = []

for i in range(len(df)):
    sentences.append(df.iloc[i, 1])

most_threshold = 0.03
second_threshold = 0.05
third_threshold = 0.07

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForSequenceClassification.from_pretrained('fine_tuned_bert')

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for i in range(len(sentences)):
    inputs_sentence = tokenizer(sentences[i], padding='max_length', \
                                max_length = 512, truncation=True, return_tensors="pt")
    
    model.to(device)
    inputs_sentence = inputs_sentence.to(device)
    outputs_sentence = model(**inputs_sentence)

    logits_sentence = outputs_sentence.logits
    probs_sentence = F.softmax(logits_sentence, dim=-1)
    pred_label_sentence = torch.argmax(probs_sentence, dim=-1).item()
    confidence_sentence = probs_sentence[0][pred_label_sentence].item()

    sentence2_idx = 0
    sentence2_input = torch.full(inputs_sentence['input_ids'].shape, 0)
    sentence2_tkn_type = torch.full(inputs_sentence['token_type_ids'].shape, 0)
    sentence2_att_mast = torch.full(inputs_sentence['attention_mask'].shape, 0)

    for j in range(512):
        token_now = inputs_sentence['input_ids'][0][j]
        token_now_int = token_now.item()
        
        sentence2_input[0][sentence2_idx] = token_now
        sentence2_att_mast[0][sentence2_idx] = 1
        sentence2_idx += 1
        if token_now_int == 101:
            df.iloc[i, 2] += '0'
            continue
        elif token_now_int == 102:
            df.iloc[i, 2] += ", 0"
            break
        
        sentence2_input = sentence2_input.to(device)
        sentence2_tkn_type = sentence2_tkn_type.to(device)
        sentence2_att_mast = sentence2_att_mast.to(device)
        # print(tokenizer.decode(sentence2_input[0]))
        outputs_sentence2 = model(input_ids=sentence2_input, attention_mask=sentence2_att_mast, token_type_ids=sentence2_tkn_type)

        logits_sentence2 = outputs_sentence2.logits
        probs_sentence2 = F.softmax(logits_sentence2, dim=-1)
        pred_label_sentence2 = torch.argmax(probs_sentence2, dim=-1).item()
        confidence_sentence2 = probs_sentence2[0][pred_label_sentence2].item()
        
        if pred_label_sentence != pred_label_sentence2:
            df.iloc[i, 2] += ", 0"
        else:
            score = abs(confidence_sentence - confidence_sentence2)
            
            if score <= most_threshold:
                df.iloc[i, 2] += ", 3"
                df.iloc[i, 3] += ", " + str(token_now_int)
                sentence2_input[0][sentence2_idx] = 0
                sentence2_att_mast[0][sentence2_idx] = 0
                sentence2_idx -= 1
            elif score <= second_threshold:
                df.iloc[i, 2] += ", 2"
                df.iloc[i, 4] += ", " + str(token_now_int)
                sentence2_input[0][sentence2_idx] = 0
                sentence2_att_mast[0][sentence2_idx] = 0
                sentence2_idx -= 1
            elif score <= third_threshold:
                df.iloc[i, 2] += ", 1"
                df.iloc[i, 5] += ", " + str(token_now_int)
                sentence2_input[0][sentence2_idx] = 0
                sentence2_att_mast[0][sentence2_idx] = 0
                sentence2_idx -= 1
            else:
                df.iloc[i, 2] += ", 0"
df.to_csv("bbc-text-with-important-003.csv", index=None)

### --- test ---

In [23]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForSequenceClassification.from_pretrained('fine_tuned_bert')
inputs = tokenizer(sentences[2], padding='max_length', \
                   max_length = 512, truncation=True, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
pred_label = logits.argmax(dim=1)
logits

tensor([[-1.6769, -1.3050,  4.5313, -1.6913, -0.7815]],
       grad_fn=<AddmmBackward0>)

In [28]:
probs = F.softmax(logits, dim=-1)

# 获取分类结果和概率值
label = torch.argmax(probs, dim=-1).item()
confidence = probs[0][label].item()
print(pred_label)
print(pred_label.item())
print(label)
print(confidence)
probs

tensor([2])
2
2
0.9882943630218506


tensor([[0.0020, 0.0029, 0.9883, 0.0020, 0.0049]], grad_fn=<SoftmaxBackward0>)