In [None]:
# Prints the NVIDIA driver / GPU status so you can confirm a GPU is visible before loading the model
!nvidia-smi

In [None]:
pip install opencv-python torch pybase64 numpy supervision matplotlib pytesseract sam2 pillow imutils

In [None]:
!git clone https://github.com/facebookresearch/segment-anything-2.git
%cd ./segment-anything-2
!pip install -e . -q

!mkdir -p ./checkpoints
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt -P ./checkpoints
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt -P ./checkpoints
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt -P ./checkpoints
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt -P ./checkpoints

In [None]:
%cd ./segment-anything-2

In [None]:
import base64
import math
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pytesseract
import supervision as sv
import torch

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

In [None]:
# Configures torch
# Only touch CUDA settings when a GPU is actually present, so the notebook still
# runs on the CPU fallback.
if torch.cuda.is_available():
    # Enter bfloat16 autocast for the rest of the session; the context is
    # deliberately never exited so every later cell runs under it.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    # Enable TF32 matmul/cuDNN kernels on devices with compute capability major >= 8
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

In [None]:
# Picks the CUDA GPU when one is available (CPU otherwise), then builds the SAM2
# model and the automatic mask generator that wraps it
CONFIG = "sam2_hiera_l.yaml"
CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

sam2_model = build_sam2(
    CONFIG,
    CHECKPOINT,
    device=DEVICE,
    apply_postprocessing=False,
)
mask_generator = SAM2AutomaticMaskGenerator(sam2_model)

In [None]:
# Generates the segmented output
IMAGE_PATH = "/home/d4rkc10ud/Documents/Projects/SmartSplit/receipt_scan/inputs/receipt_walmart.png"

image_bgr = cv2.imread(IMAGE_PATH)
# cv2.imread reports a missing or unreadable file by returning None instead of raising;
# fail here with a clear message rather than with a cryptic cvtColor error below.
if image_bgr is None:
    raise FileNotFoundError(f"Could not read input image: {IMAGE_PATH}")

# Keep both channel orders: the BGR copy is used for annotation, the RGB copy is
# what gets passed to the mask generator and cropped later.
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

sam2_result = mask_generator.generate(image_rgb)

In [None]:
# Overlays the SAM masks on a copy of the source image and shows both side by side
detections = sv.Detections.from_sam(sam_result=sam2_result)
mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

# Annotate a copy so image_bgr itself stays untouched
annotated_image = mask_annotator.annotate(
    scene=image_bgr.copy(),
    detections=detections,
)

sv.plot_images_grid(
    images=[image_bgr, annotated_image],
    titles=['source image', 'segmented image'],
    grid_size=(1, 2),
)

In [None]:
# Crops the image to every SAM segment, OCRs each crop, and saves the smallest crop
# whose text contains "total" or "receipt" as the receipt image.
CROP_DIR = "/home/d4rkc10ud/Documents/Projects/SmartSplit/receipt_scan/cropped_receipt"
# cv2.imwrite only returns False (no exception) when the target folder is missing,
# so make sure it exists before anything is written.
os.makedirs(CROP_DIR, exist_ok=True)

cropped_images = []
area = math.inf        # area of the best receipt candidate found so far
image_num = None       # index of that candidate in sam2_result (None = nothing found)
receipt_crop = None    # pixels of that candidate, kept so the crop is not recomputed

for i, result in enumerate(sam2_result):
    print(f"Result #{i} bbox:", result["bbox"])
    x, y, width, height = result["bbox"]

    # Round outwards (floor the start, ceil the end) so the crop covers the whole box
    cropped_image = image_rgb[math.floor(y):math.ceil(y + height), math.floor(x):math.ceil(x + width)]

    if cropped_image.size == 0:
        # A zero-width/height box yields an empty slice, which cv2.imwrite rejects
        print(f"Result #{i} skipped: empty crop")
        continue

    crop_text = pytesseract.image_to_string(cropped_image).lower()

    # Prefer the smallest segment whose OCR text looks like a receipt
    if result["area"] < area and ("total" in crop_text or "receipt" in crop_text):
        area = result["area"]
        image_num = i
        receipt_crop = cropped_image

    cropped_images.append(cropped_image)

    # Crops are RGB; convert back to BGR because that is the order cv2.imwrite expects
    cv2.imwrite(os.path.join(CROP_DIR, f"cropped_image_{i}.png"), cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR))

if image_num is not None:
    cv2.imwrite(os.path.join(CROP_DIR, "cropped_image.png"), cv2.cvtColor(receipt_crop, cv2.COLOR_RGB2BGR))
else:
    print("Receipt Crop not found")

# Shows every (non-empty) crop in a single row
for i, cropped_image in enumerate(cropped_images):
    plt.subplot(1, len(cropped_images), i + 1)
    plt.imshow(cropped_image)
    plt.axis("off")
plt.show()

In [None]:
# Defines preprocessing function
import cv2
import numpy as np
from PIL import Image
import imutils


def preprocess_receipt(image):
    """Binarise a receipt crop for OCR using a distance-transform clean-up.

    The image is converted from BGR to grayscale, Otsu-thresholded with
    inverted polarity, distance-transformed, rescaled to the 0-255 range and
    Otsu-thresholded a second time.

    Args:
        image: colour image in OpenCV's BGR channel order.

    Returns:
        The single-channel ``uint8`` image produced by the final binary threshold.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, inverted = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)

    # Distance of every non-zero pixel to the nearest zero pixel, scaled into [0, 1]
    distance = cv2.distanceTransform(inverted, cv2.DIST_L2, 5)
    distance = cv2.normalize(distance, distance, 0, 1.0, cv2.NORM_MINMAX)
    distance_u8 = (distance * 255).astype("uint8")

    _, binary = cv2.threshold(distance_u8, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    return binary

In [None]:
# Reads the cropped image and preprocesses it for OCR detection
IMAGE_PATH = "/home/d4rkc10ud/Documents/Projects/SmartSplit/receipt_scan/cropped_receipt/cropped_image.png"
image = cv2.imread(IMAGE_PATH)
# cv2.imread returns None instead of raising when the file is missing or unreadable
# (for example when no receipt crop was saved), so stop here with a clear message.
if image is None:
    raise FileNotFoundError(f"Could not read cropped receipt image: {IMAGE_PATH}")

preprocessed_image = preprocess_receipt(image)

SAVE_PATH = "/home/d4rkc10ud/Documents/Projects/SmartSplit/receipt_scan/cropped_receipt/preprocessed_image.png"
# cv2.imwrite reports failure through its return value, not an exception
if not cv2.imwrite(SAVE_PATH, preprocessed_image):
    print(f"Warning: could not save preprocessed image to {SAVE_PATH}")

# The result is single-channel; without cmap="gray" matplotlib would render it
# with its default false-colour colormap.
plt.imshow(preprocessed_image, cmap="gray")
plt.axis("off")
plt.show()