# Implémenter GradCAM au CNN à la maille caractère

cf le travail de Khaled GracCam.ipynb

In [1]:
from pathlib import Path

current_dir = Path.cwd()  # this points to 'notebooks/' folder
proj_path = current_dir.parent.parent 
print(proj_path)

C:\Users\wenceslas\Documents\cours\ENSAE\2A\Normal\statapp\nlp_understanding


In [2]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from deep_nlp.cnncharclassifier import CNNCharClassifier, charToTensor
import pickle

from deep_nlp.grad_cam.utils.letter import rebuild_text, prepare_heatmap, LetterToToken
from deep_nlp.grad_cam.plot import plot_bar_heatmap, plot_text_and_heatmap

In [3]:
train_df= pd.read_csv(r"../../data/01_raw/allocine_train.csv")
test_df= pd.read_csv(r"../../data/01_raw/allocine_test.csv")
valid_df= pd.read_csv(r"../../data/01_raw/allocine_valid.csv")

model_path_saved= r"../../data/06_models/cnn_char_classifier/cnn_char_model/cnn_char_model.pt"

In [4]:
cnn_sequence_len= 1014
cnn_feature_num= 83
cnn_feature_size= 256
cnn_kernel_one= 7
cnn_kernel_two= 3
cnn_stride_one= 1
cnn_stride_two= 3
cnn_output_linear= 1024
cnn_num_class= 2
cnn_dropout= 0.5
cnn_cuda_allow= True

### Load model to CPU method

In [5]:
with open(model_path_saved, 'rb') as f:
    model_saved= pickle.load(f)

test_data= charToTensor(data_df= test_df, sentence_max_size= cnn_sequence_len)

test_load = DataLoader(test_data, batch_size= 1
                      , num_workers=4, pin_memory= True)

test_sentence= next(iter(test_load))

# Initialisation
parameters = {"sequence_len": cnn_sequence_len, "feature_num": cnn_feature_num
    , "feature_size": cnn_feature_size, "kernel_one": cnn_kernel_one
    , "kernel_two": cnn_kernel_two, "stride_one": cnn_stride_one
    , "stride_two": cnn_stride_two, "output_linear": cnn_output_linear
    , "num_class": cnn_num_class, "dropout": cnn_dropout}


model = CNNCharClassifier(**parameters)
model = torch.nn.DataParallel(model)
if cnn_cuda_allow:
    model = torch.nn.DataParallel(model).cuda()
model.load_state_dict(model_saved)
model.eval()

state_dict= model.module.module.state_dict() # delete module to allow cpu loading

cpu_model= CNNCharClassifier(**parameters).cpu()
cpu_model.load_state_dict(state_dict)

cpu_model.eval()

CNNCharClassifier(
  (before_conv): Sequential(
    (conv1_conv): Conv1d(83, 256, kernel_size=(7,), stride=(1,))
    (conv1_relu): ReLU()
  )
  (pool): Sequential(
    (conv1_maxpool): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  )
  (after_conv): Sequential(
    (conv2): Sequential(
      (0): Conv1d(256, 256, kernel_size=(7,), stride=(1,))
      (1): ReLU()
      (2): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    )
    (conv3): Sequential(
      (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
      (1): ReLU()
    )
    (conv4): Sequential(
      (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
      (1): ReLU()
    )
    (conv5): Sequential(
      (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
      (1): ReLU()
    )
    (conv6): Sequential(
      (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
      (1): ReLU()
      (2): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  

In [6]:
pred_test = []
lab = []
reviews= []
alphabet= test_data.get_alphabet()+" "

with torch.no_grad():
    for review, label in test_load:
        pred_test.append(torch.exp(cpu_model(review)))
        lab.append(label.float())
        reviews.append(rebuild_text(text= review
                                 , alphabet= alphabet
                                 , space_index= 83
                                 , sequence_len= cnn_sequence_len))
        
pred_test = torch.cat(pred_test)
lab = torch.cat(lab)

In [7]:
results= pd.DataFrame({
    "review": reviews
    , "label": lab
    , "prediction_prob_1":  pred_test[:,1]
    , "prediction_label": [1.0 if i > 0.5 else 0.0 for i in pred_test[:,1]]
    , "true_review": test_df["review"].values
#     , "heatmap_test_normalized_0": heatmap_test_normalized_0
#     , "heatmap_test_normalized_1": heatmap_test_normalized_1.tolist()
#     , "heatmap_test_max_0": heatmap_test_max_0.tolist()
#     , "heatmap_test_max_1": heatmap_test_max_1.tolist()
})

# del heatmap_test_normalized_1, heatmap_test_max_0, heatmap_test_max_1

print(results.shape)
results.head()

(20000, 5)


Unnamed: 0,review,label,prediction_prob_1,prediction_label,true_review
0,"Magnifique épopée, une belle histoire, touchan...",1.0,0.997272,1.0,"Magnifique épopée, une belle histoire, touchan..."
1,Je n'ai pas aimé mais pourtant je lui mets 2 é...,0.0,0.018032,0.0,Je n'ai pas aimé mais pourtant je lui mets 2 é...
2,Un dessin animé qui brille par sa féerie et se...,1.0,0.804264,1.0,Un dessin animé qui brille par sa féerie et se...
3,"Si c'est là le renouveau du cinéma fran ais, c...",0.0,0.004019,0.0,"Si c'est là le renouveau du cinéma français, c..."
4,Et pourtant on s en Doutait !Second volet très...,0.0,0.000941,0.0,Et pourtant on s’en Doutait !Second volet très...


In [8]:
#define metric
def binary_accuracy(preds, y):
    #round predictions to the closest integer
    rounded_preds = torch.round(preds[:,1])    
    correct = (rounded_preds == y).float() 
    acc = correct.sum() / len(correct)
    
    return acc

acc= binary_accuracy(pred_test, lab.float())
print("ACC on test set : {}".format(acc))

ACC on test set : 0.9117000102996826


### Study model errors

In [9]:
# When the model is wrong what is the MAE & RMSE
mae= np.mean(np.abs(results["label"] - results["prediction_prob_1"]))
print("MAE : {}".format(mae))
rmse= np.sqrt(np.mean((results["label"] - results["prediction_prob_1"])**2))
print("RMSE : {}".format(rmse))

MAE : 0.1163390651345253
RMSE : 0.2559721767902374


In [10]:
threshold= 0.85
result_worse_pred= results[np.abs(results["label"] - results["prediction_prob_1"]) >= threshold]
print(result_worse_pred.shape)
print(result_worse_pred.describe())
result_worse_pred.head()

(670, 5)
            label  prediction_prob_1  prediction_label
count  670.000000         670.000000        670.000000
mean     0.601493           0.412906          0.398507
std      0.489957           0.423401          0.489957
min      0.000000           0.000069          0.000000
25%      0.000000           0.054643          0.000000
50%      1.000000           0.117460          0.000000
75%      1.000000           0.910351          1.000000
max      1.000000           0.999778          1.000000


Unnamed: 0,review,label,prediction_prob_1,prediction_label,true_review
16,Eli Roth est un réalisateur intéressant et red...,1.0,0.13562,0.0,Eli Roth est un réalisateur intéressant et red...
23,"Il y a maintenant presque 50ans, un petit ciné...",0.0,0.962284,1.0,"Il y a maintenant presque 50ans, un petit ciné..."
46,Cécile de France crève l écran en Marquise de ...,1.0,0.083666,0.0,Cécile de France crève l’écran en Marquise de ...
49,"C est dr le, follement dr le car les personnag...",1.0,0.112146,0.0,"C est drôle, follement drôle car les personnag..."
50,"Après la Une de ""Minute"", je ne peux que recom...",1.0,0.056054,0.0,"Après la Une de ""Minute"", je ne peux que recom..."


In [11]:
# rebuild dataloader
worst_data= charToTensor(data_df= result_worse_pred[["review", "label"]].copy(), sentence_max_size= cnn_sequence_len)

worst_load = DataLoader(worst_data, batch_size= 1
                      , num_workers=4, pin_memory= True)

In [23]:
pred_test = []
lab = []
reviews= []
alphabet= test_data.get_alphabet()+" "
heatmap= []

for review, label in worst_load:
    
    label_inv= 0 if label == 1 else 1
    
    pred_test.append(torch.exp(cpu_model(review)))
    lab.append(label.float())
    reviews.append(rebuild_text(text= review
                             , alphabet= alphabet
                             , space_index= 83
                             , sequence_len= cnn_sequence_len))
    heatmap.append(cpu_model.get_heatmap(text= review
                                    , num_class= label_inv
                                    , dim= [0, 2]
                                    , type= "normalized"))
        
pred_test = torch.cat(pred_test)
lab = torch.cat(lab)

tensor([[[ 0.0125,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0181,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0054,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0364,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0039,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0383,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000e+00,  0.0000e+00, -3.7210e-03,  ...,  0.0000e+00,
          -5.1706e-05,  0.0000e+00],
         [ 4.4227e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           1.0625e-04,  0.0000e+00],
         [-1.5631e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          -1.3263e-04,  0.0000e+00],
         ...,
         [ 0.0000e+00,  0.0000e+00,  3.7312e-03,  ...,  7.1406e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  4.8222e-04,  0.0000e+00,  ...,  0.0000e+00,
          -1.1420e-04,  0.000

tensor([[[ 0.0123,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0252,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0111,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0299,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0008,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0322,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 2.5136e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-9.3938e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 9.5664e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [-4.4626e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.4336e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.000

tensor([[[-0.0092,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0235,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0080,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0295,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0003,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0269,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0012,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0026,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0007,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0030,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0005,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0033,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000,  0.0000, -0.0015,  ...,  0.0000,  0.0000,  

tensor([[[ 0.0000,  0.0276,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0243,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0125,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0490,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0111,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0558,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0415,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0600,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0576,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0625,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0175,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0613,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000,  0.0021,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[-0.0151,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0350,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0170,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0334,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0050,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0458,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0122,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0095,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0019,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0209,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0115,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0145,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000,  0.0000, -0.0691,  ...,  0.0000,  0.0000,  

tensor([[[ 2.1753e-02,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  1.7536e-04],
         [-4.8951e-02,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -3.8245e-05],
         [ 0.0000e+00,  0.0000e+00,  4.3756e-02,  ...,  0.0000e+00,
           6.6937e-05,  0.0000e+00],
         ...,
         [-7.0753e-02,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.2660e-04],
         [ 0.0000e+00,  3.2146e-02,  0.0000e+00,  ...,  0.0000e+00,
           7.1954e-05,  0.0000e+00],
         [-8.5872e-02,  0.0000e+00,  0.0000e+00,  ..., -5.0043e-05,
           0.0000e+00,  0.0000e+00]]])
torch.Size([1, 256, 1008])
tensor([[[-7.4314e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 7.2705e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -7.8583e-05,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...

tensor([[[ 0.0176,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0481,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0228,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0488,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0085,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0815,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000,  0.0000,  0.0034,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0139,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0041,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0161,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0012,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0237,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000,  0.0000,  0.0057,  ...,  0.0000,  0.0000,  

tensor([[[ 0.0000,  0.0024,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0036,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0040,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0049,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0004,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0082,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000e+00,  0.0000e+00,  6.2361e-03,  ...,  1.7890e-05,
           0.0000e+00,  0.0000e+00],
         [-1.0133e-02,  0.0000e+00,  0.0000e+00,  ..., -5.9102e-05,
           0.0000e+00,  0.0000e+00],
         [ 9.4659e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  1.1157e-04],
         ...,
         [-1.5731e-02,  0.0000e+00,  0.0000e+00,  ..., -1.3115e-04,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  1.5478e-03,  ...,  1.8854e-04,
           0.0000e+00,  0.000

tensor([[[ 0.0000,  0.0000, -0.0023,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0050,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0009,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0047,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0012,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0058,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0113,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0161,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0060,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0147,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0034,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0200,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-5.1600e-03,  0.0000e+00,  0.0000e+00,  ..., -1.4718

tensor([[[-0.0047,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0057,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0029,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0110,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0003,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0107,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000e+00,  0.0000e+00, -2.9082e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 7.3191e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -2.8676e-03,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 8.9390e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  9.5776e-05,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.000

tensor([[[ 0.0000,  0.0035,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0025,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0023,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0061,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0001,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0057,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-4.7416e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.4789e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -6.9467e-04,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 2.7799e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  6.4339e-05,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.000

tensor([[[ 0.0136,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0076,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0027,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0204,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0053,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0185,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0054,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0102,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0055,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0181,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0022,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0183,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0064,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[-0.0014,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0025,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0011,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0032,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0010,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0049,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0137,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0230,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0145,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000, -0.0256,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0097,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0320,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0070,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[ 0.0000, -0.0261,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0441,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0183,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0618,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0030,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0660,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000e+00,  1.0589e-03,  0.0000e+00,  ...,  1.0411e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -7.7639e-04,  0.0000e+00,  ...,  0.0000e+00,
          -3.9248e-06,  0.0000e+00],
         [ 5.3086e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  1.1625e-05],
         ...,
         [-2.2640e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.6294e-05],
         [ 4.6583e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  1.140

tensor([[[ 0.0000,  0.0000, -0.0191,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0089,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0009,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0246,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0062,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0310,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0048,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0099,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0035,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0067,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0009,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0107,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0516,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[-1.3379e-02,  0.0000e+00,  0.0000e+00,  ..., -7.3866e-06,
           0.0000e+00,  0.0000e+00],
         [ 5.2096e-02,  0.0000e+00,  0.0000e+00,  ..., -2.1965e-06,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -1.0896e-02,  ...,  0.0000e+00,
           0.0000e+00, -1.1700e-05],
         ...,
         [ 7.2415e-02,  0.0000e+00,  0.0000e+00,  ...,  6.9727e-06,
           0.0000e+00,  0.0000e+00],
         [-8.0285e-04,  0.0000e+00,  0.0000e+00,  ..., -5.4767e-06,
           0.0000e+00,  0.0000e+00],
         [ 4.9009e-02,  0.0000e+00,  0.0000e+00,  ..., -6.0702e-06,
           0.0000e+00,  0.0000e+00]]])
torch.Size([1, 256, 1008])
tensor([[[-2.4135e-03,  0.0000e+00,  0.0000e+00,  ..., -3.7177e-05,
           0.0000e+00,  0.0000e+00],
         [ 6.3868e-03,  0.0000e+00,  0.0000e+00,  ...,  2.2299e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -1.9173e-03,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.0845e-04],
         ...

tensor([[[ 0.0035,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0083,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0030,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0094,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0022,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0127,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0015,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0059,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0022,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0072,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0014,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0077,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0017,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[-2.6993e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.5490e-05],
         [ 0.0000e+00,  8.1634e-03,  0.0000e+00,  ..., -3.1728e-05,
           0.0000e+00,  0.0000e+00],
         [-4.3291e-03,  0.0000e+00,  0.0000e+00,  ...,  1.2218e-04,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 0.0000e+00,  8.6491e-03,  0.0000e+00,  ...,  2.6358e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  1.0522e-03,  ...,  0.0000e+00,
           4.5487e-05,  0.0000e+00],
         [ 1.2357e-02,  0.0000e+00,  0.0000e+00,  ..., -4.7460e-06,
           0.0000e+00,  0.0000e+00]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0038,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0086,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0034,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0101,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0015,  0.0000,  

tensor([[[ 0.0030,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0019,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0014,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0040,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0003,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0051,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 2.3485e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  5.2009e-05],
         [-2.3639e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  6.4676e-06],
         [ 1.6884e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          -1.6272e-04,  0.0000e+00],
         ...,
         [ 0.0000e+00,  0.0000e+00, -4.1037e-03,  ...,  0.0000e+00,
           6.0297e-05,  0.0000e+00],
         [ 1.2742e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          -9.6524e-05,  0.000

tensor([[[ 0.0435,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0683,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0331,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.1024,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0013,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.1107,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0002,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0002,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0001,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0004,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0001,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0003,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000, -0.0113,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[ 0.0246,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0868,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0295,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0861,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0007,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.1054,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0026,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0077,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0020,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0075,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0007,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0130,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0059,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[ 0.0000e+00,  0.0000e+00, -1.7616e-02,  ...,  0.0000e+00,
           0.0000e+00, -3.7574e-05],
         [ 0.0000e+00,  4.2882e-02,  0.0000e+00,  ...,  6.1875e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -9.4140e-03,  ...,  0.0000e+00,
          -5.6477e-05,  0.0000e+00],
         ...,
         [ 0.0000e+00,  3.8959e-02,  0.0000e+00,  ...,  1.2564e-04,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  9.1949e-03,  ..., -8.1229e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  4.3864e-02,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  1.4246e-04]]])
torch.Size([1, 256, 1008])
tensor([[[-2.7419e-02,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.4968e-07],
         [ 1.1694e-02,  0.0000e+00,  0.0000e+00,  ..., -3.4640e-09,
           0.0000e+00,  0.0000e+00],
         [-1.3156e-02,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          -1.6715e-07,  0.0000e+00],
         ...

tensor([[[ 6.7825e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-8.2580e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.6302e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 0.0000e+00, -1.1056e-03,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -1.0603e-04,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-1.2346e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]])
torch.Size([1, 256, 1008])
tensor([[[-1.1939e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.4406e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -1.1963e-04,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...

tensor([[[ 0.0038,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0109,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0108,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0134,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0005,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0127,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0157,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0215,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0098,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0350,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0003,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0302,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0349,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[ 0.0533,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0636,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0225,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0936,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0076,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0890,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0175,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0096,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0073,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0246,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0046,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0144,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000e+00, -3.0387e-04,  0.0000e+00,  ...,  2.8801

tensor([[[ 0.0105,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0125,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0004,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000, -0.0202,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0011,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0143,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 5.8623e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           4.6503e-05,  0.0000e+00],
         [-1.0988e-03,  0.0000e+00,  0.0000e+00,  ...,  2.1332e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  1.8350e-04,  ..., -4.3380e-06,
           0.0000e+00,  0.0000e+00],
         ...,
         [-1.4746e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -2.1024e-05],
         [ 0.0000e+00,  4.9729e-04,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  3.563

tensor([[[ 0.0113,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0188,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0157,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0243,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0043,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0315,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0215,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0474,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0098,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0654,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0172,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0553,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000,  0.0000,  0.0407,  ...,  0.0000,  0.0000,  

tensor([[[ 0.0010,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0007,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0004,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0019,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0005,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0017,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0401,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.1185,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0480,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.1601,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0105,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.1752,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0187,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[ 0.0000,  0.0000, -0.0108,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0065,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0009,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0318,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0004,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0298,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-4.3268e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          -6.6088e-06,  0.0000e+00],
         [ 1.4024e-02,  0.0000e+00,  0.0000e+00,  ..., -9.0024e-06,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -1.3672e-03,  0.0000e+00,  ..., -1.8634e-05,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 1.8431e-02,  0.0000e+00,  0.0000e+00,  ...,  2.1662e-05,
           0.0000e+00,  0.0000e+00],
         [-4.2879e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -2.311

tensor([[[ 0.0000e+00,  0.0000e+00, -6.8169e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 6.7582e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -5.7403e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 0.0000e+00,  0.0000e+00,  1.2357e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -3.2927e-05,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.3758e-02,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000e+00, -8.2291e-03,  0.0000e+00,  ..., -1.5623e-04,
           0.0000e+00,  0.0000e+00],
         [ 2.0203e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.7270e-05],
         [ 0.0000e+00,  0.0000e+00, -2.6562e-03,  ...,  0.0000e+00,
          -1.2528e-04,  0.0000e+00],
         ...

tensor([[[ 0.0000,  0.0062,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0302,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0049,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0339,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0108,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0346,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0004,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0002,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0002,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0004,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0002,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0007,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000e+00,  1.0150e-04,  0.0000e+00,  ...,  0.0000

tensor([[[ 0.0031,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0028,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0010,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0041,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0004,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0051,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0538,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0421,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0066,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000, -0.0824,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0032,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0808,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0021,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[ 2.3686e-03,  0.0000e+00,  0.0000e+00,  ...,  9.7761e-07,
           0.0000e+00,  0.0000e+00],
         [-4.4491e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          -4.1243e-05,  0.0000e+00],
         [ 0.0000e+00,  4.1413e-03,  0.0000e+00,  ...,  4.0355e-05,
           0.0000e+00,  0.0000e+00],
         ...,
         [-5.7084e-03,  0.0000e+00,  0.0000e+00,  ..., -7.0388e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  8.2343e-04,  ...,  8.6408e-05,
           0.0000e+00,  0.0000e+00],
         [-7.7513e-03,  0.0000e+00,  0.0000e+00,  ..., -5.0981e-05,
           0.0000e+00,  0.0000e+00]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0116,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0322,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0233,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0356,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  

tensor([[[ 0.0018,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0104,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0026,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0093,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0011,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0137,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0124,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0157,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0129,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0328,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0077,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0233,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000,  0.0050,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[-0.0354,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0676,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0128,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0859,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0008,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0820,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000, -0.0136,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0204,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0125,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0403,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0038,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0393,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0099,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  

tensor([[[ 0.0000e+00,  0.0000e+00, -1.4046e-04,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 2.4152e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -8.3106e-05,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 2.7335e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 2.3027e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 2.8414e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]])
torch.Size([1, 256, 1008])
tensor([[[ 3.3884e-03,  0.0000e+00,  0.0000e+00,  ..., -5.6826e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -3.9093e-03,  ...,  0.0000e+00,
           0.0000e+00, -3.8207e-05],
         [ 1.4374e-03,  0.0000e+00,  0.0000e+00,  ...,  1.3148e-04,
           0.0000e+00,  0.0000e+00],
         ...

tensor([[[ 0.0107,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0118,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0050,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0144,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0024,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0066,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-4.3155e-03,  0.0000e+00,  0.0000e+00,  ..., -6.6527e-06,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  3.3886e-03,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  2.4933e-05],
         [ 0.0000e+00,  0.0000e+00, -1.5418e-03,  ..., -3.9245e-05,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 5.5594e-03,  0.0000e+00,  0.0000e+00,  ...,  1.9270e-05,
           0.0000e+00,  0.0000e+00],
         [-5.0635e-04,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -3.032

tensor([[[ 1.8777e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -1.3375e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -7.4754e-05,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 0.0000e+00,  0.0000e+00, -8.9383e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  2.4246e-04,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-1.6864e-02,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0220,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0456,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0024,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0510,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  

tensor([[[-3.5122e-03,  0.0000e+00,  0.0000e+00,  ..., -2.3023e-06,
           0.0000e+00,  0.0000e+00],
         [ 4.6031e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.5112e-06],
         [ 0.0000e+00,  0.0000e+00, -3.8883e-04,  ...,  0.0000e+00,
           0.0000e+00, -4.0207e-05],
         ...,
         [ 6.9950e-03,  0.0000e+00,  0.0000e+00,  ...,  3.7958e-05,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -3.8900e-03,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -3.8108e-05],
         [ 7.6137e-03,  0.0000e+00,  0.0000e+00,  ...,  2.6456e-05,
           0.0000e+00,  0.0000e+00]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000,  0.0035,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0034,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0008,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0102,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  

tensor([[[-0.0060,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0171,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0029,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0148,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0020,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0189,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[-0.0005,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0011,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0002,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0013,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0004,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0015,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
torch.Size([1, 256, 1008])
tensor([[[ 0.0000,  0.0000, -0.0010,  ...,  0.0000,  0.0000,  

In [24]:
review_rebuild= result_worse_pred["review"].values
tokens= []
cleaned_tokens= []
heatmap_test= []
i= 0
for r, h in zip(review_rebuild, heatmap):
    prepared_heatmap= prepare_heatmap(heatmap= h, text= r)
    letter_to_token= LetterToToken(text= r
                               , heatmap= prepared_heatmap)
    try:
        results_dict= letter_to_token.transform_letter_to_token(type= "tanh")
        
        tokens.append(results_dict["tokens"])
        cleaned_tokens.append(results_dict["cleaned_tokens"])
        heatmap_test.append(results_dict["heatmap"])
    except:
        print(r)
        print(i)
        print(prepared_heatmap)
    i+=1
#     tokens= results_dict["tokens"]
#     heatmap_test= results_dict["heatmap"]

In [28]:
%matplotlib
indexed_test= 0
seuil_heatmap= 0.75
h= heatmap_test[indexed_test]
plot_text_and_heatmap(text= tokens[indexed_test]
                      , heatmap= np.where(np.abs(h) >= seuil_heatmap, h, 0)
                      , figsize=(7, 7)
                      , cmap= "PiYG"
                      , word_or_letter= "word")
pred_test[indexed_test]

Using matplotlib backend: TkAgg


tensor([0.8644, 0.1356], grad_fn=<SelectBackward>)

In [29]:
plot_text_and_heatmap(text= tokens[indexed_test]
                      , heatmap= heatmap_test[indexed_test]
                      , figsize=(7, 7)
                      , cmap= "PiYG"
                      , word_or_letter= "word")
pred_test[indexed_test]

tensor([0.8644, 0.1356], grad_fn=<SelectBackward>)

### Compute GradCam on special prediction case