In [68]:
import torch
import segmentation_models_pytorch as smp
import cv2
import os
import matplotlib.pyplot as plt
import shutil
from torch.utils.data import Dataset, DataLoader, random_split

# Checkpoint later passed to torch.load / load_state_dict, and the directory
# scanned for <id>_sat.jpg / <id>_mask.png demo pairs.
model_weights_path, images_path = "model.pth", "demo_images"

In [69]:
class CustomImageDataset(Dataset):
    """In-memory dataset of (RGB image, grayscale mask) pairs for segmentation.

    Scans ``dataset_dir`` for files named ``<id>_sat.jpg`` (image) and
    ``<id>_mask.png`` (mask) and keeps only the ids that have both. Every kept
    pair is read, resized to ``image_work_size`` x ``image_work_size`` and cached
    in ``self.image_data`` at construction time.

    Args:
        dataset_dir: Directory containing the ``*_sat.jpg`` / ``*_mask.png`` files.
        transform: Optional callable ``(image, label) -> (image, label)`` applied
            to the cached HWC image / HW mask arrays on every ``__getitem__`` call.
        clip: If > 0, keep only ``clip`` pairs starting at ``clip_offset``. Pairs
            are ordered by id, so the selected subset is the same on every run.
        num_of_augmentations: How many times each cached pair is repeated in the
            dataset; every repetition is passed through ``transform`` again.
        clip_offset: Start index of the ``clip`` window.
        image_work_size: Side length in pixels that images and masks are resized to.

    Raises:
        OSError: If OpenCV cannot read one of the paired files.
    """

    _IMAGE_SUFFIX = "_sat.jpg"
    _MASK_SUFFIX = "_mask.png"

    def __init__(self, dataset_dir, transform=None, clip = 0, num_of_augmentations = 1, clip_offset = 0, image_work_size = 512):
        self.dataset_dir = dataset_dir
        self.transform = transform
        self.num_of_augmentations = num_of_augmentations
        self.clip_offset = clip_offset

        pairs = self._pair_files(dataset_dir, os.listdir(dataset_dir))
        if clip > 0:
            pairs = pairs[clip_offset:clip_offset + clip]

        self.image_data = [self._load_pair(image_path, label_path, image_work_size)
                           for image_path, label_path in pairs]

    @classmethod
    def _pair_files(cls, dataset_dir, files):
        """Match ``<id>_sat.jpg`` with ``<id>_mask.png``; return (image, mask) paths sorted by id."""
        # Strip the whole suffix rather than splitting on '_' so ids that contain
        # underscores themselves (e.g. "tile_12") don't collide and overwrite each other.
        images = {f[:-len(cls._IMAGE_SUFFIX)]: os.path.join(dataset_dir, f)
                  for f in files if f.endswith(cls._IMAGE_SUFFIX)}
        labels = {f[:-len(cls._MASK_SUFFIX)]: os.path.join(dataset_dir, f)
                  for f in files if f.endswith(cls._MASK_SUFFIX)}
        # os.listdir order is arbitrary; sorting makes clip/clip_offset reproducible.
        return [(images[key], labels[key]) for key in sorted(images) if key in labels]

    @staticmethod
    def _load_pair(image_path, label_path, size):
        """Read one image/mask pair, resize both to ``size`` x ``size``; return (RGB image, mask)."""
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {image_path}")
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise OSError(f"Could not read mask file: {label_path}")

        # OpenCV loads BGR; convert so the cached image is RGB.
        image = cv2.cvtColor(cv2.resize(image, (size, size)), cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps only label values present in the original mask;
        # the default bilinear interpolation would blend neighbouring labels at edges.
        label = cv2.resize(label, (size, size), interpolation=cv2.INTER_NEAREST)
        return image, label

    def __len__(self):
        # Every cached pair appears num_of_augmentations times.
        return len(self.image_data) * self.num_of_augmentations

    def __getitem__(self, idx):
        """Return float32 ``(image, label)`` tensors shaped (C, H, W) and (1, H, W), scaled by 1/255."""
        # Consecutive indices map to the same cached pair (one per augmentation).
        image, label = self.image_data[idx // self.num_of_augmentations]

        if self.transform:
            image, label = self.transform(image, label)

        image = image / 255.0
        label = label / 255.0

        # HWC -> CHW for the image; add a channel axis to the mask.
        return (torch.tensor(image, dtype=torch.float32).permute(2, 0, 1),
                torch.tensor(label, dtype=torch.float32).unsqueeze(0))

In [70]:
# Build only the U-Net architecture: every parameter is overwritten by the
# checkpoint below, so downloading ImageNet encoder weights first would be
# wasted work (and would fail on a machine without network access).
model = smp.Unet(
    encoder_name="efficientnet-b5",
    encoder_weights=None,
    in_channels=3,
    classes=1,
)

# Map to CPU so the checkpoint loads even without a GPU; weights_only=True
# restricts unpickling to tensors and plain containers.
state_dict = torch.load(model_weights_path, map_location=torch.device("cpu"), weights_only=True)
# strict=True (the default) fails if any parameter is missing from the checkpoint,
# so after this call no randomly initialised weights remain.
model.load_state_dict(state_dict)

def _mask_to_hw(mask):
    """Return a mask tensor of shape (H, W) or (1, H, W) as an (H, W) numpy array on the CPU."""
    mask = mask.detach().cpu()
    if mask.dim() == 3:
        mask = mask.squeeze(0)
    return mask.numpy()


def display_prediction(image, predicted_mask, actual_mask):
    """Plot the input image, the predicted mask and the ground-truth mask side by side.

    Args:
        image: Tensor in CHW layout (permuted to HWC for ``imshow``); values are
            shown as-is, so float data is expected in [0, 1].
        predicted_mask: Mask tensor of shape (H, W) or (1, H, W).
        actual_mask: Mask tensor of shape (H, W) or (1, H, W).
    """
    # Fixed vmin/vmax: with autoscaling, a constant mask (all background OR all
    # foreground) collapses to the lowest colour and both would render black.
    mask_kwargs = {"cmap": "gray", "vmin": 0, "vmax": 1}
    panels = (
        (image.detach().permute(1, 2, 0).cpu().numpy(), "Input Image", {}),  # CHW -> HWC
        (_mask_to_hw(predicted_mask), "Model Prediction", mask_kwargs),
        (_mask_to_hw(actual_mask), "Actual Mask", mask_kwargs),
    )

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (data, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    plt.show()
    # Close explicitly so calling this in a loop doesn't accumulate open figures.
    plt.close(fig)

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# eval() puts layers such as dropout / batch-norm into inference mode.
# Module.to() and eval() both return the module itself; assigning the result
# (instead of leaving `model.to(device)` as the cell's last expression) keeps
# Jupyter from printing the entire network architecture as cell output.
model = model.to(device).eval()

In [None]:
MASK_THRESHOLD = 0.5  # sigmoid probability above which a pixel counts as foreground

dataset = CustomImageDataset(images_path, clip=100)
if len(dataset) == 0:
    # Without this guard the loop below would silently display nothing.
    raise ValueError(f"No '<id>_sat.jpg' / '<id>_mask.png' pairs found in {images_path!r}")

# Show each demo image with the model's prediction and the ground-truth mask.
for idx in range(len(dataset)):
    image, actual_mask = dataset[idx]
    # Add a batch dimension for the model. actual_mask stays on the CPU because
    # it is only used for plotting.
    image = image.to(device).unsqueeze(0)

    with torch.no_grad():
        logits = model(image)
        # Sigmoid -> per-pixel probability for the single class, then threshold
        # to a binary {0., 1.} mask; squeeze(0) drops the batch dimension.
        predicted_mask = (torch.sigmoid(logits) > MASK_THRESHOLD).float().squeeze(0)

    # Drop the batch dim of the image and the single channel dim of the prediction.
    display_prediction(image.squeeze(0), predicted_mask.squeeze(0), actual_mask)
