In [None]:
import os
import glob

import torch
import cv2
import onnx
import onnxruntime as ort
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import imutils
from denku import show_images
from torch import nn
from omegaconf import OmegaConf
from matplotlib import animation, rc
from tqdm.auto import tqdm

from run import preprocess_config, parse_loggers, seed_everything, DataModule, get_obj_from_str
# Render matplotlib animations in notebook output as interactive JavaScript (jshtml) widgets.
rc('animation', html='jshtml')






In [None]:
# Load the experiment config and restore the best checkpoint for inference.
EXPERIMENT_DIR = './output/atom/efficientnetb3_Unet_randcrop_'
CHECKPOINT_PATH = f'{EXPERIMENT_DIR}/best-epoch=60-iou_valid=0.47.ckpt'

main_config = OmegaConf.load(f'{EXPERIMENT_DIR}/config.yaml')
config = preprocess_config(main_config)

seed_everything(config['common']['seed'], workers=True)
# NOTE(review): `datamodule` is not used by any cell in this notebook — confirm before removing.
datamodule = DataModule(config)

model = get_obj_from_str(config['lightning_model'])(config)
# The model is never moved to a GPU and every inference cell feeds CPU tensors, so load the
# checkpoint onto the CPU: the resulting weights are identical and no CUDA device is required.
ckpt = torch.load(CHECKPOINT_PATH, map_location='cpu')

# NOTE(review): weights go into the inner `model.model`; load_state_dict is strict by default,
# so the checkpoint's state_dict keys must match that submodule without a wrapper prefix.
model.model.load_state_dict(ckpt['state_dict'])
model = model.eval()

In [None]:
# Visual check: for a random sample of test images show the image, the predicted mask
# (alone and overlaid) and the ground-truth mask overlay side by side.
test = pd.read_csv('./dataset/test.csv')
# Sample without replacement so no image is shown twice; never request more rows than exist.
dem_df = test.sample(n=min(20, len(test)), replace=False)

path_folder = '/home/raid/hdd_storage/datasets/rosatom/sam'
# Normalisation constants; with mean=0 and std=1 only the /255 scaling has an effect.
mean = np.array([0, 0, 0])
std = np.array([1, 1, 1])

for idx, row in dem_df.iterrows():
    img_path = f'{path_folder}/{row["path_to_imgs"]}'
    gt_mask_path = f'{path_folder}/{row["path_to_masks"]}'

    image = cv2.imread(img_path)
    gt_mask = cv2.imread(gt_mask_path)
    if image is None or gt_mask is None:
        # cv2.imread signals a missing/unreadable file by returning None instead of raising.
        raise FileNotFoundError(f'cannot read {img_path} or {gt_mask_path}')

    # cv2.imread returns BGR, so use the BGR grayscale weights; resize the label mask with
    # nearest-neighbour so boundary pixels are not interpolated into non-existent labels.
    gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY)
    gt_mask = cv2.resize(gt_mask, (512, 512), interpolation=cv2.INTER_NEAREST)

    image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # same channel swap as COLOR_RGB2BGR
    img = ((image.astype(np.float32) / 255.0 - mean) / std).astype(np.float32)
    img = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW

    with torch.no_grad():
        logits = model(img.unsqueeze(0))
    # Presumably (1, num_classes, H, W) logits: drop the batch axis, argmax over classes.
    pred_mask = logits.squeeze(0).argmax(dim=0).cpu().numpy().astype(np.uint8)
    # Spread class ids over the uint8 range for display.
    # NOTE(review): uint8 wraps for class ids >= 15 (15 * 18 > 255).
    pred_mask = cv2.resize(pred_mask * 18, (512, 512), interpolation=cv2.INTER_NEAREST)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    axes[0].set_title('orig')
    axes[0].imshow(image)
    axes[1].set_title('overlay')
    axes[1].imshow(image)
    axes[1].imshow(pred_mask, alpha=0.5)
    axes[2].set_title('pred_mask')
    axes[2].imshow(pred_mask)
    axes[3].set_title('gt_mask')
    axes[3].imshow(image)
    axes[3].imshow(gt_mask, alpha=0.5)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across loop iterations
    

In [None]:
# Annotate a random sample of test images with the contour, centroid and class id of
# every connected region in the predicted mask.
test = pd.read_csv('./dataset/test.csv')
# Sample without replacement so no image is shown twice; never request more rows than exist.
dem_df = test.sample(n=min(20, len(test)), replace=False)

path_folder = '/home/raid/hdd_storage/datasets/rosatom/sam'
# Normalisation constants; with mean=0 and std=1 only the /255 scaling has an effect.
mean = np.array([0, 0, 0])
std = np.array([1, 1, 1])

for idx, row in dem_df.iterrows():
    img_path = f'{path_folder}/{row["path_to_imgs"]}'
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None instead of raising.
        raise FileNotFoundError(f'cannot read {img_path}')

    image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # same channel swap as COLOR_RGB2BGR
    img = ((image.astype(np.float32) / 255.0 - mean) / std).astype(np.float32)
    img = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW

    with torch.no_grad():
        logits = model(img.unsqueeze(0))
    # Presumably (1, num_classes, H, W) logits: drop the batch axis, argmax over classes.
    pred_mask = logits.squeeze(0).argmax(dim=0).cpu().numpy().astype(np.uint8)
    pred_mask = cv2.resize(pred_mask, (512, 512), interpolation=cv2.INTER_NEAREST)

    for u_label in np.unique(pred_mask):
        if u_label == 0:  # label 0 is skipped (presumably background)
            continue
        mask = (pred_mask == u_label).astype(np.uint8) * 255
        cnts = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cnts = imutils.grab_contours(cnts)  # normalise the OpenCV 3/4 return signature

        for c in cnts:
            M = cv2.moments(c)
            if M["m00"] == 0:  # zero-area contour: centroid undefined
                continue
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])

            cv2.drawContours(image, [c], -1, (0, 128, 255), 2)
            cv2.circle(image, (cX, cY), 7, (32, 32, 255), -1)
            cv2.putText(image, str(u_label), (cX - 5, cY - 5), cv2.FONT_HERSHEY_COMPLEX, 0.9, (128, 0, 0), 2)
    show_images([image])
    

In [None]:
# Run the model over the whole test set and collect one entry per predicted region:
# the image name, the region's centroid (x, y) and its class id.
test = pd.read_csv('./dataset/test.csv')

test_data = pd.DataFrame(columns=['filename', 'x', 'y', 'class'])  # populated in the next cell
filename = []
x = []
y = []
class_ = []

path_folder = '/home/raid/hdd_storage/datasets/rosatom/sam'
# Normalisation constants; with mean=0 and std=1 only the /255 scaling has an effect.
mean = np.array([0, 0, 0])
std = np.array([1, 1, 1])

for index, row in tqdm(test.iterrows(), total=len(test), desc='predicting'):
    img_path = f'{path_folder}/{row["path_to_imgs"]}'
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None instead of raising.
        raise FileNotFoundError(f'cannot read {img_path}')

    image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # same channel swap as COLOR_RGB2BGR
    img = ((image.astype(np.float32) / 255.0 - mean) / std).astype(np.float32)
    img = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW

    with torch.no_grad():
        logits = model(img.unsqueeze(0))
    # Presumably (1, num_classes, H, W) logits: drop the batch axis, argmax over classes.
    pred_mask = logits.squeeze(0).argmax(dim=0).cpu().numpy().astype(np.uint8)
    pred_mask = cv2.resize(pred_mask, (512, 512), interpolation=cv2.INTER_NEAREST)

    for u_label in np.unique(pred_mask):
        if u_label == 0:  # label 0 is skipped (presumably background)
            continue
        mask = (pred_mask == u_label).astype(np.uint8) * 255
        cnts = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cnts = imutils.grab_contours(cnts)  # normalise the OpenCV 3/4 return signature

        for c in cnts:
            M = cv2.moments(c)
            if M["m00"] == 0:  # zero-area contour: centroid undefined
                continue
            # NOTE(review): centroids are in the 512x512 resized frame, not original-image
            # pixels — confirm this is what the submission format expects.
            filename.append(row["name"])
            x.append(int(M["m10"] / M["m00"]))
            y.append(int(M["m01"] / M["m00"]))
            class_.append(int(u_label))

In [None]:
# Assemble the per-region predictions collected above into the submission table in one step,
# instead of filling an empty frame created in a different cell.
assert len(filename) == len(x) == len(y) == len(class_), 'prediction columns have mismatched lengths'
test_data = pd.DataFrame({'filename': filename, 'x': x, 'y': y, 'class': class_})

In [None]:
# Persist the predictions table to CSV without the DataFrame index column.
submission_path = './test_.csv'
test_data.to_csv(submission_path, index=False)