In [None]:
import numpy as np

import os
import sys
import json


os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import paddle

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
from Visualizer.visualizer import get_local

from sys import argv
# Replace everything after argv[0] with the options that
# program.preprocess() (called later) presumably parses: '-c <config>'.
# Inside Jupyter, sys.argv holds the kernel launcher's own arguments, which
# must not reach that parser. Slice assignment mutates sys.argv in place (so
# `argv` stays the same object) and, unlike `argv[1] = ...`, also works when
# argv has fewer than two entries.
argv[1:] = ['-c', './configs/rec/rec_cloformer_cppd.yml']
print(argv)

# Empty the get_local capture cache, then enable recording, before inference
# runs. NOTE(review): get_local comes from the project's Visualizer package;
# its exact semantics are not visible here.
get_local.clear()
get_local.activate()


In [None]:
# import torch
# import torchvision.transforms as T
# from timm.models.vision_transformer import vit_small_patch16_224
import json
from PIL import Image, ImageDraw
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def grid_show(to_shows, cols, figsize=None):
    """Display (image, title) pairs on a grid with ``cols`` columns.

    Unused trailing cells are filled with an all-zero image shaped like the
    first entry and titled 'pad'.

    Args:
        to_shows: sequence of (image, title) pairs. Nothing is drawn if empty.
        cols: number of grid columns.
        figsize: optional (width, height) in inches. Defaults to the
            original (rows * 8.5, cols * 2).
    """
    if len(to_shows) == 0:
        return
    rows = (len(to_shows) - 1) // cols + 1
    if figsize is None:
        # NOTE(review): matplotlib's figsize is (width, height), so this
        # scales width with rows and height with cols. Kept as the default
        # for backward compatibility; pass figsize to override.
        figsize = (rows * 8.5, cols * 2)
    # squeeze=False keeps `axs` 2-D even for a single row or column, so the
    # axs[i, j] indexing below never fails.
    fig, axs = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    pad = (np.zeros_like(to_shows[0][0]), 'pad')
    items = iter(to_shows)
    for i in range(rows):
        for j in range(cols):
            image, title = next(items, pad)
            ax = axs[i, j]
            ax.imshow(image)
            ax.set_title(title)
            ax.set_yticks([])
            ax.set_xticks([])
    plt.show()

def visualize_head(att_map):
    """Draw one attention map as a heatmap with a colorbar on the current axes, then show it."""
    current_ax = plt.gca()
    heat = current_ax.imshow(att_map)
    current_ax.figure.colorbar(heat, ax=current_ax)
    plt.show()
    
def visualize_heads(att_map, cols):
    """Plot every attention head, plus the mean over heads, on a grid.

    Args:
        att_map: array-like whose last two axes are assumed to be the 2-D
            per-head map (e.g. queries x keys). All leading axes (such as a
            batch of 1 and the head axis) are flattened into the head list.
        cols: number of grid columns passed to ``grid_show``.
    """
    att_map = np.asarray(att_map)
    # reshape instead of squeeze(): squeeze turned a single-head (1, 1, Q, K)
    # map into 2-D, so the loop iterated over rows. It also dropped length-1
    # Q/K axes.
    att_map = att_map.reshape((-1,) + att_map.shape[-2:])
    to_shows = [(head_map, f'Head {i}') for i, head_map in enumerate(att_map)]
    to_shows.append((att_map.mean(axis=0), 'Head Average'))
    grid_show(to_shows, cols=cols)

def gray2rgb(image):
    """Replicate a grayscale image into 3 identical trailing channels.

    Args:
        image: array-like of shape (..., H, W); PIL images and nested lists
            are accepted.

    Returns:
        ndarray of shape (..., H, W, 3) with the input dtype.
    """
    # Repeat along the new last axis. The previous hard-coded axis=2 only
    # matched 2-D input and repeated the wrong axis for batches.
    return np.repeat(np.asarray(image)[..., np.newaxis], 3, axis=-1)
    
def cls_padding(image, mask, cls_weight, grid_size):
    """Prepend a one-grid-cell-wide 'CLS' column to an image and its mask.

    Args:
        image: PIL image or array of shape (H, W) or (H, W, C).
        mask: attention mask already resized to the image's (H, W); any
            array-like (including a PIL image).
        cls_weight: scalar attention weight of the CLS token.
        grid_size: int or (rows, cols) patch grid of the attention map.

    Returns:
        (padded_image, padded_mask, meta_mask):
            padded_image: PIL image with a white column on the left whose
                top cell is labelled 'CLS'.
            padded_mask: float array (H, W + cell_w). Mask and CLS weight are
                divided by the same value, max(mask.max(), cls_weight). The
                padding column holds the mask minimum, except the top cell,
                which holds the CLS weight.
            meta_mask: RGBA float array, 1 (opaque white) on the padding
                column below the CLS cell and 0 elsewhere.
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)

    image = np.array(image)
    H, W = image.shape[:2]
    delta_H = int(H / grid_size[0])
    delta_W = int(W / grid_size[1])

    # White column with the image's dtype/channels, labelled 'CLS'.
    padding = np.full_like(image[:, :delta_W], 255)
    padded_image = Image.fromarray(np.hstack((padding, image)))
    draw = ImageDraw.Draw(padded_image)
    # PIL coordinates are (x, y). A colour name is valid for grayscale ('L')
    # as well as colour images, unlike an RGB tuple.
    draw.text((int(delta_W / 4), int(delta_H / 4)), 'CLS', fill='black')

    # Normalize mask and CLS weight by the SAME denominator. Previously the
    # max was recomputed after the mask had already been rescaled.
    mask = np.asarray(mask, dtype=np.float64)
    cls_weight = float(cls_weight)
    scale = max(float(np.max(mask)), cls_weight)
    if scale > 0:
        mask = mask / scale
        cls_weight = cls_weight / scale

    # Float padding column. The old one inherited the image's uint8 dtype,
    # which truncated these [0, 1] values to 0. For grayscale images it even
    # stayed at 255.
    pad_mask = np.full((H, delta_W), mask.min(), dtype=np.float64)
    pad_mask[:delta_H, :delta_W] = cls_weight
    padded_mask = np.hstack((pad_mask, mask))

    meta_mask = np.zeros((padded_mask.shape[0], padded_mask.shape[1], 4))
    meta_mask[delta_H:, 0:delta_W, :] = 1

    return padded_image, padded_mask, meta_mask
    

def visualize_grid_to_grid_with_cls(att_map, grid_index, image, grid_size=(7, 9), alpha=0.6):
    """Show one query's attention over the patch grid, with a CLS column.

    Args:
        att_map: 2-D attention matrix. Row ``grid_index`` is used: its entry 0
            is the CLS weight and entries 1: are the patch weights, reshaped
            to ``grid_size``.
        grid_index: query token index (0 is the CLS token).
        image: PIL image the attention refers to.
        grid_size: int (square grid) or (rows, cols). Defaults to (7, 9),
            which the old code forced for every non-tuple argument.
        alpha: opacity of the heatmap overlay.
    """
    if not isinstance(grid_size, tuple):
        # Previously any int was silently replaced by (7, 9).
        grid_size = (grid_size, grid_size)
    rows, cols = grid_size[0], grid_size[1]

    attention_map = np.asarray(att_map[grid_index])
    cls_weight = attention_map[0]

    mask = attention_map[1:].reshape(rows, cols)
    mask = Image.fromarray(mask).resize(image.size)

    padded_image, padded_mask, meta_mask = cls_padding(image, mask, cls_weight, grid_size)

    if grid_index != 0:
        # The padded grid has cols + 1 columns. Patch token t sits in row
        # (t - 1) // cols, so it shifts right by one cell per row.
        grid_index = grid_index + (grid_index - 1) // cols

    grid_image = highlight_grid(padded_image, [grid_index], (rows, cols + 1))

    fig, ax = plt.subplots(1, 2, figsize=(10, 7))
    fig.tight_layout()

    ax[0].imshow(grid_image)
    ax[0].axis('off')

    ax[1].imshow(grid_image)
    ax[1].imshow(padded_mask, alpha=alpha, cmap='jet')
    ax[1].imshow(meta_mask)
    ax[1].axis('off')
    plt.show()
    

def visualize_grid_to_grid(att_map, grid_index, image, grid_size=8, alpha=0.6):
    """Show one query's attention, as a heatmap, next to the plain image.

    Args:
        att_map: 2-D attention matrix. Row ``grid_index`` is reshaped to
            ``grid_size`` and resized to the image.
        grid_index: index of the query row to display.
        image: PIL image the attention refers to.
        grid_size: int (square grid) or (rows, cols) key grid.
        alpha: opacity of the heatmap overlay.
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)

    mask = np.asarray(att_map[grid_index]).reshape(grid_size[0], grid_size[1])
    # Resize via PIL (size is (W, H)) and convert back to a float array.
    mask = np.asarray(Image.fromarray(mask).resize(image.size), dtype=np.float64)
    peak = mask.max()
    if peak > 0:  # avoid a NaN heatmap for an all-zero row
        mask = mask / peak

    fig, ax = plt.subplots(1, 2, figsize=(10, 7))
    fig.tight_layout()

    ax[0].imshow(image)
    ax[0].axis('off')

    ax[1].imshow(image)
    ax[1].imshow(mask, alpha=alpha, cmap='jet')
    ax[1].axis('off')
    plt.show()
    
def highlight_grid(image, grid_indexes, grid_size=14):
    """Return a copy of ``image`` with a red box around each listed grid cell.

    Args:
        image: PIL image (left untouched).
        grid_indexes: iterable of flat, row-major cell indices.
        grid_size: int (square grid) or (rows, cols).
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)

    n_rows, n_cols = grid_size[0], grid_size[1]
    width, height = image.size
    cell_h = height / n_rows
    cell_w = width / n_cols

    canvas = image.copy()
    pen = ImageDraw.ImageDraw(canvas)
    for flat_index in grid_indexes:
        row, col = np.unravel_index(flat_index, (n_rows, n_cols))
        left, top = col * cell_w, row * cell_h
        pen.rectangle([(left, top), (left + cell_w, top + cell_h)],
                      fill=None, outline='red', width=2)
    return canvas

In [None]:
def main():
    """Run text-recognition inference on ``Global.infer_img``; log and save results.

    Adapted from PaddleOCR's ``tools/infer_rec.py``. Reads the module-level
    ``config`` and ``logger`` returned by ``program.preprocess()``, so that
    call must run first.

    For each input image the decoded result is logged and written, tab
    separated, to ``Global.save_res_path`` (default
    ``./output/rec/predicts_rec.txt``).

    Mutates ``config``: head output sizes, eval transforms and ``infer_mode``.

    NOTE(review): the model's forward passes are presumably what populate
    ``get_local.cache`` for the attention cells in this notebook.
    """
    global_config = config['Global']
    algorithm = config['Architecture']['algorithm']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model: size the head(s) from the decoder's character list
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if algorithm in ["Distillation",
                         ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if config["Architecture"]["Models"][key]["Head"][
                        "name"] == 'MultiHead':  # multi head
                    # Derive the CTC class count from the untouched char_num
                    # for every model. Decrementing char_num itself
                    # compounded the -2/-3 across several MultiHead models,
                    # and shrank any non-MultiHead head visited afterwards.
                    ctc_char_num = char_num
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        ctc_char_num = char_num - 2
                    if config['PostProcess'][
                            'name'] == 'DistillationNRTRLabelDecode':
                        ctc_char_num = char_num - 3
                    out_channels_list = {
                        'CTCLabelDecode': ctc_char_num,
                        'SARLabelDecode': ctc_char_num + 2,
                        'NRTRLabelDecode': ctc_char_num + 3,
                    }
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # multi head
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            if config['PostProcess']['name'] == 'NRTRLabelDecode':
                char_num = char_num - 3
            config['Architecture']['Head']['out_channels_list'] = {
                'CTCLabelDecode': char_num,
                'SARLabelDecode': char_num + 2,
                'NRTRLabelDecode': char_num + 3,
            }
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num
    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops: drop label ops, switch resizing to infer mode and keep
    # only the inputs each algorithm's forward() consumes
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name in ['RecResizeImg']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            if algorithm == "SRN":
                op[op_name]['keep_keys'] = [
                    'image', 'encoder_word_pos', 'gsrm_word_pos',
                    'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
                ]
            elif algorithm == "SAR":
                op[op_name]['keep_keys'] = ['image', 'valid_ratio']
            elif algorithm == "RobustScanner":
                op[op_name][
                    'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
            else:
                op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global'].get('save_res_path',
                                         "./output/rec/predicts_rec.txt")
    save_dir = os.path.dirname(save_res_path)
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.eval()

    infer_imgs = config['Global']['infer_img']
    infer_list = config['Global'].get('infer_list', None)
    # utf-8: results may contain non-ASCII text (json.dumps with
    # ensure_ascii=False below), which the platform default may not encode.
    with open(save_res_path, "w", encoding="utf-8") as fout:
        for file in get_image_file_list(infer_imgs, infer_list=infer_list):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            # Build the extra inputs some algorithms take; each gets a
            # leading batch dimension of 1.
            if algorithm == "SRN":
                encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
                gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
                gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
                gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)

                others = [
                    paddle.to_tensor(encoder_word_pos_list),
                    paddle.to_tensor(gsrm_word_pos_list),
                    paddle.to_tensor(gsrm_slf_attn_bias1_list),
                    paddle.to_tensor(gsrm_slf_attn_bias2_list)
                ]
            if algorithm == "SAR":
                valid_ratio = np.expand_dims(batch[-1], axis=0)
                img_metas = [paddle.to_tensor(valid_ratio)]
            if algorithm == "RobustScanner":
                valid_ratio = np.expand_dims(batch[1], axis=0)
                word_positons = np.expand_dims(batch[2], axis=0)
                img_metas = [
                    paddle.to_tensor(valid_ratio),
                    paddle.to_tensor(word_positons),
                ]
            if algorithm == "CAN":
                image_mask = paddle.ones(
                    (np.expand_dims(
                        batch[0], axis=0).shape), dtype='float32')
                label = paddle.ones((1, 36), dtype='int64')
            images = np.expand_dims(batch[0], axis=0)
            images = paddle.to_tensor(images)
            if algorithm == "SRN":
                preds = model(images, others)
            elif algorithm == "SAR":
                preds = model(images, img_metas)
            elif algorithm == "RobustScanner":
                preds = model(images, img_metas)
            elif algorithm == "CAN":
                preds = model([images, image_mask, label])
            else:
                preds = model(images)
            post_result = post_process_class(preds)
            info = None
            if isinstance(post_result, dict):
                rec_info = dict()
                for key in post_result:
                    if len(post_result[key][0]) >= 2:
                        rec_info[key] = {
                            "label": post_result[key][0][0],
                            "score": float(post_result[key][0][1]),
                        }
                info = json.dumps(rec_info, ensure_ascii=False)
            elif isinstance(post_result, list) and isinstance(post_result[0],
                                                              int):
                # for RFLearning CNT branch
                info = str(post_result[0])
            else:
                if len(post_result[0]) >= 2:
                    info = post_result[0][0] + "\t" + str(post_result[0][1])

            if info is not None:
                logger.info("\t result: {}".format(info))
                fout.write(file + "\t" + info + "\n")
    logger.info("success!")
    
# Build the config/logger (presumably from the sys.argv prepared in the first
# cell), then run inference once. main() reads these module-level `config`
# and `logger` names.
config, device, logger, vdl_writer = program.preprocess()
main()



In [None]:
# Image to overlay attention on, displayed as the cell output.
# NOTE(review): the cached attention maps come from whatever Global.infer_img
# main() just processed (set in the YAML config, not visible here). Confirm it
# is this same image.
image = Image.open('./train_data/common_benchmarks/CUTE80/imgs/cute_226.jpg')
image

In [None]:
# Inspect what get_local recorded: the cached keys, then the list stored
# under 'EdgeDecoderLayer.forward' (its length and the first entry's shape).
cache = get_local.cache
print(list(cache.keys()))
attention_maps = cache['EdgeDecoderLayer.forward']
print(len(attention_maps))
attention_maps[0].shape

In [None]:
# Query row 2 of slice [0, 1] (presumably batch 0, head 1) of the first map,
# using the default grid size (a 7x9 patch grid, entry 0 treated as CLS).
visualize_grid_to_grid_with_cls(attention_maps[0][0,1,:,:], 2, image)

In [None]:
# Query row 1 of slice [0, 5] (presumably batch 0, head 5) of the first map,
# on the default 8x8 grid (so each row must hold 64 weights).
visualize_grid_to_grid(attention_maps[0][0,5,:,:], 1, image)

In [None]:
# Raw attention matrix of slice [0, 0] of the first cached map.
visualize_head(attention_maps[0][0,0])

In [None]:
# Fourth cached map, slice [0, 0], transposed (axes swapped) before plotting.
print(attention_maps[3][0,0].shape)
visualize_head(attention_maps[3][0,0].transpose((1,0)))

In [None]:
# Every head (plus their average) of the third cached map, 8 per row.
visualize_heads(attention_maps[2], cols=8)

In [None]:
# Every head (plus their average) of the fourth cached map, 8 per row.
visualize_heads(attention_maps[3], cols=8)

In [None]:
import cv2
# 假设attn是你已经提取的注意力图，shape为(num_heads, height, width)
# 假设original_image是一个PIL图像或者一个numpy数组

# Normalize an attention map and resize it to the original image size.
def resize_attention(attn, target_size):
    """Average ``attn`` over axis 0, min-max normalize it and resize it.

    Args:
        attn: array-like. Axis 0 is averaged away; what it means depends on
            the caller (heads, queries or feature-map rows).
        target_size: (width, height) passed to ``PIL.Image.resize``.

    Returns:
        float32 array of shape (height, width) with values in [0, 1]. A
        constant map (zero range) becomes all zeros instead of NaN.

    NOTE(review): if the averaged map is 1-D, PIL's ``fromarray`` treats it
    as a single-pixel-wide column. Callers' orientation depends on this.
    """
    attn = np.asarray(attn).mean(axis=0)
    attn = attn - attn.min()  # min-max normalization
    value_range = attn.max()
    if value_range > 0:
        attn = attn / value_range
    # float32 -> PIL mode 'F', so bilinear resizing keeps fractional values.
    resized = Image.fromarray(attn.astype(np.float32)).resize(target_size, Image.BILINEAR)
    return np.asarray(resized)

# Overlay the axis-0-averaged attention of slice [0][0] of the first cached
# map on the input image. Open the image first: this cell previously used an
# `image` left over from an earlier cell before re-opening it.
image = Image.open('./train_data/common_benchmarks/CUTE80/imgs/cute_226.jpg')
# 3-channel RGB so cv2.addWeighted gets the same shape as the heatmap, even
# for grayscale/RGBA files.
original_image = np.array(image.convert('RGB'))

# Resize the attention to the image (PIL size is (W, H)).
attn = resize_attention(attention_maps[0][0][0], image.size)

# Colour it with the 'jet' colormap, keeping RGB and dropping alpha, as uint8.
heatmap = plt.get_cmap('jet')(attn)[:, :, :3]
heatmap = (heatmap * 255).astype(np.uint8)

# Blend: the ORIGINAL image gets weight `alpha`, the heatmap 1 - alpha.
alpha = 0.6
overlay = cv2.addWeighted(original_image, alpha, heatmap, 1 - alpha, 0)

plt.imshow(overlay)
plt.axis('off')
plt.show()


In [None]:
import cv2

# Normalize an attention map and resize it to the original image size.
# NOTE(review): this duplicates an identical definition in an earlier cell and
# shadows it. Keep the two in sync, or move the helper to a module.
def resize_attention(attn, target_size):
    """Average ``attn`` over axis 0, min-max normalize it and resize it.

    Args:
        attn: array-like. Axis 0 is averaged away; what it means depends on
            the caller (heads, queries or feature-map rows).
        target_size: (width, height) passed to ``PIL.Image.resize``.

    Returns:
        float32 array of shape (height, width) with values in [0, 1]. A
        constant map (zero range) becomes all zeros instead of NaN.

    NOTE(review): if the averaged map is 1-D, PIL's ``fromarray`` treats it
    as a single-pixel-wide column. Callers' orientation depends on this.
    """
    attn = np.asarray(attn).mean(axis=0)
    attn = attn - attn.min()  # min-max normalization
    value_range = attn.max()
    if value_range > 0:
        attn = attn / value_range
    # float32 -> PIL mode 'F', so bilinear resizing keeps fractional values.
    resized = Image.fromarray(attn.astype(np.float32)).resize(target_size, Image.BILINEAR)
    return np.asarray(resized)

def visual_att(head, img_file, save_file, num_steps=16,
               attn_shape=(12, 26, 2, 32), alpha=0.6):
    """Overlay one head's per-step attention on an image and save the results.

    Uses the first ``EdgeDecoderLayer.forward`` entry of ``get_local.cache``,
    so inference on ``img_file`` must have run just before.

    Args:
        head: index into the first axis of the reshaped attention.
        img_file: path of the image the cached attention belongs to.
        save_file: path of the combined grid figure. Each step's overlay is
            also saved as ``<dir>/sub_img/<stem>_<step><ext>``.
        num_steps: how many entries of the second axis (presumably decoding
            steps) to plot.
        attn_shape: shape the cached map is reshaped to. The defaults match
            the original hard-coded (12, 26, 2, 32): presumably (heads, steps,
            feature rows, feature cols). TODO confirm against the model.
        alpha: blend weight of the original image (the heatmap gets 1 - alpha).
    """
    attention_maps = get_local.cache['EdgeDecoderLayer.forward']
    heatmap = np.asarray(attention_maps[0]).reshape(tuple(attn_shape))

    image = Image.open(img_file).convert('RGB')  # 3 channels even for gray input
    img_w, img_h = image.size
    # Work in a transposed (W, H, 3) frame. resize_attention averages each
    # (rows, cols) step map to a 1-D vector, which becomes a column; resizing
    # it to (img_h, img_w) yields an (img_w, img_h) array. Its first axis thus
    # runs along the image width.
    original_image = np.array(image).transpose(1, 0, 2)

    # Output paths do not change per step: compute them once.
    dir_name, filename = os.path.split(save_file)
    sub_dir = os.path.join(dir_name, 'sub_img')
    os.makedirs(sub_dir, exist_ok=True)
    # splitext, unlike rsplit('.'), handles names without an extension.
    stem, ext = os.path.splitext(filename)

    maps = []
    for step in range(num_steps):
        attn = resize_attention(heatmap[head][step], (img_h, img_w))
        heatmap_i = plt.get_cmap('jet')(attn)[:, :, :3]  # RGB, drop alpha
        heatmap_i = (heatmap_i * 255).astype(np.uint8)
        overlay = cv2.addWeighted(original_image, alpha, heatmap_i, 1 - alpha, 0)
        overlay = overlay.transpose(1, 0, 2)  # back to (H, W, 3)
        maps.append(overlay)
        sub_file = os.path.join(sub_dir, '{}_{}{}'.format(stem, step, ext))
        print(sub_file)
        plt.imsave(sub_file, overlay)

    if not maps:
        return

    # Near-square grid of all step overlays.
    nrows = int(np.ceil(len(maps) ** 0.5))
    ncols = int(np.ceil(len(maps) / nrows))
    # squeeze=False keeps `axs` 2-D for any grid size (including 1x1).
    fig, axs = plt.subplots(nrows, ncols, figsize=(ncols * 3, nrows * 3), squeeze=False)
    for ax in axs.flat:
        ax.axis('off')  # also hides unused trailing cells
    for ax, img in zip(axs.flat, maps):
        ax.imshow(img)

    fig.tight_layout()
    fig.savefig(save_file)
    plt.show()
    plt.close(fig)  # avoid accumulating figures when called in a loop

In [None]:
# Visualize head `head` for a single image. Inference is re-run on that image
# first so that get_local.cache holds ITS attention maps. Otherwise the cache
# holds whatever main() processed last, which need not be this image.
dataset = 'CUTE80'
img_path = './train_data/common_benchmarks/{}/imgs/'.format(dataset)
save_path = './output/rec/visiual/'
head = 11
img = 'cute_124.jpg'
img_file = os.path.join(img_path, img)
save_file = os.path.join(save_path, '{}_h{}'.format(dataset, head), img)
os.makedirs(os.path.join(save_path, '{}_h{}'.format(dataset, head), 'sub_img'), exist_ok=True)

# Point the CLI-style config at this image (sys.argv is mutated in place).
argv[1:] = ['-c', './configs/rec/rec_cloformer_cppd.yml',
            '-o', 'Global.infer_img={}'.format(img_file)]
get_local.clear()
get_local.activate()
config, device, logger, vdl_writer = program.preprocess()
main()

visual_att(head, img_file, save_file)

In [None]:
# Pre-create the per-step output directory for head 3 on CUTE80.
# NOTE(review): redundant; the batch cell below creates this directory itself.
os.makedirs('./output/rec/visiual/CUTE80_h3/sub_img/', exist_ok=True)

In [None]:
# Batch mode: for every image in the dataset, re-run inference on that image
# alone (so get_local.cache holds its attention maps), then save its overlays.
dataset = 'CUTE80'
img_path = './train_data/common_benchmarks/{}/imgs/'.format(dataset)
save_path = './output/rec/visiual/'
head = 3
# sorted(): os.listdir order is arbitrary, which made runs non-reproducible.
img_name = sorted(os.listdir(img_path))

# The output directories do not depend on the image: create them once.
out_dir = os.path.join(save_path, '{}_h{}'.format(dataset, head))
os.makedirs(os.path.join(out_dir, 'sub_img'), exist_ok=True)

for img in img_name:
    img_file = os.path.join(img_path, img)
    if not os.path.isfile(img_file):
        continue  # skip sub-directories and other non-file entries
    save_file = os.path.join(out_dir, img)

    # Same final argv as the old pad/overwrite/trim logic:
    # [argv[0], '-c', <config>, '-o', 'Global.infer_img=<img_file>']
    argv[1:] = ['-c', './configs/rec/rec_cloformer_cppd.yml',
                '-o', 'Global.infer_img={}'.format(img_file)]

    get_local.clear()
    get_local.activate()

    config, device, logger, vdl_writer = program.preprocess()
    main()
    visual_att(head, img_file, save_file)

    print(img_file)