In [20]:
import torch
import torch.nn.functional as F
from functorch.dim import dims
import math
import torch.nn as nn
from torchvision import models
from torchvision.ops import roi_align
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

In [21]:
# Roi_Align 부분, 서현님 part
def bilinear_interpolate(input, height, width, y, x, ymask, xmask):
    
    y = y.clamp(min=0)
    x = x.clamp(min=0)
    y_low = y.int()
    x_low = x.int()
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y.to(input.dtype), y)
  
    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
    x = torch.where(x_low >= width - 1, x.to(input.dtype), x)
    
    ly = y - y_low
    lx = x - x_low
    hy = 1. - ly
    hx = 1. - lx
    
    def masked_index(y, x):
        y = torch.where(ymask, y, 0)
        x = torch.where(xmask, x, 0)
        return input[y, x]

    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)
    w1 = hy * hx
    w2 = hy * lx
    w3 = ly * hx
    w4 = ly * lx

    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
    return val

def roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    _, _, height, width = input.size()

    n, c, ph, pw = dims(4)
    
    ph.size = pooled_height
    pw.size = pooled_width
    offset_rois = rois[n]
    roi_batch_ind = offset_rois[0].int()
    offset = 0.5 if aligned else 0.0
    roi_start_w = offset_rois[1] * spatial_scale - offset
    roi_start_h = offset_rois[2] * spatial_scale - offset
    roi_end_w = offset_rois[3] * spatial_scale - offset
    roi_end_h = offset_rois[4] * spatial_scale - offset

    roi_width = roi_end_w - roi_start_w
    roi_height = roi_end_h - roi_start_h
    if not aligned:
        roi_width = torch.clamp(roi_width, min=1.0)
        roi_height = torch.clamp(roi_height, min=1.0)

    bin_size_h = roi_height / pooled_height
    bin_size_w = roi_width / pooled_width

    offset_input = input[roi_batch_ind][c]

    roi_bin_grid_h = sampling_ratio if sampling_ratio > 0 else torch.ceil(roi_height / pooled_height)
    roi_bin_grid_w = sampling_ratio if sampling_ratio > 0 else torch.ceil(roi_width / pooled_width)

    count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)

    iy, ix = dims(2)

    iy.size = height  # < roi_bin_grid_h
    ix.size = width  # < roi_bin_grid_w
    
    y = roi_start_h + ph * bin_size_h + (iy + 0.5) * bin_size_h / roi_bin_grid_h
    x = roi_start_w + pw * bin_size_w + (ix + 0.5) * bin_size_w / roi_bin_grid_w
    ymask = iy < roi_bin_grid_h
    xmask = ix < roi_bin_grid_w
    val = bilinear_interpolate(offset_input, height, width, y, x, ymask, xmask)
    val = torch.where(ymask, val, 0)
    val = torch.where(xmask, val, 0)
    output = val.sum((iy, ix))
    output /= count

    return output.order(n, c, ph, pw)


In [11]:
# featuremap을 얻기 위한 컨볼루션 부분
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)

        self.avg_pooling = nn.AdaptiveAvgPool2d((28,28))

    def forward(self,x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)

        x = self.conv3(x)
        x = self.relu(x)
        x = self.pool(x)

        x = self.conv4(x)
        x = self.relu(x)
        x = self.pool(x)

        x = self.avg_pooling(x)

        return x

In [23]:
#test
simple_cnn = CNN()

# coco 이미지 불러오기
image_path = "COCO_train2014_000000000030.jpg"
image = Image.open(image_path)

#텐서 변환
preprocess = transforms.Compose([transforms.ToTensor(),])

# preprocess = transforms.Compose([
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])
input_image = preprocess(image)
input_image = input_image.unsqueeze(0)

output_image = simple_cnn(input_image)


# 임의의 피쳐맵 생성
# featuremap은 cnn을 통과한 이미지의 output을 그대로 가져옴.
# rois는 임의로 설정.
features = output_image
print(features.shape)
rois = torch.tensor([
    [0, 204,31,458,355]
], dtype=torch.float)


#roi를 통과한 output size 설정,
#spatial_scale 은 28,28 사이즈를 7,7로 바꿔야하므로 4.0으로 변경
output_size = (7, 7)
spatial_scale = 1.0 / 4.0

# Call the roi_align function
pooled_features = roi_align(features, rois, spatial_scale, output_size[0], output_size[1], -1, False)

print(pooled_features.sum())

from torchvision.ops import roi_align as roi_align_torchvision

print(roi_align_torchvision(features, rois, output_size, spatial_scale).sum())

torch.Size([1, 512, 28, 28])
tensor(490.0711, grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)


In [13]:
print(pooled_features.shape)

torch.Size([1, 512, 7, 7])


In [67]:
rois = pooled_features[0]
height = pooled_features[0][2]
weight = pooled_features[0][3]
# print(rois)
print(rois[0])
print(height[0])
print(weight[0])

tensor([[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SelectBackward0>)
tensor([0.0254, 0.0259, 0.0264, 0.0269, 0.0276, 0.0287, 0.0298],
       grad_fn=<SelectBackward0>)
tensor([6.4516e-05, 6.0186e-05, 5.5856e-05, 5.1526e-05, 4.5897e-05, 3.7237e-05,
        2.8577e-05], grad_fn=<SelectBackward0>)


In [14]:
#마스크부분 & sampadding은 메소드인 padding ='same'과 동일하다.

class SamePad2d(nn.Module):
    """Mimics tensorflow's 'SAME' padding.
    """

    def __init__(self, kernel_size, stride):
        super(SamePad2d, self).__init__()
        self.kernel_size = torch.nn.modules.utils._pair(kernel_size)
        self.stride = torch.nn.modules.utils._pair(stride)

    def forward(self, input):
        in_width = input.size()[2]
        in_height = input.size()[3]
        out_width = math.ceil(float(in_width) / float(self.stride[0]))
        out_height = math.ceil(float(in_height) / float(self.stride[1]))
        pad_along_width = ((out_width - 1) * self.stride[0] +
                           self.kernel_size[0] - in_width)
        pad_along_height = ((out_height - 1) * self.stride[1] +
                            self.kernel_size[1] - in_height)
        pad_left = math.floor(pad_along_width / 2)
        pad_top = math.floor(pad_along_height / 2)
        pad_right = pad_along_width - pad_left
        pad_bottom = pad_along_height - pad_top
        return F.pad(input, (pad_left, pad_right, pad_top, pad_bottom), 'constant', 0)

    def __repr__(self):
        return self.__class__.__name__
  
  
#단일 피쳐맵의 실험을 할때는 batch_size가 필요하지 않아서, 주석처리
# self.batch_size = batch_size 이부분을 제외함.
# 현재 단일 이미지의 경우 [3,512,7,7]의 shape을 가진 tensor로 출력됨.
# num_classes값은 기학습 모델 기준으로 80개의 class 개수를 가지고 있어서
# 80으로 넣어줄 것.
# 현재 기준 roi_align의 output이 [num_rois,in_channels,pool_height,pool_weight]이므로
# 생성자에서 따로 생성할 필요가 없음
# 단, 첫 self.conv1의 입력값은 512로 맞춰줘야함.

class Mask(nn.Module):
  def __init__(self,num_rois,in_channels,pool_height,pool_weight, num_classes):
      super(Mask, self).__init__()
      self.num_rois = num_rois
      self.in_channels = in_channels
      self.pool_height = pool_height
      self.pool_weight = pool_weight
      self.num_classes = num_classes
      self.padding = SamePad2d(kernel_size=3,stride=1)
      self.conv1 = nn.Conv2d(self.in_channels, 256, kernel_size=3, stride=1)
      self.bn1 = nn.BatchNorm2d(256, eps=0.001)
      self.deconv = nn.ConvTranspose2d(256, 80, kernel_size=2, stride=2)
      self.conv2 = nn.Conv2d(80, self.num_classes, kernel_size=3, stride=1)
      self.sigmoid = nn.Sigmoid()
      self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    x = self.conv1(self.padding(x))
    x = self.bn1(x)
    x = self.relu(x)
    x = self.deconv(x)
    x = self.conv2(self.padding(x))
    x = self.sigmoid(x)
    p_mask = x
    return p_mask


In [15]:
mask = Mask(num_rois=1,in_channels=512,pool_height=7,pool_weight=7,num_classes=80)
mask_out = mask(pooled_features)

In [17]:
mask = Mask(num_rois=1,in_channels=512,pool_height=7,pool_weight=7,num_classes=80)
target = mask(pooled_features)


In [None]:
ck_roi = mask_out[0]
# print(ck_roi)
height = mask_out[0][2]
weight = mask_out[0][3]
print(height)
print(weight)

In [93]:
# 마스크 loss 
mask_prediction = mask_out  # 모델이 예측한 마스크 값
mask_target = torch.rand_like(mask_out, dtype=torch.float)  # 랜덤한 실제 마스크 값, 실제 데이터에 따라 적절한 값을 사용해야 합니다.

# BCELoss를 사용하여 마스크 손실 계산
mask_criterion = nn.BCELoss()
mask_loss = mask_criterion(mask_prediction, mask_target)

# 마스크 손실 출력
print("Mask Loss:", mask_loss.item())

Mask Loss: 0.6933982372283936


In [80]:
fake_target = torch.rand_like(mask_out, dtype=torch.float)
#fake_target 을 0~1로 정규화
fake_target_normal = (fake_target - fake_target.min()) / (fake_target.max() - fake_target.min())

In [103]:
fake_target = torch.rand_like(mask_out, dtype=torch.float)

# Min-Max 스케일링을 통해 [0, 1] 사이로 정규화
fake_tg_normal = F.normalize(fake_target, dim=(2, 3), p=2)
fake_tg_nor1 = fake_tg_normal

In [105]:
print(mask_out.shape)
print(fake_tg_nor1.shape)

torch.Size([3, 80, 14, 14])
torch.Size([3, 80, 14, 14])


In [18]:
import torch.optim as optim

model = Mask(num_rois=1,in_channels=512,pool_height=7,pool_weight=7,num_classes=80)
loss = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr = 0.001)

inputs = pooled_features
mask_target = target


num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    
    output = model(inputs)
    mask_loss = loss(output,mask_target)
    optimizer.zero_grad()
    optimizer.step()
    
    if (epoch+1)%1 ==0:
        print(f"Epoch:{epoch+1},loss:{mask_loss.item():.4f}")



Epoch:1,loss:0.6933
Epoch:2,loss:0.6933
Epoch:3,loss:0.6933
Epoch:4,loss:0.6933
Epoch:5,loss:0.6933


In [7]:
import cv2

image = cv2.imread("COCO_train2014_000000000030.jpg")

#target bbox좌표
#annotation의 bbox 값은 [1,2,3,4] 이렇게 4가지의 값을 가지고있음
#bbox[1,2,3,4]는 [left_top_x,left_top_y,width,height] 값임.
#rectangle을 그리기위해 x1,y1,x2,y2 값이 필요하다면,
#좌표는 annotation 기준으로 [x1,y1,x1+width,y1+height]를 [x1,y1,x2,y2]로
#넣어서 그려주면 target bbox값을 가질 수 있음.

x1,y1,x2,y2 = 204,31,458,355


cv2.rectangle(image,(x1,y1),(x2,y2),(0,255,0),2)
cv2.imshow("image with bbox",image)

cv2.waitKey(0)
cv2.destroyAllWindows()

In [None]:
from torchvision.ops import roi_align

roi = roi_align()