In [None]:
import torch
from PIL import Image
from torchvision import transforms

# Compatibility shim for loading BiRefNet's remote code: make every
# `PretrainedConfig` report `is_encoder_decoder == False`.
# NOTE(review): this patches the base class globally, so it affects every
# config object in this kernel, not only BiRefNet's.
import transformers.configuration_utils

_PretrainedConfig = transformers.configuration_utils.PretrainedConfig

# Install the patch only once. Without this guard, re-running the cell captures
# the already-patched method as the "original", and the wrapper then calls
# itself forever (RecursionError on every attribute lookup).
if not getattr(_PretrainedConfig.__getattribute__, "_birefnet_patched", False):
    original_getattribute = _PretrainedConfig.__getattribute__

    def patched_getattribute(self, key, _original=original_getattribute):
        """Return False for `is_encoder_decoder`; defer every other lookup.

        `_original` is bound at definition time so later rebinding of the
        module-level `original_getattribute` name cannot redirect the call.
        """
        if key == 'is_encoder_decoder':
            return False
        return _original(self, key)

    # Marker checked by the guard above on subsequent runs of this cell.
    patched_getattribute._birefnet_patched = True
    _PretrainedConfig.__getattribute__ = patched_getattribute

from transformers import AutoModelForImageSegmentation

In [None]:
def load_birefnet_with_custom_weights(checkpoint_path: str, base_model_id: str = "ZhengPeng7/BiRefNet"):
    """Load a BiRefNet model and overwrite its weights with a fine-tuned checkpoint.

    Args:
        checkpoint_path: Path to a file containing a state dict saved with
            ``torch.save`` (e.g. a ``.pth`` file).
        base_model_id: Hugging Face Hub id (or local directory) of the base
            model that provides the architecture. Defaults to the upstream
            ``"ZhengPeng7/BiRefNet"``.

    Returns:
        The model from ``AutoModelForImageSegmentation.from_pretrained`` with
        the checkpoint's weights loaded into it.

    Raises:
        RuntimeError: raised by ``load_state_dict`` (strict mode) when the
            checkpoint's keys or tensor shapes do not match the model.
    """
    # trust_remote_code is needed because the model class is defined by code
    # shipped in the Hub repository rather than inside `transformers`.
    model = AutoModelForImageSegmentation.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source; on torch >= 1.13
    # consider passing weights_only=True.
    state_dict = torch.load(checkpoint_path, map_location='cpu')

    # Strip wrapper prefixes in order: "module." (added by DataParallel/DDP)
    # and then "_orig_mod." (added by torch.compile). This handles either
    # wrapper alone as well as both combined.
    clean_state_dict = {}
    for key, tensor in state_dict.items():
        for prefix in ("module.", "_orig_mod."):
            if key.startswith(prefix):
                key = key[len(prefix):]
        clean_state_dict[key] = tensor

    model.load_state_dict(clean_state_dict)
    return model


def remove_background(image_path: str, model, device='cpu', resolution=(1024, 1024)):
    """Remove the background from an image using a BiRefNet segmentation model.

    Args:
        image_path: Path to the input image; it is converted to RGB.
        model: Segmentation model whose output is indexed with ``[-1]`` to get
            the final prediction. It must already live on ``device``; this
            function only moves the input tensor.
        device: Device the input tensor is moved to (default ``'cpu'``).
        resolution: Size passed to ``transforms.Resize`` before inference; a
            ``(height, width)`` pair. Defaults to ``(1024, 1024)``.

    Returns:
        An RGBA ``PIL.Image`` at the original image size whose alpha channel
        is the predicted foreground mask.
    """
    # Resize to the inference size and normalize with the standard ImageNet
    # mean/std used by the original notebook.
    preprocess = transforms.Compose([
        transforms.Resize(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # The context manager closes the file handle deterministically; convert()
    # loads the pixels and returns a new image that outlives the handle.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    input_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # Take the last element of the model's output and squash it into [0, 1].
        preds = model(input_tensor)[-1].sigmoid().cpu()

    # Turn the first (only) prediction into a PIL mask at the original size.
    mask = transforms.ToPILImage()(preds[0].squeeze())
    mask = mask.resize(image.size)

    # putalpha adds the mask as an alpha channel, making the background transparent.
    result = image.copy()
    result.putalpha(mask)
    return result

In [None]:
# Download the fine-tuned ToonOut weights from
# https://huggingface.co/joelseytre/toonout and set `local_path_to_weights`
# to the local file.
local_path_to_weights = "path/to/your/finetuned_weights.pth"
test_image_path = "images/test_image_samurai.jpg"

# Build BiRefNet with the fine-tuned weights.
model = load_birefnet_with_custom_weights(local_path_to_weights)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# RGBA copy of the test image with the predicted foreground as alpha.
result = remove_background(test_image_path, model, device)
print("Background removed successfully!")

# Show the input followed by the cut-out for a visual check.
display(Image.open(test_image_path), result)