In [1]:
# Imports
import configparser
import os
from pathlib import Path

import cv2
import numpy as np

import torch
import torchvision

import seg_data

In [2]:

## Read the config
def read_ini(file_path):
    """Parse an INI file and return the populated ``ConfigParser``.

    Args:
        file_path: Path of the INI file to read.

    Returns:
        configparser.ConfigParser: Parser holding the sections and options
        read from ``file_path``.

    Raises:
        FileNotFoundError: If nothing could be read from ``file_path``
            (the file is missing or cannot be opened).
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed; an empty list means nothing was loaded.
    # Failing here gives a clear error instead of a later KeyError on a section.
    if not config.read(file_path):
        raise FileNotFoundError(f"Config file not found or unreadable: {file_path}")
    return config

 
# Load the prediction settings; every path and parameter below comes from this file.
config = read_ini("./pred_config.ini")

dir_cfg = config["DIR"]
param_cfg = config["PARAMS"]

img_path = dir_cfg["image_dir"]
output_vis_path = dir_cfg["output_vis_path"]
checkpoint_path = dir_cfg["checkpoint_path"]
output_path = dir_cfg["output_path"]

# `start_class_i` is optional and defaults to 0; `scale` is required.
start_class_i = param_cfg.getint("start_class_i", fallback=0)
scale = int(param_cfg["scale"])

# Validate the checkpoint before creating any directories. An explicit raise is
# used instead of `assert` because assertions are stripped under `python -O`.
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint file does not exist: {checkpoint_path}")

# Create both output directories (label masks and visualisations) if needed.
Path(output_path).mkdir(parents=True, exist_ok=True)
Path(output_vis_path).mkdir(parents=True, exist_ok=True)

print(f'''Predicting info:
        Reading images from: {img_path}
        Reading the checkpoint from: {checkpoint_path}
        The output directory: {output_path}
''')

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'''System info:
        Using device: {device}
        CPU cores: {os.cpu_count()}
        GPU count: {torch.cuda.device_count()}
''')


Predicting info:
        Reading images from: ./data/suture/img
        Reading the checkpoint from: ./checkpoint/unet_checkpoint_epoch30.pth
        The output directory: ./output/suture/

System info:
        Using device: cuda
        CPU cores: 16
        GPU count: 1



In [3]:
# Build the prediction dataset and its loader. batch_size=1 with shuffle=False
# keeps the loader order aligned with `dataset_pred.imgs`, which the prediction
# loop indexes by position to recover each file name.
dataset_pred = seg_data.segDataset(
    img_path=img_path,
    scale=scale,
    is_train=False,
    start_class_i=start_class_i,
)
data_loader_pred = torch.utils.data.DataLoader(dataset_pred, batch_size=1, shuffle=False)

# NOTE(review): torch.load unpickles the whole model object, which can execute
# arbitrary code -- only load checkpoints from a trusted source. Recent PyTorch
# releases may also require `weights_only=False` for full-model pickles;
# confirm against the installed version.
# map_location remaps tensors that were saved on a GPU onto `device`, so the
# checkpoint also loads when the notebook falls back to the CPU.
model = torch.load(checkpoint_path, map_location=device)
model = model.to(device)
model.eval()

## Iterate through the images and save the predicted masks to the output directories.
# Inference only: no_grad() stops autograd from recording the forward pass,
# so no graph is kept alive for each image.
with torch.no_grad():
    for idx, (img, img_info) in enumerate(data_loader_pred):
        # Relies on the loader yielding one image at a time, in dataset order.
        img_name = dataset_pred.imgs[idx]

        img = img.to(device)
        out = model(img)

        # Some models (e.g. the "deeplab" checkpoints) return a dict with the
        # logits under 'out'; detect that from the output itself rather than
        # from the checkpoint file name.
        if isinstance(out, dict):
            out = out['out']

        # Per-pixel class index: argmax over the first axis of the single
        # batch item (presumably the class/channel axis).
        seg = out[0].cpu().numpy().argmax(axis=0)

        if scale != 1:
            # Resize to the width/height carried in img_info (presumably the
            # original image size); nearest-neighbour keeps the labels discrete.
            seg = cv2.resize(
                seg,
                (img_info['w'].item(), img_info['h'].item()),
                interpolation=cv2.INTER_NEAREST,
            )

        # Raw label mask.
        cv2.imwrite(os.path.join(output_path, img_name), seg)

        # Visualisation: stretch the labels linearly onto 1..255. The upper
        # bound is clamped to at least 1 so an all-zero mask does not hand
        # np.interp the degenerate range [0, 0].
        vis_max = max(int(seg.max()), 1)
        vis = np.interp(seg, [0, vis_max], [1, 255]).astype('uint8')
        cv2.imwrite(os.path.join(output_vis_path, img_name), vis)


img info++++++++++ {'w': tensor([1550]), 'h': tensor([1580])}
img info++++++++++ {'w': tensor([1550]), 'h': tensor([1580])}
img info++++++++++ {'w': tensor([1550]), 'h': tensor([1580])}
img info++++++++++ {'w': tensor([1550]), 'h': tensor([1580])}
img info++++++++++ {'w': tensor([1550]), 'h': tensor([1580])}
img info++++++++++ {'w': tensor([1550]), 'h': tensor([1580])}
img info++++++++++ {'w': tensor([1550]), 'h': tensor([1580])}
img info++++++++++ {'w': tensor([1550]), 'h': tensor([1580])}


In [10]:
import matplotlib.pyplot as plt

# Re-save the visualisation of the mask left over from the last loop iteration
# (`seg`, `img_name`). np.interp returns float64, so cast to uint8 before
# writing, as the prediction loop does. The upper bound is clamped to at least 1
# so an all-zero mask does not hand np.interp the degenerate range [0, 0].
vis_max = max(int(np.max(seg)), 1)
cv2.imwrite( os.path.join(output_vis_path,img_name), np.interp(seg, [0, vis_max],[1,255]).astype('uint8') )
img_name

'M1907_0215.tiff'

In [16]:
# Stretch the last mask's labels onto 1..255, convert to 8-bit and write it to
# a fixed file name; the call's return value is the cell output.
vis_a = np.interp(seg, [0, np.max(seg)], [1, 255])
vis_a = vis_a.astype('uint8')
cv2.imwrite(os.path.join(output_vis_path, "a.tif"), vis_a)

True

In [49]:
# Take the first item of the last batch, move its leading axis to the end
# (presumably channels-first -> channels-last), and map values in [0, 1]
# linearly onto [1, 255] as uint8 (np.interp clamps values outside [0, 1]).
img_new = np.interp(img[0].cpu().detach().numpy().transpose(1, 2, 0),[0,1],[1,255]).astype('uint8')

cv2.imshow("j",img_new)
# imshow must be followed by waitKey so the HighGUI event loop can actually
# draw the window; this blocks until a key is pressed, then closes the window.
cv2.waitKey(0)
cv2.destroyAllWindows()

In [15]:
# NOTE(review): this cell only evaluates the bare name `type`; the array
# recorded beneath it looks like stale output from an earlier version of the
# cell -- confirm, then delete the cell or restore the intended expression.
type

array([[  1.,   1.,   1., ...,   1.,   1.,   1.],
       [  1.,   1.,   1., ...,   1.,   1.,   1.],
       [  1.,   1.,   1., ...,   1.,   1.,   1.],
       ...,
       [128., 128., 128., ...,   1.,   1.,   1.],
       [128., 128., 128., ...,   1.,   1.,   1.],
       [128., 128., 128., ...,   1.,   1.,   1.]])

In [27]:
# Pull the last batch off the device as a NumPy array and map values in [0, 1]
# linearly onto [1, 255]; the rescaled array is the cell output.
img_np = img.detach().cpu().numpy()
np.interp(img_np, [0, 1], [1, 255])

array([[[[23.90980445, 21.9176477 , 22.91372608, ..., 23.90980445,
          22.91372608, 23.90980445],
         [20.92156933, 22.91372608, 21.9176477 , ..., 24.90588282,
          24.90588282, 25.90196119],
         [20.92156933, 21.9176477 , 21.9176477 , ..., 23.90980445,
          23.90980445, 22.91372608],
         ...,
         [31.87843142, 32.87451169, 33.87059006, ..., 39.84706029,
          38.85098192, 39.84706029],
         [38.85098192, 38.85098192, 39.84706029, ..., 39.84706029,
          40.84313866, 41.83921704],
         [45.82353052, 46.8196089 , 47.81568727, ..., 40.84313866,
          40.84313866, 42.83529541]],

        [[23.90980445, 21.9176477 , 22.91372608, ..., 23.90980445,
          22.91372608, 22.91372608],
         [20.92156933, 22.91372608, 21.9176477 , ..., 24.90588282,
          24.90588282, 25.90196119],
         [20.92156933, 21.9176477 , 21.9176477 , ..., 23.90980445,
          23.90980445, 22.91372608],
         ...,
         [32.87451169, 32.87451169