# In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
from torchvision import transforms
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
from segmentation import PetSegmentationDataset, inference_from_dataset  # segmentation.py

# Perform inference
def perform_inference(model_path, device='cuda', output_path='segmentation_result.png'):
    """Run the trained U-Net on the first validation sample and save its class map.

    Builds a ResNet34-encoder U-Net (3 input channels, 3 output classes), loads
    the weights stored at ``model_path``, prints a few dataset/mask diagnostics,
    predicts on the first sample of the ``'val'`` split of
    ``PetSegmentationDataset`` and writes the per-pixel argmax prediction to
    ``output_path`` as a colour-coded image.

    Args:
        model_path: Path to a ``state_dict`` file readable by ``torch.load``.
        device: Torch device string the model and input image are moved to.
        output_path: File the visualisation is written to.

    Returns:
        The predicted class-index map for the first validation sample
        (``preds[0]`` of the argmax over the class dimension), or ``None`` if
        the validation split is empty and nothing was saved.
    """
    num_classes = 3
    model = _load_model(model_path, device, num_classes)

    # NOTE(review): presumably must match the training-time preprocessing
    # (256x256 resize + ImageNet mean/std normalisation) — confirm.
    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    val_dataset = PetSegmentationDataset(split='val', transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Debug: inspect label values of the first few samples. Clamp the count so
    # small (or empty) splits do not raise IndexError.
    print(f"Validation dataset size: {len(val_dataset)}")
    for idx in range(min(5, len(val_dataset))):
        _, mask = val_dataset[idx]
        print(f"Mask unique values in sample {idx}: {torch.unique(mask)}")

    # Only the first batch is visualised, so fetch it directly instead of
    # looping and breaking.
    first_batch = next(iter(val_loader), None)
    if first_batch is None:
        print("Validation dataset is empty; no inference was performed.")
        return None

    img, mask = first_batch
    with torch.no_grad():
        output = model(img.to(device))
    # argmax over dim=1 (the class dimension) gives a class index per pixel.
    preds = torch.argmax(output, dim=1).cpu().numpy()

    print(f"Prediction unique values: {np.unique(preds)}")
    print(f"Mask unique values: {torch.unique(mask)}")

    _save_prediction(preds[0], output_path, num_classes)
    print(f"Inference completed and saved to: {output_path}")
    return preds[0]


def _load_model(model_path, device, num_classes):
    """Build the U-Net, load trained weights onto ``device`` and switch to eval mode."""
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights=None,  # weights come from the checkpoint below
        in_channels=3,
        classes=num_classes
    )
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()  # disable dropout / use running batch-norm statistics
    return model


def _save_prediction(class_map, output_path, num_classes):
    """Write a 2-D array of class indices to ``output_path`` as a colour-coded image."""
    cmap = ListedColormap(['purple', 'green', 'blue'])  # one colour per class index
    fig, ax = plt.subplots()
    try:
        # Pin vmin/vmax so class k always gets colour k, even when some classes
        # are absent from this prediction. 'nearest' avoids resampling blending
        # neighbouring labels into values that map to the wrong colour.
        ax.imshow(class_map, cmap=cmap, vmin=0, vmax=num_classes - 1,
                  interpolation='nearest')
        ax.axis('off')
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

# Main execution
if __name__ == "__main__":
    model_path = 'best_model.pth'  # Path to your trained model
    output_path = 'segmentation_result.png'  # Output file path

    # Prefer the second GPU as originally configured, but fall back rather than
    # failing with "invalid device ordinal" on machines with fewer GPUs.
    if torch.cuda.device_count() > 1:
        device = 'cuda:1'
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    perform_inference(model_path, device=device, output_path=output_path)
