In [12]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import BertModel
from tokenizers import BertWordPieceTokenizer

import torch
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader, Dataset

import pandas as pd
from tqdm import tqdm

from sklearn.metrics import classification_report, confusion_matrix

In [2]:
device = None

import torch
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print (device)
else:
    print ("MPS device not found.")

mps


In [3]:
# 모델 및 토크나이저 로드
model = AutoModelForSequenceClassification.from_pretrained("fabriceyhc/bert-base-uncased-yahoo_answers_topics")
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [4]:
# test 데이터 가져오기
test_data = "./yahoo_answers_csv/test.csv"
test_df = pd.read_csv(test_data)

sentence = test_df.iloc[0, 1] + " " + test_df.iloc[0, 2] + " " + test_df.iloc[0, 3]
true_label = test_df.iloc[0, 0]
print(f"sentence: {sentence}")

inputs = tokenizer(sentence, return_tensors="pt").to(device)

token_length = inputs.input_ids.shape[1]
print(f"token length: {token_length}\n")



# 모델에 입력값 넣기
outputs = model(**inputs)

predictions = outputs.logits.argmax(dim=-1)
print(f"pred output: {predictions.item() + 1}")
print(f"true label: {true_label}")
print(outputs.logits)



"""
class 정보

1: Society & Culture
2: Science & Mathematics
3: Health
4: Education & Reference
5: Computers & Internet
6: Sports
7: Business & Finance
8: Entertainment & Music
9: Family & Relationships
10: Politics & Government

"""


sentence: Why does Zebras have stripes? What is the purpose or those stripes? Who do they serve the Zebras in the wild life? this provides camouflage - predator vision is such that it is usually difficult for them to see complex patterns
token length: 48

pred output: 2
true label: 2
tensor([[ 0.7546,  3.1281, -0.6842,  2.3181, -1.2996, -0.6507,  1.6057, -1.0207,
         -1.7158, -0.7543]], device='mps:0', grad_fn=<LinearBackward0>)


'\nclass 정보\n\n1: Society & Culture\n2: Science & Mathematics\n3: Health\n4: Education & Reference\n5: Computers & Internet\n6: Sports\n7: Business & Finance\n8: Entertainment & Music\n9: Family & Relationships\n10: Politics & Government\n\n'

In [5]:
class TestDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.sentences = df.iloc[:, 1:].apply(lambda x: ' '.join(x.dropna().astype(str)), axis=1).values
        self.labels = df.iloc[:, 0].values
        self.tokenizer = tokenizer
  
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        inputs = self.tokenizer(sentence, truncation=True, max_length=512, padding='max_length', return_tensors="pt")
        label = torch.tensor(self.labels[idx]) # label은 1부터 시작하기 때문에, 나중에 inference할 때에 예측값에 1을 더해줘야 합니다.
        
        return inputs, label

In [6]:
test_data = "./yahoo_answers_csv/test.csv"
test_df = pd.read_csv(test_data, header=None)

test_dataset = TestDataset(test_df, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [14]:
# 1번째 레이어 1번 헤드 끄기

# print(model.bert.encoder.layer[0].attention.self.value.weight.data)
# model.bert.encoder.layer[0].attention.self.value.weight.data[:, :64] = 0
# print(model.bert.encoder.layer[0].attention.self.value.weight.data)

In [19]:
for layer_index in range(12):
    for head_index in range(12):

        model = AutoModelForSequenceClassification.from_pretrained("fabriceyhc/bert-base-uncased-yahoo_answers_topics")
        model = model.to(device)

        # print(model.bert.encoder.layer[layer_index].attention.self.value.weight.data)
        model.bert.encoder.layer[layer_index].attention.self.value.weight.data[head_index*64:(head_index+1)*64, :] = 0
        # print(model.bert.encoder.layer[layer_index].attention.self.value.weight.data)

        preds = []
        true_labels = []

        for batch in tqdm(test_loader, desc="Evaluating"):
            inputs, labels = batch
            inputs = {k: v.squeeze(1).to(device) for k, v in inputs.items()} 
            labels = labels.to(device)

            with torch.no_grad():
                outputs = model(**inputs)
            prediction = outputs.logits.argmax(dim=-1) + 1
            
            preds.extend(prediction.tolist())
            true_labels.extend(labels.tolist())

        print(f"off attention head index // layer: {layer_index+1}, head: {head_index+1}")
            
        print('classification report')
        print(classification_report(true_labels, preds, digits=5))

        print('confusion matrix')
        print(confusion_matrix(true_labels, preds))

Evaluating: 100%|██████████| 1875/1875 [22:57<00:00,  1.36it/s]


off attention head index // layer: 1, head: 1
classification report
              precision    recall  f1-score   support

           1    0.62009   0.61833   0.61921      6000
           2    0.73934   0.80317   0.76993      6000
           3    0.74787   0.86517   0.80226      6000
           4    0.65157   0.54200   0.59176      6000
           5    0.85083   0.88317   0.86670      6000
           6    0.91874   0.91200   0.91536      6000
           7    0.76982   0.41750   0.54139      6000
           8    0.71085   0.79367   0.74998      6000
           9    0.72584   0.82733   0.77327      6000
          10    0.76195   0.83700   0.79771      6000

    accuracy                        0.74993     60000
   macro avg    0.74969   0.74993   0.74276     60000
weighted avg    0.74969   0.74993   0.74276     60000

confusion matrix
[[3710  127  255  405   62   42   90  380  636  293]
 [ 162 4819  410  296   60   43   64   69   17   60]
 [ 157  103 5191   71    9   35   24   81  279   5

Evaluating: 100%|██████████| 1875/1875 [22:50<00:00,  1.37it/s]


off attention head index // layer: 1, head: 2
classification report
              precision    recall  f1-score   support

           1    0.61858   0.62383   0.62119      6000
           2    0.73720   0.80650   0.77030      6000
           3    0.74870   0.86500   0.80266      6000
           4    0.65031   0.54117   0.59074      6000
           5    0.85231   0.88200   0.86690      6000
           6    0.92168   0.90817   0.91488      6000
           7    0.77317   0.41300   0.53840      6000
           8    0.70837   0.79267   0.74815      6000
           9    0.72568   0.82667   0.77289      6000
          10    0.76219   0.83650   0.79762      6000

    accuracy                        0.74955     60000
   macro avg    0.74982   0.74955   0.74237     60000
weighted avg    0.74982   0.74955   0.74237     60000

confusion matrix
[[3743  123  253  404   59   36   82  381  635  284]
 [ 162 4839  403  292   55   40   59   71   19   60]
 [ 159  109 5190   69    9   35   23   80  276   5

Evaluating: 100%|██████████| 1875/1875 [22:49<00:00,  1.37it/s]


off attention head index // layer: 1, head: 3
classification report
              precision    recall  f1-score   support

           1    0.62075   0.62033   0.62054      6000
           2    0.73716   0.80350   0.76890      6000
           3    0.74863   0.86367   0.80204      6000
           4    0.65240   0.53867   0.59010      6000
           5    0.85100   0.88333   0.86686      6000
           6    0.91753   0.91050   0.91400      6000
           7    0.77391   0.41533   0.54056      6000
           8    0.71042   0.79200   0.74900      6000
           9    0.72142   0.82867   0.77133      6000
          10    0.76048   0.83717   0.79699      6000

    accuracy                        0.74932     60000
   macro avg    0.74937   0.74932   0.74203     60000
weighted avg    0.74937   0.74932   0.74203     60000

confusion matrix
[[3722  126  250  401   59   40   84  383  646  289]
 [ 163 4821  411  291   58   41   60   72   19   64]
 [ 152  111 5182   72    9   38   21   81  284   5

Evaluating: 100%|██████████| 1875/1875 [22:21<00:00,  1.40it/s]


off attention head index // layer: 1, head: 4
classification report
              precision    recall  f1-score   support

           1    0.62027   0.62017   0.62022      6000
           2    0.73889   0.80650   0.77122      6000
           3    0.74996   0.86333   0.80267      6000
           4    0.65500   0.53950   0.59167      6000
           5    0.84992   0.88533   0.86727      6000
           6    0.91772   0.91083   0.91426      6000
           7    0.77152   0.41533   0.53998      6000
           8    0.70913   0.79317   0.74880      6000
           9    0.72192   0.82817   0.77140      6000
          10    0.76361   0.83667   0.79847      6000

    accuracy                        0.74990     60000
   macro avg    0.74979   0.74990   0.74259     60000
weighted avg    0.74979   0.74990   0.74259     60000

confusion matrix
[[3721  124  250  396   60   41   85  383  648  292]
 [ 163 4839  403  283   58   42   58   74   19   61]
 [ 156  111 5180   72    9   38   23   80  283   4

Evaluating: 100%|██████████| 1875/1875 [22:09<00:00,  1.41it/s]


off attention head index // layer: 1, head: 5
classification report
              precision    recall  f1-score   support

           1    0.61913   0.62017   0.61965      6000
           2    0.73681   0.81000   0.77167      6000
           3    0.75214   0.86333   0.80391      6000
           4    0.65350   0.53783   0.59005      6000
           5    0.84912   0.88450   0.86645      6000
           6    0.91804   0.91100   0.91451      6000
           7    0.77295   0.41533   0.54033      6000
           8    0.71125   0.79317   0.74998      6000
           9    0.72366   0.82750   0.77210      6000
          10    0.76279   0.83767   0.79847      6000

    accuracy                        0.75005     60000
   macro avg    0.74994   0.75005   0.74271     60000
weighted avg    0.74994   0.75005   0.74271     60000

confusion matrix
[[3721  130  250  396   61   42   86  373  646  295]
 [ 165 4860  385  279   59   43   61   68   18   62]
 [ 153  112 5180   76    9   35   25   81  279   5

Evaluating: 100%|██████████| 1875/1875 [22:50<00:00,  1.37it/s]


off attention head index // layer: 1, head: 6
classification report
              precision    recall  f1-score   support

           1    0.62082   0.61917   0.61999      6000
           2    0.74001   0.80550   0.77137      6000
           3    0.75669   0.85833   0.80431      6000
           4    0.64770   0.54267   0.59055      6000
           5    0.84748   0.88717   0.86687      6000
           6    0.91661   0.91233   0.91447      6000
           7    0.77470   0.41033   0.53650      6000
           8    0.70665   0.79817   0.74963      6000
           9    0.72843   0.82300   0.77283      6000
          10    0.75619   0.84000   0.79589      6000

    accuracy                        0.74967     60000
   macro avg    0.74953   0.74967   0.74224     60000
weighted avg    0.74953   0.74967   0.74224     60000

confusion matrix
[[3715  121  239  415   66   45   83  392  624  300]
 [ 166 4833  376  300   64   44   56   77   18   66]
 [ 162  114 5150   77    9   45   25   85  279   5

Evaluating: 100%|██████████| 1875/1875 [22:33<00:00,  1.39it/s]


off attention head index // layer: 1, head: 7
classification report
              precision    recall  f1-score   support

           1    0.61775   0.62167   0.61970      6000
           2    0.73695   0.80683   0.77031      6000
           3    0.75196   0.86450   0.80431      6000
           4    0.65023   0.54067   0.59041      6000
           5    0.84990   0.88517   0.86717      6000
           6    0.91995   0.90983   0.91487      6000
           7    0.77330   0.41617   0.54112      6000
           8    0.70863   0.79283   0.74837      6000
           9    0.72842   0.82433   0.77342      6000
          10    0.76134   0.83633   0.79708      6000

    accuracy                        0.74983     60000
   macro avg    0.74984   0.74983   0.74268     60000
weighted avg    0.74984   0.74983   0.74268     60000

confusion matrix
[[3730  128  249  406   63   41   82  380  626  295]
 [ 166 4841  401  288   59   41   55   71   17   61]
 [ 164  106 5187   69    9   36   27   78  274   5

Evaluating: 100%|██████████| 1875/1875 [22:40<00:00,  1.38it/s]


off attention head index // layer: 1, head: 8
classification report
              precision    recall  f1-score   support

           1    0.61862   0.62017   0.61939      6000
           2    0.73976   0.80350   0.77031      6000
           3    0.74849   0.86500   0.80254      6000
           4    0.65241   0.54150   0.59180      6000
           5    0.85088   0.88250   0.86640      6000
           6    0.91991   0.90933   0.91459      6000
           7    0.77416   0.41250   0.53822      6000
           8    0.70809   0.79400   0.74859      6000
           9    0.72453   0.82717   0.77245      6000
          10    0.75894   0.83800   0.79651      6000

    accuracy                        0.74937     60000
   macro avg    0.74958   0.74937   0.74208     60000
weighted avg    0.74958   0.74937   0.74208     60000

confusion matrix
[[3721  122  254  401   61   40   83  383  641  294]
 [ 164 4821  406  293   60   43   59   74   18   62]
 [ 158  109 5190   70    9   36   22   78  277   5

Evaluating: 100%|██████████| 1875/1875 [22:46<00:00,  1.37it/s]


off attention head index // layer: 1, head: 9
classification report
              precision    recall  f1-score   support

           1    0.62239   0.61700   0.61969      6000
           2    0.74105   0.80033   0.76955      6000
           3    0.74949   0.86517   0.80319      6000
           4    0.65275   0.53950   0.59075      6000
           5    0.84783   0.88400   0.86554      6000
           6    0.91825   0.91167   0.91495      6000
           7    0.76871   0.41600   0.53985      6000
           8    0.71008   0.79233   0.74896      6000
           9    0.72071   0.83050   0.77172      6000
          10    0.76035   0.83867   0.79759      6000

    accuracy                        0.74952     60000
   macro avg    0.74916   0.74952   0.74218     60000
weighted avg    0.74916   0.74952   0.74218     60000

confusion matrix
[[3702  125  252  399   63   43   86  380  653  297]
 [ 165 4802  412  297   61   43   59   78   19   64]
 [ 149  101 5191   71    9   39   22   80  286   5

Evaluating: 100%|██████████| 1875/1875 [22:40<00:00,  1.38it/s]


off attention head index // layer: 1, head: 10
classification report
              precision    recall  f1-score   support

           1    0.62172   0.61633   0.61902      6000
           2    0.73651   0.80733   0.77029      6000
           3    0.75372   0.86000   0.80336      6000
           4    0.65116   0.53883   0.58969      6000
           5    0.85010   0.88467   0.86704      6000
           6    0.91816   0.91250   0.91532      6000
           7    0.77520   0.41267   0.53861      6000
           8    0.70549   0.79650   0.74824      6000
           9    0.72518   0.82417   0.77151      6000
          10    0.75667   0.84117   0.79669      6000

    accuracy                        0.74942     60000
   macro avg    0.74939   0.74942   0.74198     60000
weighted avg    0.74939   0.74942   0.74198     60000

confusion matrix
[[3698  124  248  402   59   45   84  398  640  302]
 [ 163 4844  385  292   60   43   58   75   17   63]
 [ 159  120 5160   73    9   38   22   87  278   

Evaluating: 100%|██████████| 1875/1875 [22:41<00:00,  1.38it/s]


off attention head index // layer: 1, head: 11
classification report
              precision    recall  f1-score   support

           1    0.61936   0.61967   0.61951      6000
           2    0.73783   0.80583   0.77033      6000
           3    0.75043   0.86400   0.80322      6000
           4    0.65396   0.53733   0.58994      6000
           5    0.84950   0.88433   0.86657      6000
           6    0.91756   0.91083   0.91419      6000
           7    0.77145   0.41517   0.53982      6000
           8    0.70918   0.79333   0.74890      6000
           9    0.72441   0.82800   0.77275      6000
          10    0.76003   0.83667   0.79651      6000

    accuracy                        0.74952     60000
   macro avg    0.74937   0.74952   0.74217     60000
weighted avg    0.74937   0.74952   0.74217     60000

confusion matrix
[[3718  128  248  396   62   44   87  386  640  291]
 [ 163 4835  401  288   61   43   59   71   19   60]
 [ 155  110 5184   74    9   37   25   79  278   

Evaluating: 100%|██████████| 1875/1875 [22:26<00:00,  1.39it/s]


off attention head index // layer: 1, head: 12
classification report
              precision    recall  f1-score   support

           1    0.62114   0.62083   0.62099      6000
           2    0.73638   0.80633   0.76977      6000
           3    0.74892   0.86350   0.80214      6000
           4    0.65367   0.53917   0.59092      6000
           5    0.84975   0.88417   0.86662      6000
           6    0.91999   0.91033   0.91514      6000
           7    0.77093   0.41283   0.53772      6000
           8    0.71051   0.79400   0.74994      6000
           9    0.72306   0.82767   0.77184      6000
          10    0.76015   0.83617   0.79635      6000

    accuracy                        0.74950     60000
   macro avg    0.74945   0.74950   0.74214     60000
weighted avg    0.74945   0.74950   0.74214     60000

confusion matrix
[[3725  125  256  401   58   41   84  380  642  288]
 [ 165 4838  399  284   61   40   58   73   20   62]
 [ 152  110 5181   73    9   38   26   78  282   

Evaluating: 100%|██████████| 1875/1875 [22:30<00:00,  1.39it/s]


off attention head index // layer: 2, head: 1
classification report
              precision    recall  f1-score   support

           1    0.61916   0.61833   0.61875      6000
           2    0.73807   0.80683   0.77092      6000
           3    0.75025   0.86417   0.80319      6000
           4    0.65422   0.54017   0.59175      6000
           5    0.84951   0.88533   0.86705      6000
           6    0.91926   0.91083   0.91503      6000
           7    0.77065   0.41667   0.54089      6000
           8    0.71118   0.79000   0.74852      6000
           9    0.72258   0.83000   0.77257      6000
          10    0.76173   0.83600   0.79714      6000

    accuracy                        0.74983     60000
   macro avg    0.74966   0.74983   0.74258     60000
weighted avg    0.74966   0.74983   0.74258     60000

confusion matrix
[[3710  127  253  400   61   41   87  376  650  295]
 [ 160 4841  398  288   58   43   60   72   19   61]
 [ 155  111 5185   71    9   35   23   81  280   5

Evaluating: 100%|██████████| 1875/1875 [22:24<00:00,  1.39it/s]


off attention head index // layer: 2, head: 2
classification report
              precision    recall  f1-score   support

           1    0.61857   0.61950   0.61904      6000
           2    0.73743   0.80650   0.77042      6000
           3    0.75058   0.86267   0.80273      6000
           4    0.65325   0.53817   0.59015      6000
           5    0.85005   0.88433   0.86685      6000
           6    0.91812   0.91017   0.91413      6000
           7    0.76968   0.41550   0.53967      6000
           8    0.71128   0.79367   0.75022      6000
           9    0.72267   0.82867   0.77205      6000
          10    0.76146   0.83583   0.79692      6000

    accuracy                        0.74950     60000
   macro avg    0.74931   0.74950   0.74222     60000
weighted avg    0.74931   0.74950   0.74222     60000

confusion matrix
[[3717  125  249  397   61   42   87  387  643  292]
 [ 163 4839  398  289   57   42   63   71   19   59]
 [ 157  111 5176   69    9   38   27   79  285   4

Evaluating: 100%|██████████| 1875/1875 [22:35<00:00,  1.38it/s]


off attention head index // layer: 2, head: 3
classification report
              precision    recall  f1-score   support

           1    0.61991   0.61950   0.61971      6000
           2    0.73757   0.80383   0.76928      6000
           3    0.74870   0.86450   0.80244      6000
           4    0.65217   0.54217   0.59210      6000
           5    0.84898   0.88633   0.86725      6000
           6    0.91908   0.91050   0.91477      6000
           7    0.77314   0.41350   0.53882      6000
           8    0.70938   0.79533   0.74990      6000
           9    0.72322   0.82917   0.77258      6000
          10    0.76509   0.83217   0.79722      6000

    accuracy                        0.74970     60000
   macro avg    0.74972   0.74970   0.74241     60000
weighted avg    0.74972   0.74970   0.74241     60000

confusion matrix
[[3717  125  251  406   62   42   85  385  649  278]
 [ 164 4823  408  288   64   41   59   75   19   59]
 [ 153  108 5187   71    9   35   26   83  280   4

Evaluating: 100%|██████████| 1875/1875 [22:38<00:00,  1.38it/s]


off attention head index // layer: 2, head: 4
classification report
              precision    recall  f1-score   support

           1    0.61894   0.62100   0.61997      6000
           2    0.73736   0.80667   0.77046      6000
           3    0.75025   0.86317   0.80276      6000
           4    0.65257   0.54000   0.59097      6000
           5    0.85079   0.88383   0.86700      6000
           6    0.91859   0.91017   0.91436      6000
           7    0.77262   0.41850   0.54292      6000
           8    0.70986   0.79350   0.74935      6000
           9    0.72504   0.82667   0.77253      6000
          10    0.76309   0.83583   0.79780      6000

    accuracy                        0.74993     60000
   macro avg    0.74991   0.74993   0.74281     60000
weighted avg    0.74991   0.74993   0.74281     60000

confusion matrix
[[3726  127  249  404   58   43   84  384  635  290]
 [ 164 4840  400  286   59   40   58   73   19   61]
 [ 158  112 5179   71    9   38   25   81  278   4

Evaluating: 100%|██████████| 1875/1875 [22:32<00:00,  1.39it/s]


off attention head index // layer: 2, head: 5
classification report
              precision    recall  f1-score   support

           1    0.62091   0.62050   0.62071      6000
           2    0.73802   0.80567   0.77036      6000
           3    0.74946   0.86350   0.80245      6000
           4    0.65443   0.53783   0.59043      6000
           5    0.85043   0.88317   0.86649      6000
           6    0.91940   0.91067   0.91501      6000
           7    0.77032   0.41700   0.54109      6000
           8    0.70982   0.79417   0.74963      6000
           9    0.72168   0.82933   0.77177      6000
          10    0.76277   0.83650   0.79793      6000

    accuracy                        0.74983     60000
   macro avg    0.74972   0.74983   0.74259     60000
weighted avg    0.74972   0.74983   0.74259     60000

confusion matrix
[[3723  125  254  398   59   39   85  378  645  294]
 [ 164 4834  398  284   60   42   62   77   19   60]
 [ 149  110 5181   71    9   36   24   85  285   5

Evaluating: 100%|██████████| 1875/1875 [22:45<00:00,  1.37it/s]


off attention head index // layer: 2, head: 6
classification report
              precision    recall  f1-score   support

           1    0.61777   0.62117   0.61946      6000
           2    0.73787   0.80367   0.76937      6000
           3    0.75065   0.86350   0.80313      6000
           4    0.65288   0.53917   0.59060      6000
           5    0.85004   0.88333   0.86637      6000
           6    0.91905   0.91017   0.91459      6000
           7    0.77261   0.41567   0.54053      6000
           8    0.71045   0.79333   0.74961      6000
           9    0.72396   0.82833   0.77264      6000
          10    0.75973   0.83633   0.79619      6000

    accuracy                        0.74947     60000
   macro avg    0.74950   0.74947   0.74225     60000
weighted avg    0.74950   0.74947   0.74225     60000

confusion matrix
[[3727  127  248  396   60   41   85  380  641  295]
 [ 165 4822  405  291   57   42   61   74   19   64]
 [ 155  108 5181   71    9   38   23   82  282   5

Evaluating: 100%|██████████| 1875/1875 [23:15<00:00,  1.34it/s]


off attention head index // layer: 2, head: 7
classification report
              precision    recall  f1-score   support

           1    0.62023   0.61733   0.61878      6000
           2    0.73762   0.80683   0.77068      6000
           3    0.75171   0.86233   0.80323      6000
           4    0.65436   0.54050   0.59200      6000
           5    0.84994   0.88550   0.86736      6000
           6    0.91847   0.91067   0.91455      6000
           7    0.77387   0.41467   0.53999      6000
           8    0.70967   0.79400   0.74947      6000
           9    0.72242   0.82933   0.77219      6000
          10    0.76051   0.83783   0.79730      6000

    accuracy                        0.74990     60000
   macro avg    0.74988   0.74990   0.74256     60000
weighted avg    0.74988   0.74990   0.74256     60000

confusion matrix
[[3704  128  249  401   61   42   85  384  648  298]
 [ 164 4841  396  288   59   42   57   74   18   61]
 [ 156  112 5174   72    9   38   24   81  283   5

Evaluating: 100%|██████████| 1875/1875 [23:01<00:00,  1.36it/s]


off attention head index // layer: 2, head: 8
classification report
              precision    recall  f1-score   support

           1    0.61965   0.61717   0.61840      6000
           2    0.74035   0.80217   0.77002      6000
           3    0.75120   0.86250   0.80301      6000
           4    0.65045   0.54150   0.59100      6000
           5    0.85078   0.88183   0.86603      6000
           6    0.91894   0.90883   0.91386      6000
           7    0.77179   0.41767   0.54201      6000
           8    0.70627   0.79550   0.74824      6000
           9    0.72222   0.82983   0.77230      6000
          10    0.76165   0.83617   0.79717      6000

    accuracy                        0.74932     60000
   macro avg    0.74933   0.74932   0.74220     60000
weighted avg    0.74933   0.74932   0.74220     60000

confusion matrix
[[3703  121  249  402   61   42   87  395  651  289]
 [ 170 4813  404  301   57   40   57   76   19   63]
 [ 159  104 5175   73    9   39   29   81  283   4

Evaluating: 100%|██████████| 1875/1875 [22:59<00:00,  1.36it/s]


off attention head index // layer: 2, head: 9
classification report
              precision    recall  f1-score   support

           1    0.62044   0.62117   0.62080      6000
           2    0.73803   0.80667   0.77082      6000
           3    0.75090   0.86517   0.80400      6000
           4    0.65303   0.54017   0.59126      6000
           5    0.84974   0.88317   0.86613      6000
           6    0.91734   0.91183   0.91458      6000
           7    0.77201   0.41650   0.54108      6000
           8    0.71115   0.79400   0.75030      6000
           9    0.72459   0.82700   0.77242      6000
          10    0.76274   0.83583   0.79761      6000

    accuracy                        0.75015     60000
   macro avg    0.75000   0.75015   0.74290     60000
weighted avg    0.75000   0.75015   0.74290     60000

confusion matrix
[[3727  125  251  399   62   42   85  382  640  287]
 [ 163 4840  396  289   56   41   59   74   19   63]
 [ 154  109 5191   71    9   38   23   77  278   5

Evaluating: 100%|██████████| 1875/1875 [23:10<00:00,  1.35it/s]


off attention head index // layer: 2, head: 10
classification report
              precision    recall  f1-score   support

           1    0.62066   0.62200   0.62133      6000
           2    0.73954   0.80400   0.77042      6000
           3    0.74917   0.86467   0.80279      6000
           4    0.65415   0.54000   0.59162      6000
           5    0.85072   0.88333   0.86672      6000
           6    0.91908   0.91050   0.91477      6000
           7    0.77161   0.41500   0.53972      6000
           8    0.70860   0.79517   0.74939      6000
           9    0.72394   0.82867   0.77277      6000
          10    0.76154   0.83567   0.79688      6000

    accuracy                        0.74990     60000
   macro avg    0.74990   0.74990   0.74264     60000
weighted avg    0.74990   0.74990   0.74264     60000

confusion matrix
[[3732  123  252  396   60   40   84  382  645  286]
 [ 162 4824  406  293   57   42   59   78   19   60]
 [ 155  107 5188   70    9   36   22   81  282   

Evaluating: 100%|██████████| 1875/1875 [23:05<00:00,  1.35it/s]


off attention head index // layer: 2, head: 11
classification report
              precision    recall  f1-score   support

           1    0.62183   0.61633   0.61907      6000
           2    0.73770   0.80717   0.77087      6000
           3    0.75105   0.86283   0.80307      6000
           4    0.65145   0.53767   0.58912      6000
           5    0.84820   0.88567   0.86653      6000
           6    0.91782   0.91017   0.91397      6000
           7    0.76776   0.41433   0.53821      6000
           8    0.71005   0.79383   0.74961      6000
           9    0.72226   0.82783   0.77145      6000
          10    0.76033   0.83700   0.79683      6000

    accuracy                        0.74928     60000
   macro avg    0.74885   0.74928   0.74187     60000
weighted avg    0.74885   0.74928   0.74187     60000

confusion matrix
[[3698  127  246  404   63   44   86  385  647  300]
 [ 159 4843  399  291   57   42   63   66   18   62]
 [ 155  107 5177   72    9   40   27   81  282   

Evaluating: 100%|██████████| 1875/1875 [23:26<00:00,  1.33it/s]


off attention head index // layer: 2, head: 12
classification report
              precision    recall  f1-score   support

           1    0.62485   0.61933   0.62208      6000
           2    0.73896   0.80350   0.76988      6000
           3    0.75087   0.86350   0.80326      6000
           4    0.65001   0.54200   0.59111      6000
           5    0.84999   0.88300   0.86618      6000
           6    0.91946   0.90950   0.91445      6000
           7    0.77013   0.41767   0.54160      6000
           8    0.70895   0.79733   0.75055      6000
           9    0.72318   0.82817   0.77212      6000
          10    0.76219   0.83650   0.79762      6000

    accuracy                        0.75005     60000
   macro avg    0.74986   0.75005   0.74289     60000
weighted avg    0.74986   0.75005   0.74289     60000

confusion matrix
[[3716  123  253  407   60   41   88  380  643  289]
 [ 162 4821  398  300   61   40   60   77   18   63]
 [ 151  113 5181   74    9   37   22   82  283   

Evaluating: 100%|██████████| 1875/1875 [23:24<00:00,  1.34it/s]


off attention head index // layer: 3, head: 1
classification report
              precision    recall  f1-score   support

           1    0.62107   0.61900   0.62003      6000
           2    0.73642   0.80650   0.76987      6000
           3    0.75022   0.86300   0.80267      6000
           4    0.64982   0.54000   0.58984      6000
           5    0.85117   0.88267   0.86663      6000
           6    0.91972   0.91083   0.91526      6000
           7    0.77090   0.41333   0.53814      6000
           8    0.71047   0.79300   0.74947      6000
           9    0.72166   0.82750   0.77096      6000
          10    0.76026   0.83667   0.79664      6000

    accuracy                        0.74925     60000
   macro avg    0.74917   0.74925   0.74195     60000
weighted avg    0.74917   0.74925   0.74195     60000

confusion matrix
[[3714  127  252  403   61   40   86  377  646  294]
 [ 161 4839  399  289   58   43   59   74   19   59]
 [ 152  114 5178   70    9   36   26   78  284   5

Evaluating: 100%|██████████| 1875/1875 [23:22<00:00,  1.34it/s]


off attention head index // layer: 3, head: 2
classification report
              precision    recall  f1-score   support

           1    0.62009   0.62050   0.62029      6000
           2    0.73783   0.80583   0.77033      6000
           3    0.75116   0.86283   0.80313      6000
           4    0.65174   0.53617   0.58833      6000
           5    0.85066   0.88383   0.86693      6000
           6    0.91911   0.91083   0.91495      6000
           7    0.77368   0.41650   0.54150      6000
           8    0.71007   0.79350   0.74947      6000
           9    0.72308   0.83033   0.77300      6000
          10    0.76036   0.83767   0.79715      6000

    accuracy                        0.74980     60000
   macro avg    0.74978   0.74980   0.74251     60000
weighted avg    0.74978   0.74980   0.74251     60000

confusion matrix
[[3723  125  249  397   61   40   82  386  644  293]
 [ 166 4835  396  290   57   42   56   75   19   64]
 [ 157  111 5177   72    9   37   22   79  284   5

Evaluating: 100%|██████████| 1875/1875 [23:27<00:00,  1.33it/s]


off attention head index // layer: 3, head: 3
classification report
              precision    recall  f1-score   support

           1    0.62060   0.62050   0.62055      6000
           2    0.73769   0.80667   0.77064      6000
           3    0.75018   0.86433   0.80322      6000
           4    0.65507   0.54000   0.59200      6000
           5    0.84960   0.88500   0.86694      6000
           6    0.91836   0.91117   0.91475      6000
           7    0.77295   0.41533   0.54033      6000
           8    0.71108   0.79333   0.74996      6000
           9    0.72257   0.82867   0.77199      6000
          10    0.76212   0.83567   0.79720      6000

    accuracy                        0.75007     60000
   macro avg    0.75002   0.75007   0.74276     60000
weighted avg    0.75002   0.75007   0.74276     60000

confusion matrix
[[3723  127  253  393   62   40   84  383  647  288]
 [ 162 4840  400  285   58   43   59   73   19   61]
 [ 154  108 5186   71    9   39   23   80  280   5

Evaluating: 100%|██████████| 1875/1875 [23:22<00:00,  1.34it/s]


off attention head index // layer: 3, head: 4
classification report
              precision    recall  f1-score   support

           1    0.61939   0.61867   0.61903      6000
           2    0.73833   0.80700   0.77114      6000
           3    0.75069   0.86317   0.80301      6000
           4    0.65328   0.53950   0.59096      6000
           5    0.84893   0.88317   0.86571      6000
           6    0.91940   0.91067   0.91501      6000
           7    0.77269   0.41583   0.54069      6000
           8    0.71043   0.79367   0.74974      6000
           9    0.72252   0.82933   0.77225      6000
          10    0.76089   0.83583   0.79660      6000

    accuracy                        0.74968     60000
   macro avg    0.74965   0.74968   0.74241     60000
weighted avg    0.74965   0.74968   0.74241     60000

confusion matrix
[[3712  125  253  403   63   43   82  379  651  289]
 [ 165 4842  395  289   57   41   58   74   19   60]
 [ 156  111 5179   72    9   36   24   80  281   5

Evaluating: 100%|██████████| 1875/1875 [23:17<00:00,  1.34it/s]


off attention head index // layer: 3, head: 5
classification report
              precision    recall  f1-score   support

           1    0.62126   0.61950   0.62038      6000
           2    0.73743   0.80650   0.77042      6000
           3    0.75091   0.86267   0.80292      6000
           4    0.65310   0.54033   0.59139      6000
           5    0.84898   0.88350   0.86589      6000
           6    0.91915   0.91133   0.91522      6000
           7    0.76949   0.41617   0.54018      6000
           8    0.71026   0.79383   0.74972      6000
           9    0.72436   0.82867   0.77301      6000
          10    0.76206   0.83700   0.79778      6000

    accuracy                        0.74995     60000
   macro avg    0.74970   0.74995   0.74269     60000
weighted avg    0.74970   0.74995   0.74269     60000

confusion matrix
[[3717  125  251  400   62   43   87  381  640  294]
 [ 164 4839  397  288   58   42   62   71   19   60]
 [ 153  114 5176   73    9   37   24   82  280   5

Evaluating: 100%|██████████| 1875/1875 [23:24<00:00,  1.34it/s]


off attention head index // layer: 3, head: 6
classification report
              precision    recall  f1-score   support

           1    0.61825   0.62000   0.61912      6000
           2    0.73713   0.80667   0.77033      6000
           3    0.75138   0.86283   0.80326      6000
           4    0.65238   0.54017   0.59099      6000
           5    0.85003   0.88233   0.86588      6000
           6    0.91939   0.91050   0.91492      6000
           7    0.77040   0.41383   0.53844      6000
           8    0.71217   0.79467   0.75116      6000
           9    0.72518   0.82767   0.77304      6000
          10    0.75963   0.83850   0.79712      6000

    accuracy                        0.74972     60000
   macro avg    0.74959   0.74972   0.74243     60000
weighted avg    0.74959   0.74972   0.74243     60000

confusion matrix
[[3720  128  248  396   62   39   87  377  641  302]
 [ 164 4840  397  287   58   42   60   72   19   61]
 [ 155  113 5177   72    9   38   25   79  281   5

Evaluating: 100%|██████████| 1875/1875 [23:18<00:00,  1.34it/s]


off attention head index // layer: 3, head: 7
classification report
              precision    recall  f1-score   support

           1    0.62342   0.61583   0.61960      6000
           2    0.73796   0.80450   0.76980      6000
           3    0.74935   0.86450   0.80282      6000
           4    0.65110   0.53900   0.58977      6000
           5    0.84999   0.88300   0.86618      6000
           6    0.92012   0.91000   0.91503      6000
           7    0.77321   0.41650   0.54138      6000
           8    0.70524   0.79833   0.74891      6000
           9    0.72234   0.82817   0.77164      6000
          10    0.76282   0.83567   0.79758      6000

    accuracy                        0.74955     60000
   macro avg    0.74955   0.74955   0.74227     60000
weighted avg    0.74955   0.74955   0.74227     60000

confusion matrix
[[3695  123  255  407   59   42   86  393  652  288]
 [ 160 4827  403  291   58   42   58   79   18   64]
 [ 146  111 5187   74    9   35   21   84  283   5

Evaluating: 100%|██████████| 1875/1875 [23:24<00:00,  1.34it/s]


off attention head index // layer: 3, head: 8
classification report
              precision    recall  f1-score   support

           1    0.61998   0.61967   0.61982      6000
           2    0.73810   0.80367   0.76949      6000
           3    0.75102   0.86267   0.80298      6000
           4    0.65118   0.53733   0.58880      6000
           5    0.84988   0.88317   0.86620      6000
           6    0.91735   0.91200   0.91467      6000
           7    0.77144   0.41683   0.54122      6000
           8    0.70972   0.79583   0.75031      6000
           9    0.72332   0.82700   0.77170      6000
          10    0.76141   0.83717   0.79749      6000

    accuracy                        0.74953     60000
   macro avg    0.74934   0.74953   0.74227     60000
weighted avg    0.74934   0.74953   0.74227     60000

confusion matrix
[[3718  125  249  401   60   43   85  384  644  291]
 [ 164 4822  401  293   60   44   63   72   19   62]
 [ 156  112 5176   74    9   38   23   81  281   5

Evaluating: 100%|██████████| 1875/1875 [23:20<00:00,  1.34it/s]


off attention head index // layer: 3, head: 9
classification report
              precision    recall  f1-score   support

           1    0.61961   0.61817   0.61889      6000
           2    0.73810   0.80650   0.77079      6000
           3    0.75091   0.86317   0.80313      6000
           4    0.65232   0.53817   0.58977      6000
           5    0.84999   0.88300   0.86618      6000
           6    0.91894   0.91067   0.91478      6000
           7    0.77104   0.41533   0.53986      6000
           8    0.70846   0.79583   0.74961      6000
           9    0.72280   0.82833   0.77198      6000
          10    0.76185   0.83600   0.79720      6000

    accuracy                        0.74952     60000
   macro avg    0.74940   0.74952   0.74222     60000
weighted avg    0.74940   0.74952   0.74222     60000

confusion matrix
[[3709  124  249  405   61   41   86  388  647  290]
 [ 162 4839  397  287   58   42   60   75   19   61]
 [ 156  109 5179   70    9   37   25   83  282   5

Evaluating: 100%|██████████| 1875/1875 [23:21<00:00,  1.34it/s]


off attention head index // layer: 3, head: 10
classification report
              precision    recall  f1-score   support

           1    0.62114   0.62000   0.62057      6000
           2    0.73619   0.80650   0.76974      6000
           3    0.74967   0.86350   0.80257      6000
           4    0.64858   0.53983   0.58923      6000
           5    0.85076   0.88167   0.86594      6000
           6    0.91926   0.91083   0.91503      6000
           7    0.76883   0.41350   0.53777      6000
           8    0.71100   0.79383   0.75014      6000
           9    0.72423   0.82683   0.77214      6000
          10    0.76190   0.83733   0.79784      6000

    accuracy                        0.74938     60000
   macro avg    0.74916   0.74938   0.74210     60000
weighted avg    0.74916   0.74938   0.74210     60000

confusion matrix
[[3720  129  253  406   61   42   86  376  637  290]
 [ 160 4839  398  288   58   42   63   74   18   60]
 [ 152  114 5181   70    9   37   28   77  282   

Evaluating: 100%|██████████| 1875/1875 [23:21<00:00,  1.34it/s]


off attention head index // layer: 3, head: 11
classification report
              precision    recall  f1-score   support

           1    0.62130   0.61550   0.61839      6000
           2    0.73845   0.80467   0.77014      6000
           3    0.75105   0.86233   0.80286      6000
           4    0.64945   0.54067   0.59009      6000
           5    0.84887   0.88467   0.86640      6000
           6    0.91936   0.91017   0.91474      6000
           7    0.77375   0.41267   0.53826      6000
           8    0.70513   0.79550   0.74759      6000
           9    0.72387   0.82883   0.77280      6000
          10    0.76144   0.83783   0.79781      6000

    accuracy                        0.74928     60000
   macro avg    0.74927   0.74928   0.74191     60000
weighted avg    0.74927   0.74928   0.74191     60000

confusion matrix
[[3693  125  245  410   62   41   86  400  645  293]
 [ 157 4828  405  295   59   41   57   78   19   61]
 [ 155  111 5174   74    9   38   24   83  282   

Evaluating: 100%|██████████| 1875/1875 [23:20<00:00,  1.34it/s]


off attention head index // layer: 3, head: 12
classification report
              precision    recall  f1-score   support

           1    0.61860   0.61983   0.61921      6000
           2    0.73833   0.80650   0.77091      6000
           3    0.75123   0.86417   0.80375      6000
           4    0.65323   0.54000   0.59124      6000
           5    0.84985   0.88483   0.86699      6000
           6    0.92008   0.90950   0.91476      6000
           7    0.77156   0.41600   0.54055      6000
           8    0.71032   0.79367   0.74969      6000
           9    0.72232   0.82850   0.77177      6000
          10    0.76251   0.83533   0.79726      6000

    accuracy                        0.74983     60000
   macro avg    0.74980   0.74983   0.74261     60000
weighted avg    0.74980   0.74983   0.74261     60000

confusion matrix
[[3719  123  248  399   61   43   86  382  648  291]
 [ 165 4839  399  286   59   41   59   72   19   61]
 [ 151  110 5185   73    9   38   23   79  281   

Evaluating: 100%|██████████| 1875/1875 [23:19<00:00,  1.34it/s]


off attention head index // layer: 4, head: 1
classification report
              precision    recall  f1-score   support

           1    0.62523   0.61200   0.61855      6000
           2    0.73632   0.80750   0.77027      6000
           3    0.75062   0.86283   0.80282      6000
           4    0.64876   0.54333   0.59138      6000
           5    0.85105   0.88083   0.86568      6000
           6    0.92042   0.90983   0.91510      6000
           7    0.77312   0.41517   0.54023      6000
           8    0.70486   0.79767   0.74840      6000
           9    0.72215   0.82867   0.77175      6000
          10    0.76196   0.83650   0.79749      6000

    accuracy                        0.74943     60000
   macro avg    0.74945   0.74943   0.74217     60000
weighted avg    0.74945   0.74943   0.74217     60000

confusion matrix
[[3672  127  256  413   59   43   85  400  649  296]
 [ 158 4845  395  292   56   40   57   77   19   61]
 [ 145  116 5177   77    9   38   24   81  284   4

Evaluating: 100%|██████████| 1875/1875 [23:21<00:00,  1.34it/s]


off attention head index // layer: 4, head: 2
classification report
              precision    recall  f1-score   support

           1    0.62141   0.62100   0.62121      6000
           2    0.73729   0.80733   0.77072      6000
           3    0.75178   0.86317   0.80363      6000
           4    0.65361   0.53967   0.59120      6000
           5    0.84920   0.88317   0.86585      6000
           6    0.91941   0.91083   0.91510      6000
           7    0.77119   0.41850   0.54257      6000
           8    0.71011   0.79367   0.74957      6000
           9    0.72416   0.82783   0.77253      6000
          10    0.76237   0.83683   0.79787      6000

    accuracy                        0.75020     60000
   macro avg    0.75005   0.75020   0.74303     60000
weighted avg    0.75005   0.75020   0.74303     60000

confusion matrix
[[3726  124  252  400   62   40   86  380  642  288]
 [ 161 4844  394  286   60   41   59   76   19   60]
 [ 155  113 5179   73    9   36   24   82  279   5

Evaluating: 100%|██████████| 1875/1875 [23:31<00:00,  1.33it/s]


off attention head index // layer: 4, head: 3
classification report
              precision    recall  f1-score   support

           1    0.61921   0.62117   0.62018      6000
           2    0.73733   0.80750   0.77082      6000
           3    0.75335   0.86183   0.80395      6000
           4    0.65179   0.54033   0.59085      6000
           5    0.85062   0.88450   0.86723      6000
           6    0.91968   0.91033   0.91498      6000
           7    0.77277   0.41433   0.53944      6000
           8    0.71010   0.79567   0.75045      6000
           9    0.72378   0.82800   0.77239      6000
          10    0.76176   0.83667   0.79746      6000

    accuracy                        0.75003     60000
   macro avg    0.75004   0.75003   0.74278     60000
weighted avg    0.75004   0.75003   0.74278     60000

confusion matrix
[[3727  125  246  401   58   39   84  385  644  291]
 [ 164 4845  389  290   58   42   59   72   18   63]
 [ 154  114 5171   75    9   37   23   83  283   5

Evaluating: 100%|██████████| 1875/1875 [23:21<00:00,  1.34it/s]


off attention head index // layer: 4, head: 4
classification report
              precision    recall  f1-score   support

           1    0.62010   0.61700   0.61855      6000
           2    0.73824   0.80800   0.77154      6000
           3    0.75109   0.86400   0.80360      6000
           4    0.65256   0.53967   0.59077      6000
           5    0.84977   0.88333   0.86623      6000
           6    0.91937   0.91033   0.91483      6000
           7    0.77114   0.41500   0.53960      6000
           8    0.71029   0.79517   0.75033      6000
           9    0.72448   0.82917   0.77330      6000
          10    0.76090   0.83800   0.79759      6000

    accuracy                        0.74997     60000
   macro avg    0.74979   0.74997   0.74263     60000
weighted avg    0.74979   0.74997   0.74263     60000

confusion matrix
[[3702  127  253  401   60   42   89  383  648  295]
 [ 163 4848  394  288   57   42   57   72   17   62]
 [ 147  111 5184   76    9   37   23   81  281   5

Evaluating: 100%|██████████| 1875/1875 [23:23<00:00,  1.34it/s]


off attention head index // layer: 4, head: 5
classification report
              precision    recall  f1-score   support

           1    0.62151   0.61933   0.62042      6000
           2    0.73710   0.80467   0.76940      6000
           3    0.74795   0.86600   0.80266      6000
           4    0.65080   0.54017   0.59035      6000
           5    0.84938   0.88350   0.86611      6000
           6    0.92002   0.91067   0.91532      6000
           7    0.77115   0.41783   0.54200      6000
           8    0.71240   0.79267   0.75039      6000
           9    0.72473   0.82800   0.77293      6000
          10    0.76253   0.83650   0.79781      6000

    accuracy                        0.74993     60000
   macro avg    0.74976   0.74993   0.74274     60000
weighted avg    0.74976   0.74993   0.74274     60000

confusion matrix
[[3716  126  255  404   61   41   84  379  642  292]
 [ 160 4828  404  295   58   41   61   71   19   63]
 [ 149  109 5196   71    9   36   23   78  278   5

Evaluating: 100%|██████████| 1875/1875 [23:22<00:00,  1.34it/s]


off attention head index // layer: 4, head: 6
classification report
              precision    recall  f1-score   support

           1    0.62119   0.62150   0.62134      6000
           2    0.73886   0.80450   0.77029      6000
           3    0.74859   0.86400   0.80217      6000
           4    0.65090   0.53883   0.58959      6000
           5    0.84873   0.88367   0.86584      6000
           6    0.91665   0.91100   0.91382      6000
           7    0.76854   0.41283   0.53714      6000
           8    0.71221   0.79233   0.75014      6000
           9    0.72445   0.82817   0.77284      6000
          10    0.76094   0.83767   0.79746      6000

    accuracy                        0.74945     60000
   macro avg    0.74911   0.74945   0.74206     60000
weighted avg    0.74911   0.74945   0.74206     60000

confusion matrix
[[3729  121  253  399   61   43   88  379  636  291]
 [ 161 4827  405  294   61   42   61   69   19   61]
 [ 151  111 5184   74    9   38   23   80  281   4

Evaluating: 100%|██████████| 1875/1875 [23:24<00:00,  1.34it/s]


off attention head index // layer: 4, head: 7
classification report
              precision    recall  f1-score   support

           1    0.62134   0.61917   0.62025      6000
           2    0.73667   0.80800   0.77069      6000
           3    0.75098   0.86350   0.80332      6000
           4    0.65386   0.53900   0.59090      6000
           5    0.84942   0.88467   0.86668      6000
           6    0.91988   0.91083   0.91533      6000
           7    0.77185   0.41500   0.53978      6000
           8    0.70927   0.79533   0.74984      6000
           9    0.72292   0.82883   0.77226      6000
          10    0.76293   0.83567   0.79765      6000

    accuracy                        0.75000     60000
   macro avg    0.74991   0.75000   0.74267     60000
weighted avg    0.74991   0.75000   0.74267     60000

confusion matrix
[[3715  126  250  396   62   41   85  385  649  291]
 [ 164 4848  396  281   58   42   57   75   19   60]
 [ 152  113 5181   74    9   38   25   79  280   4

Evaluating: 100%|██████████| 1875/1875 [23:26<00:00,  1.33it/s]


off attention head index // layer: 4, head: 8
classification report
              precision    recall  f1-score   support

           1    0.62058   0.62017   0.62037      6000
           2    0.73822   0.80700   0.77108      6000
           3    0.75051   0.86233   0.80254      6000
           4    0.65387   0.53933   0.59110      6000
           5    0.84974   0.88317   0.86613      6000
           6    0.91950   0.91000   0.91473      6000
           7    0.77238   0.41567   0.54047      6000
           8    0.70950   0.79500   0.74982      6000
           9    0.72314   0.82800   0.77203      6000
          10    0.76007   0.83683   0.79660      6000

    accuracy                        0.74975     60000
   macro avg    0.74975   0.74975   0.74249     60000
weighted avg    0.74975   0.74975   0.74249     60000

confusion matrix
[[3721  126  252  393   61   41   84  385  644  293]
 [ 163 4842  396  288   58   41   58   71   19   64]
 [ 159  113 5174   71    9   37   24   81  281   5

Evaluating: 100%|██████████| 1875/1875 [23:18<00:00,  1.34it/s]


off attention head index // layer: 4, head: 9
classification report
              precision    recall  f1-score   support

           1    0.61948   0.61917   0.61932      6000
           2    0.73999   0.80400   0.77067      6000
           3    0.75058   0.86317   0.80295      6000
           4    0.65235   0.53667   0.58888      6000
           5    0.84946   0.88500   0.86687      6000
           6    0.91997   0.91000   0.91496      6000
           7    0.77082   0.41033   0.53557      6000
           8    0.70670   0.79633   0.74884      6000
           9    0.72209   0.83017   0.77237      6000
          10    0.76078   0.83800   0.79753      6000

    accuracy                        0.74928     60000
   macro avg    0.74922   0.74928   0.74179     60000
weighted avg    0.74922   0.74928   0.74179     60000

confusion matrix
[[3715  119  251  401   62   41   85  387  651  288]
 [ 166 4824  402  288   61   41   55   81   19   63]
 [ 155  108 5179   72    9   37   22   82  284   5

Evaluating: 100%|██████████| 1875/1875 [23:13<00:00,  1.35it/s]


off attention head index // layer: 4, head: 10
classification report
              precision    recall  f1-score   support

           1    0.62077   0.62067   0.62072      6000
           2    0.73775   0.80550   0.77014      6000
           3    0.75141   0.86400   0.80378      6000
           4    0.65245   0.54033   0.59112      6000
           5    0.85000   0.88400   0.86667      6000
           6    0.91849   0.91083   0.91464      6000
           7    0.77290   0.41633   0.54116      6000
           8    0.71038   0.79267   0.74927      6000
           9    0.72247   0.82783   0.77157      6000
          10    0.76222   0.83717   0.79793      6000

    accuracy                        0.74993     60000
   macro avg    0.74988   0.74993   0.74270     60000
weighted avg    0.74988   0.74993   0.74270     60000

confusion matrix
[[3724  127  251  399   61   41   84  378  647  288]
 [ 166 4833  401  287   57   41   59   75   19   62]
 [ 152  110 5184   73    9   37   23   81  281   

Evaluating: 100%|██████████| 1875/1875 [22:34<00:00,  1.38it/s]


off attention head index // layer: 4, head: 11
classification report
              precision    recall  f1-score   support

           1    0.61960   0.61867   0.61913      6000
           2    0.73794   0.80583   0.77040      6000
           3    0.75040   0.86283   0.80270      6000
           4    0.65155   0.53883   0.58986      6000
           5    0.84843   0.88533   0.86649      6000
           6    0.92003   0.91083   0.91541      6000
           7    0.77239   0.41400   0.53906      6000
           8    0.70777   0.79400   0.74841      6000
           9    0.72334   0.82750   0.77192      6000
          10    0.76185   0.83600   0.79720      6000

    accuracy                        0.74938     60000
   macro avg    0.74933   0.74938   0.74206     60000
weighted avg    0.74933   0.74938   0.74206     60000

confusion matrix
[[3712  125  250  403   63   42   83  390  643  289]
 [ 168 4835  399  281   61   41   58   77   19   61]
 [ 153  111 5177   76    9   36   25   80  283   

Evaluating: 100%|██████████| 1875/1875 [22:52<00:00,  1.37it/s]


off attention head index // layer: 4, head: 12
classification report
              precision    recall  f1-score   support

           1    0.62029   0.61750   0.61889      6000
           2    0.73846   0.80517   0.77037      6000
           3    0.75029   0.86233   0.80242      6000
           4    0.65112   0.53967   0.59018      6000
           5    0.85045   0.88333   0.86658      6000
           6    0.91895   0.91083   0.91487      6000
           7    0.76933   0.41633   0.54028      6000
           8    0.71170   0.79200   0.74970      6000
           9    0.72155   0.82967   0.77184      6000
          10    0.76005   0.83783   0.79705      6000

    accuracy                        0.74947     60000
   macro avg    0.74922   0.74947   0.74222     60000
weighted avg    0.74922   0.74947   0.74222     60000

confusion matrix
[[3705  126  250  402   60   43   85  381  650  298]
 [ 163 4831  400  292   58   41   61   73   19   62]
 [ 156  114 5174   70    9   37   24   80  285   

Evaluating: 100%|██████████| 1875/1875 [22:57<00:00,  1.36it/s]


off attention head index // layer: 5, head: 1
classification report
              precision    recall  f1-score   support

           1    0.62350   0.61383   0.61863      6000
           2    0.73823   0.80517   0.77025      6000
           3    0.75145   0.86217   0.80301      6000
           4    0.64731   0.53683   0.58692      6000
           5    0.85176   0.88200   0.86662      6000
           6    0.91804   0.91100   0.91451      6000
           7    0.77333   0.40883   0.53489      6000
           8    0.70123   0.79800   0.74649      6000
           9    0.72229   0.83317   0.77378      6000
          10    0.75943   0.83550   0.79565      6000

    accuracy                        0.74865     60000
   macro avg    0.74866   0.74865   0.74107     60000
weighted avg    0.74866   0.74865   0.74107     60000

confusion matrix
[[3683  122  255  407   61   42   80  401  666  283]
 [ 165 4831  398  289   58   41   58   81   19   60]
 [ 148  115 5173   74    9   39   24   81  286   5

Evaluating: 100%|██████████| 1875/1875 [22:57<00:00,  1.36it/s]


off attention head index // layer: 5, head: 2
classification report
              precision    recall  f1-score   support

           1    0.62279   0.61750   0.62014      6000
           2    0.73780   0.80617   0.77047      6000
           3    0.74964   0.86433   0.80291      6000
           4    0.65156   0.53917   0.59006      6000
           5    0.84971   0.88483   0.86692      6000
           6    0.91892   0.91050   0.91469      6000
           7    0.77116   0.41617   0.54059      6000
           8    0.71119   0.79417   0.75039      6000
           9    0.72246   0.82950   0.77229      6000
          10    0.76138   0.83650   0.79717      6000

    accuracy                        0.74988     60000
   macro avg    0.74966   0.74988   0.74256     60000
weighted avg    0.74966   0.74988   0.74256     60000

confusion matrix
[[3705  125  251  403   62   43   84  381  652  294]
 [ 161 4837  401  293   58   42   62   68   19   59]
 [ 152  109 5186   73    9   37   25   79  281   4

Evaluating: 100%|██████████| 1875/1875 [22:59<00:00,  1.36it/s]


off attention head index // layer: 5, head: 3
classification report
              precision    recall  f1-score   support

           1    0.62319   0.61717   0.62016      6000
           2    0.73810   0.80650   0.77079      6000
           3    0.75109   0.86300   0.80316      6000
           4    0.65293   0.54117   0.59182      6000
           5    0.85058   0.88517   0.86753      6000
           6    0.91840   0.91167   0.91502      6000
           7    0.77225   0.41650   0.54114      6000
           8    0.70639   0.79833   0.74955      6000
           9    0.72428   0.82833   0.77282      6000
          10    0.76419   0.83500   0.79802      6000

    accuracy                        0.75028     60000
   macro avg    0.75014   0.75028   0.74300     60000
weighted avg    0.75014   0.75028   0.74300     60000

confusion matrix
[[3703  125  251  403   62   44   83  392  645  292]
 [ 161 4839  398  286   59   43   60   75   19   60]
 [ 149  111 5178   74    9   37   26   84  284   4

Evaluating: 100%|██████████| 1875/1875 [22:58<00:00,  1.36it/s]


off attention head index // layer: 5, head: 4
classification report
              precision    recall  f1-score   support

           1    0.61936   0.61533   0.61734      6000
           2    0.73847   0.80333   0.76954      6000
           3    0.74741   0.86500   0.80192      6000
           4    0.65175   0.53900   0.59004      6000
           5    0.84891   0.88400   0.86610      6000
           6    0.92037   0.91117   0.91575      6000
           7    0.77037   0.41433   0.53885      6000
           8    0.71210   0.79067   0.74933      6000
           9    0.71950   0.83067   0.77110      6000
          10    0.76144   0.83783   0.79781      6000

    accuracy                        0.74913     60000
   macro avg    0.74897   0.74913   0.74178     60000
weighted avg    0.74897   0.74913   0.74178     60000

confusion matrix
[[3692  127  258  403   58   43   84  381  658  296]
 [ 165 4820  409  287   61   42   63   72   19   62]
 [ 149  107 5190   74    9   37   24   76  284   5

Evaluating: 100%|██████████| 1875/1875 [22:56<00:00,  1.36it/s]


off attention head index // layer: 5, head: 5
classification report
              precision    recall  f1-score   support

           1    0.61761   0.62317   0.62037      6000
           2    0.73769   0.80383   0.76934      6000
           3    0.75011   0.86400   0.80304      6000
           4    0.65053   0.53983   0.59004      6000
           5    0.85107   0.88483   0.86763      6000
           6    0.91993   0.91150   0.91570      6000
           7    0.77064   0.41383   0.53849      6000
           8    0.71127   0.79200   0.74947      6000
           9    0.72546   0.82667   0.77277      6000
          10    0.76088   0.83633   0.79682      6000

    accuracy                        0.74960     60000
   macro avg    0.74952   0.74960   0.74237     60000
weighted avg    0.74952   0.74960   0.74237     60000

confusion matrix
[[3739  123  249  404   59   40   83  379  635  289]
 [ 167 4823  404  296   58   42   58   70   19   63]
 [ 158  110 5184   73    9   35   21   80  279   5

Evaluating: 100%|██████████| 1875/1875 [22:58<00:00,  1.36it/s]


off attention head index // layer: 5, head: 6
classification report
              precision    recall  f1-score   support

           1    0.62032   0.61867   0.61949      6000
           2    0.73899   0.80550   0.77081      6000
           3    0.75138   0.86233   0.80304      6000
           4    0.65144   0.54200   0.59170      6000
           5    0.85041   0.88400   0.86688      6000
           6    0.91987   0.91067   0.91524      6000
           7    0.77194   0.41633   0.54093      6000
           8    0.70921   0.79467   0.74951      6000
           9    0.72212   0.82767   0.77130      6000
          10    0.76173   0.83600   0.79714      6000

    accuracy                        0.74978     60000
   macro avg    0.74974   0.74978   0.74260     60000
weighted avg    0.74974   0.74978   0.74260     60000

confusion matrix
[[3712  126  250  408   60   39   85  385  647  288]
 [ 166 4833  397  292   58   40   59   73   19   63]
 [ 154  112 5174   74    9   38   23   81  284   5

Evaluating: 100%|██████████| 1875/1875 [22:54<00:00,  1.36it/s]


off attention head index // layer: 5, head: 7
classification report
              precision    recall  f1-score   support

           1    0.61849   0.62117   0.61982      6000
           2    0.73964   0.80633   0.77155      6000
           3    0.75211   0.86217   0.80339      6000
           4    0.65391   0.53817   0.59042      6000
           5    0.84827   0.88517   0.86632      6000
           6    0.91937   0.91033   0.91483      6000
           7    0.77222   0.41417   0.53916      6000
           8    0.71127   0.79283   0.74984      6000
           9    0.72176   0.82967   0.77196      6000
          10    0.75968   0.83717   0.79654      6000

    accuracy                        0.74972     60000
   macro avg    0.74967   0.74972   0.74238     60000
weighted avg    0.74967   0.74972   0.74238     60000

confusion matrix
[[3727  122  247  397   59   41   87  380  649  291]
 [ 168 4838  392  289   59   41   56   73   19   65]
 [ 158  109 5173   70    9   39   24   80  285   5

Evaluating:  43%|████▎     | 808/1875 [09:53<12:42,  1.40it/s]

In [None]:
preds = []
true_labels = []

for i in range(len(test_df)):
    # 문장과 정답 레이블 가져오기
    sentence = str(test_df.iloc[i, 1]) + " " + str(test_df.iloc[i, 2]) + " " + str(test_df.iloc[i, 3])
    true_label = test_df.iloc[i, 0]
    
    # 문장을 모델의 입력 형식으로 변환
    inputs = tokenizer(sentence, truncation=True, max_length=512, padding='max_length', return_tensors="pt").to(device)

    # 모델에 입력값 넣기
    outputs = model(**inputs)

    # 모델의 예측 결과를 가져오기
    prediction = outputs.logits.argmax(dim=-1).item() + 1 # prediction은 0부터 시작하기 때문에 1을 더해줍니다.
    
    # 예측 결과와 정답 레이블을 리스트에 추가
    preds.append(prediction)
    true_labels.append(true_label)

    if (i+1) % 100 == 0:
        print(f"{i+1}th done.")

# 각 클래스별 정확도 계산 및 출력
print(classification_report(true_labels, preds))