In [1]:
# set the configuration
import configparser
import os
from pathlib import Path

import cv2
import numpy as np

import torch
import torchvision

import seg_data

In [3]:

## Read the config
def read_ini(file_path):
    config = configparser.ConfigParser()
    config.read(file_path)
    return config

 
config = read_ini("./pred_config.ini")


img_path = config["DIR"]["image_dir"]
output_vis_path =config["DIR"]["output_vis_path"]
checkpoint_path = config["DIR"]["checkpoint_path"]
output_path = config["DIR"]["output_path"]

start_class_i = int(config["PARAMS"].get('start_class_i',0))

scale = int(config["PARAMS"]["scale"])

Path(output_path).mkdir(parents=True, exist_ok=True)
Path(output_vis_path).mkdir(parents=True, exist_ok=True)


assert os.path.isfile(checkpoint_path), "Checkpoint file not exist"


print(f'''Predicting info:
        Reading images from: {img_path}
        Reading the checkpoint from: {checkpoint_path}
        The output directory: {output_path}
''') 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'''System info:
        Using device: {device}
        CPU cores: {os.cpu_count()}
        GPU count: {torch.cuda.device_count()}
''')


Predicting info:
        Reading images from: C:\Users\sabbi\Downloads\project1\models\suture_demo_data\test_data
        Reading the checkpoint from: C:\Users\sabbi\Downloads\project1\models\checkpoint\unet_checkpoint_epoch30.pth\unet_checkpoint_epoch15.pth
        The output directory: C:\Users\sabbi\Downloads\project1\models\output\suture_demo_cc

System info:
        Using device: cpu
        CPU cores: 16
        GPU count: 0



In [4]:
dataset_pred = seg_data.segDataset(img_path = img_path,scale=scale,is_train=False, start_class_i = start_class_i)
data_loader_pred = torch.utils.data.DataLoader(dataset_pred, batch_size=1, shuffle=False)


model = torch.load(checkpoint_path)
model = model.to(device)
model.eval()

dataset_pred.img_path
dataset_pred.imgs

## Iterate through the images and save them to the directory.
for idx, (img,img_info) in enumerate(data_loader_pred):
    img_name = dataset_pred.imgs[idx]
        
    # print("img info++++++++++",img_info)
    img = img.to(device)  
    out = model(img)
    
    if "deeplab" in checkpoint_path:
        out = out['out']
    out_temp = out.cpu().detach().numpy()



    seg= out_temp[0].transpose(1, 2, 0).argmax(2)
    
    # output_mask = np.zeros(seg.shape).astype('uint8')
    # output_mask[seg==1]=255

    if scale!=1:
        seg = cv2.resize(seg, (img_info['w'].item(),img_info['h'].item()),
                interpolation = cv2.INTER_NEAREST )
    
    # print(os.path.join(output_path,img_name))
    cv2.imwrite( os.path.join(output_path,img_name),seg)

    cv2.imwrite( os.path.join(output_vis_path,img_name), np.interp(seg, [0, np.max(seg)],[1,255]).astype('uint8'))
    #  np.interp(img[0].cpu().detach().numpy().transpose(1, 2, 0),[0,1],[1,255]).astype('uint8'))


In [5]:
import matplotlib.pyplot as plt

# plt.imshow(np.interp(seg, [0, np.max(seg)],[1,255]) )
cv2.imwrite( os.path.join(output_vis_path,img_name), np.interp(seg, [0, np.max(seg)],[1,255]) )
img_name

'Acanthidops_bairdii_1_M_Back_Vis_G078269.jpg'

In [6]:
cv2.imwrite( os.path.join(output_vis_path,"a.tif"), np.interp(seg, [0, np.max(seg)],[1,255]).astype('uint8') )

True

In [7]:
img_new = np.interp(img[0].cpu().detach().numpy().transpose(1, 2, 0),[0,1],[1,255]).astype('uint8')

cv2.imshow("j",img_new)

In [8]:
type

type

In [9]:
np.interp(img.cpu().detach().numpy(),[0,1],[1,255])

array([[[[24.90588282, 23.90980445, 24.90588282, ..., 24.90588282,
          20.92156933, 20.92156933],
         [24.90588282, 23.90980445, 24.90588282, ..., 22.91372608,
          20.92156933, 20.92156933],
         [24.90588282, 23.90980445, 24.90588282, ..., 23.90980445,
          20.92156933, 20.92156933],
         ...,
         [47.81568727, 47.81568727, 47.81568727, ..., 39.84706029,
          40.84313866, 40.84313866],
         [47.81568727, 47.81568727, 47.81568727, ..., 39.84706029,
          40.84313866, 40.84313866],
         [47.81568727, 47.81568727, 47.81568727, ..., 39.84706029,
          40.84313866, 40.84313866]],

        [[24.90588282, 23.90980445, 24.90588282, ..., 24.90588282,
          19.92549096, 19.92549096],
         [24.90588282, 23.90980445, 24.90588282, ..., 22.91372608,
          19.92549096, 19.92549096],
         [24.90588282, 23.90980445, 24.90588282, ..., 23.90980445,
          19.92549096, 19.92549096],
         ...,
         [47.81568727, 47.81568727