In [2]:
import numpy as np
import torch
import torch.nn as nn

In [None]:
class ResNetBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions plus an identity skip connection.

    Computes ``relu(w2(relu(w1(x))) + x)``. Both convolutions map
    ``n_features`` channels to ``n_features`` channels with stride 1 and
    padding 1. This keeps the spatial size unchanged, so the skip addition is
    shape-compatible.

    Args:
        n_features: number of input (and output) channels.
    """

    def __init__(self, n_features):
        super().__init__()
        same_size_conv = dict(kernel_size=3, stride=1, padding=1)
        self.w1 = nn.Conv2d(n_features, n_features, **same_size_conv)
        self.w2 = nn.Conv2d(n_features, n_features, **same_size_conv)
        self.activation = nn.ReLU()

    def forward(self, x):
        # Copy of the input used for the residual (skip) path.
        shortcut = x.clone()
        hidden = self.activation(self.w1(x))
        return self.activation(self.w2(hidden) + shortcut)

In [None]:
class SE_ResNetBlock(nn.Module):
    """Residual block with a Squeeze-and-Excitation (SE) channel gate.

    Forward pass:
      1. Residual branch: conv1 (3x3) -> ReLU -> conv2 (3x3). Both use stride 1
         and padding 1 and keep ``n_features`` channels, so the spatial size is
         preserved.
      2. SE gate: global average pool -> 1x1 conv (reduce) -> ReLU ->
         1x1 conv (expand) -> sigmoid. This gives one weight in (0, 1) per channel.
      3. Output: ``relu(branch * gate + x)``.

    Args:
        n_features: number of input (and output) channels.
        r: reduction ratio of the SE bottleneck. The bottleneck width is
            ``max(1, n_features // r)``.
    """

    def __init__(self, n_features, r):
        super(SE_ResNetBlock, self).__init__()
        # Clamp to at least one channel. Without it, n_features < r would give a
        # zero-channel bottleneck and the gate could no longer depend on the input.
        squeezed = max(1, n_features // r)

        self.conv1 = nn.Conv2d(in_channels=n_features, out_channels=n_features, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=n_features, out_channels=n_features, kernel_size=3, stride=1, padding=1)
        self.activation = nn.ReLU()
        self.globalpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        # 1x1 convolutions act as fully connected layers on the (N, C, 1, 1)
        # pooled tensor, so no flatten/unsqueeze is needed (unlike nn.Linear).
        self.fc = nn.Conv2d(in_channels=n_features, out_channels=squeezed, kernel_size=1, stride=1, padding=0)
        self.fc2 = nn.Conv2d(in_channels=squeezed, out_channels=n_features, kernel_size=1, stride=1, padding=0)
        self.gate = nn.Sigmoid()

    def forward(self, x):
        # Residual branch.
        out = self.conv2(self.activation(self.conv1(x)))

        # Squeeze (global pool) and excitation (bottleneck + sigmoid).
        se = self.globalpool(out)
        se = self.gate(self.fc2(self.activation(self.fc(se))))

        # Nothing above modifies x in place, so x can serve directly as the
        # skip connection without a defensive clone.
        out = out * se + x
        return self.activation(out)

In [None]:
class SE_ResNet(nn.Module):
    """Three-stage SE-ResNet image classifier.

    ``self.blocks`` is one flat nn.Sequential (F = n_features):
        stage 1: Conv3x3(n_in -> F), ReLU, num_blocks x SE_ResNetBlock(F, r)
        stage 2: MaxPool2d(2), Conv3x3(F -> 2F), ReLU, num_blocks x SE_ResNetBlock(2F, r)
        stage 3: MaxPool2d(2), Conv3x3(2F -> 4F), ReLU, num_blocks x SE_ResNetBlock(4F, r)
    It is followed by ``self.fc``: Linear -> ReLU -> Linear(2048, 512) -> ReLU ->
    Linear(512, n_classes) -> Softmax over dim 1.

    Args:
        n_in: number of input channels.
        n_features: channel width F of the first stage.
        num_blocks: number of SE residual blocks per stage.
        r: SE reduction ratio passed to every SE_ResNetBlock.
        n_classes: number of output classes (default 5).
        input_size: (H, W) of the input images. The first Linear layer of the head
            is sized from it, so forward() only accepts inputs of this spatial
            size. Defaults to (320, 100).

    Raises:
        ValueError: if either side of ``input_size`` is smaller than 4 (two 2x2
            max-pools would leave no spatial extent).

    Note:
        The head ends in nn.Softmax, so forward() returns probabilities.
        nn.CrossEntropyLoss expects unnormalised logits. If training uses it,
        the final Softmax should be dropped. TODO(review): confirm which one the
        training loop expects.
    """

    def __init__(self, n_in, n_features, num_blocks=3, r=16, n_classes=5, input_size=(320, 100)):
        super(SE_ResNet, self).__init__()
        height, width = input_size
        if height < 4 or width < 4:
            raise ValueError(f"input_size must be at least 4x4 (two 2x2 max-pools), got {input_size}")

        f1, f2, f3 = n_features, 2 * n_features, 4 * n_features

        # The first conv lifts the input to the desired number of features.
        layers = [nn.Conv2d(n_in, f1, kernel_size=3, stride=1, padding=1),
                  nn.ReLU()]  # H x W, e.g. 320x100
        layers += [SE_ResNetBlock(f1, r) for _ in range(num_blocks)]

        layers += [nn.MaxPool2d(2, 2),
                   nn.Conv2d(f1, f2, kernel_size=3, stride=1, padding=1),
                   nn.ReLU()]  # H/2 x W/2, e.g. 160x50
        layers += [SE_ResNetBlock(f2, r) for _ in range(num_blocks)]

        layers += [nn.MaxPool2d(2, 2),
                   nn.Conv2d(f2, f3, kernel_size=3, stride=1, padding=1),
                   nn.ReLU()]  # H/4 x W/4, e.g. 80x25
        layers += [SE_ResNetBlock(f3, r) for _ in range(num_blocks)]

        self.blocks = nn.Sequential(*layers)

        # Each MaxPool2d(2, 2) floors the spatial size, so two of them leave a
        # (H // 4) x (W // 4) map with 4F channels (80*25*4F for the default size).
        flat_features = f3 * (height // 4) * (width // 4)
        self.fc = nn.Sequential(nn.Linear(flat_features, 2048),
                                nn.ReLU(),
                                nn.Linear(2048, 512),
                                nn.ReLU(),
                                nn.Linear(512, n_classes),
                                nn.Softmax(dim=1))

    def forward(self, x):
        x = self.blocks(x)
        # Flatten everything except the first (minibatch) dimension.
        x = x.flatten(start_dim=1)
        return self.fc(x)