# In [None]:
import torch
import torch.nn as nn
import fastai.layers as L
import torch.nn.functional as F
from collections import OrderedDict

# Reproducibility: force cuDNN to use deterministic algorithms, and explicitly
# turn off its autotuner (benchmark mode), which may pick different kernels
# from run to run and so defeat the deterministic setting.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def embed_layer(weights_matrix, padding_idx=1, device='cuda'):
  """Build a trainable ``nn.Embedding`` initialised from a pretrained matrix.

  Args:
    weights_matrix: 2-D NumPy array or tensor of shape
      ``(num_embeddings, embedding_dim)``; row ``i`` becomes the vector for
      index ``i``.
    padding_idx: index passed to ``nn.Embedding(padding_idx=...)``; that row
      receives no gradient updates. Defaults to 1.
    device: device the layer is moved to. Defaults to ``'cuda'``.

  Returns:
    The embedding layer. Its weights are copied from ``weights_matrix`` (cast to
    the layer's parameter dtype) and left trainable (``requires_grad=True``).
  """
  num_embeddings, embedding_dim = weights_matrix.shape
  # as_tensor accepts NumPy arrays (sharing memory, like from_numpy) and tensors.
  weights = torch.as_tensor(weights_matrix)

  emb_layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
  # load_state_dict copies into the existing parameter, so e.g. float64 input
  # is cast to the parameter's dtype instead of replacing it.
  emb_layer.load_state_dict({'weight': weights})
  emb_layer = emb_layer.to(device)
  # Keep the pretrained vectors trainable (fine-tuned with the rest of the model).
  emb_layer.weight.requires_grad = True

  return emb_layer

class dropsampleMergeLayer(nn.Module):
  """Shortcut merge placed at the end of a downsampling transition.

  Intended to run inside a fastai ``L.SequentialEx``. It receives the output of
  the preceding layer as ``x`` and reads the container's original input from
  ``x.orig`` (an attribute the enclosing ``SequentialEx`` is expected to attach).
  It returns ``dropout(x) + downsample(x.orig)``, where ``downsample`` is a
  stride-2 ``L.conv1d`` mapping ``inplanes`` to ``2 * inplanes`` channels
  (default kernel size, presumably 1x1), so the shortcut matches the main path.

  Args:
    inplanes: channel count of the ``SequentialEx`` input.
    dropout: dropout probability applied to the main path during training.
    dense: accepted for API compatibility; currently unused.
  """
  def __init__(self, inplanes, dropout, dense):
    super().__init__()
    self.downsample = L.conv1d(inplanes, inplanes*2, stride=2)
    self.dropout = dropout

  def forward(self, x):
    # Tie dropout to train/eval mode: F.dropout defaults to training=True,
    # which would keep dropping activations at inference time.
    x_drop = F.dropout(x, p=self.dropout, training=self.training)
    return x_drop + self.downsample(x.orig)

def transition(inplanes, dropout, dense=True):
  """Build the downsampling stage placed between two residual stages.

  Main path: a kernel-3, stride-2, padding-1 ``L.conv1d`` that doubles the
  channels. It is followed by ``dropsampleMergeLayer``, which adds a strided
  projection of the transition's own input. Both are wrapped in
  ``L.SequentialEx`` so the merge layer can reach that input.

  Args:
    inplanes: channels entering the transition (output has ``2 * inplanes``).
    dropout: dropout probability forwarded to the merge layer.
    dense: forwarded to the merge layer.
  """
  outplanes = inplanes * 2
  # Build the conv before the merge layer so parameter init order is unchanged.
  strided_conv = L.conv1d(inplanes, outplanes, ks=3, stride=2, padding=1)
  merge = dropsampleMergeLayer(inplanes, dropout=dropout, dense=dense)
  return L.SequentialEx(strided_conv, merge)

def resLayer(inplanes, leaky, self_attention):
  """Return one 1-D fastai residual block with ``inplanes`` channels.

  Thin wrapper over ``L.res_block`` that fixes ``is_1d=True`` and forwards the
  activation slope ``leaky`` and the ``self_attention`` flag.
  """
  return L.res_block(
      inplanes,
      is_1d=True,
      self_attention=self_attention,
      leaky=leaky,
  )

class ResNet(nn.Module):
  """1-D residual text classifier over pretrained word embeddings.

  Pipeline:
    embedding
    -> ``init_conv`` (kernel ``(3, embed_dim)``, collapsing the embedding axis)
    -> BatchNorm -> activation
    -> ``len(layers)`` residual stages, with a channel-doubling ``transition``
       between consecutive stages
    -> global max-pool -> ``fc`` -> activation -> 2-way ``classifier``.

  Args:
    weights_matrix: pretrained embedding matrix handed to ``embed_layer``.
    layers: number of residual layers in each stage.
    inplanes: channels produced by ``init_conv``; doubled at every transition.
    dropout: dropout probability used after residual layers and in transitions.
    leaky: slope forwarded to every activation (via ``L.relu`` / ``resLayer``).
    embed_dim: embedding width; must equal the width of ``weights_matrix``.
    self_attention: forwarded to every residual layer.

  Raises:
    ValueError: if ``embed_dim`` differs from the embedding layer's width.
  """
  def __init__(self, weights_matrix, layers = (3, 3, 6, 6), inplanes = 32,
               dropout = 0.35, leaky = 0.01, embed_dim = 50, self_attention = True):

    super().__init__()
    self.embedding = embed_layer(weights_matrix)
    if self.embedding.embedding_dim != embed_dim:
      # init_conv's kernel must span the whole embedding width so forward can
      # squeeze the trailing axis; a mismatch would otherwise only fail in forward.
      raise ValueError(
          'embed_dim=%d does not match embedding width %d of weights_matrix'
          % (embed_dim, self.embedding.embedding_dim))

    # NOTE(review): L.conv1d (presumably an nn.Conv1d) is given a 2-D kernel and
    # fed a 4-D (N, 1, seq, embed_dim) tensor in forward. Older PyTorch accepted
    # this; recent releases reject 4-D input to conv1d. It is kept unchanged
    # here so parameters/checkpoints stay the same. Confirm against your PyTorch
    # version.
    self.init_conv = L.conv1d(1, inplanes, ks=(3, embed_dim), stride=1, padding=(1,0), bias=False)
    self.dropout = dropout
    # init_conv is registered both as self.init_conv and as features.init_conv
    # (same module, two state_dict entries).
    self.features = nn.Sequential(OrderedDict([
                                               ('init_conv', self.init_conv),
                                               ('init_norm', nn.BatchNorm1d(inplanes)),
                                               # Honour `leaky`, like the head activation below.
                                               ('init_relu', L.relu(inplace=True, leaky=leaky))]))

    num_features = inplanes
    last_stage = len(layers) - 1

    for i, layer in enumerate(layers):
      self.features.add_module('resblock%d' % (i+1),
                               self._make_block(num_features, layer, leaky=leaky, self_attention=self_attention))
      if i != last_stage:
        # Double the channels (and stride the sequence) between stages.
        self.features.add_module('transition%d' % (i+1),
                                 transition(num_features, dropout=self.dropout))
        num_features = num_features*2

    self.leakyrelu = L.relu(inplace=True, leaky=leaky)
    self.maxpool = nn.AdaptiveMaxPool1d(1)
    self.fc = nn.Linear(num_features, num_features*2)
    self.classifier = nn.Linear(num_features*2, 2)

    # Applies to every BatchNorm1d / Linear in the model, including those
    # created inside the residual layers and transitions.
    for m in self.modules():
      if isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight, 0.5)
        nn.init.constant_(m.bias, 0)
      elif isinstance(m, nn.Linear):
        nn.init.constant_(m.bias, 0)

  def _make_block(self, outplanes, layer_count, leaky, self_attention):
    """Stack ``layer_count`` residual layers, each followed by dropout."""
    layers = []

    for _ in range(layer_count):
      layers.append(resLayer(outplanes, leaky=leaky, self_attention=self_attention))
      # Every residual layer, including the last one in the stage, gets a
      # dropout. NOTE(review): if the last one was meant to be skipped, guard
      # with `i < layer_count - 1`.
      layers.append(nn.Dropout(p=self.dropout, inplace=True))

    return nn.Sequential(*layers)

  def forward(self, x):
    """Return 2-class logits for a batch of token-id sequences.

    Assumes ``x`` is a ``(batch, seq_len)`` tensor of indices into
    ``self.embedding``. TODO: confirm with the data pipeline.
    """
    # (batch, seq_len, embed_dim) -> add a singleton channel axis for init_conv.
    x = self.embedding(x).unsqueeze(1)

    for i, layer in enumerate(self.features):
      if i == 1:
        # init_conv (features[0]) spans the full embedding width, leaving a
        # trailing axis of size 1; drop it so later layers see (N, C, L).
        x = x.squeeze(3)
      x = layer(x)
    # Global max-pool over the sequence axis, then flatten to (N, C).
    x = self.maxpool(x).view(x.size(0), -1)
    x = self.fc(x)
    x = self.leakyrelu(x)
    x = self.classifier(x)

    return x