In [None]:
'''FPN in PyTorch.
See the paper "Feature Pyramid Networks for Object Detection" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class Bottleneck(nn.Module):
    '''Residual block made of a 1x1, a 3x3 and a 1x1 convolution.

    The block emits `expansion * planes` channels. The 3x3 convolution
    carries the stride. When the stride is not 1 or the channel counts of
    input and output differ, the skip connection is a strided 1x1
    convolution followed by batch norm; otherwise it is the identity.

    Args:
      in_planes: number of input channels.
      planes: width of the two inner convolutions.
      stride: stride of the 3x3 convolution and of the projection shortcut.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path: reduce -> spatial -> expand.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Skip path: project only when the shapes would not line up.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        '''Return relu(main_path(x) + shortcut(x)).'''
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = F.relu(self.bn2(hidden))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class Res_block(nn.Module):
    '''Two-convolution residual block with optional 2x downsampling.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> ReLU. The skip
    path (`parallel`) is added afterwards and a final ReLU is applied.

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels.
      sample: when true, the first convolution and the skip path use
        stride 2, halving the spatial size.
    '''
    def __init__(self, in_channels, out_channels, sample):
        super(Res_block, self).__init__()
        # Not read anywhere in this class; kept so existing attribute access still works.
        self.in_planes = 64

        stride = 2 if sample else 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)

        # The skip path must produce the same shape as the main path, so a
        # 1x1 projection is needed whenever the resolution or width changes.
        if sample or in_channels != out_channels:
            self.parallel = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.parallel = nn.Sequential()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        '''Apply the block to a feature map and return the activated sum.

        Args:
          input: tensor with `in_channels` channels.
        Returns:
          tensor with `out_channels` channels.
        '''
        parallel = self.parallel(input)
        # F.relu applies the activation; nn.ReLU(...) would only construct a module.
        out = F.relu(self.bn1(self.conv1(input)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = out + parallel
        return F.relu(out)

class FPN(nn.Module):
    '''Bottom-up convolutional backbone with a pooled linear head.

    A stem (`layer0`) is followed by four stages of two residual blocks
    each, with widths 64, 128, 256 and 512; stages 2-4 are built with
    `sample=True` on their first block. The result is globally
    average-pooled and mapped to `outputs` features by a linear layer.

    NOTE: the top-down pathway (lateral 1x1 convolutions, upsample-add and
    3x3 smoothing) is not wired into `forward`; `_upsample_add` is kept as
    the building block for it.

    Args:
      in_channels: number of channels of the input image.
      resblock: callable `resblock(in_channels, out_channels, sample=...)`
        returning the module used for every residual block.
      outputs: size of the final linear layer's output.
    '''
    def __init__(self, in_channels, resblock, outputs=256):
        super().__init__()
        # Stem: strided 7x7 convolution and max-pool, then BN and ReLU.
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.layer1 = nn.Sequential(
            resblock(64, 64, sample=False),
            resblock(64, 64, sample=False)
        )

        self.layer2 = nn.Sequential(
            resblock(64, 128, sample=True),
            resblock(128, 128, sample=False)
        )

        self.layer3 = nn.Sequential(
            resblock(128, 256, sample=True),
            resblock(256, 256, sample=False)
        )

        self.layer4 = nn.Sequential(
            resblock(256, 512, sample=True),
            resblock(512, 512, sample=False)
        )

        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(512, outputs)

    def forward(self, input):
        '''Run the backbone and head.

        Args:
          input: image batch with `in_channels` channels.
        Returns:
          tensor of shape [N, outputs].
        '''
        input = self.layer0(input)
        input = self.layer1(input)
        input = self.layer2(input)
        input = self.layer3(input)
        input = self.layer4(input)
        input = self.gap(input)
        # Flatten from dim 1 so the batch dimension survives: [N,512,1,1] -> [N,512].
        input = torch.flatten(input, 1)
        input = self.fc(input)

        return input

    def _upsample_add(self, x, y):
        '''Upsample and add two feature maps.
        Args:
          x: top feature map to be upsampled.
          y: lateral feature map.
        Returns:
          added feature map, with the spatial size of `y`.
        Note in PyTorch, when input size is odd, the upsampled feature map
        with `F.interpolate(..., scale_factor=2, mode='nearest')`
        maybe not equal to the lateral feature map size.
        e.g.
        original input size: [N,_,15,15] ->
        conv2d feature map size: [N,_,8,8] ->
        upsampled feature map size: [N,_,16,16]
        So we choose bilinear upsample which supports arbitrary output sizes.
        '''
        _, _, H, W = y.size()
        # F.upsample is deprecated; align_corners=False is its implicit default.
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y

    


def FPN101():
    '''Build the FPN network for 3-channel images.

    `FPN.__init__` takes `(in_channels, resblock, outputs)`, and calls
    `resblock(in, out, sample=...)`; `Res_block` is the class in this file
    with that signature. Passing a block class and a list of block counts,
    as the earlier call did, does not match that constructor.

    Returns:
      an `FPN` instance with the default number of outputs.
    '''
    return FPN(3, Res_block)


def test():
    '''Smoke-test `FPN101` on one random 3x128x128 image.

    Prints the size of the network output so the run shows something
    observable instead of silently discarding the result.

    Returns:
      the network output.
    '''
    net = FPN101()
    # Forward-only check, so no autograd graph is needed.
    with torch.no_grad():
        fm = net(torch.randn(1, 3, 128, 128))
    print(fm.size())
    return fm


test()

In [None]:
class FRGB(nn.Module):
    """Three-convolution head mapping 256 input channels to 128.

    Two 3x3 conv -> GroupNorm -> ReLU stages are followed by a 1x1
    convolution and a GroupNorm with no activation. Spatial size is
    preserved (padding 1 on the 3x3 convolutions).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.gn1 = nn.GroupNorm(num_groups=32, num_channels=128)
        self.ReLU1 = nn.ReLU()
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.gn2 = nn.GroupNorm(num_groups=32, num_channels=128)
        self.ReLU2 = nn.ReLU()
        self.conv3 = nn.Conv2d(128, 128, kernel_size=1)
        self.gn3 = nn.GroupNorm(num_groups=32, num_channels=128)

    def forward(self, x):
        """Return the 128-channel feature map for a 256-channel input."""
        activated_stages = (
            (self.conv1, self.gn1, self.ReLU1),
            (self.conv2, self.gn2, self.ReLU2),
        )
        for conv, norm, act in activated_stages:
            x = act(norm(conv(x)))
        # Last stage is normalised but deliberately left without a ReLU.
        return self.gn3(self.conv3(x))

class Fspatial(nn.Module):
    """Convolutional head mapping 256 input channels to a 1-channel map.

    Two 3x3 conv -> GroupNorm -> ReLU stages are followed by a bare 1x1
    convolution, so the output has no normalisation or activation applied.
    Spatial size is preserved.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.gn1 = nn.GroupNorm(num_groups=32, num_channels=128)
        self.ReLU1 = nn.ReLU()
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.gn2 = nn.GroupNorm(num_groups=32, num_channels=128)
        self.ReLU2 = nn.ReLU()
        self.conv3 = nn.Conv2d(128, 1, kernel_size=1)

    def forward(self, x):
        """Return the single-channel map for a 256-channel input."""
        activated_stages = (
            (self.conv1, self.gn1, self.ReLU1),
            (self.conv2, self.gn2, self.ReLU2),
        )
        for conv, norm, act in activated_stages:
            x = act(norm(conv(x)))
        return self.conv3(x)
    
class TwoLayerCNN(nn.Module):
    """Two 1x1 convolutions mapping 64 input channels to 128.

    The first convolution is followed by GroupNorm and ReLU, the second by
    GroupNorm only. Being 1x1, the convolutions act per pixel and keep the
    spatial size.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(64, 128, kernel_size=1)
        self.gn1 = nn.GroupNorm(num_groups=32, num_channels=128)
        self.ReLU1 = nn.ReLU()
        self.conv2 = nn.Conv2d(128, 128, kernel_size=1)
        self.gn2 = nn.GroupNorm(num_groups=32, num_channels=128)

    def forward(self, x):
        """Return the 128-channel feature map for a 64-channel input."""
        hidden = self.conv1(x)
        hidden = self.gn1(hidden)
        hidden = self.ReLU1(hidden)
        projected = self.conv2(hidden)
        return self.gn2(projected)


In [None]:
import time

#PyTorch
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid-activated predictions.

    Both tensors are flattened, so the loss is a single scalar over the
    whole batch: 1 - (2*|X.Y| + smooth) / (|X| + |Y| + smooth).

    Args:
      weight: accepted for signature compatibility; not used.
      size_average: accepted for signature compatibility; not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the scalar Dice loss.

        Args:
          inputs: raw predictions (logits); a sigmoid is applied here.
          targets: ground-truth tensor with the same number of elements.
          smooth: added to numerator and denominator so the ratio stays
            defined when both tensors sum to zero.
        """
        # Comment out if your model contains a sigmoid or equivalent activation layer.
        # torch.sigmoid replaces the deprecated F.sigmoid.
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors; reshape (unlike view) also
        # accepts non-contiguous tensors such as transposed or sliced ones.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice
    

class ComboNet(nn.Module):
    """Two-branch network combining an RGB image with an encoded depth map.

    The RGB branch runs `f_FPN` then `f_RGB` and is upsampled by 2; the
    depth branch positionally encodes the (half-resolution) depth map and
    runs `f_depth`. Both 128-channel results are concatenated on the
    channel axis and passed through `f_spatial`.

    The module owns its optimizer and losses; `step` performs one update.

    Args:
      batchSize: accepted for signature compatibility; the positional
        encoding denominator broadcasts over the batch, so it is not used.
      numPeChannels: number of positional-encoding channels. `f_depth`
        (TwoLayerCNN) takes 64 input channels, so 64 is the value that
        matches the rest of this file.
      learningRate: learning rate for the Adam optimizer.
    """

    def __init__(self, batchSize, numPeChannels, learningRate):
        super(ComboNet, self).__init__()
        # NOTE(review): FPN as defined in this file takes
        # (in_channels, resblock, outputs) and returns a flat vector, while
        # this call passes a block class and a list of block counts, and
        # `forward` feeds the result to a 256-channel convolution. Confirm
        # which FPN definition this model is meant to use.
        self.f_FPN = FPN(Bottleneck, [2,2,2,2])
        self.f_RGB = FRGB()
        self.f_depth = TwoLayerCNN()
        self.f_spatial = Fspatial()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=learningRate)
        # NOTE(review): f_spatial ends in a bare convolution, so its output
        # is not confined to [0, 1] as nn.BCELoss expects; revisit before
        # enabling the BCE term in `step`.
        self.lossBCE = nn.BCELoss()
        self.lossDICE = DiceLoss()
        self.lossCE = nn.CrossEntropyLoss() #THIS IS JUST HERE AS A DUMMY FOR NOW

        # Denominator for the positional encoding, computed once:
        # channel i gets 200**(2*i/numPeChannels). Stored as [1, C, 1, 1] so
        # it broadcasts over any batch and image size, and registered as a
        # (non-persistent) buffer so it follows the module across devices
        # without changing the state_dict.
        channel = torch.arange(numPeChannels, dtype=torch.float32)
        denom = (200.0 ** (2 * channel / numPeChannels)).view(1, numPeChannels, 1, 1)
        self.register_buffer("denom", denom, persistent=False)
        self.numPeChannels = numPeChannels

    def forward(self, x1, x2, z):
        """Compute the network output.

        Args:
          x1: RGB image batch.
          x2: depth image batch.
          z: distance value forwarded to `positionEncoding`.
        Returns:
          the 1-channel map produced by `f_spatial`.
        """
        # RGB feature processing
        x1 = self.f_FPN(x1)
        x1 = self.f_RGB(x1)
        # Double the spatial size so it matches the depth branch
        # (e.g. 128x128 -> 256x256).
        x1 = torch.nn.functional.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=False)

        # Depth feature processing
        x2 = self.positionEncoding(x2, z)
        x2 = self.f_depth(x2)

        # Combine features on the channel dimension (dim 1) and pass them
        # through the final CNN.
        x = torch.cat((x1, x2), 1)
        x = self.f_spatial(x)

        return x

    def positionEncoding(self, depth, z):
        """
        Computes the positional encoding (as defined by the paper) for a depth
        - depth: the input depth image; assumed to have a single channel,
          as in `trainModel` - TODO confirm
        - z: the distance we wish to evaluate (currently not used)

        The depth map is first downsampled by 2. Even channels hold
        sin(50*depth/denom) and odd channels cos(50*depth/denom).
        """
        depth = torch.nn.functional.interpolate(depth, scale_factor=0.5, mode='bilinear', align_corners=False)
        # [N,1,H,W] / [1,C,1,1] broadcasts to [N,C,H,W].
        angle = (50 * depth) / self.denom
        # empty_like keeps the encoding on the same device/dtype as the input.
        pe = torch.empty_like(angle)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])

        return pe

    def step(self, x1, x2, z, y):
        """
        Iterates over a single training step
        - x1: RGB input batch
        - x2: depth input batch
        - z: distance value forwarded to `forward`
        - y: expected labels for batch

        Returns the loss of this step as a numpy value.
        """
        self.optimizer.zero_grad() #Reset parameter gradients to 0

        # Call the module rather than .forward so registered hooks run.
        outputs = self(x1, x2, z)
#         loss = self.lossBCE(outputs,y) + self.lossDICE(outputs,y)
        loss = self.lossCE(outputs, y) + self.lossCE(outputs, y) #dummy example showing we can easily sum losses
        loss.backward()
        self.optimizer.step()

        return loss.detach().cpu().numpy()
        
        
        
        

    
def trainModel(epochs=10):
    """Run a dummy training loop on random data and print the elapsed time.

    Every step feeds a fresh random RGB batch, an all-ones depth map and a
    random label map to `ComboNet.step`; this exercises the training
    plumbing rather than learning anything.

    Args:
      epochs: number of passes over the (single-item) dummy dataset.
    """
    #Instantiate model
    net = ComboNet(1, 64, 0.001)

    #Get data; only its first dimension is used, to set the steps per epoch
    data = torch.randn(1, 3, 512, 512)

    #Training loop
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments.
    st = time.perf_counter()
    for i in range(epochs):
        print("Processing epoch ",i)

        for _ in data:
            ipt1 = torch.randn(1, 3, 512, 512)
            ipt2 = torch.ones(1, 1, 512, 512)
            y = torch.randn(1, 1, 256, 256)
            net.step(ipt1, ipt2, 1, y)

    print("Total training time: ",round(time.perf_counter()-st,2))


trainModel()
        

Processing epoch  0
Processing epoch  1
Processing epoch  2
Processing epoch  3
Processing epoch  4
Processing epoch  5
Processing epoch  6
Processing epoch  7
Processing epoch  8
Processing epoch  9
Total training time:  86.2
