### TCN Implementation

In [10]:
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries along the third dimension.

    Used after a padded ``Conv1d`` to cut the output back to the input
    length (the blocks in this notebook pass the conv's ``padding`` as
    ``chomp_size``).

    Args:
        chomp_size: number of trailing positions to remove; ``0`` keeps the
            tensor unchanged.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous copy of ``x`` without its last ``chomp_size`` steps."""
        # ``x[:, :, :-0]`` is an *empty* slice, so a zero chomp (e.g. a conv
        # built with padding 0) must be handled as a no-op explicitly.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual block made of two weight-normalised dilated 1-D convolutions.

    Each convolution is followed by ``Chomp1d(padding)``, ReLU and dropout.
    The block input is added back to the result (through a 1x1 convolution
    when the channel counts differ) and passed through a final ReLU.

    Args:
        n_inputs: channels of the incoming tensor.
        n_outputs: channels produced by both convolutions.
        kernel_size, stride, dilation, padding: forwarded to both ``nn.Conv1d``.
        dropout: drop probability used after each convolution.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_opts = dict(stride=stride, padding=padding, dilation=dilation)

        # First conv -> chomp -> ReLU -> dropout stage.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_opts))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second stage, same layout, n_outputs -> n_outputs.
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_opts))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        # The same module objects are also chained in ``net``; this is why
        # every layer shows up twice in the printed module tree.
        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        )

        # A 1x1 conv aligns the residual only when the widths differ.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the conv (and downsample) weights from N(0, 0.01)."""
        # NOTE(review): with weight_norm the effective ``weight`` is presumably
        # recomputed from weight_g / weight_v on every forward, so this in-place
        # init on ``.weight`` may have no lasting effect - confirm.
        for conv in (self.conv1, self.conv2):
            conv.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Return ``relu(net(x) + residual(x))``."""
        out = self.net(x)
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        return self.relu(out + shortcut)


class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` levels whose dilation doubles at every level.

    Args:
        num_inputs: channels of the network input.
        num_channels: output width of each level; its length sets the depth.
        kernel_size: kernel size shared by all levels.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        stacked = []
        in_channels = num_inputs
        for i, out_channels in enumerate(num_channels):
            rate = 2 ** i
            print(f"TB[{i}] -> in_channels: {in_channels} out_channels: {out_channels}")
            stacked.append(
                TemporalBlock(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    dilation=rate,
                    # (k - 1) * dilation of padding; the block chomps the same amount.
                    padding=(kernel_size - 1) * rate,
                    dropout=dropout,
                )
            )
            # The next level consumes what this level produced.
            in_channels = out_channels

        self.network = nn.Sequential(*stacked)

    def forward(self, x):
        """Run ``x`` through every level in order."""
        return self.network(x)



### TCN used for word-level Penn Treebank

In [11]:
class TCN(nn.Module):
    """Language model: embedding -> ``TemporalConvNet`` -> linear decoder.

    Args:
        input_size: embedding dimension (also the TCN input width).
        output_size: vocabulary size, used for both the embedding table and
            the decoder output.
        num_channels: per-level widths handed to ``TemporalConvNet``.
        kernel_size: convolution kernel size.
        dropout: dropout inside the TCN blocks.
        emb_dropout: dropout applied to the embeddings.
        tied_weights: share the decoder weight with the embedding table;
            requires ``num_channels[-1] == input_size``.

    Raises:
        ValueError: if ``tied_weights`` is set and the widths do not match.
    """

    def __init__(self, input_size, output_size, num_channels,
                 kernel_size=2, dropout=0.3, emb_dropout=0.1, tied_weights=False):
        super().__init__()
        self.encoder = nn.Embedding(output_size, input_size)
        self.tcn = TemporalConvNet(input_size, num_channels, kernel_size, dropout=dropout)

        top_width = num_channels[-1]
        self.decoder = nn.Linear(top_width, output_size)
        if tied_weights:
            # Sharing one matrix only works when both ends have the same width.
            if top_width != input_size:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight
            print("Weight tied")

        self.drop = nn.Dropout(emb_dropout)
        self.emb_dropout = emb_dropout
        self.init_weights()

    def init_weights(self):
        """Initialise embedding/decoder weights from N(0, 0.01) and zero the bias."""
        std = 0.01
        self.encoder.weight.data.normal_(0, std)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.normal_(0, std)

    def forward(self, input):
        """Map token ids to per-position vocabulary scores.

        The embedding output is laid out as (N, L, C); it is transposed to
        (N, C, L) for the convolutional stack and transposed back before
        decoding. ``input`` is assumed to be an integer tensor of shape (N, L).
        """
        embedded = self.drop(self.encoder(input))
        features = self.tcn(embedded.transpose(1, 2))
        scores = self.decoder(features.transpose(1, 2))
        return scores.contiguous()

### Default settings for the TCN on the Penn Treebank dataset

In [12]:
# Baseline (non-elastic) TCN configuration.
# The first two settings belong to the embedding layer; the embedding is not
# expected to change dynamically during weight-shared training.
emsize = 600
n_words = 10000
# Hidden width of the TCN levels; this is the part that can be made elastic.
nhid = 600
levels = 4
# Every level uses nhid channels except the last, which uses emsize so that
# the tied-weights check in TCN (num_channels[-1] == input_size) passes.
num_chans = [nhid]*(levels-1) + [emsize]
dropout = 0.45
emb_dropout = 0.25
k_size = 3
tied = True
# Constructing the model prints one "TB[i]" line per level and "Weight tied".
model = TCN(emsize, n_words, num_chans, dropout=dropout, emb_dropout=emb_dropout, kernel_size=k_size, tied_weights=tied)

TB[0] -> in_channels: 600 out_channels: 600
TB[1] -> in_channels: 600 out_channels: 600
TB[2] -> in_channels: 600 out_channels: 600
TB[3] -> in_channels: 600 out_channels: 600
Weight tied


In [13]:
# Display the module tree of the baseline TCN (repr shown below).
model

TCN(
  (encoder): Embedding(10000, 600)
  (tcn): TemporalConvNet(
    (network): Sequential(
      (0): TemporalBlock(
        (conv1): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(2,))
        (chomp1): Chomp1d()
        (relu1): ReLU()
        (dropout1): Dropout(p=0.45, inplace=False)
        (conv2): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(2,))
        (chomp2): Chomp1d()
        (relu2): ReLU()
        (dropout2): Dropout(p=0.45, inplace=False)
        (net): Sequential(
          (0): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(2,))
          (1): Chomp1d()
          (2): ReLU()
          (3): Dropout(p=0.45, inplace=False)
          (4): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(2,))
          (5): Chomp1d()
          (6): ReLU()
          (7): Dropout(p=0.45, inplace=False)
        )
        (relu): ReLU()
      )
      (1): TemporalBlock(
        (conv1): Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(4,), dila

In [31]:
# Disabled experiment: run the first block's first conv on a dummy input.
# model.tcn.network[0].net[0](torch.ones((1,600,32)))
# Inspect the first conv of the second block; its repr below shows
# padding=(4,) and dilation=(2,).
model.tcn.network[1].net[0]

Conv1d(600, 600, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(2,))

In [14]:
# Stand-alone weight-normalised dilated conv, built the same way
# TemporalBlock builds its convolutions, for use as a reference.
in_channel = 20
out_channel = 40
kernel_size = 3
stride=1
dilation = 2 
# Same padding rule as TemporalConvNet: (kernel_size - 1) * dilation.
padding_val = (kernel_size-1) * dilation
# Disabled alternative: build a whole TemporalBlock instead of a single conv.
# temp =  TemporalBlock(in_channel, out_channel, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding_val, dropout=0.45)
tcn_conv = weight_norm(nn.Conv1d(in_channel, out_channel, kernel_size,
                                           stride=stride, padding=padding_val, dilation=dilation))


In [None]:
# Look at the `weight_v` attribute of the weight-normalised conv
# (presumably the direction parameter added by weight_norm - confirm).
# This cell has no execution count, i.e. it was not run in the saved notebook.
tcn_conv.weight_v

In [15]:
# Disabled: enumerate the sub-modules of the commented-out `temp` block.
# for idx, m in enumerate(temp.net.modules()):
#     print(f"ID: {idx} Module: {m}")
# Feed an all-ones tensor with 20 channels and 32 positions through the
# reference conv and display the raw (un-chomped) output.
inp = torch.ones((1,20,32))
out = tcn_conv(inp)
out

tensor([[[ 0.6395,  0.6395,  0.5636,  ...,  0.8489,  0.9248,  0.9248],
         [ 0.0329,  0.0329, -0.3584,  ..., -0.4991, -0.1079, -0.1079],
         [ 0.2309,  0.2309,  0.4339,  ...,  0.2844,  0.0814,  0.0814],
         ...,
         [ 0.2028,  0.2028, -0.1882,  ..., -0.6475, -0.2564, -0.2564],
         [-0.3380, -0.3380, -0.3536,  ..., -0.1350, -0.1194, -0.1194],
         [ 0.2696,  0.2696,  0.6128,  ..., -0.3301, -0.6732, -0.6732]]],
       grad_fn=<SqueezeBackward1>)

In [164]:
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import _weight_norm, norm_except_dim
import torch.nn.functional as F
from ofa.utils import get_same_padding

class DynamicConv1dWtNorm(nn.Module):
    """1-D convolution whose active input/output channel counts can shrink at run time.

    A full-size ``nn.Conv1d`` is allocated for ``max_in_channels`` x
    ``max_out_channels``. On every forward pass the filter bank is sliced to
    the first ``out_channel`` filters and the first ``x.size(1)`` input
    channels. With ``weight_norm=True`` the weight is stored as a magnitude
    (``conv_g``) and a direction (``conv_v``) and recombined with
    ``torch._weight_norm`` on the sliced tensors; with ``weight_norm=False``
    the plain ``conv.weight`` is sliced instead.

    Args:
        max_in_channels: largest number of input channels supported.
        max_out_channels: largest number of output channels supported.
        kernel_size, stride, dilation, padding: usual ``Conv1d`` settings.
        dim: dimension kept when computing the weight norm.
        weight_norm: whether to use the g/v weight-norm parametrisation.
    """

    def __init__(self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1, padding=0, dim=0, weight_norm=True):
        super(DynamicConv1dWtNorm, self).__init__()
        self.max_in_channels = max_in_channels
        self.max_out_channels = max_out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.dim = dim
        # Keep the flag on the instance: outside __init__ the bare name
        # ``weight_norm`` resolves to whatever the module namespace holds
        # (here the imported torch helper), which is always truthy.
        self.weight_norm = weight_norm
        self.conv = nn.Conv1d(self.max_in_channels, self.max_out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation)
        if self.weight_norm:
            # Replace conv.weight by the (g, v) pair; only the bias stays on
            # the wrapped conv. Without weight norm the weight must be kept,
            # otherwise there would be nothing left to convolve with.
            weight = self.conv.weight
            del self.conv._parameters["weight"]
            self.register_parameter("conv_g", Parameter(
                norm_except_dim(weight, 2, dim=self.dim).data
            ))
            self.register_parameter("conv_v", Parameter(weight.data))
        self.active_out_channel = self.max_out_channels

    def get_active_filter(self, out_channel, in_channel):
        """Return the filter bank restricted to ``out_channel`` x ``in_channel``."""
        if self.weight_norm:
            # The norm is taken over the *sliced* direction tensor.
            return _weight_norm(self.conv_v[:out_channel, :in_channel, :], self.conv_g[:out_channel, :, :], self.dim)
        return self.conv.weight[:out_channel, :in_channel, :]

    def forward(self, x, out_channel=None):
        """Convolve ``x`` using the currently active sub-filter.

        Args:
            x: input tensor; its size along dimension 1 selects the number of
                input channels used.
            out_channel: number of output channels; defaults to
                ``self.active_out_channel``.
        """
        if out_channel is None:
            out_channel = self.active_out_channel
        in_channel = x.size(1)
        filters = self.get_active_filter(out_channel, in_channel).contiguous()

        y = F.conv1d(x, filters, self.conv.bias[:out_channel], self.stride, self.padding, self.dilation, 1)
        return y

In [17]:
# Build a DynamicConv1dWtNorm with the same hyper-parameters as the
# reference conv above: up to 20 input and 40 output channels.
in_channel = 20
out_channel = 40
kernel_size = 3
stride=1
dilation = 2 
# Same padding rule as before: (kernel_size - 1) * dilation.
padding_val = (kernel_size-1) * dilation
dynamic_layer = DynamicConv1dWtNorm(in_channel, out_channel, kernel_size,
                                           stride=stride, padding=padding_val, dilation=dilation)

In [18]:
# Use only part of the layer: 12 of the 20 input channels (taken from the
# input tensor) and 10 of the 40 output channels.
inp = torch.ones((1,12,32))
dynamic_layer.active_out_channel = 10
out = dynamic_layer(inp)

# Length is 36 rather than 32 because the padded conv output is not chomped here.
out.size()

torch.Size([1, 10, 36])

In [165]:
from ofa.utils import make_divisible
class DynamicTemporalBlock(nn.Module):
    """Residual temporal block whose hidden ("middle") width is elastic.

    The block is conv1 -> chomp -> ReLU -> dropout -> conv2 -> chomp -> ReLU
    -> dropout, plus a residual connection and a final ReLU. ``conv1`` maps to
    ``round(active_out_channel * active_expand_ratio)`` channels and ``conv2``
    maps back to ``active_out_channel``. Both convolutions are allocated for
    the largest expand ratio and sliced at run time.

    Args:
        maxin_channel: largest number of input channels.
        maxout_channel: number of output channels.
        kernel_size, stride, dilation, padding: settings of both convolutions.
        dropout: dropout probability after each convolution.
        expand_ratio_list: candidate ratios for the middle width.
    """

    def __init__(self, maxin_channel, maxout_channel, kernel_size, stride, dilation, padding, dropout=0.2, expand_ratio_list=[1.0]):
        super(DynamicTemporalBlock, self).__init__()
        # Copy so the shared default list (or the caller's list) is never aliased.
        self.expand_ratio_list = list(expand_ratio_list)
        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = maxout_channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.dropout = dropout
        # The active ratio is the maximum at this point, so this is the largest middle width.
        max_middle_channel = self.active_middle_channels
        self.conv1 = DynamicConv1dWtNorm(maxin_channel, max_middle_channel, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = DynamicConv1dWtNorm(max_middle_channel, maxout_channel, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        # NOTE(review): this path needs DynamicConv1dWtNorm to keep a usable
        # ``conv.weight`` when weight_norm=False - confirm before relying on it.
        self.downsample = DynamicConv1dWtNorm(maxin_channel, maxout_channel, 1, weight_norm=False) if maxin_channel != maxout_channel else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Initialise the optional downsample conv from N(0, 0.01)."""
        # conv1 / conv2 keep the initialisation done by DynamicConv1dWtNorm.
        if self.downsample is not None:
            self.downsample.conv.weight.data.normal_(0, 0.01)

    @property
    def active_middle_channels(self):
        """Hidden width implied by the currently active expand ratio."""
        return round(self.active_out_channel * self.active_expand_ratio)

    @staticmethod
    def _copy_weight_normed_conv(src, dst, out_channel, in_channel):
        """Copy the leading ``out_channel`` x ``in_channel`` slice of ``src`` into ``dst``."""
        dst.conv_g.data.copy_(src.conv_g.data[:out_channel, :, :])
        dst.conv_v.data.copy_(src.conv_v.data[:out_channel, :in_channel, :])
        dst.conv.bias.data.copy_(src.conv.bias.data[:out_channel])

    def get_active_subnet(self, in_channel, preserve_weight=True):
        """Build a fixed-size ``MyTemporalBlock`` matching the active configuration.

        Args:
            in_channel: number of input channels the extracted block receives.
            preserve_weight: when True (default) copy the matching weight
                slices into the new block; when False return it freshly
                initialised.

        Returns:
            The stand-alone ``MyTemporalBlock``.
        """
        middle_channel = self.active_middle_channels
        out_channel = self.active_out_channel
        sub_layer = MyTemporalBlock(in_channel, out_channel, middle_channel, self.kernel_size, self.stride, self.dilation, self.padding, self.dropout)
        if not preserve_weight:
            return sub_layer

        self._copy_weight_normed_conv(self.conv1, sub_layer.conv1, middle_channel, in_channel)
        self._copy_weight_normed_conv(self.conv2, sub_layer.conv2, out_channel, middle_channel)

        # The residual projection is part of the block's function too; without
        # this copy the extracted block would keep a random 1x1 conv.
        if self.downsample is not None and sub_layer.downsample is not None:
            sub_layer.downsample.weight.data.copy_(self.downsample.conv.weight.data[:out_channel, :in_channel, :])
            sub_layer.downsample.bias.data.copy_(self.downsample.conv.bias.data[:out_channel])
        return sub_layer

    def forward(self, x):
        """Run the block with the currently active middle/out widths."""
        # Push the active widths down to the dynamic convolutions first.
        feature_dim = self.active_middle_channels
        self.conv1.active_out_channel = feature_dim
        self.conv2.active_out_channel = self.active_out_channel
        if self.downsample is not None:
            self.downsample.active_out_channel = self.active_out_channel

        out = self.conv1(x)
        out = self.chomp1(out)
        out = self.relu1(out)
        out = self.dropout1(out)

        out = self.conv2(out)
        out = self.chomp2(out)
        out = self.relu2(out)
        out = self.dropout2(out)

        res = x if self.downsample is None else self.downsample(x)

        return self.relu(out + res)
        
          

class Conv1dWtNorm(nn.Module):
    """Fixed-size 1-D convolution with an optional g/v weight-norm parametrisation.

    With ``weight_norm=True`` the wrapped conv's weight is replaced by the
    parameters ``conv_g`` (magnitude) and ``conv_v`` (direction), which are
    recombined with ``torch._weight_norm`` on every forward pass. With
    ``weight_norm=False`` the plain ``conv.weight`` is used.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride, dilation, padding: usual ``Conv1d`` settings.
        dim: dimension kept when computing the weight norm.
        weight_norm: whether to use the g/v parametrisation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, padding=0, dim=0, weight_norm=True):
        super(Conv1dWtNorm, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.dim = dim
        # Keep the flag on the instance: outside __init__ the bare name
        # ``weight_norm`` resolves to whatever the module namespace holds
        # (here the imported torch helper), which is always truthy.
        self.weight_norm = weight_norm
        self.conv = nn.Conv1d(self.in_channels, self.out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation)
        if self.weight_norm:
            # Swap conv.weight for the (g, v) pair; only the bias stays on the
            # wrapped conv. Without weight norm the weight must be kept.
            weight = self.conv.weight
            del self.conv._parameters["weight"]
            self.register_parameter("conv_g", Parameter(
                norm_except_dim(weight, 2, dim=self.dim).data
            ))
            self.register_parameter("conv_v", Parameter(weight.data))

    def get_active_filter(self):
        """Return the effective convolution weight."""
        if self.weight_norm:
            return _weight_norm(self.conv_v, self.conv_g, self.dim)
        return self.conv.weight

    def init_weights(self):
        """Intentionally a no-op: the weights keep their construction-time values."""
        pass

    def forward(self, x, out_channel=None):
        """Convolve ``x``; ``out_channel`` is accepted but not used."""
        filters = self.get_active_filter().contiguous()

        y = F.conv1d(x, filters, self.conv.bias, self.stride, self.padding, self.dilation, 1)
        return y
    
class MyTemporalBlock(nn.Module):
    """Fixed-size temporal block with an explicit middle width.

    Layout: conv1 (n_inputs -> middle_channel) and conv2 (middle_channel ->
    n_outputs), each followed by chomp, ReLU and dropout, then a residual
    connection (1x1 conv when the widths differ) and a final ReLU.
    """

    def __init__(self, n_inputs, n_outputs, middle_channel, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_opts = dict(stride=stride, padding=padding, dilation=dilation)

        # Stage 1: n_inputs -> middle_channel.
        self.conv1 = Conv1dWtNorm(n_inputs, middle_channel, kernel_size, **conv_opts)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Stage 2: middle_channel -> n_outputs.
        self.conv2 = Conv1dWtNorm(middle_channel, n_outputs, kernel_size, **conv_opts)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        # Residual projection is only needed when the widths differ.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Delegate to the convs' own init and draw the downsample weight from N(0, 0.01)."""
        for conv in (self.conv1, self.conv2):
            conv.init_weights()
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Return ``relu(stages(x) + residual(x))``."""
        pipeline = (
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        )
        out = x
        for stage in pipeline:
            out = stage(out)

        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        return self.relu(out + shortcut)
    
   

In [166]:
from turtle import forward
from ofa.utils import val2list
import random

class MyTemporalConvNet(nn.Module):
    """Sequential container for an already-built list of temporal blocks.

    Args:
        blocks: iterable of modules applied one after another.
    """

    def __init__(self, blocks):
        super(MyTemporalConvNet, self).__init__()
        # A ModuleList (not a plain list) registers the blocks, so their
        # parameters appear in parameters()/state_dict(), follow .to(device),
        # and receive train()/eval() through the standard nn.Module methods.
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every block in order and return the final output."""
        out = x
        for block in self.blocks:
            out = block(out)
        return out
    
class DynamicTemporalConvNet(nn.Module):
    """Stack of ``DynamicTemporalBlock`` with elastic depth and expand ratio.

    Depth is controlled through ``runtime_depth``, the number of *trailing*
    blocks skipped; it is set to ``max(depth_list) - d`` by
    ``set_active_subnet``. Each block additionally has its own active expand
    ratio.

    Args:
        input_channel: channels of the network input.
        num_channels: output width of each block; its length is the number of blocks.
        kernel_size: kernel size shared by all blocks.
        dropout: dropout probability inside the blocks.
        depth_list: candidate depth settings ``d``.
        expand_ratio_list: candidate expand ratios for every block.
    """

    def __init__(self, input_channel, num_channels, kernel_size=2, dropout=0.2, depth_list = [2], expand_ratio_list=[0.25]):
        super(DynamicTemporalConvNet, self).__init__()
        self.runtime_depth = 0
        # Copies avoid aliasing the mutable default arguments.
        self.depth_list = list(depth_list)
        self.max_depth = max(self.depth_list)
        self.expand_ratio_list = list(expand_ratio_list)
        self.input_channel = input_channel
        # ModuleList registers the blocks so their parameters are visible to
        # optimizers, state_dict(), .to(device) and train()/eval().
        self.blocks = nn.ModuleList()
        for i, out_channels in enumerate(num_channels):
            dilation_size = 2 ** i
            in_channels = input_channel if i == 0 else num_channels[i-1]
            self.blocks.append(DynamicTemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size-1) * dilation_size, dropout=dropout, expand_ratio_list=self.expand_ratio_list))

    def _active_blocks(self):
        """Blocks currently in use: all except the last ``runtime_depth``."""
        return list(self.blocks)[:len(self.blocks) - self.runtime_depth]

    def set_active_subnet(self, d=None, e=None, w=None):
        """Select the active depth and expand ratios.

        Args:
            d: depth setting (or a list whose first element is used); ``None``
                leaves the depth unchanged.
            e: one expand ratio for all blocks or one per block; ``None``
                leaves the ratios unchanged.
            w: ignored; width multipliers are not supported.
        """
        # current elastic elasticTCN doesn't support width multipliers
        if isinstance(d, list):
            d = d[0]
        if e is not None:
            # Without this guard an omitted ``e`` would overwrite every
            # block's ratio with None and break the next forward pass.
            expand_ratios = val2list(e, len(self.blocks))
            for block, expand_ratio in zip(self.blocks, expand_ratios):
                block.active_expand_ratio = expand_ratio
        if d is not None:
            self.runtime_depth = self.max_depth - d

    def set_max_net(self):
        """Activate the largest depth and the largest expand ratio."""
        self.set_active_subnet(d=max(self.depth_list), e=max(self.expand_ratio_list), w=None)

    def sample_active_subnet(self):
        """Randomly pick, activate and return a configuration ``{"d", "e", "w"}``."""
        # current elastic elasticTCN doesn't support width multipliers
        expand_setting = [random.choice(block.expand_ratio_list) for block in self.blocks]
        depth = random.choice(self.depth_list)

        arch_config = {
            "d" : depth,
            "e" : expand_setting,
            "w" : None
        }
        self.set_active_subnet(**arch_config)
        return arch_config

    def get_active_subnet(self):
        """Extract the active blocks into a stand-alone ``MyTemporalConvNet``."""
        blocks = []
        input_channel = self.input_channel
        for block in self._active_blocks():
            blocks.append(block.get_active_subnet(in_channel=input_channel))
            input_channel = block.active_out_channel

        return MyTemporalConvNet(blocks=blocks)

    def forward(self, x):
        """Run ``x`` through the active blocks only."""
        out = x
        for block in self._active_blocks():
            out = block(out)
        return out

class MyTCN(nn.Module):
    """Stand-alone language model assembled from ready-made parts.

    Args:
        encoder: module mapping token ids to embeddings.
        tcn: temporal network operating on channel-first tensors.
        decoder: module mapping features to output scores.
        drop: dropout module applied to the embeddings.
    """

    def __init__(self, encoder, tcn, decoder, drop):
        super(MyTCN, self).__init__()
        self.encoder = encoder
        self.tcn = tcn
        self.decoder = decoder
        self.drop = drop

    def train(self, mode=True):
        """Set training mode; same contract as ``nn.Module.train`` (returns ``self``)."""
        self.training = mode
        for child in (self.encoder, self.tcn, self.decoder, self.drop):
            # Use the zero-argument calls so that children which override
            # train()/eval() without a ``mode`` parameter keep working.
            if mode:
                child.train()
            else:
                child.eval()
        return self

    def eval(self):
        """Switch to evaluation mode and return ``self``."""
        return self.train(False)

    def forward(self, input):
        """Embed, run the TCN channel-first, and decode back in (N, L, C) layout."""
        emb = self.drop(self.encoder(input))
        y = self.tcn(emb.transpose(1, 2)).transpose(1, 2)
        y = self.decoder(y)
        return y.contiguous()
        
class ElasticTCN(nn.Module):
    """Weight-shared language model whose TCN depth and expand ratios are elastic.

    Embedding -> ``DynamicTemporalConvNet`` -> linear decoder. The embedding
    and decoder are fixed; only the TCN is elastic.

    Args:
        input_size: embedding dimension (also the TCN input width).
        output_size: vocabulary size.
        num_channels: per-block output widths of the TCN.
        kernel_size: convolution kernel size.
        dropout: dropout inside the TCN blocks.
        emb_dropout: dropout applied to the embeddings.
        tied_weights: share the decoder weight with the embedding table;
            requires ``num_channels[-1] == input_size``.
        depth_list: candidate depth settings.
        expand_ratio_list: candidate expand ratios.

    Raises:
        ValueError: if ``tied_weights`` is set and the widths do not match.
    """

    def __init__(self, input_size, output_size, num_channels,
                 kernel_size=2, dropout=0.3, emb_dropout=0.1, tied_weights=False, depth_list = [2], expand_ratio_list=[0.25]):
        super(ElasticTCN, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.encoder = nn.Embedding(output_size, input_size)
        self.tcn = DynamicTemporalConvNet(input_size, num_channels, kernel_size, dropout=dropout, depth_list=depth_list, expand_ratio_list=expand_ratio_list)
        self.last_channel = num_channels[-1]

        self.decoder = nn.Linear(num_channels[-1], output_size)
        if tied_weights:
            if num_channels[-1] != input_size:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight
            print("Weight tied")
        self.drop = nn.Dropout(emb_dropout)
        self.emb_dropout = emb_dropout
        self.init_weights()

    def train(self, mode=True):
        """Set training mode; same contract as ``nn.Module.train`` (returns ``self``)."""
        self.training = mode
        for child in (self.encoder, self.tcn, self.decoder, self.drop):
            # Use the zero-argument calls so that children which override
            # train()/eval() without a ``mode`` parameter keep working.
            if mode:
                child.train()
            else:
                child.eval()
        return self

    def eval(self):
        """Switch to evaluation mode and return ``self``."""
        return self.train(False)

    def set_max_net(self):
        """Activate the largest sub-network of the TCN."""
        self.tcn.set_max_net()

    def sample_active_subnet(self):
        """Randomly activate a sub-network and return its configuration."""
        return self.tcn.sample_active_subnet()

    def set_active_subnet(self, d=None, e=None, w=None):
        """Forward the depth / expand-ratio / width settings to the TCN."""
        self.tcn.set_active_subnet(d=d, e=e, w=w)

    def get_active_subnet(self):
        """Extract the active configuration as a stand-alone ``MyTCN``.

        The embedding and decoder values are copied into fresh modules, so the
        returned model shares no parameters with this one.
        """
        active_tcn = self.tcn.get_active_subnet()
        encoder = nn.Embedding(self.output_size, self.input_size)
        encoder.weight.data.copy_(self.encoder.weight.data)
        decoder = nn.Linear(self.last_channel, self.output_size)
        decoder.weight.data.copy_(self.decoder.weight.data)
        decoder.bias.data.copy_(self.decoder.bias.data)

        return MyTCN(
            encoder, active_tcn, decoder, nn.Dropout(self.emb_dropout)
        )

    def init_weights(self):
        """Initialise embedding/decoder weights from N(0, 0.01) and zero the bias."""
        self.encoder.weight.data.normal_(0, 0.01)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.normal_(0, 0.01)

    def forward(self, input):
        """Map token ids to per-position vocabulary scores.

        The embedding output is laid out as (N, L, C); it is transposed to
        (N, C, L) for the TCN and transposed back before decoding.
        """
        emb = self.drop(self.encoder(input))
        y = self.tcn(emb.transpose(1, 2)).transpose(1, 2)
        y = self.decoder(y)
        return y.contiguous()
        

In [167]:
# Same embedding / width settings as the baseline TCN above.
emsize = 600
n_words = 10000
# Hidden width of the TCN blocks; this is the part that can be made elastic.
nhid = 600
levels = 4
# Last block uses emsize channels so the tied-weights check passes.
num_chans = [nhid]*(levels-1) + [emsize]
dropout = 0.45
emb_dropout = 0.25
k_size = 3
tied = True
# Elastic search space. The network skips max(depth_list) - d trailing blocks,
# so d=2 keeps all four blocks and d=0 keeps the first two.
depth_list = [0,1,2]
# Candidate ratios for each block's middle width.
expand_ratio_list = [0.1, 0.2, 0.25, 0.5, 1]
model = ElasticTCN(emsize, n_words, num_chans, dropout=dropout, emb_dropout=emb_dropout, kernel_size=k_size, tied_weights=tied, depth_list=depth_list, expand_ratio_list=expand_ratio_list)

Weight tied


In [168]:
# Randomly pick a depth and per-block expand ratios (no seed is set, so the
# configuration printed below changes between runs), then extract that
# configuration as a stand-alone model.
subnet_config = model.sample_active_subnet()
print(subnet_config)
subnet = model.get_active_subnet()

{'d': 1, 'e': [0.2, 0.5, 0.25, 0.2], 'w': None}


In [169]:
# Activate the largest configuration and extract it as a stand-alone model.
model.set_max_net()
max_subnet = model.get_active_subnet()

In [170]:
# One sequence of 600 token ids, all equal to 1.
input_sentence = torch.ones((1,600), dtype=torch.long)
model.set_max_net()
# Evaluation mode on both models, so dropout cannot make the outputs differ.
max_subnet.eval()
model.eval()
# The extracted max sub-network must reproduce the supernet output exactly.
torch.equal(max_subnet(input_sentence), model(input_sentence))

True

In [171]:
# The sampled sub-network is expected to differ from the max network and from
# the supernet (which is still set to the max configuration here).
# NOTE(review): this assumes the sampled configuration is not itself the max one.
subnet.eval()
assert not torch.equal(max_subnet(input_sentence), subnet(input_sentence))
assert not torch.equal(model(input_sentence), subnet(input_sentence))


In [172]:
# Re-activate the sampled configuration on the supernet; its output must now
# match the extracted sub-network exactly.
model.set_active_subnet(**subnet_config)
model.eval()
torch.equal(model(input_sentence), subnet(input_sentence))

True