In [None]:
import torch
import torch.nn as nn
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from fastai.layers import PixelShuffle_ICNR
%load_ext tensorboard

In [None]:
# log_dir=None is SummaryWriter's default: event files go to ./runs/<datetime>_<hostname>.
writer = SummaryWriter(log_dir=None)

In [None]:
# Pretrained ResNeSt-50 from torch.hub, loaded here only to inspect its layer layout
# (ResUNeSt below loads its own copy of the backbone).
model = torch.hub.load("zhanghang1989/ResNeSt", "resnest50", pretrained=True)

In [None]:
# Show the full module tree of the backbone loaded above (rich repr of the cell's last expression).
model

In [None]:
base_layers = list(model.children())
# One compact line per top-level child: index, attribute name and class. The indices are
# what ResUNeSt slices (base_layers[0], [1:4], [4]..[7]); the full repr of every child is
# already displayed by the previous cell, so it is not repeated here.
for i, (name, layer) in enumerate(model.named_children()):
    print(f'{i:>2}  {name:<10} {type(layer).__name__}')

In [None]:
class DoubleConv(nn.Module):
    """Residual conv block computing ``relu(skip(x) + double_conv(x))``.

    ``double_conv`` is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The 3x3 convs use
    padding 1, so the spatial size is preserved. ``skip`` is a 1x1 conv that projects
    the input to ``out_channels`` so the two branches can be summed element-wise.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shortcut branch: 1x1 projection to out_channels.
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size = 1)
        residual_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.double_conv = nn.Sequential(*residual_layers)

    def forward(self, x):
        # Sum both branches, then apply ReLU in place on the freshly allocated sum.
        return F.relu_(self.skip(x) + self.double_conv(x))

class PsUpsample(nn.Module): # Upsampling using pixel shuffle
    """Decoder step: pixel-shuffle upsample ``x1``, concatenate skip ``x2``, then ``DoubleConv``.

    Args:
        in_channels: Channels of the low-resolution input ``x1``. The upsampled ``x1``
            has ``in_channels // 2`` channels and is concatenated with ``x2`` before a
            ``DoubleConv(in_channels, ...)``, so ``x2`` must carry ``in_channels // 2``
            channels (assuming ``in_channels`` is even).
        out_channels: Channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # fastai PixelShuffle_ICNR: 2x spatial upsampling to in_channels // 2 channels.
        self.upsample = PixelShuffle_ICNR(in_channels, in_channels//2, scale=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2``'s spatial size, concatenate and convolve."""
        x1 = self.upsample(x1)
        # If the encoder input size is not a multiple of its total stride, the 2x-upsampled
        # map can be a pixel larger or smaller than the skip and torch.cat would fail.
        # Pad (or crop, via negative padding) x1 so both tensors have the same H and W.
        diff_h = x2.shape[-2] - x1.shape[-2]
        diff_w = x2.shape[-1] - x1.shape[-1]
        if diff_h or diff_w:
            x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                            diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat((x2, x1), dim=1)
        return self.conv(x)

In [None]:
class ResUNeSt(nn.Module):
    """U-Net style 2x upscaler on top of a torch.hub ResNeSt-50 encoder.

    The input is bilinearly upsampled by 2 and encoded by the ResNeSt-50 stem and
    its four stages. It is then decoded by five ``PsUpsample`` blocks, each of which
    doubles the resolution and concatenates a skip tensor, and finally by a
    ``DoubleConv`` head.

    Args:
        out_channels: Number of channels produced by the final ``DoubleConv`` head.
        pretrained: Forwarded to ``torch.hub.load`` for the ResNeSt-50 backbone
            (default ``True``).

    Notes:
        * The head is a ``DoubleConv``, whose forward ends in a ReLU, so outputs are
          non-negative.
        * The output has twice the input's spatial size. Input sizes presumably need
          to be multiples of 32 so encoder and decoder feature maps line up (the
          notebook uses 32x32) — TODO confirm for other sizes.
    """

    def __init__(self, out_channels, pretrained=True):
        super().__init__()

        # Keeping the whole backbone as a submodule also registers any children that are
        # not sliced below (presumably the pooling/classifier head); forward never uses them.
        self.base_model = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=pretrained)
        self.base_layers = list(self.base_model.children())

        # Encoder path (indices follow the backbone's top-level children order).
        self.in_layer1 = self.base_layers[0]
        self.in_layer2 = nn.Sequential(*self.base_layers[1:4])
        self.layer1 = nn.Sequential(*self.base_layers[4])
        self.layer2 = nn.Sequential(*self.base_layers[5])
        self.layer3 = nn.Sequential(*self.base_layers[6])
        self.layer4 = nn.Sequential(*self.base_layers[7])

        # Cross path: 1x1 convs that bring shallow skips to the channel count the
        # matching decoder block expects (64 -> 128 for up4, 3 -> 64 for up5).
        self.down_in1 = nn.Conv2d(64, 128, kernel_size=1)
        self.down_up = nn.Conv2d(3, 64, kernel_size=1)

        # Decoder path: each block halves the channel count and doubles the resolution.
        self.up1 = PsUpsample(2048, 1024)
        self.up2 = PsUpsample(1024, 512)
        self.up3 = PsUpsample(512, 256)
        self.up4 = PsUpsample(256, 128)
        self.up5 = PsUpsample(128, 64)

        # Output head: honour the requested number of output channels.
        self.out_layer = DoubleConv(64, out_channels)

    def forward(self, x):
        # Encoder path, run on a 2x bilinearly upsampled copy of the input.
        x_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x_in1 = self.in_layer1(x_up)
        # x_in2 feeds layer1 but is not reused as a skip connection, since it has the
        # same spatial size as x_l1.
        x_in2 = self.in_layer2(x_in1)
        x_l1 = self.layer1(x_in2)
        x_l2 = self.layer2(x_l1)
        x_l3 = self.layer3(x_l2)
        x_l4 = self.layer4(x_l3)

        # Decoder path, fusing encoder features from deepest to shallowest.
        x = self.up1(x_l4, x_l3)
        x = self.up2(x, x_l2)
        x = self.up3(x, x_l1)
        x = self.up4(x, self.down_in1(x_in1))
        x = self.up5(x, self.down_up(x_up))
        return self.out_layer(x)

In [None]:
# Upscaler with a 3-channel output head (rebinds `model`, replacing the backbone loaded earlier).
model = ResUNeSt(out_channels=3)

In [None]:
# Dummy batch of 16 three-channel 32x32 inputs for the smoke test. A seeded local
# generator makes it reproducible on Restart & Run All without touching the global RNG.
images = torch.randn(16, 3, 32, 32, generator=torch.Generator().manual_seed(0))

In [None]:
# Smoke-test forward pass. eval() + no_grad() stop the random batch from updating the
# pretrained BatchNorm running statistics and avoid building an autograd graph. The
# model's previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out = model(images)
finally:
    model.train(was_training)

In [None]:
# Expect a 3-channel output at twice the input's spatial resolution.
assert out.shape == (images.shape[0], 3, 2 * images.shape[-2], 2 * images.shape[-1]), out.shape
out.shape

In [None]:
# Trace the model into TensorBoard. Tracing executes a forward pass, so switch to eval()
# to leave BatchNorm running statistics untouched. Restore the previous mode and close
# the writer even if tracing fails.
was_training = model.training
model.eval()
try:
    writer.add_graph(model, images)
finally:
    model.train(was_training)
    writer.close()

In [None]:
# Show TensorBoard inline; SummaryWriter() without a log_dir writes under ./runs by default.
%tensorboard --logdir=runs