In [1]:
import os
from os.path import basename
from pathlib import Path

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import CheckpointIO
from pytorch_lightning.utilities import rank_zero_only
from sconf import Config
import PIL.Image

# from donut import DonutModel, DonutConfig
import donut.model_custom
from donut.model_custom import DonutModel, DonutConfig, BARTCustomTokenizer
from donut import DonutDataset
from lightning_module import DonutDataPLModule, DonutModelPLModule

import cv2
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config_path='/data/murayama/k8s/ocr_dxs1/donut/config/test_health_box.yaml'
config=Config(config_path)
model = DonutModel.from_pretrained(
    config.pretrained_model_name_or_path,
    input_size=config.input_size,
    max_length=config.max_length,
    align_long_axis=config.align_long_axis,
    enable_char_map=config.char_map,
    box_pred=config.get('box_pred',False)
)

if 'classes' in config:
    stokens = []
    for cn in config.classes:
        stokens += [fr"<s_{cn}>", fr"</s_{cn}>"]
    model.decoder.add_special_tokens(stokens)

model = model.cuda().half().eval()


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'XLMRobertaTokenizer'. 
The class this function is called from is 'BARTCustomTokenizer'.


In [53]:
image_path='/data/murayama/k8s/ocr_dxs1/dxsx-atypical-engine/data/vhr_testdata_64_focr/testdata_2021/G49-2021-000442-001.jpg'
image=PIL.Image.open(image_path)
# outs = model.inference(image, prompt='<s_health>', return_attentions=True, return_confs=True, return_tokens=True)
try:
    outs = model.inference(image, prompt='<s_health><s_health>', return_attentions=True, return_segmap=False, return_confs=True, return_tokens=True, token_score_thresh=0.45)
except Exception as ex:
    import traceback
    traceback.print_exc()

In [23]:
outs["predictions"]

[[{'bmi_0': ['20.5', 1.0, [2, 3]]},
  {'bmi_1': ['20.8', 1.0, [7, 8]]},
  {'bmi_2': ['20.8', 0.99951171875, [12, 13]]},
  {'bp_diastolic_0': ['66', 1.0, [17]]},
  {'bp_diastolic_1': ['70', 1.0, [21]]},
  {'bp_diastolic_2': ['48', 1.0, [25]]},
  {'bp_diastolic_2_0': ['55', 1.0, [29]]},
  {'bp_diastolic_2_1': ['60', 0.9990234375, [33]]},
  {'bp_systolic_0': ['102', 1.0, [37]]},
  {'bp_systolic_1': ['104', 1.0, [41]]},
  {'bp_systolic_2': ['84', 0.99951171875, [45]]},
  {'bp_systolic_2_0': ['96', 1.0, [49]]},
  {'bp_systolic_2_1': ['108', 0.99951171875, [53]]},
  {'btype_abo_0': ['B', 0.99853515625, [57]]},
  {'btype_abo_1': ['B', 1.0, [61]]},
  {'btype_abo_2': ['B', 1.0, [65]]},
  {'btype_rh_0': ['+', 0.99853515625, [69]]},
  {'btype_rh_1': ['+', 1.0, [73]]},
  {'btype_rh_2': ['+', 1.0, [77]]},
  {'ecg_rest_0': ['異常所見なし', 0.9995930989583334, [82, 83, 84, 85, 86, 87]]},
  {'ecg_rest_1': ['異常所見なし', 0.9998372395833334, [91, 92, 93, 94, 95, 96]]},
  {'ecg_rest_2': ['異常所見なし',
    0.9999186197

In [5]:
decoder_states = [ds[-1] for ds in outs["decoder_state"]]
decoder_states = torch.cat(decoder_states, dim=1)
_, boxes, confs = model.forward_box_head(decoder_states)
box_mask = outs['tokens'][:,1:] > model.decoder.tokenizer.vocab_size
#boxes = boxes[box_mask][1::2]
boxes *= 1600
boxes = boxes.detach().cpu().numpy().astype(np.int)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  boxes = boxes.detach().cpu().numpy().astype(np.int)


In [19]:
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
mean = torch.Tensor(IMAGENET_DEFAULT_MEAN).reshape(3,1,1)
std = torch.Tensor(IMAGENET_DEFAULT_STD).reshape(3,1,1)

vsize = model.decoder.tokenizer.vocab_size + 1
_img = model.encoder.prepare_input(image)
_img = ((_img*std+mean)*255).permute(1,2,0).numpy().astype(np.uint8)
_img = _img.copy()
tkns = outs['tokens'][:,1:]
scores = outs['scores'][1:]
# idxs = [19,21]
idxs = [ii for ii in range(len(scores))]
_boxes = [boxes[0,ii] for ii in idxs]
_tkns = [tkns[0,ii] for ii in idxs]
_scores = [scores[ii] for ii in idxs]
_confs = [confs[0,ii] for ii in idxs]
# _boxes = boxes
for box, tt, ss, cc in zip(_boxes,_tkns,_scores,_confs):
# for box in _boxes:
    if tt < vsize:
        continue
    x0 = box[0] - box[2]//2
    y0 = box[1] - box[3]//2
    x1 = box[0] + box[2]//2
    y1 = box[1] + box[3]//2
    print(tt,ss,cc, x0,y0,x1,y1)
    # print(x0,y0,x1,y1)
    cv2.rectangle(_img, (x0,y0),(x1,y1),(0,255,0),2)
cv2.imwrite('hoge1.jpg',_img)


tensor(58397, device='cuda:0') 0.9892578125 tensor(0.2440, device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>) 1109 684 1145 706
tensor(57553, device='cuda:0') 1.0 tensor(0.5596, device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>) 1130 678 1158 692
tensor(58398, device='cuda:0') 1.0 tensor(0.5044, device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>) 1313 678 1341 692
tensor(57554, device='cuda:0') 1.0 tensor(0.5679, device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>) 1311 678 1339 694
tensor(58399, device='cuda:0') 1.0 tensor(0.5269, device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>) 1456 677 1486 693
tensor(57555, device='cuda:0') 0.9951171875 tensor(0.5591, device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>) 1455 677 1481 693
tensor(58403, device='cuda:0') 0.95654296875 tensor(0.4392, device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>) 1137 811 1159 827
tensor(57559, device='cuda:0') 1.0 tensor

True

In [114]:
import cv2
import numpy as np
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import importlib
importlib.reload(donut.model_custom)
token_indexes = [41]
decoder_cross_attentions = outs["attentions"]["cross_attentions"]
mean = torch.Tensor(IMAGENET_DEFAULT_MEAN).reshape(3,1,1)
std = torch.Tensor(IMAGENET_DEFAULT_STD).reshape(3,1,1)

box, thres_heatmap, agg_heatmap = donut.model_custom.DonutModel.max_bbox_from_heatmap(decoder_cross_attentions, token_indexes, discard_ratio=0, return_thres_agg_heatmap=True, final_h=1600,final_w=1600,heatmap_h=1600//32,heatmap_w=1600//32)
_img = model.encoder.prepare_input(image)
_img = ((_img*std+mean)*255).permute(1,2,0).numpy().astype(np.uint8)

hmap = cv2.applyColorMap(agg_heatmap, cv2.COLORMAP_JET)
viz = cv2.addWeighted(_img, 0.5, hmap, 0.5, 0)
cv2.imwrite('hoge.jpg',viz)


True

In [14]:
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
mean = torch.Tensor(IMAGENET_DEFAULT_MEAN).reshape(3,1,1)
std = torch.Tensor(IMAGENET_DEFAULT_STD).reshape(3,1,1)

_img = model.encoder.prepare_input(image)
_img = ((_img*std+mean)*255).permute(1,2,0).numpy().astype(np.uint8)
hh, ww = _img.shape[0]//32, _img.shape[1]//32

decoder_cross_attentions = outs["attentions"]["cross_attentions"]
token_indexes = [73]

# attens = [decoder_cross_attentions[ii][-1].squeeze() for ii in token_indexes]
# attens = torch.stack(attens).max(dim=0)[0].max(dim=0)[0]
# attens = attens.reshape(_img.shape[0]//32, _img.shape[1]//32).detach().cpu().numpy().astype(np.float32)

attens = [decoder_cross_attentions[ii][-1] for ii in token_indexes]
attens = torch.stack(attens).squeeze(3)
attens = attens.permute(1,2,0,3)
attens = attens.reshape(1,16,-1,hh,ww).max(dim=2)[0]
attens = torch.pixel_shuffle(attens,4)[0,0].detach().cpu().numpy().astype(np.float32)

cv2.imwrite('attens.jpg', attens*255)
# attens = cv2.resize(attens, None, None, fx=8, fy=8, interpolation=cv2.INTER_NEAREST)
attens = cv2.resize(attens, None, None, fx=2, fy=2, interpolation=cv2.INTER_NEAREST)

seg_map = outs['segmap'].detach().cpu().numpy().squeeze().astype(np.float32)
cv2.imwrite('segmap.jpg', seg_map*255)

seg_map = seg_map*attens
cv2.imwrite('score.jpg', seg_map*255)

# seg_map = np.tile(seg_map[:,:,np.newaxis], (1,1,3))
seg_map = cv2.resize(attens, None, fx=4, fy=4, interpolation=cv2.INTER_NEAREST)
seg_map = cv2.applyColorMap((seg_map*255).astype(np.uint8), cv2.COLORMAP_JET)

hoge = cv2.addWeighted(_img, 0.5, seg_map, 0.5, 0)
cv2.imwrite('viz.jpg', hoge)


KeyError: 'segmap'

In [9]:
enc = outs['last_hidden_state']
token_indexes = [123, 124, 125]

attens = [outs['attentions']['cross_attentions'][ii][-1] for ii in token_indexes]
#attens = [decoder_cross_attentions[ii][-1] for ii in range(len(decoder_cross_attentions))]
attens = torch.stack(attens).squeeze(3)
attens = attens.permute(1,2,0,3)
mask = torch.where(attens<0.5)
attens[mask]=0
seg = model.forward_seg_head(enc, attens).sigmoid()[0,0]
seg = seg.detach().cpu().numpy().astype(np.float32)*255
seg=cv2.resize(seg, None, fx=4, fy=4, interpolation=cv2.INTER_NEAREST)
seg = cv2.applyColorMap(seg.astype(np.uint8), cv2.COLORMAP_JET)
hoge = cv2.addWeighted(_img, 0.5, seg, 0.5, 0)
cv2.imwrite('hoge.jpg', hoge)

True

TypeError: decode() missing 1 required positional argument: 'token_confs'

In [106]:
sel=outs["attentions"]['self_attentions'][418][-1]
sel = sel.squeeze().max(dim=0)[0]
sel

tensor([2.5415e-01, 4.1237e-03, 1.6699e-03, 8.2350e-04, 2.3880e-03, 6.2764e-05,
        8.3447e-04, 5.2214e-05, 1.0192e-04, 8.3113e-04, 9.4235e-05, 6.9666e-04,
        1.9729e-05, 8.7142e-05, 1.3008e-03, 5.2261e-03, 3.7651e-03, 4.6396e-04,
        1.9588e-03, 7.1168e-05, 8.5354e-04, 1.0031e-04, 8.0681e-04, 1.4365e-05,
        5.7745e-04, 2.1243e-04, 3.4828e-03, 8.0156e-04, 3.2711e-03, 5.5981e-04,
        7.2327e-03, 1.0681e-04, 6.8760e-04, 8.8573e-05, 1.0271e-03, 4.4966e-04,
        4.9896e-03, 1.0614e-03, 1.4534e-03, 2.5392e-05, 7.4148e-04, 4.6432e-05,
        2.1756e-04, 4.1187e-05, 2.7704e-04, 1.7881e-04, 1.1482e-03, 2.0466e-03,
        3.1700e-03, 4.9400e-04, 4.6611e-04, 3.2377e-04, 3.3741e-03, 4.0829e-05,
        2.4872e-03, 4.2000e-03, 1.3985e-02, 2.3937e-03, 2.2297e-03, 2.3193e-03,
        1.6098e-02, 2.2831e-03, 2.4834e-03, 1.1129e-03, 7.5417e-03, 1.1629e-04,
        1.2887e-04, 4.1187e-05, 4.2248e-04, 8.7142e-05, 1.3418e-03, 2.9707e-04,
        1.3771e-02, 1.8063e-03, 9.4318e-

In [54]:
decoder_states = [ds[-1] for ds in outs["decoder_state"]]
decoder_states = torch.cat(decoder_states, dim=1)
decoder_states.shape

torch.Size([1, 577, 1024])