# InSPyReNet Background Removal - Google Colab

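This notebook demonstrates how to use InSPyReNet for background removal in Google Colab: it clones the InSPyReNet repository, downloads a pre-trained checkpoint, and then removes the background from an image you upload, producing a transparent PNG you can download.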

In [None]:
# Install required packages
!pip install gdown

# Clone the InSPyReNet repository
!git clone https://github.com/plemeri/InSPyReNet.git

# Download the pre-trained model
!gdown --id 1qw1TGadiNHPKJ9MwCTgOrRKNIdrsxVsP -O InSPyReNet/saved_models/isnet.pth

In [None]:
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.nn import functional as F
import matplotlib.pyplot as plt
from google.colab import files

# Add the InSPyReNet directory to Python path
import sys
sys.path.append('InSPyReNet')

class InSPyReNetWrapper:
    """Inference wrapper that cuts the background out of an image with InSPyReNet.

    The model is loaded once. Each image then goes through these steps:
    - resize to a square ``input_size`` x ``input_size`` image;
    - normalise with the ImageNet channel mean/std;
    - run the model;
    - resize the predicted mask back to the original resolution;
    - use the mask as the alpha channel of an RGBA cutout.

    Attributes:
        device: Torch device string the model and input tensors are moved to.
        input_size: Side length (pixels) of the square network input.
        model: The loaded network in eval mode, or ``None`` until loading succeeds.
        transform: ToTensor + Normalize pipeline applied to the resized image.
        original_size: ``(width, height)`` (PIL order) of the image most recently
            passed to ``preprocess_image``. Kept so that ``postprocess_mask``
            can still be called without an explicit size.
    """

    def __init__(self, model_path='InSPyReNet/saved_models/isnet.pth',
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 input_size=384):
        """Build the preprocessing pipeline and load the model.

        Args:
            model_path: Path to a checkpoint file. A falsy value skips weight
                loading, so the network keeps whatever weights its constructor set.
            device: Torch device to run on. Defaults to CUDA when available.
            input_size: Side of the square resolution images are resized to
                before inference. The default of 384 is the previously
                hard-coded value.

        Raises:
            ImportError: the InSPyReNet model module cannot be imported.
            FileNotFoundError: ``model_path`` is set but does not exist.
        """
        self.device = device
        self.input_size = input_size
        self.model = None
        self.original_size = None
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean / std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        self.load_model(model_path)

    def load_model(self, model_path):
        """Instantiate InSPyReNet, optionally load weights, and move it to ``self.device``.

        ``self.model`` is only assigned once everything has succeeded. A failed
        load therefore never leaves a half-initialised network that
        ``remove_background`` would happily run.

        Args:
            model_path: Checkpoint path, or a falsy value to skip loading weights.

        Raises:
            ImportError: ``models.InSPyReNet`` is not importable (the cloned
                repository must be on ``sys.path``).
            FileNotFoundError: ``model_path`` is given but does not exist.
        """
        try:
            # Imported lazily because the module comes from the cloned repository.
            from models.InSPyReNet import InSPyReNet
        except ImportError as e:
            # Chain the original error so the actually-missing module stays visible.
            raise ImportError(f"Error loading InSPyReNet: {str(e)}") from e

        if model_path and not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        model = InSPyReNet()
        if model_path:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only load checkpoints from a trusted source.
            # Assumes the file holds a bare state_dict (not wrapped in another
            # dict) -- TODO confirm against the checkpoint.
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)

        self.model = model.to(self.device).eval()
        print(f"Model loaded successfully on {self.device}")

    @staticmethod
    def _load_image(image):
        """Return ``image`` as an RGB PIL image; accepts a file path or a PIL image."""
        if isinstance(image, (str, os.PathLike)):
            if not os.path.exists(image):
                raise FileNotFoundError(f"Image file not found: {image}")
            # The context manager closes the file handle; convert() returns a loaded copy.
            with Image.open(image) as img:
                return img.convert('RGB')
        if not isinstance(image, Image.Image):
            raise TypeError("Input must be a PIL Image or path to image")
        # Grayscale / RGBA / palette inputs would not match the 3-channel Normalize.
        return image.convert('RGB')

    def preprocess_image(self, image):
        """Convert an image into a normalised ``(1, 3, S, S)`` tensor on ``self.device``.

        ``S`` is ``self.input_size``. As a side effect, the RGB image's
        ``(width, height)`` is stored in ``self.original_size``.

        Args:
            image: Path to an image file or a PIL image.

        Raises:
            FileNotFoundError: ``image`` is a path that does not exist.
            TypeError: ``image`` is neither a path nor a PIL image.
        """
        image = self._load_image(image)
        self.original_size = image.size
        resized = image.resize((self.input_size, self.input_size), Image.BILINEAR)
        return self.transform(resized).unsqueeze(0).to(self.device)

    def postprocess_mask(self, pred_mask, size=None):
        """Resize a predicted mask to the target resolution and convert it to uint8.

        Args:
            pred_mask: Model output. Assumed to be a 4-D ``(1, 1, h, w)`` tensor of
                foreground scores in [0, 1] -- TODO confirm against the model.
            size: Target ``(width, height)`` in PIL order. Defaults to
                ``self.original_size`` from the last ``preprocess_image`` call.

        Returns:
            ``np.ndarray`` of dtype uint8. Its shape is ``(height, width)`` for a
            ``(1, 1, h, w)`` input.
        """
        width, height = self.original_size if size is None else size
        # PIL sizes are (width, height) but F.interpolate expects (height, width);
        # passing the PIL tuple directly would transpose non-square masks.
        resized = F.interpolate(pred_mask, size=(height, width),
                                mode='bilinear', align_corners=False)
        mask = resized.squeeze().detach().cpu().numpy()
        # Clip first: values outside [0, 1] would overflow the uint8 cast.
        return (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

    def remove_background(self, image, return_mask=False):
        """Return an RGBA copy of ``image`` whose alpha channel is the predicted mask.

        Args:
            image: Path to an image file or a PIL image (any mode; converted to RGB).
            return_mask: If True, also return the mask as a grayscale PIL image.

        Returns:
            The RGBA cutout, or ``(cutout, mask_image)`` when ``return_mask`` is True.

        Raises:
            RuntimeError: no model has been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Please load the model first.")

        # Decode once and reuse the same RGB image for inference and compositing.
        rgb_image = self._load_image(image)
        image_tensor = self.preprocess_image(rgb_image)

        with torch.no_grad():
            pred_mask = self.model(image_tensor)

        mask_image = Image.fromarray(self.postprocess_mask(pred_mask, size=rgb_image.size))

        # Use the mask directly as alpha. Pasting onto a transparent black canvas
        # would blend soft-edge pixels toward black (dark halos).
        result = rgb_image.convert('RGBA')
        result.putalpha(mask_image)

        if return_mask:
            return result, mask_image
        return result

    def save_result(self, result, output_path):
        """Save ``result`` to ``output_path`` as PNG, a format that keeps the alpha channel."""
        result.save(output_path, 'PNG')

def display_results(original, result, mask=None):
    """Show the original image, the cutout and (optionally) the mask side by side.

    Args:
        original: Image to show in the first panel.
        result: Background-removed image for the second panel.
        mask: Optional mask, drawn with a gray colormap in a third panel.
            Compared against ``None`` explicitly, so array masks work too
            (``if mask:`` is ambiguous for numpy arrays).
    """
    panels = [(original, 'Original Image', None),
              (result, 'Background Removed', None)]
    if mask is not None:
        panels.append((mask, 'Mask', 'gray'))

    # 5 inches of width per panel keeps every panel the same size.
    fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
    for ax, (img, title, cmap) in zip(axes, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

## Upload and Process an Image

Run the following cell to upload your image and remove its background:

In [None]:
# Initialize the wrapper (loads the model)
wrapper = InSPyReNetWrapper()

# Upload an image; files.upload() returns a dict keyed by filename
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No file was uploaded - re-run this cell and choose an image.")
image_path = next(iter(uploaded))

# Open the original image once; the context manager closes the file handle
with Image.open(image_path) as img:
    original_image = img.convert('RGB')

# Remove background and get mask (reuses the decoded image instead of re-reading the file)
result, mask = wrapper.remove_background(original_image, return_mask=True)

# Display results
display_results(original_image, result, mask)

# Save the result and download the same file that was written
output_path = 'result.png'
wrapper.save_result(result, output_path)
files.download(output_path)