In [2]:
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np

In [3]:
%reload_ext autoreload
%autoreload 2

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation, CLIPSegModel

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
# model = torch.nn.DataParallel(model,device_ids=[0, 1, 2, 3, 4, 5, 6, 7])
model = model.to(device)

In [6]:
COLOR_MAP = dict({
    # 'background': (0, 0, 0),
    'ship': (0, 0, 63),
    'storage tank': (0, 191, 127),
    'baseball diamond': (0, 63, 0),
    'tennis court': (0, 63, 127),
    'basketball court': (0, 63, 191),
    'ground track field': (0, 63, 255),
    'bridge': (0, 127, 63),
    'large vehicle': (0, 127, 127),
    'small vehicle': (0, 0, 127),
    'helicopter': (0, 0, 191),
    'swimming pool': (0, 0, 255),
    'roundabout': (0, 63, 63),
    'soccer ball field': (0, 127, 191),
    'plane': (0, 127, 255),
    'harbor': (0, 100, 155),
})

prompts = list(COLOR_MAP.keys())
class_num = len(prompts)
color = torch.tensor(list(COLOR_MAP.values()), dtype=torch.uint8).to(device)
threshold = 0.3
print(prompts)

['ship', 'storage tank', 'baseball diamond', 'tennis court', 'basketball court', 'ground track field', 'bridge', 'large vehicle', 'small vehicle', 'helicopter', 'swimming pool', 'roundabout', 'soccer ball field', 'plane', 'harbor']


In [7]:
def get_pred(image_np, debug=0):
  image_ori = Image.fromarray(image_np)
  final_shape = np.ceil(np.array(image_np.shape)[0:2]/352).astype(np.int32)*352
  final_shape = (class_num, final_shape[0], final_shape[1])
  final_pred = torch.zeros(final_shape).to(device)
  # maybe use unfold
  for w_start in range(0, final_shape[2], 352):
    for h_start in range(0, final_shape[1], 352):
      image = image_ori.crop((w_start, h_start, 352+w_start, 352+h_start))
      inputs = processor(text=prompts, images=[image] * class_num, padding="max_length", return_tensors="pt")
      inputs = inputs.to(device)
      # predict
      with torch.no_grad():
        outputs = model(**inputs)
      # visualize prediction
      preds = outputs.logits
      final_pred[:, h_start:h_start+352, w_start:w_start+352] = preds
      if not debug:
        continue
      display(image)
      _, ax = plt.subplots(1, class_num+1, figsize=(20, 20))
      [a.axis('off') for a in ax.flatten()]
      ax[0].imshow(image_ori)
      p = preds.sigmoid()
      for i in range(class_num):
          res = p[i].to('cpu')
          ax[i+1].imshow(res)
      [ax[i+1].text(0, -15, prompts[i]) for i in range(class_num)];
      plt.show()
      prob, pred = p.max(dim=0)
      ans_map = color[pred]
      # 置信度低的标签为背景
      ans_map[prob<threshold] = 255
      img = Image.fromarray(ans_map.to('cpu').numpy(), mode='RGB')
      display(img)
  
  final_pred = final_pred[:, :image_ori.height, :image_ori.width]
  return final_pred.sigmoid()

In [8]:
# visualize prediction
def vis(image_np, final_pred, threshold_list = [0.3]):
    image_ori = Image.fromarray(image_np)
    prob, pred = final_pred.max(dim=0)
    # 置信度低的标签为背景

    for threshold in threshold_list:
        ans_map = color[pred].clone()
        ans_map[prob<threshold] = 255
        img = Image.fromarray(ans_map.cpu().numpy(), mode='RGB')
        display(Image.blend(image_ori, img, 0.6))
        ans_map[prob<threshold] = 0
        img = Image.fromarray(ans_map.cpu().numpy(), mode='RGB')
        display(Image.blend(image_ori, img, 1))
# save prediction
def save(image_np, final_pred, file_name, threshold_list = [0.3]):
    image_ori = Image.fromarray(image_np)
    prob, pred = final_pred.max(dim=0)
    # 置信度低的标签为背景

    for threshold in threshold_list:
        ans_map = color[pred].clone()
        # ans_map[prob<threshold] = 255
        # img = Image.fromarray(ans_map.cpu().numpy(), mode='RGB')
        # display(Image.blend(image_ori, img, 0.6))
        ans_map[prob<threshold] = 0
        img = Image.fromarray(ans_map.cpu().numpy(), mode='RGB')
        img.save('farsegResultWithThreshold='+str(threshold)+'/'+file_name)

In [9]:
def get_iou_per_class(cm):
    sum_over_row = cm.sum(axis=0)
    sum_over_col = cm.sum(axis=1)
    diag = torch.diag(cm)
    print(sum_over_col.dtype)
    print(sum_over_row)
    print(diag)
    denominator = sum_over_row + sum_over_col - diag
    iou_per_class = diag / denominator
    return iou_per_class

In [10]:

from data.farseg import ImageFolderDataset
from torchmetrics.classification import MulticlassConfusionMatrix
from torchmetrics.classification import MulticlassJaccardIndex
dataset = ImageFolderDataset(img_dir='isaid_segm/val/images/images', mask_dir='isaid_segm/val/masks/images')

In [12]:
mIOU = []
cm = []

threshold_list = [0.3, 0.4, 0.5]
for i in range(3):
    mIOU.append(MulticlassJaccardIndex(class_num+1).to(device))
    cm.append(MulticlassConfusionMatrix(class_num+1).to(device))

for i in range(140, len(dataset)):
    img, gt, filename = dataset[i]
    gt = torch.tensor(gt).to(device)
    pred = get_pred(img)
    # vis(img, pred)
    save(img, pred, filename, threshold_list)

    print(filename)
    prob, y_pred = pred.max(dim=0)
    y_pred += 1

    gt = gt.reshape(-1)
    t = y_pred.reshape(-1)
    for i in range(3):
        y_pred[prob<threshold_list[i]]=0
        cm[i](t, gt)
        print(mIOU[i](t, gt))

for i in range(3):
    print(mIOU.compute())

P2393.png
tensor(0.0533, device='cuda:0')
tensor(0.0603, device='cuda:0')
tensor(0.0668, device='cuda:0')
P0665.png
tensor(0.0447, device='cuda:0')
tensor(0.0431, device='cuda:0')
tensor(0.0443, device='cuda:0')
P1143.png
tensor(0.0799, device='cuda:0')
tensor(0.0828, device='cuda:0')
tensor(0.0838, device='cuda:0')
P1269.png
tensor(0.0168, device='cuda:0')
tensor(0.0249, device='cuda:0')
tensor(0.0349, device='cuda:0')
P0543.png
tensor(0.0378, device='cuda:0')
tensor(0.0435, device='cuda:0')
tensor(0.0541, device='cuda:0')
P1397.png
tensor(0.0397, device='cuda:0')
tensor(0.0514, device='cuda:0')
tensor(0.0646, device='cuda:0')
P0368.png
tensor(0.0578, device='cuda:0')
tensor(0.0637, device='cuda:0')
tensor(0.0694, device='cuda:0')
P2478.png
tensor(0.0328, device='cuda:0')
tensor(0.0451, device='cuda:0')
tensor(0.0543, device='cuda:0')
P2714.png
tensor(0.0478, device='cuda:0')
tensor(0.0545, device='cuda:0')
tensor(0.0596, device='cuda:0')
P2413.png
tensor(0.0366, device='cuda:0')
tens

AttributeError: 'list' object has no attribute 'compute'

In [15]:
# print(cm.compute())
for i in range(3):
    print(mIOU[i].compute())
    print(cm[i].compute())
    # iou_per_class = get_iou_per_class(cm[i].compute())
    # print(iou_per_class, iou_per_class.mean())

tensor(0.0809, device='cuda:0')
tensor(0.0809, device='cuda:0')
tensor(0.0919, device='cuda:0')
tensor(0.0919, device='cuda:0')
tensor(0.1018, device='cuda:0')
tensor(0.1018, device='cuda:0')
