In [18]:
import cv2
import os
import numpy as np
import torch
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from PIL import Image, ImageDraw
from segmentation_models_pytorch.losses import DiceLoss

In [19]:
# GPU index used on the original multi-GPU machine. Hosts that expose fewer
# CUDA devices fall back to cuda:0 instead of failing on first use.
preferred_cuda_index = 2

if torch.cuda.is_available():
    cuda_index = preferred_cuda_index if preferred_cuda_index < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{cuda_index}")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

using device: cuda:2


In [20]:
def save_mask(image: Image, mask: np.ndarray, box, iter, alpha=0.7):
    """Blend a tinted mask over an image, draw the box and save the frame as JPEG.

    Args:
        image: Array-like picture indexed as (row, column, channel) with values
            on a 0-255 scale; it is divided by 255 and transposed channel-first.
        mask: Array-like on a 0-255 scale. Only the first slice along axis 0
            (``mask[0]``) is used. Assumed to match the image's height and
            width -- TODO confirm at call sites.
        box: Four numbers ``(x0, y0, x1, y1)`` in image pixel coordinates.
        iter: Frame index, zero-padded to four digits for the file name. The
            name shadows the builtin inside this function only and is kept so
            existing callers keep working.
        alpha: Blend weight of the tint wherever the tint is non-zero.

    The frame is written to ``./bbox_frames/<iter>.jpg``; the directory is
    created on demand so a fresh checkout does not fail on the first frame.
    """
    frame = (np.array(image) / 255.).transpose(2, 0, 1)
    mask = np.array(mask) / 255.
    # Tint (RGB 87, 186, 168) scaled by the mask intensity, one plane per channel.
    tint = np.stack([mask[0] * 87. / 255, mask[0] * 186. / 255, mask[0] * 168. / 255])

    blended = (frame * (1 - alpha) + tint * alpha).clip(0, 1)
    # Only pixels with a non-zero tint are blended; all others keep their colour.
    frame = np.where(tint, blended, frame).clip(0, 1)
    picture = Image.fromarray((frame.transpose(1, 2, 0) * 255).astype(dtype=np.uint8))

    draw = ImageDraw.Draw(picture)
    x0, y0, x1, y1 = box

    # Five nested 1-px outlines form a 5-px border that grows inwards.
    for inset in range(5):
        left, top, right, bottom = x0 + inset, y0 + inset, x1 - inset, y1 - inset
        if right < left or bottom < top:
            # The box is too small for another inset; recent Pillow releases
            # reject inverted rectangle coordinates, so stop drawing here.
            break
        draw.rectangle([left, top, right, bottom], outline=(0, 255, 0), width=1)

    out_dir = './bbox_frames/'
    os.makedirs(out_dir, exist_ok=True)
    picture.save(out_dir + str(iter).zfill(4) + '.jpg')

In [21]:
def square(coords):
    """Return the area of a box given as ``(x0, y0, x1, y1)``.

    The result is signed: it is negative when exactly one side is inverted.
    """
    width = coords[2] - coords[0]
    height = coords[3] - coords[1]
    return width * height

In [102]:
# Read the GrabCut sample and keep it as an RGB array (rows, columns, channels).
image = np.array(Image.open('../images/GrabCut/data_GT/person6.jpg').convert("RGB"))
# Optional crop of the last row and column (disabled):
# image = image[:-1,:-1,:]

In [103]:
# Image height and width in pixels: the first two axes of the RGB array.
h, w = image.shape[0], image.shape[1]
print(h, w)

600 450


In [104]:
# Load the boundary_GT mask and convert it to a float tensor. Pixel values are
# kept exactly as stored in the file; no rescaling happens here.
zero_mask = torch.as_tensor(
    np.array(Image.open('../images/GrabCut/boundary_GT/person6.bmp')),
    dtype=torch.float,
)
# Optional crop of the last row and column (disabled):
# zero_mask = zero_mask[:-1,:-1]

In [105]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Checkpoint file and matching model config for the SAM2 build.
sam2_checkpoint = "../../checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

# Turn off gradient tracking for every model parameter, so that only tensors
# which explicitly request gradients later on are optimised.
for weight in predictor.model.parameters():
    weight.requires_grad_(False)

In [106]:
# Register the RGB array with the predictor before any prompt is evaluated.
# NOTE(review): presumably this computes and caches the image embedding used by
# the later `_predict` calls -- confirm against SAM2ImagePredictor.set_image.
predictor.set_image(image)

In [115]:
# Initial box prompt as (x0, y0, x1, y1) in pixel coordinates of the loaded image.
input_box = np.array([170, 206, 296, 282])

In [116]:
# Turn the pixel-space box into the tensors that `predictor._predict` consumes.
# Only a box prompt is given: there are no point prompts and no mask logits.
mask_input, unnorm_coords, labels, unnorm_box_input = predictor._prep_prompts(
    point_coords=None,
    point_labels=None,
    box=input_box,
    mask_logits=None,
    normalize_coords=True,
)

In [117]:
# Inspect the prepared box: a (1, 2, 2) tensor holding the two corners, with x
# scaled by 1024 / w and y scaled by 1024 / h relative to `input_box`.
unnorm_box_input

tensor([[[386.8445, 351.5733],
         [673.5645, 481.2800]]], device='cuda:2')

In [118]:
# boxes_n = torch.as_tensor(input_box, dtype=torch.float, device=predictor.device)
# boxes = boxes_n.reshape(-1, 2, 2)
# coords = boxes.clone()
# coords[..., 0] = coords[..., 0] / w
# print(coords)
# coords[..., 1] = coords[..., 1] / h
# print(coords)
# coords = coords * 1024  # not sure why; possibly SAM2's maximum input resolution
# print(coords)

In [119]:
# Undo the coordinate scaling that was applied to the box prompt.
def back_norm(unnorm_box, w, h, model_size=1024):
    """Map a box from the model's coordinate scale back to image pixels.

    Args:
        unnorm_box: Tensor whose last axis holds ``(x, y)`` pairs on a
            ``0..model_size`` scale, e.g. the ``(1, 2, 2)`` box tensor produced
            by ``predictor._prep_prompts``. It is never modified.
        w: Image width in pixels (scales the x component).
        h: Image height in pixels (scales the y component).
        model_size: Side length of the scale the box is expressed in. The
            default 1024 matches the factor seen in the ``_prep_prompts``
            output; presumably SAM2's input resolution -- confirm if a
            different model is used.

    Returns:
        Flat float32 NumPy array of the rescaled coordinates, e.g.
        ``[x0, y0, x1, y1]`` for a single box.
    """
    # Detach first: the result is only used as plain numbers, so there is no
    # reason to record these operations in the autograd graph. The division
    # allocates a new tensor, which makes the in-place scaling below safe.
    coords = unnorm_box.detach() / model_size
    coords[..., 0] *= w
    coords[..., 1] *= h
    return coords.squeeze(0).float().cpu().numpy().flatten()

In [120]:
# Dice loss in binary mode, used below to compare predicted masks with the
# target mask.
# NOTE(review): the predictions passed in later are requested with
# return_logits=True -- confirm DiceLoss is configured to expect logits.
criterion = DiceLoss('binary')

In [121]:
# --- Optimisation settings ---------------------------------------------------
iters = 2000
lambda_reg = 0.0000001  # weight of the box-area penalty (currently logged only)

# Optimise a private copy of the prepared box. Aliasing `unnorm_box_input`
# would overwrite the starting point in place, so re-running this cell would
# silently continue from the previous result instead of the initial box.
unnorm_box = unnorm_box_input.detach().clone().requires_grad_(True)
opt = optim.Adam([unnorm_box], lr=1e-2)

# The target never changes between steps, so build it once outside the loop.
zero_mask_torch = zero_mask.unsqueeze(0).clip(0, 1)

# One frame is saved per step; make sure the destination exists up front.
os.makedirs('./bbox_frames', exist_ok=True)

# `step` instead of `iter`: a notebook-level loop variable named `iter` would
# shadow the builtin for every cell executed afterwards.
for step in range(iters):
    opt.zero_grad()

    masks, iou_predictions, low_res_masks = predictor._predict(
        unnorm_coords,
        labels,
        unnorm_box,
        mask_input,
        multimask_output=False,
        return_logits=True,
    )

    # Visualisation only.
    # NOTE(review): the masks are requested as logits, so clip(0, 1) treats
    # every value >= 1 as fully "on" -- confirm this is the intended rendering.
    masks_np = masks.clip(0, 1).squeeze(0).detach().cpu().numpy() * 255
    unnorm_coords_np = back_norm(unnorm_box, w, h)
    save_mask(image, masks_np, unnorm_coords_np, step, alpha=0.7)

    # Monitoring values. `zero_mask` is a CPU tensor, so the prediction is
    # moved to the CPU for the Dice comparison.
    masks_loss = masks.cpu()
    loss = criterion(masks_loss, zero_mask_torch)
    # The area difference is computed from NumPy values and is therefore not
    # differentiable; it cannot contribute gradients in its current form.
    reg_loss = abs(square(input_box) - square(unnorm_coords_np))

    # Alternative objectives explored by the author (not used for the update):
    #   minimisation: loss + lambda_reg * reg_loss**2
    #   maximisation: -loss
    # NOTE(review): the update below maximises the model's *predicted IoU*;
    # the Dice loss above is only logged. Confirm which objective is intended.

    if step % 100 == 0:
        print('\n\niter = ', step)
        print('IoU', iou_predictions)
        print('loss', loss)
        print('reg', lambda_reg * reg_loss**2)

    (-iou_predictions).backward()
    opt.step()



iter =  0
IoU tensor([[0.7196]], device='cuda:2', grad_fn=<WhereBackward0>)
loss tensor(0.8217, grad_fn=<MeanBackward0>)
reg 0.0


iter =  100
IoU tensor([[0.7261]], device='cuda:2', grad_fn=<WhereBackward0>)
loss tensor(0.8220, grad_fn=<MeanBackward0>)
reg 0.0009332805275917053


iter =  200
IoU tensor([[0.7388]], device='cuda:2', grad_fn=<WhereBackward0>)
loss tensor(0.8241, grad_fn=<MeanBackward0>)
reg 0.004136059619522094


iter =  300
IoU tensor([[0.7445]], device='cuda:2', grad_fn=<WhereBackward0>)
loss tensor(0.8241, grad_fn=<MeanBackward0>)
reg 0.008085969876194


iter =  400
IoU tensor([[0.7512]], device='cuda:2', grad_fn=<WhereBackward0>)
loss tensor(0.8247, grad_fn=<MeanBackward0>)
reg 0.01349995518503189


iter =  500
IoU tensor([[0.7588]], device='cuda:2', grad_fn=<WhereBackward0>)
loss tensor(0.8239, grad_fn=<MeanBackward0>)
reg 0.021127758384799956


iter =  600
IoU tensor([[0.7657]], device='cuda:2', grad_fn=<WhereBackward0>)
loss tensor(0.8219, grad_fn=<MeanBackward0

In [122]:
!ffmpeg -i /home/user20/segment-anything-2/notebooks/sam2_adversarial_attacks/bbox_frames/%04d.jpg -r 30 bbox_video/out54.mp4

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab