In [33]:
#pip install torchstat

from torchstat import stat
import torchvision.models as models


import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small two-convolution CNN that returns log-probabilities over 10 classes.

    The first fully connected layer expects ``_FLAT_FEATURES`` inputs, which is
    what the convolution/pooling stack produces for a ``3 x 224 x 224`` image:
    per spatial axis ``224 -> conv5 -> 220 -> pool2 -> 110 -> conv5 -> 106 ->
    pool2 -> 53``, times 20 channels, i.e. ``20 * 53 * 53 = 56180``.
    """

    # Flattened feature count after the conv stack for a 3x224x224 input.
    _FLAT_FEATURES = 56180

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image tensor, ``(N, 3, H, W)`` (or a single ``(3, H, W)`` image,
               which is treated as a batch of one).

        Returns:
            Tensor of shape ``(N, 10)`` holding log-probabilities (``log_softmax``
            over ``dim=1``).

        Raises:
            RuntimeError: if the spatial size does not flatten to
                ``_FLAT_FEATURES`` values per image.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per example instead of `view(-1, 56180)`: the latter silently
        # re-partitions the batch when the input size is wrong (e.g. one
        # 224x436 image would come out as a "batch" of two).
        batch_size = x.size(0) if x.dim() == 4 else 1
        x = x.reshape(batch_size, -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    
class ThreeLayer(nn.Module):
    """Fully connected network ``input_dim -> 100 -> 50 -> output_dim``.

    Every layer, including the last one, is followed by a ReLU, so the returned
    values are non-negative. All three weight matrices are initialised with
    ``nn.init.kaiming_normal_``; biases keep the ``nn.Linear`` default.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, output_dim):
        super(ThreeLayer, self).__init__()
        self.fc1 = nn.Linear(input_dim, 100)
        self.fc2 = nn.Linear(100, 50)
        self.fc3 = nn.Linear(50, output_dim)
        nn.init.kaiming_normal_(self.fc1.weight)
        nn.init.kaiming_normal_(self.fc2.weight)
        nn.init.kaiming_normal_(self.fc3.weight)

    def forward(self, x):
        """Apply the three ReLU-activated linear layers to ``x``.

        Args:
            x: tensor whose last dimension has ``input_dim`` entries.

        Returns:
            Tensor with the same leading dimensions as ``x`` and ``output_dim``
            entries in the last dimension.
        """
        # No shape printing here: forward runs once per batch, so a debug print
        # floods the notebook output.
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return x
    
    
class aLayer(nn.Module):
    """MLP with exit heads whose forward pass stops at the first exit.

    ``__init__`` registers the full stack (``fc1``/``exit_1``, ``fc2``/``exit_2``,
    ``fc3``/``exit_3``), but ``forward`` only evaluates ``fc1`` and ``exit_1``;
    the remaining layers are never called by this class.

    Args:
        input_dim: number of input features.
        output_dim: number of outputs produced by each exit head.
    """

    def __init__(self, input_dim, output_dim):
        super(aLayer, self).__init__()
        self.fc1 = nn.Linear(input_dim, 100)
        self.exit_1 = nn.Linear(100, output_dim)
        self.fc2 = nn.Linear(100, 50)
        self.exit_2 = nn.Linear(50, output_dim)
        self.fc3 = nn.Linear(50, output_dim)
        self.exit_3 = nn.Linear(output_dim, output_dim)

    def forward(self, x):
        """Return the raw (un-normalised) ``exit_1`` output for ``x``.

        Args:
            x: tensor whose last dimension has ``input_dim`` entries.

        Returns:
            Tensor with ``output_dim`` entries in the last dimension.
        """
        x = F.relu(self.fc1(x))
        exit_1 = self.exit_1(x)
        # Negative entropy of the exit-1 class distribution. The softmax runs
        # over the last (class) axis so batched input is normalised per example,
        # and log_softmax avoids the NaN that `0 * log(0)` would produce.
        # NOTE(review): the value is computed but not used to gate the exit.
        log_p_1 = F.log_softmax(exit_1, dim=-1)
        neg_entropy_1 = torch.sum(log_p_1.exp() * log_p_1)
        return exit_1
        
        
    
class abLayer(nn.Module):
    """MLP with exit heads whose forward pass stops at the second exit.

    ``forward`` evaluates ``fc1``, ``exit_1``, ``fc2`` and ``exit_2``. ``fc3``
    and ``exit_3`` are registered by ``__init__`` but never called by this class.

    Args:
        input_dim: number of input features.
        output_dim: number of outputs produced by each exit head.
    """

    def __init__(self, input_dim, output_dim):
        super(abLayer, self).__init__()
        self.fc1 = nn.Linear(input_dim, 100)
        self.exit_1 = nn.Linear(100, output_dim)
        self.fc2 = nn.Linear(100, 50)
        self.exit_2 = nn.Linear(50, output_dim)
        self.fc3 = nn.Linear(50, output_dim)
        self.exit_3 = nn.Linear(output_dim, output_dim)

    def forward(self, x):
        """Run the network up to the second exit.

        Args:
            x: tensor whose last dimension has ``input_dim`` entries.

        Returns:
            Tuple ``(2, exit_2)``: the constant exit index and the raw
            (un-normalised) output of the second exit head.
        """
        x = F.relu(self.fc1(x))
        # exit_1 is still evaluated even though its output is not returned, so
        # the work done by this forward pass (and what module-level profilers
        # record) stays the same.
        exit_1 = self.exit_1(x)
        # Negative entropies use a softmax over the last (class) axis so batched
        # input is normalised per example; log_softmax avoids NaN from 0*log(0).
        # NOTE(review): neither entropy is used to gate the exit.
        log_p_1 = F.log_softmax(exit_1, dim=-1)
        neg_entropy_1 = torch.sum(log_p_1.exp() * log_p_1)
        x = F.relu(self.fc2(x))
        exit_2 = self.exit_2(x)
        log_p_2 = F.log_softmax(exit_2, dim=-1)
        neg_entropy_2 = torch.sum(log_p_2.exp() * log_p_2)
        return 2, exit_2
        

    
    
class RNNLM(nn.Module):
    """LSTM language model: token ids -> embeddings -> 2-layer LSTM -> linear projection.

    Args:
        params: mapping that must provide ``'vocab_size'``, ``'d_emb'``,
            ``'learning_rate'``, ``'d_hid'`` and ``'batch_size'``; each value is
            stored on the instance under the same name.
    """

    def __init__(self, params):
        super().__init__()

        # Copy the hyper-parameters onto the instance (same lookup order as the
        # attribute list, so a missing key is reported for the first one absent).
        for key in ('vocab_size', 'd_emb', 'learning_rate', 'd_hid', 'batch_size'):
            setattr(self, key, params[key])

        self.num_layers = 2

        # Sub-modules are created in this order so that seeded weight
        # initialisation draws the same random numbers.
        self.lstm = nn.LSTM(
            input_size=self.d_emb,
            hidden_size=self.d_hid,
            num_layers=self.num_layers,
            batch_first=True,
            dropout=0.02,
        )
        self.embeddings = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.d_emb)
        # Registered but not applied in forward(), which returns raw scores.
        self.softmax = nn.LogSoftmax(dim=1)

        self.W = nn.Linear(in_features=self.d_hid, out_features=self.vocab_size)

    def forward(self, batch):
        """Score every vocabulary entry at every position of ``batch``.

        Args:
            batch: 2-D tensor of token ids, ``(batch, seq_len)``. Per the
                author's note, each row has the form ``<BOS> w1 ... wn <EOS>``
                and the goal is to predict everything except ``<BOS>``; no
                shifting or slicing for that is done here.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)`` with un-normalised
            scores from the output projection ``W``.
        """
        # Unpacking doubles as a check that `batch` has exactly two dimensions.
        _n_sequences, _n_tokens = batch.size()

        hidden_states, _ = self.lstm(self.embeddings(batch))
        return self.W(hidden_states)
    
    
class ThreeLayerBN(nn.Module):
    """MLP with three exit heads whose forward pass runs through to the last exit.

    ``forward`` evaluates all six layers in order (``fc1``, ``exit_1``, ``fc2``,
    ``exit_2``, ``fc3``, ``exit_3``). Despite the class name, no
    batch-normalisation layer is defined here.

    Args:
        input_dim: number of input features.
        output_dim: number of outputs produced by each exit head.
    """

    def __init__(self, input_dim, output_dim):
        super(ThreeLayerBN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 100)
        self.exit_1 = nn.Linear(100, output_dim)
        self.fc2 = nn.Linear(100, 50)
        self.exit_2 = nn.Linear(50, output_dim)
        self.fc3 = nn.Linear(50, output_dim)
        self.exit_3 = nn.Linear(output_dim, output_dim)

    def forward(self, x):
        """Run the whole network.

        Args:
            x: tensor whose last dimension has ``input_dim`` entries.

        Returns:
            Tuple ``(3, exit_3)``: the constant exit index and the raw
            (un-normalised) output of the third exit head.
        """
        x = F.relu(self.fc1(x))
        # The earlier exit heads are still evaluated even though their outputs
        # are not returned, so the work done by this forward pass (and what
        # module-level profilers record) stays the same.
        exit_1 = self.exit_1(x)
        # Negative entropies use a softmax over the last (class) axis so batched
        # input is normalised per example; log_softmax avoids NaN from 0*log(0).
        # NOTE(review): neither entropy is used to gate the exit.
        log_p_1 = F.log_softmax(exit_1, dim=-1)
        neg_entropy_1 = torch.sum(log_p_1.exp() * log_p_1)
        x = F.relu(self.fc2(x))
        exit_2 = self.exit_2(x)
        log_p_2 = F.log_softmax(exit_2, dim=-1)
        neg_entropy_2 = torch.sum(log_p_2.exp() * log_p_2)
        x = F.relu(self.fc3(x))
        exit_3 = self.exit_3(x)
        return 3, exit_3


class StackedLSTM(nn.Module):
    """Linear input projection, a stack of single-layer LSTMs, and a linear output layer.

    Each LSTM in ``rnns`` maps ``hidden_dim -> hidden_dim`` with
    ``batch_first=True`` and is fed the full output sequence of the previous
    one. The weights of ``inp`` and ``out`` are initialised with
    ``nn.init.kaiming_normal_``.

    Args:
        output_dim: size of the final output.
        hidden_dim: width of the projection and of every LSTM.
        embedding_dim: size of the last dimension of the input.
        num_layers: number of stacked LSTMs; must be at least 1.

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """

    def __init__(self, output_dim, hidden_dim=1200, embedding_dim=300, num_layers=3):
        super(StackedLSTM, self).__init__()
        # forward() reads the hidden state of the last LSTM, so an empty stack
        # cannot work; fail here instead of with an unbound local at call time.
        if num_layers < 1:
            raise ValueError(
                "num_layers must be at least 1, got {}".format(num_layers))

        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        self.inp = nn.Linear(self.embedding_dim, self.hidden_dim)
        self.rnns = torch.nn.ModuleList(
            [nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True)
             for _ in range(self.num_layers)])
        self.out = nn.Linear(self.hidden_dim, self.output_dim)

        nn.init.kaiming_normal_(self.inp.weight)
        nn.init.kaiming_normal_(self.out.weight)

    def forward(self, batch):
        """Compute outputs from the final hidden state of the last LSTM.

        Args:
            batch: tensor whose last dimension has ``embedding_dim`` entries,
                laid out batch-first, e.g. ``(batch, seq_len, embedding_dim)``.

        Returns:
            ``out`` applied to the last LSTM's final hidden state ``h``. For a
            ``(batch, seq_len, embedding_dim)`` input ``h`` has shape
            ``(1, batch, hidden_dim)``, so the result is
            ``(1, batch, output_dim)`` (the leading axis is kept).
        """
        lstm_out = self.inp(batch)

        for layer in self.rnns:
            # Only the final hidden state of the last layer is used below; the
            # cell state is discarded.
            lstm_out, (h, _) = layer(lstm_out)

        return self.out(h)





In [34]:
# Profile one of the models defined in the previous cell with torchstat.
# Author's notes, kept as written (their meaning is not shown in this notebook;
# TODO confirm):
#   test size 9042
#   70.2-3     1.92-2  27.8-1
#   (presumably the percentage of test examples associated with exits 3 / 2 / 1)
# Classes available to profile: ThreeLayer ThreeLayerBN RNNLM aLayer abLayer
# Constructor arguments below are (input dim, out dim).
model = ThreeLayerBN(300, 25)
# Input size handed to torchstat; equals the input dim passed above.
inp_size = [300]
stat(model, inp_size)

# Alternative call with a 3x224x224 input size, kept for reference:
#stat(model, (3, 224, 224))


      module name input shape output shape   params memory(MB)      MAdd     Flops  MemRead(B)  MemWrite(B) duration[%]  MemR+W(B)
0             fc1         300          100  30100.0       0.00  59,900.0  30,000.0    121600.0        400.0      20.73%   122000.0
1          exit_1         100           25   2525.0       0.00   4,975.0   2,500.0     10500.0        100.0      15.93%    10600.0
2             fc2         100           50   5050.0       0.00   9,950.0   5,000.0     20600.0        200.0      17.23%    20800.0
3          exit_2          50           25   1275.0       0.00   2,475.0   1,250.0      5300.0        100.0      15.98%     5400.0
4             fc3          50           25   1275.0       0.00   2,475.0   1,250.0      5300.0        100.0      14.52%     5400.0
5          exit_3          25           25    650.0       0.00   1,225.0     625.0      2700.0        100.0      15.60%     2800.0
total                                       40875.0       0.00  81,000.0  40,625.0 

In [None]:
def compute_LSTM_flops(module, inp, out):
    """Return ``rows * input_width * output_width`` for an ``nn.LSTM`` module.

    Args:
        module: must be an ``nn.LSTM`` instance (checked with ``assert``).
        inp: 2-D tensor; its first dimension is taken as the batch size.
        out: 2-D tensor; only its second dimension is used.

    Returns:
        ``inp.size(0) * inp.size(1) * out.size(1)`` as a Python int.

    NOTE(review): this product is the multiply count of one dense layer; it does
    not involve the module's gates, recurrent weights, or layer count. Confirm
    that this is the intended LSTM cost model.
    """
    assert isinstance(module, nn.LSTM)
    assert inp.dim() == 2 and out.dim() == 2
    n_rows, in_width = inp.shape
    out_width = out.shape[1]
    return n_rows * in_width * out_width
