In [None]:
try:
    import cc3d
except:
    !pip install 'connected-components-3d'


import sys
sys.path.append('/kaggle/input/hengck23-rectification-net-demo')

import numpy as np
import pandas as pd
import cv2
import cc3d

import matplotlib.pyplot as plt
import matplotlib
from scipy.interpolate import griddata

from model import *


print('import ok!')

In [None]:
def encode_with_resnet(e, x):
	"""Run a ResNet-style encoder ``e`` on ``x`` and return its four stage outputs.

	The stem is ``conv1 -> bn1 -> act1``. The stem max-pool is intentionally
	skipped, so for a 256-px input the stages come out at 128, 64, 32 and 16 px.

	Args:
		e: encoder exposing ``conv1``, ``bn1``, ``act1`` and ``layer1`` .. ``layer4``.
		x: normalised input batch.

	Returns:
		list with the outputs of ``layer1`` .. ``layer4``, shallow to deep.
	"""
	feature = e.act1(e.bn1(e.conv1(x)))  # stem, no max-pool

	features = []
	for stage in (e.layer1, e.layer2, e.layer3, e.layer4):
		feature = stage(feature)
		features.append(feature)
	return features


class Net(nn.Module):
	"""Grid detector: timm encoder + ``MyUnetDecoder`` + three 1x1 logit heads.

	The encoder comes from ``timm.create_model`` with the classifier and the
	global pooling disabled. It is run through ``encode_with_resnet``, so it must
	expose ResNet-style ``conv1/bn1/act1/layer1..4`` attributes.

	The deepest encoder stage feeds the decoder, and the shallower stages are
	passed as skips, deepest first. Three 1-channel 1x1 convolutions on the last
	decoder map predict per-pixel logits for:
	- grid points (``gpoint``)
	- horizontal grid lines (``ghline``)
	- vertical grid lines (``gvline``)

	Args:
		pretrained: forwarded to ``timm.create_model`` (encoder weights).
		cfg: accepted but currently unused.

	``self.output_type`` (default ``['infer', 'loss']``) selects what
	``forward`` returns.
	"""

	def __init__(self, pretrained=False, cfg=None):
		super().__init__()
		self.output_type = ['infer', 'loss']

		# Scalar buffer used only to find out which device the module is on.
		self.register_buffer('D', torch.tensor(0))
		# ImageNet RGB mean/std, shaped to broadcast over a (B, 3, H, W) batch.
		self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1))
		self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1))

		# Alternative backbone kept for reference. Only the ResNet path is wired
		# into forward (via encode_with_resnet).
		# arch = 'convnext_tiny.fb_in22k'
		arch = 'resnet18d.ra4_e3600_r224_in1k'

		# Output channels of each encoder stage, shallow -> deep.
		channels_by_arch = {
			'resnet18d.ra4_e3600_r224_in1k': [64, 128, 256, 512],
			'convnext_tiny_in22k': [96, 192, 384, 768],
			'convnext_tiny.fb_in22k': [96, 192, 384, 768],
			'convnext_small.fb_in22k': [96, 192, 384, 768],
			'convnext_base.fb_in22k': [128, 256, 512, 1024],
			'resnet50d': [64, 256, 512, 1024, 2048],
		}
		encoder_dim = channels_by_arch[arch]
		decoder_dim = [256, 128, 64, 32]

		self.encoder = timm.create_model(
			model_name=arch,
			pretrained=pretrained,
			in_chans=3,
			num_classes=0,
			global_pool='',
		)

		# Skips are the non-deepest stages in deepest-first order. The last
		# decoder step gets no skip (0 channels). scale=[2, 2, 2, 2] is
		# presumably one x2 upsample per decoder step; see MyUnetDecoder.
		self.decoder = MyUnetDecoder(
			in_channel=encoder_dim[-1],
			skip_channel=list(reversed(encoder_dim[:-1])) + [0],
			out_channel=decoder_dim,
			scale=[2, 2, 2, 2],
		)

		# One logit channel per head; forward applies the sigmoid for 'infer'.
		head_in = decoder_dim[-1]
		self.gpoint = nn.Conv2d(head_in, 1, kernel_size=1)
		self.ghline = nn.Conv2d(head_in, 1, kernel_size=1)
		self.gvline = nn.Conv2d(head_in, 1, kernel_size=1)

	def forward(self, batch):
		"""Predict grid maps for ``batch['image']`` and optionally compute losses.

		Args:
			batch: dict with ``'image'``, a 4-D (B, 3, H, W) tensor whose values
				are divided by 255 here, so presumably uint8 pixels. When
				``'loss'`` is in ``output_type`` the dict must also hold
				``'gpoint'``, ``'ghline'`` and ``'gvline'`` targets. The targets
				are binarised at 0.5.

		Returns:
			dict with the following entries:
			- ``'loss'`` in ``output_type``: ``gpoint_loss``, ``ghline_loss`` and
			  ``gvline_loss`` (BCE with logits), plus
			  ``grid_loss = 2 * gpoint_loss + ghline_loss + gvline_loss``.
			- ``'infer'`` in ``output_type``: ``gpoint``, ``ghline`` and
			  ``gvline`` sigmoid probabilities.
		"""
		device = self.D.device
		image = batch['image'].to(device)
		# Unpacking also enforces a 4-D input.
		_batch, _channels, _height, _width = image.shape

		x = (image.float() / 255 - self.mean) / self.std

		features = encode_with_resnet(self.encoder, x)
		last, _decode = self.decoder(
			feature=features[-1],
			skip=list(reversed(features[:-1])) + [None],
		)

		logits = {
			'gpoint': self.gpoint(last),
			'ghline': self.ghline(last),
			'gvline': self.gvline(last),
		}

		output = {}
		if 'loss' in self.output_type:
			for name, logit in logits.items():
				target = (batch[name].to(device) > 0.5).float()
				output[name + '_loss'] = F.binary_cross_entropy_with_logits(logit, target)
			output['grid_loss'] = 2 * output['gpoint_loss'] + output['ghline_loss'] + output['gvline_loss']
			# TODO: masked loss that ignores invalid points.

		if 'infer' in self.output_type:
			for name, logit in logits.items():
				output[name] = torch.sigmoid(logit)

		return output


def run_check_net():
	"""Smoke-test ``Net`` with one forward pass on random data.

	Builds a batch of random uint8 images and random [0, 1) target maps. Runs
	``Net(pretrained=True)`` under ``no_grad``, which presumably needs network
	access or a timm cache for the weights. Then prints the batch shapes, the
	output map shapes and the loss values.

	``autocast('cuda')`` only affects tensors that live on the GPU. The net is
	therefore moved to CUDA when it is available, and autocast is enabled only
	there. On CPU it runs in plain fp32. Inputs come from a seeded generator
	so repeated runs see identical data.
	"""
	H, W = 320, 320
	batch_size = 4
	rng = np.random.default_rng(0)

	def random_map():
		# Random float target map; Net binarises targets at 0.5.
		return torch.from_numpy(rng.uniform(0, 1, (batch_size, 1, H, W))).float()

	batch = {
		'image': torch.from_numpy(rng.integers(0, 256, (batch_size, 3, H, W), dtype=np.uint8)),
		'gpoint': random_map(),
		'ghline': random_map(),
		'gvline': random_map(),
	}

	device = 'cuda' if torch.cuda.is_available() else 'cpu'
	net = Net(pretrained=True).to(device)
	# print(net)

	# Net.forward moves the batch onto the net's device itself.
	with torch.no_grad():
		with torch.amp.autocast(device, enabled=(device == 'cuda')):
			output = net(batch)

	print('batch')
	for k, v in batch.items():
		print(f'{k:>32} : {v.shape} ')

	print('output')
	for k, v in output.items():
		if 'loss' not in k:
			print(f'{k:>32} : {v.shape} ')
	print('loss')
	for k, v in output.items():
		if 'loss' in k:
			print(f'{k:>32} : {v.item()} ')
run_check_net()

In [None]:
KAGGLE_DIR = '/kaggle/input/physionet-ecg-image-digitization'
# Fall back to CPU so the demo still runs (slowly) without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

image_id = '31294838'
type_id = '0005'

image_path = f'{KAGGLE_DIR}/train/{image_id}/{image_id}-{type_id}.png'
image = cv2.imread(image_path, cv2.IMREAD_COLOR_RGB)
if image is None:
    # cv2.imread reports a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError(f'could not read image: {image_path}')
H, W = image.shape[:2]
print('image:', H, W)

# Zero-pad bottom/right so both sides are multiples of 32. Note this always
# adds at least 32 px, even when a side is already a multiple of 32.
# The prediction is cropped back to H x W after inference.
pH = (H // 32) * 32 + 32
pW = (W // 32) * 32 + 32
padded = np.pad(image, [[0, pH - H], [0, pW - W], [0, 0]], mode='constant', constant_values=0)


def run_infer_demo():
    """Predict grid maps for the demo image with flip test-time augmentation.

    Steps:
    - Load the ``00002500`` demo checkpoint into ``Net`` (non-strict).
    - Run the module-level ``padded`` image under four flips: none, dim 2
      (height), dim 3 (width), and both.
    - Undo each flip on the predicted (gpoint, ghline, gvline) maps and
      average them.
    - Crop back to the module-level ``H`` x ``W``.

    Writes ``prob.npy`` with the averaged maps, and ``image.png`` with the
    module-level ``image`` converted RGB -> BGR for ``cv2.imwrite``. Also reads
    the module-level ``DEVICE``.
    """
    net = Net(pretrained=False)
    # f = torch.load(f'/kaggle/input/hengck23-rectification-net-demo/last.checkpoint.pth', map_location=lambda storage, loc: storage)
    checkpoint = torch.load(
        '/kaggle/input/hengck23-rectification-net-demo/00002500.checkpoint.pth',
        map_location=lambda storage, loc: storage,
    )

    # strict=False never raises; the printed result lists missing/unexpected keys.
    print(net.load_state_dict(checkpoint['state_dict'], strict=False))
    net = net.to(DEVICE).eval()
    net.output_type = ['infer']

    # Flip dims applied per trial. A flip is its own inverse, so the same dims
    # map the prediction back onto the original orientation.
    flips = [None, [2], [3], [2, 3]]

    print('padded', padded.shape)
    base = torch.from_numpy(np.ascontiguousarray(padded.transpose(2, 0, 1))).unsqueeze(0)

    prob_sum = 0
    for trial, dims in enumerate(flips):
        print('trial', trial)
        crop = base if dims is None else torch.flip(base, dims).contiguous()

        # with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        with torch.no_grad():
            output = net({'image': crop})

        p = torch.cat([output['gpoint'], output['ghline'], output['gvline']], 1)
        if dims is not None:
            p = torch.flip(p, dims).contiguous()

        prob_sum += p.float().data.cpu().numpy()[0]
        torch.cuda.empty_cache()

    prob = (prob_sum / len(flips))[..., :H, :W]

    np.save('prob.npy', prob)
    cv2.imwrite('image.png', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

run_infer_demo()
print('infer ok!')

In [None]:
from post_process import *

def _label_lines_in_order(mask, sort_axis, dust_threshold=1000):
	"""Label the large connected components of ``mask`` 1..N, ranked along one axis.

	Components smaller than ``dust_threshold`` pixels are removed with
	``cc3d.dust``. The survivors are ranked by the ``sort_axis`` coordinate of
	their centroid (0 = y, 1 = x), and every pixel is painted with the 1-based
	rank of its component.

	Returns:
		(labels, centers). ``labels`` has the shape of ``mask``, with 0 for
		background. ``centers`` holds the centroid of each kept component, in
		cc3d label order.
	"""
	cc = cc3d.connected_components(mask)
	cc = cc3d.dust(cc, threshold=dust_threshold, connectivity=26, in_place=False)
	stats = cc3d.statistics(cc)

	# dust() may zero small components without renumbering the survivors. In
	# that case the statistics contain empty label slots (0 voxels, NaN
	# centroid), so keep only labels that still own pixels.
	kept = np.flatnonzero(stats['voxel_counts'][1:]) + 1
	centers = stats['centroids'][kept]
	order = np.argsort(centers[:, sort_axis])

	# Lookup table cc3d label -> rank: one gather instead of a full-image
	# comparison per component. The dtype grows past uint8 if N > 255.
	lut = np.zeros(int(cc.max()) + 1, dtype=np.min_scalar_type(len(kept)))
	lut[kept[order]] = np.arange(1, len(kept) + 1)
	return lut[cc], centers


def probability_to_gridpoint(probability, image=None, num_point=2492, num_hline=44, num_vline=57):
	"""Convert U-Net probability maps into a grid of (x, y) rectification points.

	Args:
		probability: array indexed ``[channel, y, x]``; channels 0, 1 and 2 are
			the point, horizontal-line and vertical-line probabilities.
		image: unused; kept for API compatibility.
		num_point, num_hline, num_vline: expected detection counts. The grid is
			filled only when all three counts match exactly, because filtering
			for other counts is not implemented yet.

	Returns:
		dict with the following entries:
		- ``'gridpoint_xy'``: float32 array (num_hline, num_vline, 2) of (x, y).
		  Cells without a detected point stay (0, 0).
		- ``'hcc'``, ``'vcc'``: label images of the horizontal and vertical
		  lines. Labels are 1..N, ordered top-to-bottom and left-to-right.

	Raises:
		ValueError: if ``probability`` is not 3-D with at least 3 channels.
	"""
	if probability.ndim != 3 or probability.shape[0] < 3:
		raise ValueError(f'expected probability of shape (3, H, W), got {probability.shape}')
	gp, gh, gv = probability[0], probability[1], probability[2]

	# Grid points: centroid of each blob above the point threshold.
	cc = cc3d.connected_components(gp > 0.7)
	pcenter = cc3d.statistics(cc)['centroids'][1:]
	print('pcenter:', len(pcenter))

	# Horizontal lines ranked by centroid y; vertical lines ranked by centroid x.
	hcc, hcenter = _label_lines_in_order(gh > 0.5, sort_axis=0)
	print('hcenter', len(hcenter))
	vcc, vcenter = _label_lines_in_order(gv > 0.5, sort_axis=1)
	print('vcenter', len(vcenter))

	gridpoint_xy = np.zeros((num_hline, num_vline, 2), np.float32)

	# TODO: filtering, e.g. choose the top/longest num_hline horizontal lines.
	# Until then, only the exact expected layout is meshed.
	if len(pcenter) == num_point and len(hcenter) == num_hline and len(vcenter) == num_vline:
		for y, x in pcenter:
			uy = ROUND(y)
			ux = ROUND(x)
			j = int(hcc[uy, ux])
			i = int(vcc[uy, ux])
			if j == 0 or i == 0:
				# Point is off every labelled line. Indexing with j-1 / i-1 would
				# wrap to the last row/column, or underflow an unsigned label.
				continue
			gridpoint_xy[j - 1, i - 1] = [x, y]
	else:
		print('filtering method not impelemnted yet ....!')

	return {
		'gridpoint_xy': gridpoint_xy,
		'vcc': vcc,
		'hcc': hcc,
	}

def draw_probability_point(probability, image, threshold=0.7, radius=10, color=(0, 255, 0)):
    """Draw the detected grid-point centroids on a darkened copy of ``image``.

    Args:
        probability: array indexed ``[channel, y, x]``; channel 0 is the point map.
        image: image to draw on. It is not modified; a copy divided by 3 is used.
        threshold: probability above which a pixel belongs to a point blob.
        radius: radius in pixels of the filled circle drawn at each centroid.
        color: circle colour, in the channel order of ``image``.

    Returns:
        The darkened overlay with one filled circle per connected blob.
    """
    cc = cc3d.connected_components(probability[0] > threshold)
    centroids = cc3d.statistics(cc)['centroids'][1:]  # drop background label 0
    print('centroid', len(centroids))

    overlay1 = image // 3
    for y, x in centroids:
        uy = ROUND(y)
        ux = ROUND(x)
        cv2.circle(overlay1, (ux, uy), radius, color=color, thickness=-1)
    return overlay1
    
######################################################################

image = cv2.imread('image.png', cv2.IMREAD_COLOR_RGB)
if image is None:
    # cv2.imread returns None instead of raising; fail early with a clear message.
    raise FileNotFoundError("'image.png' not found - run the inference cell first")
probability = np.load('prob.npy')  # channels: gpoint, ghline, gvline

# Convert the U-Net probability maps into the rectification grid.
out = probability_to_gridpoint(probability, image)
gridpoint_xy = out['gridpoint_xy']

rectified = rectify_image(
    image,
    gridpoint_xy,
    # interpolate_xy
)
plt.imshow(image); plt.show()
print('rectified')
plt.imshow(rectified); plt.show()
print('rectified (zoom)')
plt.imshow(rectified[-800:, :800]); plt.show()


# Option: everything below is for visualisation only.
print('unet lines probability')
mline, _ = draw_probability(probability)  # the point map is redrawn below
plt.imshow(RESIZE(mline)); plt.show()

mpoint = draw_probability_point(probability, image)
print('unet points probability')
plt.imshow(mpoint); plt.show()


vcc = out['vcc']
hcc = out['hcc']
hcc1 = color_line(hcc, cmap='repeat')
vcc2 = color_line(vcc, cmap='repeat')
print('clustered horizontal lines using connected component analysis cca')
plt.imshow(RESIZE(hcc1)); plt.show()
print('clustered vertical lines using connected component analysis cca')
plt.imshow(RESIZE(vcc2)); plt.show()