# In [None]:
import torch
from torchvision import models, transforms
from PIL import Image
import os

# Pretrained DeepLabV3 segmentation network (ResNet-101 backbone), switched to
# inference mode. Module.eval() returns the module itself, so it chains.
model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()

# Converts a PIL image to a float tensor of shape (C, H, W) scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])


def load_and_preprocess_image(image_path):
    """Load an image file as a batched RGB tensor.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The image converted to RGB, passed through the module-level
        ``transform``, with a leading batch dimension of size 1 added.
    """
    # The context manager guarantees the underlying file handle is released;
    # Pillow only closes it on its own after loading single-frame images.
    with Image.open(image_path) as image:
        rgb_image = image.convert("RGB")  # fully loaded, independent copy
    return transform(rgb_image).unsqueeze(0)


def segment_image(image_tensor, model):
    """Return per-pixel class predictions for the first image in a batch.

    Args:
        image_tensor: Batched input passed straight to ``model``; only the
            first item of the batch is used for the prediction.
        model: Callable whose result maps ``"out"`` to per-class scores with
            the batch on dim 0 and classes on dim 1.

    Returns:
        A tensor giving, for each pixel, the index of the highest-scoring
        class.
    """
    with torch.no_grad():
        class_scores = model(image_tensor)["out"]
    first_image_scores = class_scores[0]  # drop the batch dimension
    return first_image_scores.argmax(dim=0)


# Per-channel ImageNet statistics that torchvision's pretrained models,
# including DeepLabV3, expect their input to be normalized with.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def remove_background(input_image_path, output_image_path, target_class=15):
    """Black out every pixel not predicted as ``target_class`` and save it.

    The image is segmented with the module-level ``model``. Pixels whose
    predicted label equals ``target_class`` keep their original colour; all
    other pixels become black. The result is written as a 3-channel RGB
    image (no alpha channel), so JPEG output works.

    Args:
        input_image_path: Path of the image to process.
        output_image_path: Destination path; PIL picks the file format from
            the extension.
        target_class: Label index to keep. The default, 15, is "person" in
            the Pascal VOC label set predicted by torchvision's DeepLabV3
            weights.
    """
    # RGB tensor (1, 3, H, W) in [0, 1] from the ToTensor-only `transform`;
    # kept un-normalized so it can be composited into the output image.
    input_tensor = load_and_preprocess_image(input_image_path)

    # Normalize a separate copy for the model only: the pretrained weights
    # expect ImageNet-normalized input, while the output must keep the
    # original colours.
    normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
    mask = segment_image(normalize(input_tensor), model)

    keep = (mask == target_class).to(input_tensor.dtype)
    height_width = tuple(input_tensor.shape[-2:])
    if tuple(keep.shape) != height_width:
        # The mask normally already matches the input size; if it does not,
        # resize with nearest-neighbour so it stays strictly 0/1.
        keep = torch.nn.functional.interpolate(
            keep[None, None], size=height_width, mode="nearest"
        )[0, 0]

    # Broadcasting the (H, W) mask over the channel axis zeroes all three
    # channels of every background pixel.
    composited = input_tensor[0] * keep

    # Round (rather than truncate) when mapping [0, 1] back to 8-bit values.
    pixels = composited.mul(255).round().clamp(0, 255).to(torch.uint8)
    hwc = pixels.permute(1, 2, 0).contiguous().numpy()
    Image.fromarray(hwc).save(output_image_path)

input_image_path = "/content/photo.jpg"

# Write the result next to the source image, prefixing its file name.
input_directory = os.path.dirname(input_image_path)
input_filename = os.path.basename(input_image_path)
output_image_filename = f"output_{input_filename}"
output_image_path = os.path.join(input_directory, output_image_filename)

remove_background(input_image_path, output_image_path)
