### TCN Implementation

In [11]:
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` steps along the final axis of a 3-D tensor.

    Args:
        chomp_size: number of trailing positions to remove from the last axis.
            ``0`` leaves the length unchanged.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous copy/view of ``x`` without its last ``chomp_size`` steps."""
        # `x[:, :, :-0]` is `x[:, :, :0]`, i.e. an empty tensor, so a zero chomp
        # has to be handled explicitly instead of through the negative slice.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual block built from two weight-normalised dilated ``Conv1d`` stages.

    Each stage is ``conv -> Chomp1d(padding) -> ReLU -> Dropout``. The block
    returns ``relu(net(x) + res)``, where ``res`` is ``x`` itself when
    ``n_inputs == n_outputs`` and a kernel-size-1 convolution of ``x`` otherwise.

    Args:
        n_inputs: channels of the incoming tensor.
        n_outputs: channels produced by both convolutions.
        kernel_size, stride, dilation, padding: forwarded to both ``nn.Conv1d``
            layers; ``padding`` is also the amount removed by each ``Chomp1d``.
        dropout: probability used by both dropout layers.

    NOTE(review): ``init_weights`` writes into ``.weight`` of convolutions that
    were wrapped by ``weight_norm``; whether that write survives the wrapper's
    recomputation of ``weight`` should be confirmed for the torch version in use.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

        # Stage 1: n_inputs -> n_outputs.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_kwargs))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Stage 2: n_outputs -> n_outputs.
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_kwargs))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        # The same module objects are also chained so forward() can run them in one call.
        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        )

        # A 1x1 convolution on the skip path is only needed when channel counts differ.
        if n_inputs == n_outputs:
            self.downsample = None
        else:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the convolution weights from N(0, 0.01)."""
        targets = [self.conv1, self.conv2]
        if self.downsample is not None:
            targets.append(self.downsample)
        for layer in targets:
            layer.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Run both stages and add the (possibly projected) input before the final ReLU."""
        out = self.net(x)
        if self.downsample is None:
            res = x
        else:
            res = self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` levels whose dilation doubles at every level.

    Args:
        num_inputs: channel count fed to the first level.
        num_channels: output channel count of each level; its length is the depth.
        kernel_size: kernel size shared by all levels.
        dropout: dropout probability handed to every block.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        in_channels = num_inputs
        for i, out_channels in enumerate(num_channels):
            print(f"TB[{i}] -> in_channels: {in_channels} out_channels: {out_channels}")
            dilation_size = 2 ** i
            block = TemporalBlock(
                in_channels,
                out_channels,
                kernel_size,
                stride=1,
                dilation=dilation_size,
                # Same amount that each block later removes again with Chomp1d.
                padding=(kernel_size - 1) * dilation_size,
                dropout=dropout,
            )
            blocks.append(block)
            # The next level consumes what this level produced.
            in_channels = out_channels

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Pass ``x`` through every level in order."""
        return self.network(x)



### TCN used for word-level Penn Treebank language modeling

In [12]:
class TCN(nn.Module):
    """Embedding -> TemporalConvNet -> linear decoder.

    Args:
        input_size: embedding width, also the channel count fed to the TCN.
        output_size: number of embedding rows and number of decoder outputs.
        num_channels: per-level channel counts passed to ``TemporalConvNet``.
        kernel_size: kernel size passed to ``TemporalConvNet``.
        dropout: dropout probability inside the TCN.
        emb_dropout: dropout probability applied to the embeddings.
        tied_weights: when true, the decoder reuses the embedding matrix as its
            weight; this requires ``num_channels[-1] == input_size``.

    Raises:
        ValueError: if ``tied_weights`` is set and the last channel count differs
            from ``input_size``.
    """

    def __init__(self, input_size, output_size, num_channels,
                 kernel_size=2, dropout=0.3, emb_dropout=0.1, tied_weights=False):
        super().__init__()
        self.encoder = nn.Embedding(output_size, input_size)
        self.tcn = TemporalConvNet(input_size, num_channels, kernel_size, dropout=dropout)

        last_width = num_channels[-1]
        self.decoder = nn.Linear(last_width, output_size)
        if tied_weights and last_width != input_size:
            raise ValueError('When using the tied flag, nhid must be equal to emsize')
        if tied_weights:
            # Encoder and decoder now share one parameter tensor.
            self.decoder.weight = self.encoder.weight
            print("Weight tied")

        self.drop = nn.Dropout(emb_dropout)
        self.emb_dropout = emb_dropout
        self.init_weights()

    def init_weights(self):
        """Zero the decoder bias and draw encoder/decoder weights from N(0, 0.01)."""
        for layer in (self.encoder, self.decoder):
            layer.weight.data.normal_(0, 0.01)
        self.decoder.bias.data.fill_(0)

    def forward(self, input):
        """Embed ``input``, run the TCN and decode every position.

        The embedded tensor has its axes 1 and 2 swapped before the TCN and
        swapped back afterwards, so the convolutions see channels on axis 1.
        Assumes ``input`` holds integer token indices of shape (N, L) -- TODO confirm.
        """
        embedded = self.drop(self.encoder(input))
        channels_first = embedded.transpose(1, 2)
        features = self.tcn(channels_first).transpose(1, 2)
        logits = self.decoder(features)
        return logits.contiguous()

### Default settings for TCN on the Penn Treebank dataset

In [13]:
# The first two settings configure the embedding layer.
# The embedding layer is not expected to change dynamically
# during weight-shared training.
emsize = 600
n_words = 10000
# The channel widths below (num_channels) are the part that can be elastic.
nhid = 600
levels = 4
# Every level uses nhid channels except the last, which uses emsize so that the
# tied-weights check in TCN (last channel count == embedding size) passes.
num_chans = [nhid]*(levels-1) + [emsize]
dropout = 0.45
emb_dropout = 0.25
k_size = 3
tied = True
# Constructing the model prints one "TB[i]" line per level and "Weight tied".
model = TCN(emsize, n_words, num_chans, dropout=dropout, emb_dropout=emb_dropout, kernel_size=k_size, tied_weights=tied)

TB[0] -> in_channels: 600 out_channels: 600
TB[1] -> in_channels: 600 out_channels: 600
TB[2] -> in_channels: 600 out_channels: 600
TB[3] -> in_channels: 600 out_channels: 600
Weight tied


In [14]:
# Display the module tree of the model built in the previous cell.
model

TCN(
  (encoder): Embedding(10000, 600)
  (tcn): TemporalConvNet(
    (network): Sequential(
      (0): TemporalBlock(
        (conv1): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(2,))
        (chomp1): Chomp1d()
        (relu1): ReLU()
        (dropout1): Dropout(p=0.45, inplace=False)
        (conv2): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(2,))
        (chomp2): Chomp1d()
        (relu2): ReLU()
        (dropout2): Dropout(p=0.45, inplace=False)
        (net): Sequential(
          (0): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(2,))
          (1): Chomp1d()
          (2): ReLU()
          (3): Dropout(p=0.45, inplace=False)
          (4): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(2,))
          (5): Chomp1d()
          (6): ReLU()
          (7): Dropout(p=0.45, inplace=False)
        )
        (relu): ReLU()
      )
      (1): TemporalBlock(
        (conv1): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(4,), dila

In [31]:
# model.tcn.network[0].net[0](torch.ones((1,600,32)))
# Show the first convolution of the second temporal block (index 1),
# whose recorded repr reports padding 4 and dilation 2.
model.tcn.network[1].net[0]

Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(2,))

In [43]:
# Build a stand-alone weight-normalised Conv1d to use as a reference layer.
in_channel = 20
out_channel = 40
kernel_size = 3
stride=1
dilation = 2 
# Same padding formula that TemporalConvNet uses for its blocks.
padding_val = (kernel_size-1) * dilation
# temp =  TemporalBlock(in_channel, out_channel, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding_val, dropout=0.45)
tcn_conv = weight_norm(nn.Conv1d(in_channel, out_channel, kernel_size,
                                           stride=stride, padding=padding_val, dilation=dilation))


In [None]:
# Inspect the `weight_v` attribute of the weight-normalised convolution built above.
tcn_conv.weight_v

In [None]:
# for idx, m in enumerate(temp.net.modules()):
#     print(f"ID: {idx} Module: {m}")
# Feed an all-ones tensor with 20 channels and 32 steps through the reference convolution.
inp = torch.ones((1,20,32))
out = tcn_conv(inp)
out

In [15]:
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import _weight_norm, norm_except_dim
import torch.nn.functional as F
from ofa.utils import get_same_padding

class DynamicConv1dWtNorm(nn.Module):
    """1-D convolution that can run with fewer input/output channels than it owns.

    A full-size ``nn.Conv1d`` holds the bias and, when weight normalisation is
    off, the weight. With weight normalisation on, the filter is rebuilt on every
    call from ``conv_v`` and ``conv_g`` through ``torch._weight_norm``. In both
    modes only the leading ``out_channel`` x ``in_channel`` slice is used.

    Args:
        max_in_channels: largest number of input channels supported.
        max_out_channels: largest number of output channels supported.
        kernel_size, stride, dilation, padding: forwarded to ``nn.Conv1d`` and
            ``F.conv1d``.
        dim: dimension argument passed to ``norm_except_dim`` / ``_weight_norm``.
        weight_norm: build and use the ``conv_g`` / ``conv_v`` parameters when true.
    """

    def __init__(self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1, padding=0, dim=0, weight_norm=True):
        super().__init__()
        self.max_in_channels = max_in_channels
        self.max_out_channels = max_out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.dim = dim
        # Keep the flag on the instance. Outside __init__ the bare name
        # `weight_norm` would resolve to a module-level object, not this argument.
        self.use_weight_norm = weight_norm

        self.conv = nn.Conv1d(self.max_in_channels, self.max_out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation)
        if self.use_weight_norm:
            weight = self.conv.weight
            self.conv_g = Parameter(norm_except_dim(weight, 2, dim=self.dim).data)
            self.conv_v = Parameter(weight.data)
        self.active_out_channel = self.max_out_channels

    def get_active_filter(self, out_channel, in_channel):
        """Return the filter restricted to the first ``out_channel`` x ``in_channel`` channels."""
        if self.use_weight_norm:
            # conv_g is sliced along its first axis only.
            # NOTE(review): that matches the default dim=0; other dims are not handled here.
            return _weight_norm(self.conv_v[:out_channel, :in_channel, :], self.conv_g[:out_channel, :, :], self.dim)
        return self.conv.weight[:out_channel, :in_channel, :]

    def forward(self, x, out_channel=None):
        """Convolve ``x`` using the active slice of the filter and bias.

        Args:
            x: input whose size along axis 1 selects the number of input channels.
            out_channel: number of output channels; defaults to
                ``self.active_out_channel``.
        """
        if out_channel is None:
            out_channel = self.active_out_channel
        in_channel = x.size(1)
        filters = self.get_active_filter(out_channel, in_channel).contiguous()

        return F.conv1d(x, filters, self.conv.bias[:out_channel], self.stride, self.padding, self.dilation, 1)

In [50]:
# Build the dynamic counterpart of the reference convolution: up to 20 input
# and 40 output channels, with the same kernel size, dilation and padding.
in_channel = 20
out_channel = 40
kernel_size = 3
stride=1
dilation = 2 
padding_val = (kernel_size-1) * dilation
dynamic_layer = DynamicConv1dWtNorm(in_channel, out_channel, kernel_size,
                                           stride=stride, padding=padding_val, dilation=dilation)

In [62]:
# Run the dynamic layer on a 12-channel input (fewer than its 20-channel maximum)
# while asking for only 10 output channels.
inp = torch.ones((1,12,32))
dynamic_layer.active_out_channel = 10
out = dynamic_layer(inp)

# The recorded size is [1, 10, 36]; no Chomp1d is applied to this output.
out.size()

torch.Size([1, 10, 36])

In [19]:
from ofa.utils import make_divisible
class DynamicTemporalBlock(nn.Module):
    """Temporal residual block whose middle and output widths can be reduced at run time.

    The block holds two ``DynamicConv1dWtNorm`` stages, each followed by
    ``Chomp1d``, ReLU and dropout, plus a kernel-size-1 ``DynamicConv1dWtNorm``
    on the skip path when ``maxin_channel != maxout_channel``. ``forward`` reads
    ``active_out_channel`` and ``active_expand_ratio`` to decide how many
    channels each stage produces.

    Args:
        maxin_channel: largest number of input channels.
        maxout_channel: largest number of output channels.
        kernel_size, stride, dilation, padding: forwarded to both convolutions;
            ``padding`` is also the amount removed by each ``Chomp1d``.
        dropout: probability used by both dropout layers.
        expand_ratio_list: candidate ratios; the largest one sizes the middle width.
    """

    def __init__(self, maxin_channel, maxout_channel, kernel_size, stride, dilation, padding, dropout=0.2, expand_ratio_list=[1.0]):
        super().__init__()
        # Only max() is read from the list; it is never mutated.
        self.active_expand_ratio = max(expand_ratio_list)
        self.active_out_channel = maxout_channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.dropout = dropout
        # With the largest ratio and the full output width active, this is the maximum.
        max_middle_channel = self.active_middle_channels
        self.conv1 = DynamicConv1dWtNorm(maxin_channel, max_middle_channel, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = DynamicConv1dWtNorm(max_middle_channel, maxout_channel, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        self.downsample = DynamicConv1dWtNorm(maxin_channel, maxout_channel, 1, weight_norm=False) if maxin_channel != maxout_channel else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the weights of the wrapped ``nn.Conv1d`` layers from N(0, 0.01)."""
        self.conv1.conv.weight.data.normal_(0, 0.01)
        self.conv2.conv.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.conv.weight.data.normal_(0, 0.01)

    @property
    def active_middle_channels(self):
        """Width between the two convolutions for the current active settings."""
        return round(self.active_out_channel * self.active_expand_ratio)

    def get_active_subnet(self, in_channel, preserve_weight=True):
        """Build a fixed-size ``MyTemporalBlock`` matching the active configuration.

        Args:
            in_channel: number of input channels of the extracted block.
            preserve_weight: when true, copy the leading slices of this block's
                parameters into the new block; when false, return it as constructed.

        Returns:
            The new ``MyTemporalBlock``.
        """
        middle_channel = self.active_middle_channels
        out_channel = self.active_out_channel
        # This block has no `dim` attribute of its own; the value lives on the convolutions.
        sub_layer = MyTemporalBlock(in_channel, out_channel, middle_channel, self.kernel_size, self.stride, self.dilation, self.padding, self.dropout, self.conv1.dim)
        if not preserve_weight:
            return sub_layer

        # TODO (alind): add proper selection mechanisms
        sub_layer.conv1.conv_g.data.copy_(self.conv1.conv_g.data[:middle_channel, :, :])
        sub_layer.conv1.conv_v.data.copy_(self.conv1.conv_v.data[:middle_channel, :in_channel, :])
        sub_layer.conv1.conv.bias.data.copy_(self.conv1.conv.bias.data[:middle_channel])

        sub_layer.conv2.conv_g.data.copy_(self.conv2.conv_g.data[:out_channel, :, :])
        sub_layer.conv2.conv_v.data.copy_(self.conv2.conv_v.data[:out_channel, :middle_channel, :])
        sub_layer.conv2.conv.bias.data.copy_(self.conv2.conv.bias.data[:out_channel])

        # Carry the skip-path projection over as well when both blocks have one;
        # otherwise the extracted block would keep freshly initialised weights there.
        if self.downsample is not None and sub_layer.downsample is not None:
            sub_layer.downsample.weight.data.copy_(self.downsample.conv.weight.data[:out_channel, :in_channel, :])
            sub_layer.downsample.bias.data.copy_(self.downsample.conv.bias.data[:out_channel])
        return sub_layer

    def forward(self, x):
        """Run both stages at the active widths and add the skip path before the final ReLU."""
        # Push the current configuration down to the dynamic convolutions.
        feature_dim = self.active_middle_channels
        self.conv1.active_out_channel = feature_dim
        self.conv2.active_out_channel = self.active_out_channel
        if self.downsample is not None:
            self.downsample.active_out_channel = self.active_out_channel

        out = self.conv1(x)
        out = self.chomp1(out)
        out = self.relu1(out)
        out = self.dropout1(out)

        out = self.conv2(out)
        out = self.chomp2(out)
        out = self.relu2(out)
        out = self.dropout2(out)

        res = x if self.downsample is None else self.downsample(x)

        return self.relu(out + res)
        
          

class Conv1dWtNorm(nn.Module):
    """Fixed-size 1-D convolution with optional weight normalisation.

    A ``nn.Conv1d`` holds the bias and, when weight normalisation is off, the
    weight. With weight normalisation on, the filter is rebuilt on every call
    from ``conv_v`` and ``conv_g`` through ``torch._weight_norm``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, dilation, padding: forwarded to ``nn.Conv1d`` and
            ``F.conv1d``.
        dim: dimension argument passed to ``norm_except_dim`` / ``_weight_norm``.
        weight_norm: build and use the ``conv_g`` / ``conv_v`` parameters when true.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, padding=0, dim=0, weight_norm=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.dim = dim
        # Keep the flag on the instance. Outside __init__ the bare name
        # `weight_norm` would resolve to a module-level object, not this argument.
        self.use_weight_norm = weight_norm

        self.conv = nn.Conv1d(self.in_channels, self.out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation)
        if self.use_weight_norm:
            weight = self.conv.weight
            self.conv_g = Parameter(norm_except_dim(weight, 2, dim=self.dim).data)
            self.conv_v = Parameter(weight.data)
        # A fixed-size layer always produces all of its output channels.
        self.active_out_channel = self.out_channels

    def init_weights(self):
        """Draw the tensor the filter is built from out of N(0, 0.01).

        With weight normalisation that tensor is ``conv_v`` (``conv_g`` is left
        unchanged); otherwise it is the weight of the wrapped ``nn.Conv1d``.
        """
        target = self.conv_v if self.use_weight_norm else self.conv.weight
        target.data.normal_(0, 0.01)

    def get_active_filter(self):
        """Return the full convolution filter."""
        if self.use_weight_norm:
            return _weight_norm(self.conv_v, self.conv_g, self.dim)
        return self.conv.weight

    def forward(self, x, out_channel=None):
        """Convolve ``x`` with the full filter and bias.

        ``out_channel`` is accepted so the call signature matches
        ``DynamicConv1dWtNorm.forward``; this layer does not slice by it.
        """
        filters = self.get_active_filter().contiguous()
        return F.conv1d(x, filters, self.conv.bias, self.stride, self.padding, self.dilation, 1)
    
class MyTemporalBlock(nn.Module):
    """Fixed-size temporal residual block with a separate middle width.

    Two ``Conv1dWtNorm`` stages (``n_inputs -> middle_channel -> n_outputs``),
    each followed by ``Chomp1d``, ReLU and dropout, plus a kernel-size-1
    ``nn.Conv1d`` on the skip path when ``n_inputs != n_outputs``.

    Args:
        n_inputs: number of input channels.
        n_outputs: number of output channels.
        middle_channel: number of channels between the two convolutions.
        kernel_size, stride, dilation, padding: forwarded to both convolutions;
            ``padding`` is also the amount removed by each ``Chomp1d``.
        dropout: probability used by both dropout layers.
        dim: weight-normalisation dimension forwarded to both ``Conv1dWtNorm`` layers.
    """

    def __init__(self, n_inputs, n_outputs, middle_channel, kernel_size, stride, dilation, padding, dropout=0.2, dim=0):
        super().__init__()

        self.dim = dim
        # Forward `dim` so the value stored on the block is the one the layers actually use.
        self.conv1 = Conv1dWtNorm(n_inputs, middle_channel, kernel_size, stride=stride, padding=padding, dilation=dilation, dim=dim)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = Conv1dWtNorm(middle_channel, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation, dim=dim)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the convolution weights from N(0, 0.01).

        The wrapped ``nn.Conv1d`` of each ``Conv1dWtNorm`` is initialised directly,
        the same way ``DynamicTemporalBlock.init_weights`` does it, so this does
        not rely on ``Conv1dWtNorm`` providing an ``init_weights`` method.
        """
        for wrapped in (self.conv1, self.conv2):
            wrapped.conv.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Run both stages and add the (possibly projected) input before the final ReLU."""
        out = self.conv1(x)
        out = self.chomp1(out)
        out = self.relu1(out)
        out = self.dropout1(out)

        out = self.conv2(out)
        out = self.chomp2(out)
        out = self.relu2(out)
        out = self.dropout2(out)

        res = x if self.downsample is None else self.downsample(x)

        return self.relu(out + res)
    
   

In [None]:
class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` levels whose dilation doubles at every level.

    NOTE(review): a class with this name and the same body is defined earlier
    in the notebook; running this cell rebinds the name.

    Args:
        num_inputs: channel count fed to the first level.
        num_channels: output channel count of each level; its length is the depth.
        kernel_size: kernel size shared by all levels.
        dropout: dropout probability handed to every block.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        in_channels = num_inputs
        for i, out_channels in enumerate(num_channels):
            print(f"TB[{i}] -> in_channels: {in_channels} out_channels: {out_channels}")
            dilation_size = 2 ** i
            block = TemporalBlock(
                in_channels,
                out_channels,
                kernel_size,
                stride=1,
                dilation=dilation_size,
                # Same amount that each block later removes again with Chomp1d.
                padding=(kernel_size - 1) * dilation_size,
                dropout=dropout,
            )
            blocks.append(block)
            # The next level consumes what this level produced.
            in_channels = out_channels

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Pass ``x`` through every level in order."""
        return self.network(x)