In [144]:
from model import YOLOv3
import torch
import numpy as np

# yolo model

In [145]:
import torch
import torch.nn as nn

In [146]:
""" 
Information about architecture config:
Tuple is structured by (filters, kernel_size, stride) 
Every conv is a same convolution. 
List is structured by "B" indicating a residual block followed by the number of repeats
"S" is for scale prediction block and computing the yolo loss
"U" is for upsampling the feature map and concatenating with a previous layer
"""
yolo_config = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],  # To this point is Darknet-53
    (512, 1, 1),
    (1024, 3, 1),
    "S",
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]

# SSD300 CONFIGS for COCO (active learning)
coco300_active = {
    'num_classes': 81,
    'lr_steps': (80000, 100000, 120000),
    'max_iter': 120000,
    'feature_maps': [38, 19, 10, 5, 3, 1],
    'min_dim': 300,
    'steps': [8, 16, 32, 64, 100, 300],
    'min_sizes': [21, 45, 99, 153, 207, 261],
    'max_sizes': [45, 99, 153, 207, 261, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
    'variance': [0.1, 0.2],
    'clip': True,
    'num_initial_labeled_set': 5000,
    'num_total_images': 82081,
    'acquisition_budget': 1000,
    'num_cycles': 3,
    'name': 'COCO',
}

In [147]:
'''class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        if self.use_bn_act:
            return self.leaky(self.bn(self.conv(x)))
        else:
            return self.conv(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.layers = nn.ModuleList()
        for repeat in range(num_repeats):
            self.layers += [
                nn.Sequential(
                    CNNBlock(channels, channels // 2, kernel_size=1),
                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
                )
            ]

        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        for layer in self.layers:
            if self.use_residual:
                x = x + layer(x)
            else:
                x = layer(x)

        return x


class ScalePrediction(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            CNNBlock(
                2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
            ),
        )
        self.num_classes = num_classes

    def forward(self, x):
        return (
            self.pred(x)
            .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
            .permute(0, 1, 3, 4, 2)
        )


class YOLOv3(nn.Module):
    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        outputs = []  # for each scale
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)

            elif isinstance(layer, nn.Upsample):
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()

        return outputs

    def _create_conv_layers(self):
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats,))

            elif isinstance(module, str):
                if module == "S":
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2

                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2),)
                    in_channels = in_channels * 3

        return layers
'''
print("")




In [148]:
# just for testing
# model = YOLOv3(num_classes=80)
# x = torch.randn((2, 3, 416, 416))
# out = model(x)

In [149]:
# model

# yolo gmm

In [150]:
def multibox(out_channels, num_classes):
    number_anchor = 3
    
    loc_mu_1_layers = []
    loc_var_1_layers = []
    loc_pi_1_layers = []
    loc_mu_2_layers = []
    loc_var_2_layers = []
    loc_pi_2_layers = []
    loc_mu_3_layers = []
    loc_var_3_layers = []
    loc_pi_3_layers = []
    loc_mu_4_layers = []
    loc_var_4_layers = []
    loc_pi_4_layers = []
    conf_mu_1_layers = []
    conf_var_1_layers = []
    conf_pi_1_layers = []
    conf_mu_2_layers = []
    conf_var_2_layers = []
    conf_pi_2_layers = []
    conf_mu_3_layers = []
    conf_var_3_layers = []
    conf_pi_3_layers = []
    conf_mu_4_layers = []
    conf_var_4_layers = []
    conf_pi_4_layers = []

    for c in out_channels:
        # for loc and conf (mu var pi) 
        loc_mu_1_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_var_1_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_pi_1_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_mu_2_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_var_2_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_pi_2_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_mu_3_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_var_3_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_pi_3_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_mu_4_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_var_4_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        loc_pi_4_layers += [nn.Conv2d(c, number_anchor * 4, kernel_size=3, padding=1)]
        conf_mu_1_layers += [nn.Conv2d(c, number_anchor * num_classes, kernel_size=3, padding=1)]
        conf_var_1_layers += [nn.Conv2d(c, number_anchor * num_classes, kernel_size=3, padding=1)]
        conf_pi_1_layers += [nn.Conv2d(c, number_anchor * 1, kernel_size=3, padding=1)]
        conf_mu_2_layers += [nn.Conv2d(c, number_anchor * num_classes, kernel_size=3, padding=1)]
        conf_var_2_layers += [nn.Conv2d(c, number_anchor * num_classes, kernel_size=3, padding=1)]
        conf_pi_2_layers += [nn.Conv2d(c, number_anchor * 1, kernel_size=3, padding=1)]
        conf_mu_3_layers += [nn.Conv2d(c, number_anchor * num_classes, kernel_size=3, padding=1)]
        conf_var_3_layers += [nn.Conv2d(c, number_anchor * num_classes, kernel_size=3, padding=1)]
        conf_pi_3_layers += [nn.Conv2d(c, number_anchor * 1, kernel_size=3, padding=1)]
        conf_mu_4_layers += [nn.Conv2d(c, number_anchor * num_classes, kernel_size=3, padding=1)]
        conf_var_4_layers += [nn.Conv2d(c, number_anchor * num_classes, kernel_size=3, padding=1)]
        conf_pi_4_layers += [nn.Conv2d(c, number_anchor * 1, kernel_size=3, padding=1)]

    return (
        loc_mu_1_layers, loc_var_1_layers, loc_pi_1_layers, loc_mu_2_layers, loc_var_2_layers, loc_pi_2_layers, \
        loc_mu_3_layers, loc_var_3_layers, loc_pi_3_layers, loc_mu_4_layers, loc_var_4_layers, loc_pi_4_layers, \
        conf_mu_1_layers, conf_var_1_layers, conf_pi_1_layers, conf_mu_2_layers, conf_var_2_layers, conf_pi_2_layers, \
        conf_mu_3_layers, conf_var_3_layers, conf_pi_3_layers, conf_mu_4_layers, conf_var_4_layers, conf_pi_4_layers
    )

#### PriorBox, Detect_GMM, And Test condition is remaining
that part is commented

In [151]:
# Originated from https://github.com/amdegroot/ssd.pytorch

from __future__ import division
from math import sqrt as sqrt
from itertools import product as product
import torch


class PriorBox(object):
    """Compute priorbox coordinates in center-offset form for each source
    feature map.
    """
    def __init__(self, cfg):
        super(PriorBox, self).__init__()
        self.image_size = cfg['min_dim']
        # number of priors for feature map location (either 4 or 6)
        self.num_priors = len(cfg['aspect_ratios'])
        self.variance = cfg['variance'] or [0.1]
        self.feature_maps = cfg['feature_maps']
        self.min_sizes = cfg['min_sizes']
        self.max_sizes = cfg['max_sizes']
        self.steps = cfg['steps']
        self.aspect_ratios = cfg['aspect_ratios']
        self.clip = cfg['clip'] 
        self.version = cfg['name']
        for v in self.variance:
            if v <= 0:
                raise ValueError('Variances must be greater than 0')

    def forward(self):
        mean = []
        for k, f in enumerate(self.feature_maps):
            for i, j in product(range(f), repeat=2):
                f_k = self.image_size / self.steps[k]
                # unit center x,y
                cx = (j + 0.5) / f_k
                cy = (i + 0.5) / f_k

                # aspect_ratio: 1
                # rel size: min_size
                s_k = self.min_sizes[k]/self.image_size
                mean += [cx, cy, s_k, s_k]

                # aspect_ratio: 1
                # rel size: sqrt(s_k * s_(k+1))
                s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size))
                mean += [cx, cy, s_k_prime, s_k_prime]

                # rest of aspect ratios
                for ar in self.aspect_ratios[k]:
                    mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)]
                    mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)]
        # back to torch land
        output = torch.Tensor(mean).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output


In [152]:

from torch.autograd import Variable

class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        if self.use_bn_act:
            return self.leaky(self.bn(self.conv(x)))
        else:
            return self.conv(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.layers = nn.ModuleList()
        for repeat in range(num_repeats):
            self.layers += [
                nn.Sequential(
                    CNNBlock(channels, channels // 2, kernel_size=1),
                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
                )
            ]

        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        for layer in self.layers:
            if self.use_residual:
                x = x + layer(x)
            else:
                x = layer(x)

        return x


class ScalePrediction(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            CNNBlock(
                2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
            ),
        )
        self.num_classes = num_classes

    def forward(self, x):
        return (
            self.pred(x)
            .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
            .permute(0, 1, 3, 4, 2)
        )


class YOLO_GMM(nn.Module):
    def __init__(self, gmm_head ,in_channels=3, num_classes=80, phase = "train"):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.phase = phase
        ## ***********************
        # REMAINING
        self.priorbox = PriorBox(coco300_active) # '''???'''
        
        with torch.no_grad():
           self.priors = Variable(self.priorbox.forward())
        
        self.layers = self._create_conv_layers()
        # localization GMM parameters
        self.loc_mu_1 = nn.ModuleList(gmm_head[0])
        self.loc_var_1 = nn.ModuleList(gmm_head[1])
        self.loc_pi_1 = nn.ModuleList(gmm_head[2])
        self.loc_mu_2 = nn.ModuleList(gmm_head[3])
        self.loc_var_2 = nn.ModuleList(gmm_head[4])
        self.loc_pi_2 = nn.ModuleList(gmm_head[5])
        self.loc_mu_3 = nn.ModuleList(gmm_head[6])
        self.loc_var_3 = nn.ModuleList(gmm_head[7])
        self.loc_pi_3 = nn.ModuleList(gmm_head[8])
        self.loc_mu_4 = nn.ModuleList(gmm_head[9])
        self.loc_var_4 = nn.ModuleList(gmm_head[10])
        self.loc_pi_4 = nn.ModuleList(gmm_head[11])

        # Classification GMM parameters
        self.conf_mu_1 = nn.ModuleList(gmm_head[12])
        self.conf_var_1 = nn.ModuleList(gmm_head[13])
        self.conf_pi_1 = nn.ModuleList(gmm_head[14])
        self.conf_mu_2 = nn.ModuleList(gmm_head[15])
        self.conf_var_2 = nn.ModuleList(gmm_head[16])
        self.conf_pi_2 = nn.ModuleList(gmm_head[17])
        self.conf_mu_3 = nn.ModuleList(gmm_head[18])
        self.conf_var_3 = nn.ModuleList(gmm_head[19])
        self.conf_pi_3 = nn.ModuleList(gmm_head[20])
        self.conf_mu_4 = nn.ModuleList(gmm_head[21])
        self.conf_var_4 = nn.ModuleList(gmm_head[22])
        self.conf_pi_4 = nn.ModuleList(gmm_head[23])
        
        

        # **************
        # REMAINING
        #if phase == 'test':
        #    self.softmax = nn.Softmax(dim=-1)
        #    self.detect = Detect_GMM(num_classes, 0, 200, 0.01, 0.45)
        
        
        #self.layers = self._create_conv_layers()

    def forward(self, x):
        outputs = []  # for each scale
        route_connections = []
        #sources = list()
        loc_mu_1 = list()
        loc_var_1 = list()
        loc_pi_1 = list()
        loc_mu_2 = list()
        loc_var_2 = list()
        loc_pi_2 = list()
        loc_mu_3 = list()
        loc_var_3 = list()
        loc_pi_3 = list()
        loc_mu_4 = list()
        loc_var_4 = list()
        loc_pi_4 = list()
        conf_mu_1 = list()
        conf_var_1 = list()
        conf_pi_1 = list()
        conf_mu_2 = list()
        conf_var_2 = list()
        conf_pi_2 = list()
        conf_mu_3 = list()
        conf_var_3 = list()
        conf_pi_3 = list()
        conf_mu_4 = list()
        conf_var_4 = list()
        conf_pi_4 = list()
        
        
        
        # for layer in self.layers: 
        # removing ScalePrediction layer and giving previous layer output to gmm layers
        for i ,layer in enumerate(self.layers):
            #if isinstance(layer, ScalePrediction):
            # list represen
            
            x = layer(x) # moved up to pass previous layer out to next layer
            if i in [14, 20, 27]: 
                outputs.append(x)
                
            elif isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)

            elif isinstance(layer, nn.Upsample):
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()
               
        # for every feature output all corner mu var pi val
        for (x, l_mu_1, l_var_1, l_pi_1, l_mu_2, l_var_2, l_pi_2, l_mu_3, l_var_3, l_pi_3, l_mu_4, l_var_4, l_pi_4, \
        c_mu_1, c_var_1, c_pi_1, c_mu_2, c_var_2, c_pi_2, c_mu_3, c_var_3, c_pi_3, c_mu_4, c_var_4, c_pi_4) in zip(outputs, \
        self.loc_mu_1, self.loc_var_1, self.loc_pi_1, self.loc_mu_2, self.loc_var_2, self.loc_pi_2, \
        self.loc_mu_3, self.loc_var_3, self.loc_pi_3, self.loc_mu_4, self.loc_var_4, self.loc_pi_4, \
        self.conf_mu_1, self.conf_var_1, self.conf_pi_1, self.conf_mu_2, self.conf_var_2, self.conf_pi_2, \
        self.conf_mu_3, self.conf_var_3, self.conf_pi_3, self.conf_mu_4, self.conf_var_4, self.conf_pi_4):
            
            ## no need to changed shape as taking previous layers of scale prid only needed if scaleprid layers are used
            ## change shape [2, 3, 52, 52, 85] to [2, 52, 52, 255]
            # shape = x.shape
            # x = torch.reshape(x.permute(0, 2, 3, 1,4), (shape[0], shape[2],shape[3], shape[1]* shape[4]))
            # x = x.permute(0, 3, 1, 2)
            
            loc_mu_1.append(l_mu_1(x).permute(0, 2, 3, 1).contiguous())
            loc_var_1.append(l_var_1(x).permute(0, 2, 3, 1).contiguous())
            loc_pi_1.append(l_pi_1(x).permute(0, 2, 3, 1).contiguous())
            loc_mu_2.append(l_mu_2(x).permute(0, 2, 3, 1).contiguous())
            loc_var_2.append(l_var_2(x).permute(0, 2, 3, 1).contiguous())
            loc_pi_2.append(l_pi_2(x).permute(0, 2, 3, 1).contiguous())
            loc_mu_3.append(l_mu_3(x).permute(0, 2, 3, 1).contiguous())
            loc_var_3.append(l_var_3(x).permute(0, 2, 3, 1).contiguous())
            loc_pi_3.append(l_pi_3(x).permute(0, 2, 3, 1).contiguous())
            loc_mu_4.append(l_mu_4(x).permute(0, 2, 3, 1).contiguous())
            loc_var_4.append(l_var_4(x).permute(0, 2, 3, 1).contiguous())
            loc_pi_4.append(l_pi_4(x).permute(0, 2, 3, 1).contiguous())
            conf_mu_1.append(c_mu_1(x).permute(0, 2, 3, 1).contiguous())
            conf_var_1.append(c_var_1(x).permute(0, 2, 3, 1).contiguous())
            conf_pi_1.append(c_pi_1(x).permute(0, 2, 3, 1).contiguous())
            conf_mu_2.append(c_mu_2(x).permute(0, 2, 3, 1).contiguous())
            conf_var_2.append(c_var_2(x).permute(0, 2, 3, 1).contiguous())
            conf_pi_2.append(c_pi_2(x).permute(0, 2, 3, 1).contiguous())
            conf_mu_3.append(c_mu_3(x).permute(0, 2, 3, 1).contiguous())
            conf_var_3.append(c_var_3(x).permute(0, 2, 3, 1).contiguous())
            conf_pi_3.append(c_pi_3(x).permute(0, 2, 3, 1).contiguous())
            conf_mu_4.append(c_mu_4(x).permute(0, 2, 3, 1).contiguous())
            conf_var_4.append(c_var_4(x).permute(0, 2, 3, 1).contiguous())
            conf_pi_4.append(c_pi_4(x).permute(0, 2, 3, 1).contiguous())
            
        # get every think in a row [N, float output of  6 features] 
        loc_mu_1 = torch.cat([o.view(o.size(0), -1) for o in loc_mu_1], 1)
        loc_var_1 = torch.cat([o.view(o.size(0), -1) for o in loc_var_1], 1)
        loc_pi_1 = torch.cat([o.view(o.size(0), -1) for o in loc_pi_1], 1)
        loc_mu_2 = torch.cat([o.view(o.size(0), -1) for o in loc_mu_2], 1)
        loc_var_2 = torch.cat([o.view(o.size(0), -1) for o in loc_var_2], 1)
        loc_pi_2 = torch.cat([o.view(o.size(0), -1) for o in loc_pi_2], 1)
        loc_mu_3 = torch.cat([o.view(o.size(0), -1) for o in loc_mu_3], 1)
        loc_var_3 = torch.cat([o.view(o.size(0), -1) for o in loc_var_3], 1)
        loc_pi_3 = torch.cat([o.view(o.size(0), -1) for o in loc_pi_3], 1)
        loc_mu_4 = torch.cat([o.view(o.size(0), -1) for o in loc_mu_4], 1)
        loc_var_4 = torch.cat([o.view(o.size(0), -1) for o in loc_var_4], 1)
        loc_pi_4 = torch.cat([o.view(o.size(0), -1) for o in loc_pi_4], 1)
        conf_mu_1 = torch.cat([o.view(o.size(0), -1) for o in conf_mu_1], 1)
        conf_var_1 = torch.cat([o.view(o.size(0), -1) for o in conf_var_1], 1)
        conf_pi_1 = torch.cat([o.view(o.size(0), -1) for o in conf_pi_1], 1)
        conf_mu_2 = torch.cat([o.view(o.size(0), -1) for o in conf_mu_2], 1)
        conf_var_2 = torch.cat([o.view(o.size(0), -1) for o in conf_var_2], 1)
        conf_pi_2 = torch.cat([o.view(o.size(0), -1) for o in conf_pi_2], 1)
        conf_mu_3 = torch.cat([o.view(o.size(0), -1) for o in conf_mu_3], 1)
        conf_var_3 = torch.cat([o.view(o.size(0), -1) for o in conf_var_3], 1)
        conf_pi_3 = torch.cat([o.view(o.size(0), -1) for o in conf_pi_3], 1)
        conf_mu_4 = torch.cat([o.view(o.size(0), -1) for o in conf_mu_4], 1)
        conf_var_4 = torch.cat([o.view(o.size(0), -1) for o in conf_var_4], 1)
        conf_pi_4 = torch.cat([o.view(o.size(0), -1) for o in conf_pi_4], 1)
        
        if self.phase == "test":
            # REMAINING 
            loc_var_1 = torch.sigmoid(loc_var_1)
            loc_var_2 = torch.sigmoid(loc_var_2)
            loc_var_3 = torch.sigmoid(loc_var_3)
            loc_var_4 = torch.sigmoid(loc_var_4)

            loc_pi_1 = loc_pi_1.view(-1, 4)
            loc_pi_2 = loc_pi_2.view(-1, 4)
            loc_pi_3 = loc_pi_3.view(-1, 4)
            loc_pi_4 = loc_pi_4.view(-1, 4)

            pi_all = torch.stack(
                [
                    loc_pi_1.reshape(-1),
                    loc_pi_2.reshape(-1),
                    loc_pi_3.reshape(-1),
                    loc_pi_4.reshape(-1)
                ]
            )
            pi_all = pi_all.transpose(0,1)
            pi_all = (torch.softmax(pi_all, dim=1)).transpose(0,1).reshape(-1)
            (
                loc_pi_1,
                loc_pi_2,
                loc_pi_3,
                loc_pi_4
            ) = torch.split(pi_all, loc_pi_1.reshape(-1).size(0), dim=0)
            loc_pi_1 = loc_pi_1.view(-1, 4)
            loc_pi_2 = loc_pi_2.view(-1, 4)
            loc_pi_3 = loc_pi_3.view(-1, 4)
            loc_pi_4 = loc_pi_4.view(-1, 4)

            conf_var_1 = torch.sigmoid(conf_var_1)
            conf_var_2 = torch.sigmoid(conf_var_2)
            conf_var_3 = torch.sigmoid(conf_var_3)
            conf_var_4 = torch.sigmoid(conf_var_4)

            conf_pi_1 = conf_pi_1.view(-1, 1)
            conf_pi_2 = conf_pi_2.view(-1, 1)
            conf_pi_3 = conf_pi_3.view(-1, 1)
            conf_pi_4 = conf_pi_4.view(-1, 1)

            conf_pi_all = torch.stack(
                [
                    conf_pi_1.reshape(-1),
                    conf_pi_2.reshape(-1),
                    conf_pi_3.reshape(-1),
                    conf_pi_4.reshape(-1)
                ]
            )
            conf_pi_all = conf_pi_all.transpose(0,1)
            conf_pi_all = (torch.softmax(conf_pi_all, dim=1)).transpose(0,1).reshape(-1)
            (
                conf_pi_1,
                conf_pi_2,
                conf_pi_3,
                conf_pi_4
            ) = torch.split(conf_pi_all, conf_pi_1.reshape(-1).size(0), dim=0)
            conf_pi_1 = conf_pi_1.view(-1, 1)
            conf_pi_2 = conf_pi_2.view(-1, 1)
            conf_pi_3 = conf_pi_3.view(-1, 1)
            conf_pi_4 = conf_pi_4.view(-1, 1)

            output = self.detect(
                self.priors.type(type(x.data)),
                loc_mu_1.view(loc_mu_1.size(0), -1, 4),
                loc_var_1.view(loc_var_1.size(0), -1, 4),
                loc_pi_1.view(loc_var_1.size(0), -1, 4),
                loc_mu_2.view(loc_mu_2.size(0), -1, 4),
                loc_var_2.view(loc_var_2.size(0), -1, 4),
                loc_pi_2.view(loc_var_2.size(0), -1, 4),
                loc_mu_3.view(loc_mu_3.size(0), -1, 4),
                loc_var_3.view(loc_var_3.size(0), -1, 4),
                loc_pi_3.view(loc_var_3.size(0), -1, 4),
                loc_mu_4.view(loc_mu_4.size(0), -1, 4),
                loc_var_4.view(loc_var_4.size(0), -1, 4),
                loc_pi_4.view(loc_var_4.size(0), -1, 4),
                self.softmax(conf_mu_1.view(conf_mu_1.size(0), -1, self.num_classes)),
                conf_var_1.view(conf_var_1.size(0), -1, self.num_classes),
                conf_pi_1.view(conf_var_1.size(0), -1, 1),
                self.softmax(conf_mu_2.view(conf_mu_2.size(0), -1, self.num_classes)),
                conf_var_2.view(conf_var_2.size(0), -1, self.num_classes),
                conf_pi_2.view(conf_var_2.size(0), -1, 1),
                self.softmax(conf_mu_3.view(conf_mu_3.size(0), -1, self.num_classes)),
                conf_var_3.view(conf_var_3.size(0), -1, self.num_classes),
                conf_pi_3.view(conf_var_3.size(0), -1, 1),
                self.softmax(conf_mu_4.view(conf_mu_4.size(0), -1, self.num_classes)),
                conf_var_4.view(conf_var_4.size(0), -1, self.num_classes),
                conf_pi_4.view(conf_var_4.size(0), -1, 1)
            )
            # pass
        else:
            gmm_output = (
                self.priors,
                loc_mu_1.view(loc_mu_1.size(0), -1, 4),
                loc_var_1.view(loc_var_1.size(0), -1, 4),
                loc_pi_1.view(loc_pi_1.size(0), -1, 4),
                loc_mu_2.view(loc_mu_2.size(0), -1, 4),
                loc_var_2.view(loc_var_2.size(0), -1, 4),
                loc_pi_2.view(loc_pi_2.size(0), -1, 4),
                loc_mu_3.view(loc_mu_3.size(0), -1, 4),
                loc_var_3.view(loc_var_3.size(0), -1, 4),
                loc_pi_3.view(loc_pi_3.size(0), -1, 4),
                loc_mu_4.view(loc_mu_4.size(0), -1, 4),
                loc_var_4.view(loc_var_4.size(0), -1, 4),
                loc_pi_4.view(loc_pi_4.size(0), -1, 4),
                conf_mu_1.view(conf_mu_1.size(0), -1, self.num_classes),
                conf_var_1.view(conf_var_1.size(0), -1, self.num_classes),
                conf_pi_1.view(conf_pi_1.size(0), -1, 1),
                conf_mu_2.view(conf_mu_2.size(0), -1, self.num_classes),
                conf_var_2.view(conf_var_2.size(0), -1, self.num_classes),
                conf_pi_2.view(conf_pi_2.size(0), -1, 1),
                conf_mu_3.view(conf_mu_3.size(0), -1, self.num_classes),
                conf_var_3.view(conf_var_3.size(0), -1, self.num_classes),
                conf_pi_3.view(conf_pi_3.size(0), -1, 1),
                conf_mu_4.view(conf_mu_4.size(0), -1, self.num_classes),
                conf_var_4.view(conf_var_4.size(0), -1, self.num_classes),
                conf_pi_4.view(conf_pi_4.size(0), -1, 1)
            )
            
            
    

        return gmm_output

    def _create_conv_layers(self):
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in yolo_config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats,))

            elif isinstance(module, str):
                if module == "S":
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        #ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2

                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2),)
                    in_channels = in_channels * 3
                    

        return layers


In [153]:
def build_yolo_gmm(phase = "train", size=416, num_classes=80):
    if phase != "test" and phase != "train":
        print("ERROR: Phase: " + phase + " not recognized")
        return
    
    out_channels = [512, 256, 128] # for before scale prid layer 
    # out_channels = [255, 255, 255] # if final yolo out given to gmm

    gmm_head = multibox(out_channels, num_classes)

    return YOLO_GMM(gmm_head , phase = phase)

In [154]:
model = build_yolo_gmm()

In [155]:
x = torch.randn((2, 3, 416, 416))
out = model(x)

In [156]:
print("OUTPUT of gmm")
for o in out:
    print(o.shape)

OUTPUT of gmm
torch.Size([8732, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 4])
torch.Size([2, 2535, 80])
torch.Size([2, 2535, 80])
torch.Size([2, 2535, 1])
torch.Size([2, 2535, 80])
torch.Size([2, 2535, 80])
torch.Size([2, 2535, 1])
torch.Size([2, 2535, 80])
torch.Size([2, 2535, 80])
torch.Size([2, 2535, 1])
torch.Size([2, 2535, 80])
torch.Size([2, 2535, 80])
torch.Size([2, 2535, 1])


In [157]:

def point_form(boxes):
    """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
    representation for comparison to point form ground truth data.
    Args:
        boxes: (tensor) center-size default boxes from priorbox layers.
    Return:
        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
    """
    return torch.cat((boxes[:, :2] - boxes[:, 2:]/2,     # xmin, ymin
                     boxes[:, :2] + boxes[:, 2:]/2), 1)  # xmax, ymax


def center_size(boxes):
    """ Convert prior_boxes to (cx, cy, w, h)
    representation for comparison to center-size form ground truth data.
    Args:
        boxes: (tensor) point_form boxes
    Return:
        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
    """
    return torch.cat((boxes[:, 2:] + boxes[:, :2])/2,  # cx, cy
                     boxes[:, 2:] - boxes[:, :2], 1)  # w, h


def intersect(box_a, box_b):
    """ We resize both tensors to [A,B,2] without new malloc:
    [A,2] -> [A,1,2] -> [A,B,2]
    [B,2] -> [1,B,2] -> [A,B,2]
    Then we compute the area of intersect between box_a and box_b.
    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    A = box_a.size(0)
    B = box_b.size(0)
    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
                       box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
                       box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    inter = torch.clamp((max_xy - min_xy), min=0)
    return inter[:, :, 0] * inter[:, :, 1]


def jaccard(box_a, box_b):
    """Compute the jaccard overlap of two sets of boxes.  The jaccard overlap
    is simply the intersection over union of two boxes.  Here we operate on
    ground truth boxes and default boxes.
    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2]-box_a[:, 0]) *
              (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2]-box_b[:, 0]) *
              (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    union = area_a + area_b - inter
    return inter / union  # [A,B]


def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the bounding boxes, then return the matched indices
    corresponding to both confidence and location preds.
    Args:
        threshold: (float) The overlap threshold used when mathing boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors].
        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
        variances: (tensor) Variances corresponding to each prior coord,
            Shape: [num_priors, 4].
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        loc_t: (tensor) Tensor to be filled w/ endcoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        idx: (int) current batch index
    Return:
        The matched indices corresponding to 1)location and 2)confidence preds.
    """
    # jaccard index
    overlaps = jaccard(
        truths,
        point_form(priors)
    )
    # (Bipartite Matching)
    # [1,num_objects] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
    # [1,num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior
    # TODO refactor: index  best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]          # Shape: [num_priors,4]
    conf = labels[best_truth_idx] + 1         # Shape: [num_priors]
    conf[best_truth_overlap < threshold] = 0  # label as background
    loc = encode(matches, priors, variances)
    loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior


def encode(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth boxes
    we have matched (based on jaccard overlap) with the prior boxes.
    Args:
        matched: (tensor) Coords of ground truth for each prior in point-form
            Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in center-offset form
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """

    # dist b/t match center and prior's center
    g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, 2:])
    # match wh / prior wh
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh + 1e-9) / variances[1]
    # return target for smooth_l1_loss
    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes


def log_sum_exp(x):
    """Utility function for computing log_sum_exp while determining
    This will be used to determine unaveraged confidence loss across
    all examples in a batch.
    Args:
        x (Variable(tensor)): conf_preds from conf layers
    """
    x_max = x.data.max()
    return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max


# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
        scores: (tensor) The class predscores for the img, Shape:[num_priors].
        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
        top_k: (int) The Maximum number of box preds to consider.
    Return:
        The indices of the kept boxes with respect to num_priors.
    """

    keep = scores.new(scores.size(0)).zero_().long()
    if boxes.numel() == 0:
        return keep
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    v, idx = scores.sort(0)  # sort in ascending order
    # I = I[v >= 0.01]
    idx = idx[-top_k:]  # indices of the top-k largest vals
    xx1 = boxes.new()
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()

    # keep = torch.Tensor()
    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        # keep.append(i)
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view
        # load bboxes of next highest vals
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        # store element-wise max with next highest score
        xx1 = torch.clamp(xx1, min=x1[i])
        yy1 = torch.clamp(yy1, min=y1[i])
        xx2 = torch.clamp(xx2, max=x2[i])
        yy2 = torch.clamp(yy2, max=y2[i])
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w = xx2 - xx1
        h = yy2 - yy1
        # check sizes of xx1 and xx2.. after each iteration
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)
        inter = w*h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas)
        union = (rem_areas - inter) + area[i]
        IoU = inter/union  # store result in iou
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
 

In [158]:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from data import coco as cfg
# from ..box_utils import match, log_sum_exp
import math

def Gaussian(y, mu, var):
    eps = 0.3
    result = (y-mu)/var
    result = (result**2)/2*(-1)
    exp = torch.exp(result)
    result = exp/(math.sqrt(2*math.pi))/(var + eps)

    return result

def NLL_loss(bbox_gt, bbox_pred, bbox_var):
        bbox_var = torch.sigmoid(bbox_var)
        prob = Gaussian(bbox_gt, bbox_pred, bbox_var)

        return prob

class MultiBoxLoss_GMM(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching  ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched  'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, cls_type='Type-1'):
        super(MultiBoxLoss_GMM, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        # self.variance = cfg['variance']
        self.variance = 0.1
        self.cls_type = cls_type

    def forward(self, predictions, targets):
        priors, loc_mu_1, loc_var_1, loc_pi_1, loc_mu_2, loc_var_2, loc_pi_2, \
        loc_mu_3, loc_var_3, loc_pi_3, loc_mu_4, loc_var_4, loc_pi_4, \
        conf_mu_1, conf_var_1, conf_pi_1, conf_mu_2, conf_var_2, conf_pi_2, \
        conf_mu_3, conf_var_3, conf_pi_3, conf_mu_4, conf_var_4, conf_pi_4 = predictions

        num = loc_mu_1.size(0)
        priors = priors[:loc_mu_1.size(1), :]
        num_priors = (priors.size(0))
        num_classes = self.num_classes

        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.data
            match(self.threshold,
                  truths,
                  defaults,
                  self.variance,
                  labels,
                  loc_t,
                  conf_t,
                  idx)
        if self.use_gpu:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
        # wrap targets
        loc_t = Variable(loc_t, requires_grad=False)
        conf_t = Variable(conf_t, requires_grad=False)

        pos = conf_t > 0
        num_pos = pos.sum(dim=1, keepdim=True)

        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_mu_1)
        loc_mu_1_ = loc_mu_1[pos_idx].view(-1, 4)
        loc_mu_2_ = loc_mu_2[pos_idx].view(-1, 4)
        loc_mu_3_ = loc_mu_3[pos_idx].view(-1, 4)
        loc_mu_4_ = loc_mu_4[pos_idx].view(-1, 4)

        loc_t = loc_t[pos_idx].view(-1, 4)

        # localization loss
        loss_l_1 = NLL_loss(loc_t, loc_mu_1_, loc_var_1[pos_idx].view(-1, 4))
        loss_l_2 = NLL_loss(loc_t, loc_mu_2_, loc_var_2[pos_idx].view(-1, 4))
        loss_l_3 = NLL_loss(loc_t, loc_mu_3_, loc_var_3[pos_idx].view(-1, 4))
        loss_l_4 = NLL_loss(loc_t, loc_mu_4_, loc_var_4[pos_idx].view(-1, 4))

        loc_pi_1_ = loc_pi_1[pos_idx].view(-1, 4)
        loc_pi_2_ = loc_pi_2[pos_idx].view(-1, 4)
        loc_pi_3_ = loc_pi_3[pos_idx].view(-1, 4)
        loc_pi_4_ = loc_pi_4[pos_idx].view(-1, 4)

        pi_all = torch.stack([
                    loc_pi_1_.reshape(-1),
                    loc_pi_2_.reshape(-1),
                    loc_pi_3_.reshape(-1),
                    loc_pi_4_.reshape(-1)
                    ])
        pi_all = pi_all.transpose(0,1)
        pi_all = (torch.softmax(pi_all, dim=1)).transpose(0,1).reshape(-1)
        (
            loc_pi_1_,
            loc_pi_2_,
            loc_pi_3_,
            loc_pi_4_
        ) = torch.split(pi_all, loc_pi_1_.reshape(-1).size(0), dim=0)
        loc_pi_1_ = loc_pi_1_.view(-1, 4)
        loc_pi_2_ = loc_pi_2_.view(-1, 4)
        loc_pi_3_ = loc_pi_3_.view(-1, 4)
        loc_pi_4_ = loc_pi_4_.view(-1, 4)

        _loss_l = (
            loc_pi_1_*loss_l_1 +
            loc_pi_2_*loss_l_2 +
            loc_pi_3_*loss_l_3 +
            loc_pi_4_*loss_l_4
        )

        epsi = 10**-9
        # balance parameter
        balance = 2.0
        loss_l = -torch.log(_loss_l + epsi)/balance
        loss_l = loss_l.sum()

        if self.cls_type == 'Type-1':
            # Classification loss (Type-1)
            conf_pi_1_ = conf_pi_1.view(-1, 1)
            conf_pi_2_ = conf_pi_2.view(-1, 1)
            conf_pi_3_ = conf_pi_3.view(-1, 1)
            conf_pi_4_ = conf_pi_4.view(-1, 1)

            conf_pi_all = torch.stack([
                            conf_pi_1_.reshape(-1),
                            conf_pi_2_.reshape(-1),
                            conf_pi_3_.reshape(-1),
                            conf_pi_4_.reshape(-1)
                            ])
            conf_pi_all = conf_pi_all.transpose(0,1)
            conf_pi_all = (torch.softmax(conf_pi_all, dim=1)).transpose(0,1).reshape(-1)
            (
                conf_pi_1_,
                conf_pi_2_,
                conf_pi_3_,
                conf_pi_4_
            ) = torch.split(conf_pi_all, conf_pi_1_.reshape(-1).size(0), dim=0)
            conf_pi_1_ = conf_pi_1_.view(conf_pi_1.size(0), -1)
            conf_pi_2_ = conf_pi_2_.view(conf_pi_2.size(0), -1)
            conf_pi_3_ = conf_pi_3_.view(conf_pi_3.size(0), -1)
            conf_pi_4_ = conf_pi_4_.view(conf_pi_4.size(0), -1)

            conf_var_1 = torch.sigmoid(conf_var_1)
            conf_var_2 = torch.sigmoid(conf_var_2)
            conf_var_3 = torch.sigmoid(conf_var_3)
            conf_var_4 = torch.sigmoid(conf_var_4)

            rand_val_1 = torch.randn(conf_var_1.size(0), conf_var_1.size(1), conf_var_1.size(2))
            rand_val_2 = torch.randn(conf_var_2.size(0), conf_var_2.size(1), conf_var_2.size(2))
            rand_val_3 = torch.randn(conf_var_3.size(0), conf_var_3.size(1), conf_var_3.size(2))
            rand_val_4 = torch.randn(conf_var_4.size(0), conf_var_4.size(1), conf_var_4.size(2))

            batch_conf_1 = (conf_mu_1+torch.sqrt(conf_var_1)*rand_val_1).view(-1, self.num_classes)
            batch_conf_2 = (conf_mu_2+torch.sqrt(conf_var_2)*rand_val_2).view(-1, self.num_classes)
            batch_conf_3 = (conf_mu_3+torch.sqrt(conf_var_3)*rand_val_3).view(-1, self.num_classes)
            batch_conf_4 = (conf_mu_4+torch.sqrt(conf_var_4)*rand_val_4).view(-1, self.num_classes)

            loss_c_1 = log_sum_exp(batch_conf_1) - batch_conf_1.gather(1, conf_t.view(-1, 1))
            loss_c_2 = log_sum_exp(batch_conf_2) - batch_conf_2.gather(1, conf_t.view(-1, 1))
            loss_c_3 = log_sum_exp(batch_conf_3) - batch_conf_3.gather(1, conf_t.view(-1, 1))
            loss_c_4 = log_sum_exp(batch_conf_4) - batch_conf_4.gather(1, conf_t.view(-1, 1))

            loss_c = (
                loss_c_1 * conf_pi_1_.view(-1, 1) +
                loss_c_2 * conf_pi_2_.view(-1, 1) +
                loss_c_3 * conf_pi_3_.view(-1, 1) +
                loss_c_4 * conf_pi_4_.view(-1, 1)
            )
            loss_c = loss_c.view(pos.size()[0], pos.size()[1])
            loss_c[pos] = 0  # filter out pos boxes for now  : true -> zero
            loss_c = loss_c.view(num, -1)

            _, loss_idx = loss_c.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)
            num_pos = pos.long().sum(1, keepdim=True)
            num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
            neg = idx_rank < num_neg.expand_as(idx_rank)

            # Confidence Loss Including Positive and Negative Examples
            pos_idx = pos.unsqueeze(2).expand_as(conf_mu_1)
            neg_idx = neg.unsqueeze(2).expand_as(conf_mu_1)

            batch_conf_1_ = conf_mu_1+torch.sqrt(conf_var_1)*rand_val_1
            batch_conf_2_ = conf_mu_2+torch.sqrt(conf_var_2)*rand_val_2
            batch_conf_3_ = conf_mu_3+torch.sqrt(conf_var_3)*rand_val_3
            batch_conf_4_ = conf_mu_4+torch.sqrt(conf_var_4)*rand_val_4

            conf_pred_1 = batch_conf_1_[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
            conf_pred_2 = batch_conf_2_[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
            conf_pred_3 = batch_conf_3_[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
            conf_pred_4 = batch_conf_4_[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)

            targets_weighted = conf_t[(pos+neg).gt(0)]

            loss_c_1 = log_sum_exp(conf_pred_1) - conf_pred_1.gather(1, targets_weighted.view(-1, 1))
            loss_c_2 = log_sum_exp(conf_pred_2) - conf_pred_2.gather(1, targets_weighted.view(-1, 1))
            loss_c_3 = log_sum_exp(conf_pred_3) - conf_pred_3.gather(1, targets_weighted.view(-1, 1))
            loss_c_4 = log_sum_exp(conf_pred_4) - conf_pred_4.gather(1, targets_weighted.view(-1, 1))

            _conf_pi_1 = conf_pi_1_[(pos+neg).gt(0)]
            _conf_pi_2 = conf_pi_2_[(pos+neg).gt(0)]
            _conf_pi_3 = conf_pi_3_[(pos+neg).gt(0)]
            _conf_pi_4 = conf_pi_4_[(pos+neg).gt(0)]

            loss_c = (
                loss_c_1 * _conf_pi_1.view(-1, 1) +
                loss_c_2 * _conf_pi_2.view(-1, 1) +
                loss_c_3 * _conf_pi_3.view(-1, 1) +
                loss_c_4 * _conf_pi_4.view(-1, 1)
            )
            loss_c = loss_c.sum()

        else:
            # Classification loss (Type-2)
            # more details are in our supplementary material
            conf_pi_1_ = conf_pi_1.view(-1, 1)
            conf_pi_2_ = conf_pi_2.view(-1, 1)
            conf_pi_3_ = conf_pi_3.view(-1, 1)
            conf_pi_4_ = conf_pi_4.view(-1, 1)

            conf_pi_all = torch.stack([
                            conf_pi_1_.reshape(-1),
                            conf_pi_2_.reshape(-1),
                            conf_pi_3_.reshape(-1),
                            conf_pi_4_.reshape(-1)
                            ])
            conf_pi_all = conf_pi_all.transpose(0,1)
            conf_pi_all = (torch.softmax(conf_pi_all, dim=1)).transpose(0,1).reshape(-1)
            (
                conf_pi_1_,
                conf_pi_2_,
                conf_pi_3_,
                conf_pi_4_
            ) = torch.split(conf_pi_all, conf_pi_1_.reshape(-1).size(0), dim=0)
            conf_pi_1_ = conf_pi_1_.view(conf_pi_1.size(0), -1)
            conf_pi_2_ = conf_pi_2_.view(conf_pi_2.size(0), -1)
            conf_pi_3_ = conf_pi_3_.view(conf_pi_3.size(0), -1)
            conf_pi_4_ = conf_pi_4_.view(conf_pi_4.size(0), -1)

            conf_var_1 = torch.sigmoid(conf_var_1)
            conf_var_2 = torch.sigmoid(conf_var_2)
            conf_var_3 = torch.sigmoid(conf_var_3)
            conf_var_4 = torch.sigmoid(conf_var_4)

            rand_val_1 = torch.randn(conf_var_1.size(0), conf_var_1.size(1), conf_var_1.size(2))
            rand_val_2 = torch.randn(conf_var_2.size(0), conf_var_2.size(1), conf_var_2.size(2))
            rand_val_3 = torch.randn(conf_var_3.size(0), conf_var_3.size(1), conf_var_3.size(2))
            rand_val_4 = torch.randn(conf_var_4.size(0), conf_var_4.size(1), conf_var_4.size(2))

            batch_conf_1 = (conf_mu_1+torch.sqrt(conf_var_1)*rand_val_1).view(-1, self.num_classes)
            batch_conf_2 = (conf_mu_2+torch.sqrt(conf_var_2)*rand_val_2).view(-1, self.num_classes)
            batch_conf_3 = (conf_mu_3+torch.sqrt(conf_var_3)*rand_val_3).view(-1, self.num_classes)
            batch_conf_4 = (conf_mu_4+torch.sqrt(conf_var_4)*rand_val_4).view(-1, self.num_classes)

            soft_max = nn.Softmax(dim=1)

            epsi = 10**-9
            weighted_softmax_out = (
                        soft_max(batch_conf_1)*conf_pi_1_.view(-1, 1) +
                        soft_max(batch_conf_2)*conf_pi_2_.view(-1, 1) +
                        soft_max(batch_conf_3)*conf_pi_3_.view(-1, 1) +
                        soft_max(batch_conf_4)*conf_pi_4_.view(-1, 1)
            )
            softmax_out_log = -torch.log(weighted_softmax_out+epsi)
            loss_c = softmax_out_log.gather(1, conf_t.view(-1,1))

            loss_c = loss_c.view(pos.size()[0], pos.size()[1])
            loss_c[pos] = 0  # filter out pos boxes for now  : true -> zero
            loss_c = loss_c.view(num, -1)

            _, loss_idx = loss_c.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)
            num_pos = pos.long().sum(1, keepdim=True)
            num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
            neg = idx_rank < num_neg.expand_as(idx_rank)

            # Confidence Loss Including Positive and Negative Examples
            pos_idx = pos.unsqueeze(2).expand_as(conf_mu_1)
            neg_idx = neg.unsqueeze(2).expand_as(conf_mu_1)

            batch_conf_1_ = conf_mu_1+torch.sqrt(conf_var_1)*rand_val_1
            batch_conf_2_ = conf_mu_2+torch.sqrt(conf_var_2)*rand_val_2
            batch_conf_3_ = conf_mu_3+torch.sqrt(conf_var_3)*rand_val_3
            batch_conf_4_ = conf_mu_4+torch.sqrt(conf_var_4)*rand_val_4

            conf_pred_1 = batch_conf_1_[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
            conf_pred_2 = batch_conf_2_[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
            conf_pred_3 = batch_conf_3_[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
            conf_pred_4 = batch_conf_4_[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)

            targets_weighted = conf_t[(pos+neg).gt(0)]

            _conf_pi_1 = conf_pi_1_[(pos+neg).gt(0)]
            _conf_pi_2 = conf_pi_2_[(pos+neg).gt(0)]
            _conf_pi_3 = conf_pi_3_[(pos+neg).gt(0)]
            _conf_pi_4 = conf_pi_4_[(pos+neg).gt(0)]

            weighted_softmax_out = (
                        soft_max(conf_pred_1)*_conf_pi_1.view(-1, 1) +
                        soft_max(conf_pred_2)*_conf_pi_2.view(-1, 1) +
                        soft_max(conf_pred_3)*_conf_pi_3.view(-1, 1) +
                        soft_max(conf_pred_4)*_conf_pi_4.view(-1, 1)
            )
            softmax_out_log = -torch.log(weighted_softmax_out+epsi)
            loss_c = softmax_out_log.gather(1, targets_weighted.view(-1,1))
            loss_c = loss_c.sum()

        N = num_pos.data.sum()
        loss_l /= N
        loss_c /= N
        return loss_l, loss_c


ModuleNotFoundError: No module named 'data'

In [None]:
"""
Main file for training Yolo model on Pascal VOC and COCO dataset
"""

import config
import torch
import torch.optim as optim

from model import YOLOv3
from tqdm import tqdm
from utils import (
    mean_average_precision,
    cells_to_bboxes,
    get_evaluation_bboxes,
    save_checkpoint,
    load_checkpoint,
    check_class_accuracy,
    get_loaders,
    plot_couple_examples
)
from loss import YoloLoss
import warnings
warnings.filterwarnings("ignore")

torch.backends.cudnn.benchmark = True


def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    loop = tqdm(train_loader, leave=True)
    losses = []
    for batch_idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y0, y1, y2 = (
            y[0].to(config.DEVICE),
            y[1].to(config.DEVICE),
            y[2].to(config.DEVICE),
        )

        with torch.cuda.amp.autocast():
            out = model(x)
            loss = (
                loss_fn(out[0], y0, scaled_anchors[0])
                + loss_fn(out[1], y1, scaled_anchors[1])
                + loss_fn(out[2], y2, scaled_anchors[2])
            )

        losses.append(loss.item())
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update progress bar
        mean_loss = sum(losses) / len(losses)
        loop.set_postfix(loss=mean_loss)



def main():
    # model = YOLO_GMM(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    model = build_yolo_gmm().to(config.DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )
    loss_fn = YoloLoss()
    scaler = torch.cuda.amp.GradScaler()

    train_loader, test_loader, train_eval_loader = get_loaders(
        train_csv_path=config.DATASET + "/train.csv", test_csv_path=config.DATASET + "/test.csv"
    )

    # if config.LOAD_MODEL:
    #     load_checkpoint(
    #         config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
    #     )

    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)

    for epoch in range(config.NUM_EPOCHS):
        #plot_couple_examples(model, test_loader, 0.6, 0.5, scaled_anchors)
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)

        #if config.SAVE_MODEL:
        #    save_checkpoint(model, optimizer, filename=f"checkpoint.pth.tar")

        #print(f"Currently epoch {epoch}")
        #print("On Train Eval loader:")
        #print("On Train loader:")
        #check_class_accuracy(model, train_loader, threshold=config.CONF_THRESHOLD)

        if epoch > 0 and epoch % 3 == 0:
            check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)
            pred_boxes, true_boxes = get_evaluation_bboxes(
                test_loader,
                model,
                iou_threshold=config.NMS_IOU_THRESH,
                anchors=config.ANCHORS,
                threshold=config.CONF_THRESHOLD,
            )
            mapval = mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=config.NUM_CLASSES,
            )
            print(f"MAP: {mapval.item()}")
            model.train()


if __name__ == "__main__":
    main()


  0%|          | 0/2069 [00:17<?, ?it/s]


IndexError: too many indices for tensor of dimension 2