In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from bs4 import BeautifulSoup
from IPython.display import clear_output
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

print(f"Can work with GPU?: {torch.cuda.is_available()}")

Can work with GPU?: True


In [4]:
from tqdm import tqdm

In [2]:
"""
1. Load in the images
2. Load in the annotations
3. Use this as the prompt for the SAM
4. Run SAM with top 3
5. Store all 3 masks
6. Pick best masks and store

Dataset -> SAM segmented -> Data augmented (in process of model) -> Model training

Pipeline:
Image upload -> click on point for SAM segmentation -> predict model -> return probabilities
"""

'\n1. Load in the images\n2. Load in the annotations\n3. Use this as the prompt for the SAM\n4. Run SAM with top 3\n5. Store all 3 masks\n6. Pick best masks and store\n\nDataset -> SAM segmented -> Data augmented (in process of model) -> Model training\n\nPipeline:\nImage upload -> click on point for SAM segmentation -> predict model -> return probabilities\n'

### Load in Annotations

### Helper Functions

In [5]:
def get_image_data(species, image, dataset_root="dataset/dataset_new"):
    """Return the image path and the integer prompt point for one annotated image.

    Parameters
    ----------
    species : str
        Name of the species sub-directory below ``dataset_root``.
    image : element
        Annotation element that supports ``image['name']`` and
        ``image.select("points")`` (elsewhere in this notebook the elements
        come from ``soup.select("image")``).
    dataset_root : str, optional
        Directory containing one sub-directory per species. Defaults to the
        location that was previously hard-coded.

    Returns
    -------
    tuple
        ``(file_path, x, y)`` where ``x`` and ``y`` are the coordinates of the
        first annotated point, truncated to ``int``.

    Raises
    ------
    IndexError
        If the image carries no ``points`` annotation.
    """
    name = image['name']
    file_path = f"{dataset_root}/{species}/{name}"

    annotations = image.select("points")
    if not annotations:
        raise IndexError(f"no 'points' annotation found for image {name!r}")

    # Assumes several points in one annotation are written as ';'-separated
    # "x,y" pairs (CVAT-style export) -- TODO confirm against the XML files.
    # Only the first pair is used; a single "x,y" value is unaffected.
    first_pair = annotations[0]['points'].split(";")[0]
    x, y = first_pair.split(",")
    return file_path, int(float(x)), int(float(y))

def show_mask(mask, ax):
    """Draw ``mask`` on ``ax`` as a semi-transparent blue RGBA overlay.

    The last two dimensions of ``mask`` are used as height and width. Every
    mask value is multiplied into a fixed RGBA colour (alpha 0.6), so zero /
    False entries produce fully transparent pixels.
    """
    rgba = np.array((30 / 255, 144 / 255, 255 / 255, 0.6))
    height, width = mask.shape[-2:]
    # (H, W, 1) * (1, 1, 4) broadcasts to an (H, W, 4) RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars, coloured by label.

    Rows of ``coords`` whose entry in ``labels`` equals 1 are drawn in green,
    rows labelled 0 in red. Column 0 of ``coords`` is used as x and column 1
    as y. ``marker_size`` is forwarded to ``ax.scatter`` as ``s``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)
    
def show_box(box, ax):
    """Add an unfilled green rectangle for ``box`` to ``ax``.

    ``box`` is indexed as ``(x0, y0, x1, y1)``: the first two entries give the
    anchor corner, and width / height are the differences to the last two.
    """
    x_start, y_start, x_end, y_end = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill, only the edge is visible
        lw=2,
    )
    ax.add_patch(outline)

In [6]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook does not depend on a CUDA machine just to get past this cell.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Record the library versions this run used (torch, CUDA toolkit, cuDNN).
print(torch.__version__, torch.version.cuda, torch.backends.cudnn.version())

1.13.1+cu117 11.7 8500


### Setup SAM

In [7]:
# Look up the builder registered under "default" and construct the SAM model
# from the checkpoint file stored next to the notebook.
build_default_sam = sam_model_registry["default"]
sam = build_default_sam(checkpoint="./sam_vit_h_4b8939.pth")

In [8]:
# Move the model to the device selected in the `device = ...` cell rather than
# repeating a hard-coded 'cuda' literal here. The trailing ';' keeps Jupyter
# from echoing the returned module.
sam.to(device=device);

In [9]:
# Wrap the loaded model in a predictor; the processing loop calls
# predictor.set_image(...) and predictor.predict(...) on it.
predictor = SamPredictor(sam_model=sam)

# Pipeline
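- Load each image together with its point annotation
- Pass the image to SAM
- Use the annotated point as the SAM prompt
- Take the mask SAM returns
- Apply the mask to the image so that everything outside the mask becomes black
- Save the masked image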

In [11]:
def get_image_data(species, image, dataset_root="dataset/dataset_new"):
    """Return the image path and the integer prompt point for one annotated image.

    NOTE(review): this notebook defines ``get_image_data`` twice; this later
    definition is the one in effect for the cells that follow it.

    Parameters
    ----------
    species : str
        Name of the species sub-directory below ``dataset_root``.
    image : element
        Annotation element that supports ``image['name']`` and
        ``image.select("points")`` (elsewhere in this notebook the elements
        come from ``soup.select("image")``).
    dataset_root : str, optional
        Directory containing one sub-directory per species. Defaults to the
        location that was previously hard-coded.

    Returns
    -------
    tuple
        ``(file_path, x, y)`` where ``x`` and ``y`` are the coordinates of the
        first annotated point, truncated to ``int``.

    Raises
    ------
    IndexError
        If the image carries no ``points`` annotation.
    """
    name = image['name']
    file_path = f"{dataset_root}/{species}/{name}"

    annotations = image.select("points")
    if not annotations:
        raise IndexError(f"no 'points' annotation found for image {name!r}")

    # Assumes several points in one annotation are written as ';'-separated
    # "x,y" pairs (CVAT-style export) -- TODO confirm against the XML files.
    # Only the first pair is used; a single "x,y" value is unaffected.
    first_pair = annotations[0]['points'].split(";")[0]
    x, y = first_pair.split(",")
    return file_path, int(float(x)), int(float(y))

In [21]:
# Presumably the species handled in earlier runs of this cell (inferred from
# the name) -- nothing in this cell reads the list.
done = [
    'aal', 'alver', 'amerikaanse_hondsvis', 'atlantische_steur',
    'barbeel', 'beekforel', 'bermpje', 'bittervoorn', 'blankvoorn',
    'blauwneus', 'bot', 'brasem', 'bronforel',
    'diklipharder', 'donaubrasem', 'driedoornige_stekelbaars', 'dunlipharder',
    'elrits', 'europese_meerval',
    'fint',
    'gestippelde_alver', 'giebel', 'goudharder', 'goudvis', 'graskarper',
    'grootkopkarper', 'grote_marene', 'gup',
    'karper', 'kesslers_grondel', 'kolblei', 'kopvoorn', 'kroeskarper', 'kwabaal',
    'marmergrondel',
    'pontische_stroomgrondel', 'pos',
]

# Species processed by the loop in this cell.
# blauwband not included
species_list = [
    'regenboogforel', 'rietvoorn', 'rivierdonderpad', 'riviergrondel', 'roofblei',
    'serpeling', 'siberische_steur', 'sneep', 'snoekbaars', 'spiering', 'sterlet',
    'tiendoornige_stekelbaars',
    'vetje',
    'wijting', 'winde',
    'zalm', 'zeeforel', 'zeelt', 'zilverkarper', 'zonnebaars',
    'zwarte_dwergmeerval', 'zwarte_grondel',
]
# Index of the SAM candidate that is kept for every image. predict() is called
# with multimask_output=True and returns several candidates.
# NOTE(review): the planning notes at the top of the notebook list "pick best
# masks"; this loop always stores the candidate at index 1 and does not look
# at `scores`.
correct_mask = 1

# Number of images that could not be read and were skipped.
err = 0
for species in species_list:
    # One annotation file per species, read from the working directory.
    with open(f'{species}.xml', 'r') as f:
        data = f.read()

    soup = BeautifulSoup(data, "xml")
    images = soup.select("image")

    output_path = f"sam_dataset/images/{species}"
    # Saving below does not create missing directories, so make sure the
    # per-species output folder exists before the first image is written.
    os.makedirs(output_path, exist_ok=True)

    for elem in tqdm(images):
        file_path, x, y = get_image_data(species, elem)
        print(file_path)

        # cv2.imread returns None instead of raising when a file is missing or
        # cannot be decoded. Check for that explicitly rather than catching
        # every exception around cvtColor, which would also hide real errors.
        image = cv2.imread(file_path)
        if image is None:
            err += 1
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        predictor.set_image(image)

        # Single annotated point used as the prompt, with label 1.
        input_point = np.array([[x, y]])
        input_label = np.array([1])

        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )

        # Add a trailing channel axis so the (H, W) mask broadcasts over the
        # colour channels; pixels outside the mask are multiplied by zero.
        mask = masks[int(correct_mask)]
        cob = image * mask[..., np.newaxis]

        im = Image.fromarray(cob)
        im.save(f"{output_path}/sam_{elem['name']}")
        clear_output(wait=True)
    
    

100%|██████████| 49/49 [01:01<00:00,  1.25s/it]


In [23]:
err

0