In [6]:
import os
import torch
import cv2
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
import segmentation_models_pytorch as smp

In [7]:
# Device used both for mapping the checkpoint tensors and for running inference.
DEVICE = "cpu"

# Function to load model and optimizer from checkpoint
def load_model_from_checkpoint(model, optimizer, checkpoint_path, map_location=None):
    """Restore model (and optionally optimizer) state from a training checkpoint.

    The checkpoint must be a dict containing the keys ``'model_state_dict'``,
    ``'epoch'`` and ``'loss'``, plus ``'optimizer_state_dict'`` when an
    optimizer is passed.

    Args:
        model: Module whose parameters/buffers are overwritten in place via
            ``load_state_dict``.
        optimizer: Optimizer to restore, or ``None`` to skip restoring the
            optimizer state (enough for inference-only use).
        checkpoint_path: Path to the file written with ``torch.save``.
        map_location: Where stored tensors are mapped when loading; defaults
            to the notebook-level ``DEVICE``.

    Returns:
        Tuple ``(model, optimizer, start_epoch, best_loss)``. ``model`` and
        ``optimizer`` are the same objects that were passed in.

    Raises:
        KeyError: if a required key is missing from the checkpoint.
    """
    if map_location is None:
        map_location = DEVICE

    # NOTE(review): torch.load unpickles the file, which can run arbitrary
    # code -- only load checkpoints from trusted sources. Recent PyTorch
    # versions warn about this default; weights_only=True would avoid it if
    # the checkpoint holds only tensors and plain containers/scalars.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Load the state dict of model and (if given) optimizer
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    start_epoch = checkpoint['epoch']
    best_loss = checkpoint['loss']

    print(f"Model loaded from {checkpoint_path}. Resuming from epoch {start_epoch} with best loss {best_loss:.4f}")

    return model, optimizer, start_epoch, best_loss

# Create the UNet model. Every weight is overwritten by the checkpoint below
# (load_state_dict is strict by default), so downloading ImageNet encoder
# weights here would only cost time and require network access.
model = smp.Unet(
    encoder_name="resnet50",        # Use ResNet50 as the encoder (must match training)
    encoder_weights=None,           # Weights come from the checkpoint, not ImageNet
    in_channels=1,                  # Number of input channels (1 for grayscale)
    classes=1                       # Number of output classes (binary mask)
)

# Move the model to the target device *before* building the optimizer, so the
# optimizer state restored from the checkpoint is placed on the same device as
# the parameters it belongs to.
model = model.to(DEVICE)

# Create an optimizer for the model (only needed if training is resumed)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

# Load the model from the checkpoint
checkpoint_path = '../U-Net/model/best_model_checkpoint.pth'
model, optimizer, start_epoch, best_loss = load_model_from_checkpoint(model, optimizer, checkpoint_path)

  checkpoint = torch.load(checkpoint_path, map_location=DEVICE)


Model loaded from ../U-Net/model/best_model_checkpoint.pth. Resuming from epoch 10 with best loss 0.0356


In [22]:
def predict_and_visualize(image_file, model, threshold=0.5, input_size=(224, 224), show=True):
    """Segment a grayscale image with ``model`` and optionally plot the result.

    Args:
        image_file: Path to an image readable by ``cv2.imread``; it is loaded
            as single-channel grayscale.
        model: Segmentation network taking a ``(1, 1, H, W)`` tensor and
            returning a single channel of raw logits (``smp.Unet`` built
            without an ``activation`` does this).
        threshold: Probability cut-off in ``[0, 1]``, applied after a sigmoid;
            pixels strictly above it are foreground.
        input_size: ``(width, height)`` the image is resized to before
            inference. NOTE(review): must match the training input size --
            224x224 is assumed here, confirm against the training code.
        show: If ``True``, display the mask and an overlay on the original.

    Returns:
        ``float32`` array with the original image's height and width,
        containing 1.0 for foreground pixels and 0.0 for background.

    Raises:
        ValueError: if ``image_file`` cannot be read as an image.
    """
    image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image file: {image_file}")
    height, width = image.shape

    image_resized = cv2.resize(image, input_size)

    transform = transforms.Compose([
        transforms.ToTensor(),
        # NOTE(review): ImageNet red-channel statistics; must match the
        # normalization used during training.
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ])
    input_tensor = transform(image_resized).unsqueeze(0).to(DEVICE)

    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)

    # The network outputs raw logits; convert to probabilities so that
    # `threshold` really is a probability cut-off.
    probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()

    # Upsample the smooth probability map and threshold afterwards, so the
    # returned mask stays strictly binary (bilinear-resizing an already
    # binary mask would introduce fractional values along the edges).
    probabilities_resized = cv2.resize(probabilities, (width, height), interpolation=cv2.INTER_LINEAR)
    predicted_mask_resized = (probabilities_resized > threshold).astype(np.float32)

    if show:
        original_image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        fig, (ax_mask, ax_overlay) = plt.subplots(1, 2, figsize=(10, 5))

        ax_mask.imshow(predicted_mask_resized, cmap='gray')
        ax_mask.set_title("Predicted Mask")
        ax_mask.axis('off')

        ax_overlay.imshow(original_image)
        ax_overlay.imshow(predicted_mask_resized, cmap='hot', alpha=0.5)  # Overlay predicted mask
        ax_overlay.set_title("Overlay")
        ax_overlay.axis('off')

        plt.show()
        # Release the figure so repeated calls (e.g. over a folder) do not
        # accumulate open figures.
        plt.close(fig)

    return predicted_mask_resized


In [23]:
def predict_on_folder(input_folder, model, output_folder=None, threshold=0.5,
                      extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
    """Run ``predict_and_visualize`` on every image file in ``input_folder``.

    Args:
        input_folder: Directory scanned non-recursively. Entries are processed
            in sorted name order; sub-directories and files whose extension is
            not in ``extensions`` are skipped with a message.
        model: Segmentation model forwarded to ``predict_and_visualize``.
        output_folder: If given, each mask is written there as
            ``"<image file name>_mask.png"`` (the folder is created if missing).
        threshold: Forwarded to ``predict_and_visualize``.
        extensions: File extensions treated as images (compared
            case-insensitively).

    Raises:
        IOError: if ``cv2.imwrite`` reports that a mask could not be written.
    """
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    allowed_extensions = {ext.lower() for ext in extensions}

    for img_name in sorted(os.listdir(input_folder)):
        img_path = os.path.join(input_folder, img_name)

        # Folders often contain non-image entries (sub-directories, hidden
        # files such as .DS_Store); feeding them to the predictor would abort
        # the whole batch.
        is_image = (os.path.isfile(img_path)
                    and os.path.splitext(img_name)[1].lower() in allowed_extensions)
        if not is_image:
            print(f"Skipping {img_name} (not an image file)")
            continue

        print(f"Predicting for {img_name}...")
        predicted_mask_resized = predict_and_visualize(img_path, model, threshold)

        if output_folder:
            output_path = os.path.join(output_folder, f"{img_name}_mask.png")
            # Write an explicit 8-bit image with values in 0..255.
            mask_u8 = np.clip(np.rint(predicted_mask_resized * 255), 0, 255).astype(np.uint8)
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(output_path, mask_u8):
                raise IOError(f"Failed to write mask to {output_path}")

In [25]:
# Folder of images to segment, and the folder the predicted masks go to.
input_folder, output_folder = "./input_images/", "./output_predictions/"

In [None]:
# Segment every image in input_folder and save the masks to output_folder.
predict_on_folder(input_folder=input_folder, model=model, output_folder=output_folder)