In [None]:
import torch
import torchvision.transforms as T
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
# Pretrained DeepLabV3 (ResNet-101 backbone) fetched through torch.hub from the
# pinned torchvision v0.10.0 tag. `.eval()` returns the module itself, so the
# chained call both switches it to inference mode and binds it to `model`.
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'deeplabv3_resnet101',
    pretrained=True,
).eval()

# PIL image -> float tensor, then per-channel normalization with the usual
# ImageNet mean/std that torchvision's pretrained backbones expect.
preprocess = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def preprocess_image(image_path):
    """Load an image from disk and turn it into a model-ready batch.

    Args:
        image_path: Path to an image file readable by Pillow.

    Returns:
        A tuple ``(input_batch, input_image)``:
        ``input_batch`` is the normalized tensor produced by ``preprocess``
        with a leading batch dimension of size 1 added, and ``input_image``
        is the RGB ``PIL.Image`` it was computed from.
    """
    # Open inside a context manager so the file handle is released
    # deterministically. ``convert`` returns a new, fully loaded image, so it
    # remains valid after the source file has been closed.
    with Image.open(image_path) as raw_image:
        input_image = raw_image.convert("RGB")
    input_tensor = preprocess(input_image)
    # The model expects a mini-batch, so prepend a batch axis of size 1.
    input_batch = input_tensor.unsqueeze(0)
    return input_batch, input_image

def segment_image(input_batch, net=None):
    """Run semantic segmentation and return the per-pixel class map.

    Args:
        input_batch: Batched input tensor (e.g. from ``preprocess_image``).
            Only the first sample of the batch is segmented.
        net: Segmentation network to run. It is called as ``net(input_batch)``
            and must return a mapping with an ``'out'`` entry of per-class
            scores laid out as ``(N, classes, H, W)``. Defaults to the
            module-level ``model`` loaded earlier in the notebook.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` and shape ``(H, W)`` holding the
        arg-max class index of each pixel for the first sample.
    """
    if net is None:
        net = model
    with torch.no_grad():
        scores = net(input_batch)['out'][0]  # (classes, H, W) for sample 0
    # ``.byte()`` casts to uint8, which is enough for up to 256 classes.
    return scores.argmax(0).byte().cpu().numpy()

def apply_mask(image, mask, class_id=15):
    """Zero out every pixel whose predicted class is not selected.

    Args:
        image: ``PIL.Image`` (or anything ``np.asarray`` accepts) whose array
            form is either single-channel ``(H, W)`` or multi-channel
            ``(H, W, C)``.
        mask: Integer array of shape ``(H, W)`` with one class index per
            pixel, e.g. the output of ``segment_image``.
        class_id: Class index to keep, or a sequence of class indices to keep.
            NOTE(review): the default 15 is presumably 'person' in the
            Pascal VOC label set used by torchvision's DeepLabV3 weights;
            confirm it is the intended target class.

    Returns:
        ``PIL.Image`` built from an array with the same shape and dtype as
        ``image``, where pixels outside the selected class(es) are set to 0.
    """
    pixels = np.asarray(image)
    # np.isin accepts both a scalar id and a sequence of ids.
    keep = np.isin(mask, class_id)
    if pixels.ndim == 3:
        # Broadcast the (H, W) mask over the channel axis only. For a 2-D
        # (grayscale) image, adding an axis would broadcast to a wrong
        # (H, H, W) shape.
        keep = keep[..., np.newaxis]
    masked_pixels = np.where(keep, pixels, 0).astype(pixels.dtype, copy=False)
    return Image.fromarray(masked_pixels)


In [None]:
# File extensions (lower-case) treated as images when scanning the input folder.
_IMAGE_EXTENSIONS = frozenset({'.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp'})


def _list_image_files(image_folder):
    """Return sorted paths of regular files in ``image_folder`` with an image extension."""
    image_paths = []
    # Sorting makes the processing/display order reproducible, since
    # os.listdir order is arbitrary.
    for name in sorted(os.listdir(image_folder)):
        path = os.path.join(image_folder, name)
        # Skip subdirectories and non-image files (e.g. .DS_Store), which
        # would otherwise make Image.open raise mid-loop.
        if os.path.isfile(path) and os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS:
            image_paths.append(path)
    return image_paths


def _show_side_by_side(original, masked):
    """Display the original and masked images next to each other, then close the figure."""
    fig, (ax_original, ax_masked) = plt.subplots(1, 2, figsize=(12, 6))
    for ax, title, img in ((ax_original, "Original Image", original),
                           (ax_masked, "Masked Image", masked)):
        ax.set_title(title)
        ax.imshow(img)
        ax.axis('off')
    plt.show()
    # Close explicitly so figures do not accumulate across loop iterations.
    plt.close(fig)


def process_and_visualize(image_folder, segmented_folder='segmented_salamanders'):
    """Segment every image in a folder, save the masked result and display it.

    For each image file in ``image_folder`` (sorted by name), this runs
    ``preprocess_image`` -> ``segment_image`` -> ``apply_mask``. It saves the
    masked image under the same file name in ``segmented_folder`` and shows
    the original and masked images side by side.

    Args:
        image_folder: Directory containing the input images.
        segmented_folder: Output directory for masked images; created if it
            does not exist. Relative paths resolve against the current
            working directory.
    """
    os.makedirs(segmented_folder, exist_ok=True)

    for image_path in _list_image_files(image_folder):
        input_batch, input_image = preprocess_image(image_path)
        output_predictions = segment_image(input_batch)
        masked_image = apply_mask(input_image, output_predictions)

        segmented_image_path = os.path.join(segmented_folder, os.path.basename(image_path))
        masked_image.save(segmented_image_path)
        print(f"Segmented image saved to {segmented_image_path}")

        _show_side_by_side(input_image, masked_image)

# Example usage: segment every image in the building-images folder, save the
# masked copies and display each original/masked pair.
image_folder = '../../../tfe_data/Building_images'
process_and_visualize(image_folder=image_folder)
