In [44]:
import numpy as np 
import torch
from torchvision import transforms
from PIL import Image
from model import UNET 
import os 
import albumentations as A 
from albumentations.pytorch import ToTensorV2 


In [51]:
def load_checkpoint(checkpoint, model):
    """Copy the weights stored in ``checkpoint`` into ``model`` in place.

    Args:
        checkpoint: mapping that holds the model weights under the
            ``"state_dict"`` key (e.g. the object returned by ``torch.load``).
        model: ``torch.nn.Module`` whose parameters/buffers are overwritten.

    Returns:
        None. The result of ``load_state_dict`` is discarded.
    """
    print("=> Loading checkpoint")
    saved_weights = checkpoint["state_dict"]
    model.load_state_dict(saved_weights)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_HEIGHT = 270  # height (px) every input is resized to before inference
IMAGE_WIDTH = 480  # width (px) every input is resized to before inference
CHECKPOINT_PATH = "my_checkpoint.pth_0.9998.tar"

model = UNET(in_channels=3, out_channels=1).to(DEVICE)
# map_location remaps tensors that were saved on a GPU onto DEVICE, so a
# checkpoint written on a CUDA machine still loads on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from a trusted source.
load_checkpoint(torch.load(CHECKPOINT_PATH, map_location=DEVICE), model)

def inference(model, image, threshold=0.5):
    """Run ``model`` on ``image`` and return a binary mask.

    The model is put in evaluation mode and evaluated without autograd
    tracking. Its raw outputs are passed through a sigmoid and then
    thresholded, so every element of the result is exactly 0.0 or 1.0.

    Args:
        model: module (anything with ``.eval()`` that is callable) mapping
            ``image`` to logits.
        image: input tensor, already on the same device as ``model``.
        threshold: probability that an element must strictly exceed to be
            marked 1.0. Defaults to 0.5 (the previous hard-coded value).

    Returns:
        Float tensor with the same shape as the model output, holding 0.0/1.0.
    """
    model.eval()  # eval mode changes layers such as dropout / batch norm
    with torch.no_grad():  # no graph is needed for pure inference
        probabilities = torch.sigmoid(model(image))
        mask = (probabilities > threshold).float()
    return mask

# Preprocessing applied to every test image: resize to the network input size,
# normalize, then convert the HWC numpy array into a torch tensor.
# NOTE(review): Normalize keeps albumentations' default mean/std (ImageNet
# statistics in current releases -- verify for the installed version) but sets
# max_pixel_value=1.0, so those defaults are applied to raw 0-255 pixel values
# without rescaling. A previously commented-out variant instead used
# mean=0, std=1, max_pixel_value=255 (i.e. scale to [0, 1]). Confirm this
# matches the normalization used when the checkpoint was trained.
transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Normalize(max_pixel_value=1.0),
    ToTensorV2(),
])

image_dir = "test_images_real"

# Keep only regular files whose extension Pillow recognizes: this skips
# sub-directories (including the predictions folder below) and stray files
# such as .DS_Store that would make Image.open fail. The result is sorted
# because os.listdir returns entries in arbitrary order, and outputs are
# numbered by position in this list.
_image_extensions = set(Image.registered_extensions())
all_images = sorted(
    f
    for f in os.listdir(image_dir)
    if os.path.isfile(os.path.join(image_dir, f))
    and os.path.splitext(f)[1].lower() in _image_extensions
)

output_dir = os.path.join(image_dir, "predictions")
os.makedirs(output_dir, exist_ok=True)  # no-op if it already exists

# Run the model on every listed image and save each binary mask into
# output_dir as prediction_<index>.png. Masks are presumably at the resized
# (IMAGE_HEIGHT x IMAGE_WIDTH) resolution, not the source image's size --
# confirm against the UNET output shape.
for i, file_name in enumerate(all_images):
    image_path = os.path.join(image_dir, file_name)
    # The context manager closes the file handle once the RGB copy is made.
    with Image.open(image_path) as pil_image:
        rgb_array = np.array(pil_image.convert("RGB"))
    transformed = transform(image=rgb_array)
    # ToTensorV2 yields a channel-first tensor; unsqueeze adds the batch axis.
    image = transformed["image"].unsqueeze(0).to(DEVICE)

    pred = inference(model, image)

    # Drop the batch axis and convert the 0.0/1.0 mask into a PIL image.
    pred_image = transforms.ToPILImage()(pred.squeeze(0).cpu())
    pred_image.save(os.path.join(output_dir, f"prediction_{i}.png"))

=> Loading checkpoint
