In [1]:
# Reload edited project modules (e.g. `models`) automatically before each cell runs;
# loading the extension alone does not enable reloading.
%load_ext autoreload
%autoreload 2
# Pin the notebook to GPU 0; this has to happen before torch initializes CUDA.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [4]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from models import DenseNet

In [None]:
# Use the GPU exposed via CUDA_VISIBLE_DEVICES when present, but fall back to CPU so
# the notebook still runs top-to-bottom on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [190]:
class ConvBlock(nn.Sequential):
    """BatchNorm -> ReLU -> convolution, i.e. DenseNet's pre-activation layer order.

    Args:
        stride: stride of the convolution.
        kernel_size: kernel size of the convolution.
        in_channels: number of input channels (also the BatchNorm's ``num_features``).
        out_channels: number of channels produced by the convolution.
        deconv: if True the convolution is an ``nn.ConvTranspose2d`` (upsampling),
            otherwise a plain ``nn.Conv2d``. In both cases it has no bias term.

    NOTE(review): the default ``stride=37`` looks accidental; every call site in this
    notebook passes ``stride`` explicitly, so the default is left as-is.
    """

    def __init__(self, stride=37, kernel_size=4, in_channels=10, out_channels=10, deconv=True):
        conv_cls = nn.ConvTranspose2d if deconv else nn.Conv2d
        # Layers are registered positionally, so submodule names stay "0", "1", "2".
        super().__init__(
            nn.BatchNorm2d(num_features=in_channels),
            # In-place is safe here: it acts on the BatchNorm output, not on the block input.
            nn.ReLU(inplace=True),
            conv_cls(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                bias=False,
            ),
        )
            

In [205]:
class FCN(nn.Module):
    """FCN-style segmentation network on a DenseNet121 backbone.

    The backbone's top-level children are split into four stages. The outputs of
    stages 2, 3 and 4 are upsampled with transposed convolutions to the spatial size
    of stage 1 and concatenated with it (4 x 128 channels). The result is reduced to
    ``num_classes`` channels by a 1x1 convolution and finally upsampled x8 by
    ``classifier``.

    Shape comments are for a 224x224 input; the smoke-test cell checks
    (5, 3, 224, 224) -> (5, 10, 224, 224).

    NOTE(review): the slice indices below assume the child layout of
    ``models.DenseNet`` (not visible here) -- confirm against that class.

    Args:
        num_classes: number of output channels (per-pixel class scores).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        densenet_modules = list(DenseNet().children())
        self.blockpool1 = nn.Sequential(*densenet_modules[:7])    # up to transition 1, out (128, 28, 28)
        self.blockpool2 = nn.Sequential(*densenet_modules[7:10])  # transition 2, out (256, 14, 14)
        self.blockpool3 = nn.Sequential(*densenet_modules[10:13]) # transition 3, out (512, 7, 7)
        # DenseBlock4 (BN-ReLU-conv), out (1024, 7, 7). avgpool4 is not used because
        # its output is 1x1.
        self.block4 = densenet_modules[13]

        # Upsample each deeper stage to blockpool1's resolution (7 -> 28, 14 -> 28).
        self.deconv4 = ConvBlock(stride=4, kernel_size=4, in_channels=1024, out_channels=128)  # block4 -> blockpool1
        self.deconv3 = ConvBlock(stride=4, kernel_size=4, in_channels=512, out_channels=128)   # blockpool3 -> blockpool1
        self.deconv2 = ConvBlock(stride=2, kernel_size=2, in_channels=256, out_channels=128)   # blockpool2 -> blockpool1
        # 1x1 conv: four concatenated 128-channel maps -> num_classes channels.
        self.pointwise = ConvBlock(deconv=False, stride=1, kernel_size=1, in_channels=128 * 4, out_channels=num_classes)
        # x8 transposed conv back to input resolution (28 -> 224 for a 224 input).
        self.classifier = ConvBlock(kernel_size=8, stride=8, in_channels=num_classes, out_channels=num_classes)

    @staticmethod
    def _match_size(feat, size):
        """Bilinearly resize ``feat`` to spatial ``size`` (H, W); no-op when it already matches."""
        if feat.shape[-2:] == size:
            return feat
        return F.interpolate(feat, size=tuple(size), mode='bilinear', align_corners=False)

    def forward(self, x):
        """Map (N, C, H, W) images to (N, num_classes, 8*h1, 8*w1) class scores.

        ``h1 x w1`` is blockpool1's output size (28x28 for a 224x224 input, giving a
        224x224 output). Upsampled skip features are resized to blockpool1's size only
        when they do not already match it. Inputs whose stage sizes do not divide
        evenly therefore no longer crash in ``torch.cat``, and results are unchanged
        for sizes that already lined up.
        """
        pool1 = self.blockpool1(x)
        pool2 = self.blockpool2(pool1)
        pool3 = self.blockpool3(pool2)
        top = self.block4(pool3)

        size = pool1.shape[-2:]
        up4 = self._match_size(self.deconv4(top), size)
        up3 = self._match_size(self.deconv3(pool3), size)
        up2 = self._match_size(self.deconv2(pool2), size)

        x = torch.cat([pool1, up2, up3, up4], dim=1)
        x = self.pointwise(x)
        return self.classifier(x)
    

In [206]:
# Smoke test: a 224x224 batch should come back as per-pixel class scores at full resolution.
torch.manual_seed(0)  # reproducible weights and inputs on Restart & Run All
net = FCN()
inputs = torch.randn(5, 3, 224, 224)
with torch.no_grad():  # shape check only -- no need to build the autograd graph
    outputs = net(inputs)
assert outputs.shape == (5, 10, 224, 224), outputs.shape
outputs.shape

torch.Size([5, 10, 224, 224])