In [1]:
import numpy as np 
import torch
from torchvision import transforms
from PIL import Image
from model import UNET 
import os 
import albumentations as A 
from albumentations.pytorch import ToTensorV2 
import time 

  check_for_updates()


In [2]:
# Assuming your model is defined elsewhere
# model = UNET(in_channels=3, out_channels=1).to(DEVICE)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 
IMAGE_HEIGHT = 480 
IMAGE_WIDTH = 640 

model = UNET(in_channels=3, out_channels=1).to(DEVICE) 
load_checkpoint(torch.load("./models/my_checkpoint.pth.tar", map_location=torch.device(DEVICE)), model) 

def inference(model, image): 
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation
        preds = torch.sigmoid(model(image))
        preds = (preds > 0.5).float()
    return preds 

transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            max_pixel_value=1.0,
        ),
        ToTensorV2(), 
    ]
)

# image_dir = "./Test Images/sim_images_20250325/" 
# image_dir = "/home/anegi/abhay_ws/marker_detection_failure_recovery/segmentation_model/Test Images/sdg_markers_20250325-132238/rgb/"
image_dir = "/home/anegi/abhay_ws/marker_detection_failure_recovery/segmentation_model/Test Images/GITAI/Top Right Frames/"
# image_dir = "sim_images_20250314" 
# image_dir = "/home/anegi/abhay_ws/marker_detection_failure_recovery/output/markers_20250314-181037/rgb/"
all_images = [f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))]  # Filter out directories
output_dir = os.path.join(image_dir,f"predictions_{time.strftime('%Y%m%d-%H%M%S')}")
os.makedirs(output_dir, exist_ok=True)  # Create the output directory if it doesn't exist
os.makedirs(os.path.join(output_dir, "predictions"), exist_ok=True)  # Create the output directory if it doesn't exist
os.makedirs(os.path.join(output_dir, "combined"), exist_ok=True)  # Create the output directory if it doesn't exist

for i in range(len(all_images)): 
    image_path = os.path.join(image_dir, all_images[i]) 
    image = Image.open(image_path).convert("RGB")  # Open the image and convert to RGB
    image = np.array(image)  # Convert the image to a numpy array
    transformed = transform(image=image)  # Apply the transform
    image_tensor = transformed["image"].unsqueeze(0).to(DEVICE)  # Add batch dimension and move to device

    pred = inference(model, image_tensor)

    # Convert the prediction to a PIL image (single channel)
    pred_image = transforms.ToPILImage()(pred.squeeze(0).cpu())

    # Convert original image to PIL and resize it to the same size as the prediction
    original_image = Image.fromarray(image)
    original_image = original_image.resize((IMAGE_WIDTH, IMAGE_HEIGHT))

    # Convert the predicted image to 3 channels (repeat the single channel across 3 channels)
    pred_image_rgb = pred_image.convert("RGB")
    pred_image_rgb.save(os.path.join(output_dir, "predictions", all_images[i].removesuffix(".png") + "_prediction.png"))

    # Concatenate the original image and the predicted image side by side
    combined = np.concatenate((np.array(original_image), np.array(pred_image_rgb)), axis=1)

    # Convert back to PIL Image and save the combined image
    combined_image = Image.fromarray(combined)
    combined_image.save(os.path.join(output_dir, "combined", all_images[i].removesuffix(".png") + "_combined.png"))


  load_checkpoint(torch.load("./models/my_checkpoint.pth.tar", map_location=torch.device(DEVICE)), model)


=> Loading checkpoint
