# In [None]:
import torch
from torchvision import transforms
from PIL import Image

def load_model(model_path, model_class, map_location=None):
    """Instantiate ``model_class`` and load weights from a saved state dict.

    Args:
        model_path: Path to a checkpoint holding a state dict, e.g. one
            written with ``torch.save(model.state_dict(), path)``.
        model_class: Zero-argument callable (typically an ``nn.Module``
            subclass) that builds the untrained network.
        map_location: Forwarded to ``torch.load``. Pass ``"cpu"`` to load a
            checkpoint that was saved from GPU tensors on a CPU-only host;
            the default ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        The model with the checkpoint's weights loaded, switched to
        evaluation mode (``model.eval()``).
    """
    model = model_class()
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model

def predict(model, image_path, transform):
    """Classify a single image file and return the predicted class index.

    The image is opened with PIL, converted to single-channel ("L") mode,
    passed through ``transform``, given one leading batch dimension and run
    through ``model`` with gradient tracking disabled.

    Args:
        model: Network to run. Assumed to return scores shaped
            (1, num_classes) -- TODO confirm for each model used here.
        image_path: Path to the input image.
        transform: Callable mapping a PIL image to a tensor; with
            ``transforms.ToTensor`` a mode-"L" image becomes (1, H, W).

    Returns:
        int: index of the highest score along dim 1 of the model output.
        ``.item()`` raises ``RuntimeError`` if that argmax has more than one
        element (e.g. a per-pixel segmentation output).
    """
    # Close the file handle promptly; convert() returns an independent image.
    with Image.open(image_path) as img:
        image = img.convert("L")
    # Exactly one batch dimension: (C, H, W) -> (1, C, H, W), the 4-D layout
    # 2-D convolution layers expect.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    return output.argmax(dim=1).item()

# Preprocessing shared by both networks: ToTensor converts a PIL image into
# a float tensor (8-bit pixel values scaled to [0, 1]), then Normalize maps
# every value x to (x - 0.5) / 0.5, i.e. into the range [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Pretrained networks, restored from checkpoints in the working directory
# and put in eval mode by load_model. UNet and EfficientCapsNet are not
# defined in this cell -- presumably earlier in the notebook.
unet_model = load_model(model_path='unet_model.pth', model_class=UNet)
capsnet_model = load_model(
    model_path='efficient_capsnet_model.pth',
    model_class=EfficientCapsNet,
)

def _unet_output_to_mask(output):
    """Collapse a batched UNet output into a 2-D float numpy mask in [0, 1].

    Assumes ``output`` is shaped (1, C, H, W) -- TODO confirm against the
    UNet definition. With one channel the values are treated as logits and
    thresholded at 0 (equivalent to sigmoid > 0.5). With several channels the
    per-pixel argmax class index is divided by ``C - 1`` so every class maps
    into [0, 1] (for C == 2 this gives a 0/1 mask).

    Raises:
        ValueError: if ``output`` is not a single-item 4-D batch.
    """
    if output.dim() != 4 or output.shape[0] != 1:
        raise ValueError(
            f"expected UNet output of shape (1, C, H, W), got {tuple(output.shape)}"
        )
    num_channels = output.shape[1]
    if num_channels == 1:
        # NOTE(review): assumes raw logits; if the UNet already applies a
        # sigmoid, threshold at 0.5 instead.
        mask = (output[0, 0] > 0).float()
    else:
        mask = output[0].argmax(dim=0).float() / (num_channels - 1)
    return mask.cpu().numpy()


def segment_and_predict(image_path, segmented_image_path='segmented_image.png'):
    """Segment an image with ``unet_model``, then classify the mask with ``capsnet_model``.

    Args:
        image_path: Path to the input image (converted to mode "L").
        segmented_image_path: Where the 8-bit grayscale mask is written as a
            PNG; defaults to ``'segmented_image.png'`` in the working directory.

    Returns:
        int: index of the highest CapsNet score along dim 1. The CapsNet is
        assumed to return (1, num_classes) scores -- TODO confirm.
    """
    # Segmentation needs the full per-pixel UNet output, so the model is run
    # directly rather than through ``predict``, which reduces its output to a
    # single class index.
    with Image.open(image_path) as img:
        image = img.convert("L")
    with torch.no_grad():
        unet_output = unet_model(transform(image).unsqueeze(0))
    mask = _unet_output_to_mask(unet_output)

    # Scale [0, 1] -> [0, 255]; fromarray on a 2-D uint8 array yields a mode-"L"
    # image, saved for further processing.
    mask_image = Image.fromarray((mask * 255).round().astype('uint8'))
    mask_image.save(segmented_image_path)

    # PNG is lossless, so the in-memory mask has exactly the saved pixels;
    # classify it directly instead of reading the file back.
    with torch.no_grad():
        capsnet_output = capsnet_model(transform(mask_image).unsqueeze(0))
    return capsnet_output.argmax(dim=1).item()

# Example usage. The guard is true when run as a script or in a notebook
# kernel, but keeps the demo (and its file I/O) from running on import.
# Replace the placeholder path with a real image before running.
if __name__ == "__main__":
    image_path = 'path/to/test/image.png'
    prediction = segment_and_predict(image_path)
    print(f"Prediction: {prediction}")
