In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50

class SegModel(nn.Module):  # todo: move to models
    """Frozen 3-class DeepLabV3-ResNet50 segmenter returning per-class probabilities.

    The network's weights come from a training checkpoint. After loading, the module is
    switched to eval mode and all of its parameters have ``requires_grad`` disabled, so
    it acts as a fixed critic: gradients can still flow to its *input*.

    Args:
        ckpt_path: path to a ``torch.save``-d dict whose ``'state'`` entry holds this
            module's state dict. Defaults to ``'segmentator.pt'`` (relative to the
            current working directory).
    """

    def __init__(self, ckpt_path: str = 'segmentator.pt') -> None:
        super().__init__()
        # The strict load_state_dict below overwrites every parameter and buffer, so
        # fetching ImageNet backbone weights first would only cost a (network) download.
        self.model = deeplabv3_resnet50(pretrained=False, num_classes=3, pretrained_backbone=False)
        # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        # The checkpoint also stores a 'loss.weight' entry that this module has no slot
        # for; drop it so the strict load does not fail on an unexpected key.
        state = {k: v for k, v in checkpoint['state'].items() if k != 'loss.weight'}
        self.load_state_dict(state)
        self.eval().requires_grad_(False)

    def forward(self, x: torch.Tensor):
        """Segment ``x`` and return ``(background, body, head)`` probability maps.

        Args:
            x: image batch fed straight to DeepLabV3 (assumed ``(B, 3, H, W)``).
                NOTE(review): confirm whether the checkpoint expects ImageNet
                normalisation or raw ``[0, 1]`` pixels.

        Returns:
            Three tensors, one per class channel (0 = background, 1 = body,
            2 = head). Each keeps a singleton channel dim, so it has shape
            ``(B, 1, ...)`` and broadcasts against multi-channel images.
            Across the three maps, values at each pixel sum to 1 (softmax over classes).
        """
        probs = F.softmax(self.model(x)['out'], dim=1)
        # Length-1 slices keep dim 1, equivalent to ``probs[:, i].unsqueeze(1)``.
        background = probs[:, 0:1]
        body = probs[:, 1:2]
        head = probs[:, 2:3]
        return background, body, head

In [2]:
# Build the frozen segmenter once; it loads its weights from 'segmentator.pt'.
model: SegModel = SegModel()



In [3]:
import cv2
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np


IMAGE_PATH = 'test.png'

# cv2.imread does not raise on a missing/unreadable file -- it returns None, which
# would otherwise surface as an opaque error inside cvtColor.
image = cv2.imread(IMAGE_PATH)
if image is None:
    raise FileNotFoundError(f"could not read image: {IMAGE_PATH!r}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model side expects RGB
transform = transforms.ToTensor()
tensor_image = transform(image)  # HWC uint8 -> CHW float tensor in [0, 1]

In [4]:
def semantic_loss(Ig, I, model):
    """Semantic-consistency losses between a generated image and a reference image.

    Both images are run through the same segmentation ``model``. Only two maps are
    used: the body map of the reference ``I`` and the head maps of both images.

    Args:
        Ig: generated image batch (gradients may flow through it).
        I: reference image batch; must broadcast against ``Ig``.
        model: callable returning ``(background, body, head)`` maps (e.g. ``SegModel``).
            The maps are assumed to keep a singleton channel dim so they broadcast
            over image channels -- TODO confirm for other models.

    Returns:
        ``(Limg, Lhead)`` scalar tensors:
          * ``Limg``: squared L2 norm of ``Ig - I`` weighted per pixel by
            ``1 - body(I)``. Differences count less where ``I`` is likely body.
          * ``Lhead``: squared L2 distance between the head maps of ``I`` and ``Ig``.
    """
    _, _, head_gen = model(Ig)
    _, body_ref, head_ref = model(I)
    mask = 1 - body_ref
    # A plain sum of squares equals ||.||_2 ** 2 but skips the sqrt/square round trip
    # of the deprecated torch.norm.
    Limg = (mask * (Ig - I)).pow(2).sum()
    Lhead = (head_ref - head_gen).pow(2).sum()
    return Limg, Lhead


In [5]:
# Add a batch dim without hard-coding the image height (view(1, 3, 490, -1) breaks or
# silently scrambles pixels for images that are not 490 px tall).
image_batch = tensor_image.unsqueeze(0)
# Sanity check: an image compared with itself must give exactly zero for both terms.
semantic_loss(image_batch, image_batch, model)

(tensor(0.), tensor(0.))