In [None]:
!git clone "https://github.com/facebookresearch/segment-anything.git"
!git clone "https://github.com/LilianFontalvo/HackathonSSD2024.git"
!pip install -q supervision --upgrade supervision

In [None]:
import cv2
import supervision as sv
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Switch the working directory to the cloned repository; the relative paths used
# in later cells ("../HackathonSSD2024/...", "./sam_vit_h_4b8939.pth") are resolved from here.
# NOTE(review): not idempotent — re-running this cell without restarting the kernel
# tries to enter a nested "segment-anything" directory.
%cd segment-anything

In [None]:
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Model variant key and the checkpoint file downloaded in the previous cell.
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

# Use the first CUDA device when one is available, otherwise run on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the model from the checkpoint, move it to DEVICE, and wrap it in the
# automatic mask generator used by get_fragments.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam.to(device=DEVICE)
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
def get_fragments(image_path):
    """Split an image into binary fragment masks using the SAM mask generator.

    Args:
        image_path: Path of the image file to segment.

    Returns:
        A list of 2-D integer arrays (values 0/1) with the image's height and
        width. The first entry marks every pixel that no generated mask covers;
        the remaining entries are the generated masks, in generator order.

    Raises:
        FileNotFoundError: If the image cannot be read from ``image_path``.
    """
    image_bgr = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    masks = [o["segmentation"] for o in mask_generator.generate(image_rgb)]

    # Accumulate the union of all masks in a boolean buffer.
    covered = np.zeros(image_rgb.shape[:2], dtype=bool)
    for mask in masks:
        np.logical_or(covered, mask, out=covered)

    # Whatever is left uncovered becomes its own fragment, stored first.
    fragments = [np.logical_not(covered).astype(int)]
    fragments.extend(mask.astype(int) for mask in masks)
    return fragments

In [None]:
dir_path = "../HackathonSSD2024/dataset/trainset/images/"
# os.listdir returns entries in arbitrary order; sorting makes the subset
# selected below the same on every run and every machine.
files = sorted(os.listdir(dir_path))
all_fragments = {}

# Only the first 20 files are segmented.
for file in tqdm(files[:20]):
    all_fragments[file] = get_fragments(dir_path + file)

In [None]:
def get_label(frag, file, show=False):
    """Decide whether a fragment lies mostly inside the image's training mask.

    The mask is read from the trainset ``masks`` directory, using the image
    file name with its extension replaced by ``.png``.

    Args:
        frag: 2-D 0/1 array marking the fragment's pixels.
        file: File name of the source image (no directory part).
        show: When True, plot the image, the mask and the fragment side by side.

    Returns:
        True when more than 70% of the fragment's pixels fall on mask pixels,
        False otherwise (including when the fragment is empty).

    Raises:
        FileNotFoundError: If the mask file cannot be read.
    """
    # Swap only the extension; a plain str.replace would also rewrite any
    # "jpg" inside the stem and would miss other extensions such as ".JPG".
    mask_file = (
        "../HackathonSSD2024/dataset/trainset/masks/"
        + os.path.splitext(file)[0]
        + ".png"
    )
    mask = cv2.imread(mask_file)
    # cv2.imread signals failure by returning None instead of raising.
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_file}")
    # Integer division maps gray value 255 to 1 and every lower value to 0.
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)//255
    test = mask * frag
    if show:
        plt.subplot(1,3,1)
        img = cv2.imread("../HackathonSSD2024/dataset/trainset/images/" + file)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.imshow(img)
        plt.subplot(1,3,2)
        plt.imshow(mask, vmin=0, vmax=1)
        plt.subplot(1,3,3)
        plt.imshow(frag, vmin=0, vmax=1)
        plt.show()
    frag_area = np.sum(frag)
    # An empty fragment has no overlap ratio; avoid the 0/0 division.
    if frag_area == 0:
        return False
    return bool(np.sum(test) / frag_area > .7)


# Visual sanity check: label the third fragment of the first processed image
# and show the image, its mask and the fragment.
file = list(all_fragments)[0]
frag = all_fragments[file][2]
print(get_label(frag, file, show=True))

In [None]:
# Pretrained ResNet-50 with its last child module (the classification layer)
# dropped, switched to evaluation mode.
model = torch.nn.Sequential(
    *list(models.resnet50(weights="ResNet50_Weights.DEFAULT").children())[:-1]
).eval()

# Preprocessing applied to a PIL image before it is fed to the model:
# resize, center-crop to 224, convert to a tensor, normalize per channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def get_features(frag):
    """Compute a feature vector for an image fragment with the truncated ResNet.

    Args:
        frag: Image array accepted by ``PIL.Image.fromarray``; callers in this
            notebook pass RGB ``uint8`` arrays.

    Returns:
        The model output for the preprocessed fragment, flattened to a 1-D
        NumPy array.
    """
    # Preprocess and add the leading batch dimension expected by the model.
    batch = preprocess(Image.fromarray(frag)).unsqueeze(0)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        embedding = model(batch)
    return embedding.view(-1).numpy()

In [None]:
def add_file_fragments(directory, file, fragments, training_set):
    """Append one training record per fragment of an image to ``training_set``.

    Args:
        directory: Directory containing the image (with or without a trailing
            separator).
        file: Image file name inside ``directory``.
        fragments: Iterable of 2-D 0/1 masks with the image's height and width.
        training_set: List extended in place with one dict per fragment holding
            "filename", "mask", "size", "label", "fragment" and "features".

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image_path = os.path.join(directory, file)
    origin = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if origin is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    origin = cv2.cvtColor(origin, cv2.COLOR_BGR2RGB)

    for frag in fragments:
        mask = frag.astype(int)
        # Broadcast the 2-D mask over the channel axis to blank every pixel
        # outside the fragment in a single vectorized step.
        fragment = (origin * mask[:, :, np.newaxis]).astype("uint8")
        training_set.append(
            {
                "filename": file,
                "mask": mask,
                "size": np.sum(mask),
                "label": get_label(frag, file),
                "fragment": fragment,
                "features": get_features(fragment),
            }
        )
    

# Build the labelled reference set from every image segmented above.
training_set = []
for file, fragments in tqdm(all_fragments.items()):
    add_file_fragments(
        "../HackathonSSD2024/dataset/trainset/images/",
        file,
        fragments,
        training_set,
    )
print(f"{len(training_set)} fragments in training set.")

In [None]:
# Segment the sample image and report how many fragments were found.
sample_frags = get_fragments("../HackathonSSD2024/sample.jpg")
print(len(sample_frags))

# The first fragment is the region that no generated mask covers.
sample_mask = sample_frags[0]
sample_frag = cv2.cvtColor(
    cv2.imread("../HackathonSSD2024/sample.jpg"), cv2.COLOR_BGR2RGB
)
# Zero every pixel outside the mask; assigning back into the uint8 array keeps
# its dtype, as the per-channel assignment did.
sample_frag[...] = sample_frag * sample_mask[:, :, np.newaxis]
plt.imshow(sample_frag)
plt.show()


In [None]:
def is_frag_water(frag, show=False):
    """Label a fragment with the label of its most similar training fragment.

    Similarity is the cosine similarity between the fragment's feature vector
    and the stored "features" of each ``training_set`` entry.

    Args:
        frag: Fragment image, in the form accepted by ``get_features``.
        show: When True, display the best-matching training fragment.

    Returns:
        The "label" value of the best-matching training entry (presumably
        "is water", going by the function name).

    Raises:
        ValueError: If ``training_set`` is empty.
    """
    if not training_set:
        raise ValueError("training_set is empty; build it before classifying.")
    query = get_features(frag).reshape(1, -1)
    # Score against all reference vectors in one call instead of one call per entry.
    reference = np.stack([o["features"] for o in training_set])
    scores = cosine_similarity(query, reference)[0]
    best = int(np.argmax(scores))
    if show:
        plt.imshow(training_set[best]["fragment"])
        plt.show()
    return training_set[best]["label"]

# Classify the masked sample image; as the last expression of the cell, the
# returned label is shown as the cell output.
is_frag_water(sample_frag)