In [None]:
!pip install segmentation_models_pytorch

In [2]:
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomDataset1(Dataset):
    """Dataset that yields resized RGB PIL images from a flat folder.

    Every file in ``data_folder`` whose name ends with ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive) is part of the dataset. Items are returned as
    PIL images, not tensors; no labels are provided.

    Args:
        data_folder: Directory that contains the image files.
        input_size: Target size handed unchanged to ``transforms.Resize``.
            torchvision interprets a 2-sequence as ``(height, width)``, so the
            default ``(960, 640)`` produces images 960 pixels tall and 640
            pixels wide. Pass ``(640, 960)`` for 960x640 landscape frames.
    """

    # Extensions (lower-case) that are recognised as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_folder, input_size=(960,640)):
        self.input_size = input_size
        self.image_dir = data_folder
        # os.listdir returns entries in arbitrary order; sort them so that the
        # index -> file mapping is the same on every run and every machine.
        self.images = sorted(
            name for name in os.listdir(self.image_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        self.transform_image = transforms.Compose([
            transforms.Resize(self.input_size),
        ])

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to RGB and resize it.

        Raises:
            IndexError: If ``idx`` is out of range (this is also what ends a
                plain ``for image in dataset`` loop).
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle once the pixel data has been copied by convert().
        # Converting to RGB guarantees three channels even for grayscale,
        # palette or RGBA files.
        with Image.open(img_path) as source:
            image = source.convert('RGB')
        return self.transform_image(image)


In [None]:
import segmentation_models_pytorch as smp
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import cv2
import timeit

# --- Model -----------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = 'resnet18'

# encoder_weights=None: every parameter is restored from the checkpoint below,
# so fetching the library's default pretrained encoder weights first would only
# add a pointless download (and a network dependency).
model = smp.Unet(
    encoder, encoder_weights=None, in_channels=3, classes=1, activation=None
).to(device)

checkpoint_path = "/model/unet_resnet18_100k.pth"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()

# --- Data ------------------------------------------------------------------
image_path = "/data/demo/"
# (height, width), the order transforms.Resize expects. It has to be passed to
# the dataset explicitly: the dataset default is (960, 640), which would yield
# 960x640 (H x W) model outputs that cannot be combined with the 640x960
# region-of-interest mask built below.
input_size = (640, 960)
dataset = CustomDataset1(image_path, input_size=input_size)

# --- Region of interest ----------------------------------------------------
# Trapezoid spanning rows y0 <= y < y1. Its left edge runs from x0 (at y0) to
# x3 (at y1) and its right edge from x1 (at y0) to x2 (at y1).
x0, x1, x2, x3 = 380,960-380,960-300,300
y0, y1 = 520, 640

# Open grids: y has shape (640, 1) and x has shape (1, 960), so the comparison
# below broadcasts to a (640, 960) boolean array.
y, x = np.ogrid[:input_size[0], :input_size[1]]

condition = (
    (x >= x0 + (x3 - x0) / (y1 - y0) * (y - y0))
    & (x < x1 + (x2 - x1) / (y1 - y0) * (y - y0))
    & (y >= y0)
    & (y < y1)
)

# PIL image -> tensor conversion applied to every frame before inference.
transform = transforms.Compose([
        transforms.ToTensor()
    ])

# --- Inference loop ----------------------------------------------------------
# For every image: run the model, min-max normalise its output, mark strong
# responses, sum the result inside the region of interest and tint that region
# green (sum below the limit) or red (otherwise) on top of the input image.

ACTIVATION_THRESHOLD = 0.83  # normalised responses above this are set to 255
ROI_SUM_LIMIT = 25500        # ROI sums below this give a green tint, else red
RED_CHANNEL, GREEN_CHANNEL = 0, 1

# The region of interest is the same for every image, so its masks are built
# once here instead of allocating a tensor on the device and copying it back
# to the host on every iteration.
boolean_mask_np = np.asarray(condition, dtype=bool)
mask = boolean_mask_np.astype(np.uint8)
roi_height, roi_width = boolean_mask_np.shape

for input_image in dataset:
    start_time = timeit.default_timer()
    input_tensor = transform(input_image).to(device)
    input_batch = input_tensor.unsqueeze(0)
    with torch.no_grad():
        output = model(input_batch)

    output_image = output.squeeze().cpu().numpy()

    # Min-max normalisation; a constant output map would divide by zero and
    # fill the array with NaN, so treat it as "no response" instead.
    value_range = output_image.max() - output_image.min()
    if value_range == 0:
        output_image = np.zeros_like(output_image)
    else:
        output_image = (output_image - output_image.min()) / value_range

    # NOTE(review): values at or below the threshold keep their fractional
    # value and therefore still contribute to the ROI sum below. If a binary
    # map was intended they should be zeroed here -- confirm before changing.
    output_image[output_image > ACTIVATION_THRESHOLD] = 255

    res = output_image * mask
    sum_value = np.sum(res)

    # Both outcomes render the same overlay; only the tinted channel differs.
    tint_channel = GREEN_CHANNEL if sum_value < ROI_SUM_LIMIT else RED_CHANNEL

    # PIL's resize takes (width, height).
    input_image = input_image.resize((roi_width, roi_height))
    original_image_np = np.array(input_image)
    tint = np.zeros_like(original_image_np)
    tint[:, :, tint_channel] = 255
    overlay = original_image_np.copy()
    # Blend only the ROI pixels: original + 0.5 * tint (cv2 saturates at 255).
    overlay[boolean_mask_np] = cv2.addWeighted(
        original_image_np[boolean_mask_np], 1,
        tint[boolean_mask_np], 0.5, 0
    )

    end_time = timeit.default_timer()
    elapsed_time = end_time - start_time
    print(f"Elapsed time: {elapsed_time} seconds")

    fig, axes = plt.subplots(1, 3, figsize=(10, 5))

    axes[0].set_title('Original Image')
    axes[0].imshow(input_image)

    axes[1].set_title('Output Image')
    axes[1].imshow(overlay)

    axes[2].set_title('Model Output')
    axes[2].imshow(output_image, cmap='gray')

    plt.show()
    # Release the figure so a long dataset does not accumulate open figures.
    plt.close(fig)
