<a href="https://colab.research.google.com/github/comp6248-polaris/reproducibility_challenge/blob/master/dl_reproducibility.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Multinomial
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset

from torchtext import datasets
from torchtext import data

import os
import numpy as np 
import random
from itertools import count
import collections
from tqdm import  tqdm

from sklearn.metrics import f1_score


## Preprocessing

In [0]:
def set_seed(seed):
    """Make a run repeatable by seeding every RNG this notebook relies on.

    Seeds Python's `random`, NumPy's legacy global RNG, the torch CPU RNG and
    the current CUDA device RNG, then switches cuDNN into deterministic,
    non-autotuning mode.

    Args:
        seed: integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Deterministic kernels, and no per-shape algorithm benchmarking that could pick different kernels run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


SEED = 2019
set_seed(SEED)

In [0]:
# Review text: spaCy tokenisation, lower-cased, padded/truncated to a fixed 400 tokens
# (the training cell splits these 400 tokens into 20-token chunks).
TEXT = data.Field(tokenize='spacy', lower=True, sequential=True, fix_length=400)
# Sentiment label as a float tensor, the dtype nn.BCEWithLogitsLoss expects for targets.
LABEL = data.LabelField(dtype=torch.float)

In [4]:
# IMDB ships 25,000 labelled training reviews and 25,000 test reviews.
train, test_data = datasets.IMDB.splits(text_field=TEXT, label_field=LABEL)
# Hold out 20% of the training reviews: 20,000 for training, 5,000 for validation.
train_data, valid_data = train.split(split_ratio=0.8)

downloading aclImdb_v1.tar.gz


aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:07<00:00, 10.9MB/s]


In [5]:
MAX_VOCAB_SIZE = 25_000

# Vocabularies are built from the training split only. Out-of-vocabulary GloVe rows
# are initialised from a normal distribution.
# NOTE(review): TEXT.vocab.vectors is never copied into CNN_LSTM's nn.Embedding in
# this notebook, so the pretrained vectors appear unused — confirm intent.
TEXT.build_vocab(
    train_data,
    vectors="glove.6B.100d",
    max_size=MAX_VOCAB_SIZE,
    unk_init=torch.Tensor.normal_,
)
LABEL.build_vocab(train_data)

.vector_cache/glove.6B.zip: 862MB [00:54, 15.9MB/s]                           
100%|█████████▉| 398983/400000 [00:19<00:00, 20372.74it/s]

In [6]:
# Sanity-check split sizes and vocabulary sizes.
for split_name, split_data in (('training', train_data), ('validation', valid_data), ('testing', test_data)):
    print(f'Number of {split_name} examples: {len(split_data)}')
for field_name, field in (('TEXT', TEXT), ('LABEL', LABEL)):
    print(f"Unique tokens in {field_name} vocabulary: {len(field.vocab)}")

Number of training examples: 20000
Number of validation examples: 5000
Number of testing examples: 25000
Unique tokens in TEXT vocabulary: 25002
Unique tokens in LABEL vocabulary: 2


In [0]:
BATCH_SIZE = 1

# Prefer the GPU when one is present; the iterators place every batch on this device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# One iterator per split, all sharing the same batch size and device.
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    datasets=(train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
)

## Network architecture

In [0]:
class CNN_LSTM(nn.Module):
    """Encode one chunk of token ids into a fixed-size feature vector.

    Pipeline: embedding -> Conv2d whose kernel spans the full embedding width
    (one activation per filter per window position) -> ReLU -> flatten all
    filter/position activations -> single-step LSTM -> Linear on the final
    hidden state.
    """

    def __init__(self, input_dim, embedding_dim, ker_size, n_filters, hidden_dim, output_dim, chunk_size=20):
        """
        Args:
            input_dim: vocabulary size of the embedding table.
            embedding_dim: size of each token embedding.
            ker_size: convolution window height, in tokens.
            n_filters: number of convolution output channels.
            hidden_dim: LSTM hidden size.
            output_dim: size of the returned feature vector.
            chunk_size: tokens per input chunk. Together with ker_size and
                n_filters it fixes the LSTM input size. The default of 20
                reproduces the previously hard-coded 16 window positions
                for ker_size=5.

        Raises:
            ValueError: if ker_size is larger than chunk_size.
        """
        super().__init__()
        conv_len = chunk_size - ker_size + 1  # window positions of a 'valid' convolution
        if conv_len < 1:
            raise ValueError(f'ker_size ({ker_size}) must not exceed chunk_size ({chunk_size})')
        # Layers are created in the original order so seeded initialisation is unchanged.
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.conv = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(ker_size, embedding_dim))
        self.lstm = nn.LSTM(n_filters * conv_len, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Encode a batch of chunks.

        Args:
            text: LongTensor of token ids, shape (batch, chunk_size).

        Returns:
            Tensor of shape (batch, output_dim).
        """
        embedded = self.embedding(text)                      # (batch, chunk_size, embedding_dim)
        conved = F.relu(self.conv(embedded.unsqueeze(1)))    # (batch, n_filters, conv_len, 1)
        conved = conved.squeeze(3)                           # (batch, n_filters, conv_len)
        # nn.LSTM expects (seq_len, batch, input_size); each chunk is a single time step.
        stacked = conved.reshape(conved.size(0), -1).unsqueeze(0)
        _, (hn, _) = self.lstm(stacked)                      # hn: (num_layers, batch, hidden_dim)
        return self.fc(hn[-1])                               # (batch, output_dim)
    
class Policy_S(nn.Module):
    """Stop policy pi_s: probability of stopping to classify at the current chunk.

    Three hidden Linear layers then a Linear output, squashed with a sigmoid.
    The training cell samples a Bernoulli from this probability and classifies
    when the sample is 1.
    NOTE(review): there is no nonlinearity between the Linear layers, so the
    stack is mathematically equivalent to a single affine map.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.fc_s_hidden0 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_s_hidden1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_s_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_s_output = nn.Linear(hidden_dim, output_dim)

    def forward(self, ht):
        features = ht
        for layer in (self.fc_s_hidden0, self.fc_s_hidden1, self.fc_s_hidden2):
            features = layer(features)
        return torch.sigmoid(self.fc_s_output(features))

class Policy_C(nn.Module):
    """Classification policy pi_c: one hidden Linear layer and a Linear output.

    Returns raw logits (no sigmoid). The training cell passes them to
    nn.BCEWithLogitsLoss.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.fc_c_hidden = nn.Linear(hidden_dim, hidden_dim)
        self.fc_c_output = nn.Linear(hidden_dim, output_dim)

    def forward(self, ht):
        hidden = self.fc_c_hidden(ht)
        logits = self.fc_c_output(hidden)
        return logits
  
  
class Policy_N(nn.Module):
    """Step policy pi_n: distribution over max_k + 1 step sizes (indices 0..max_k).

    Three hidden Linear layers then a Linear output, normalised with a softmax
    over dim 1, so the input must be 2-D (batch, hidden_dim).
    NOTE(review): there is no nonlinearity between the Linear layers.
    """

    def __init__(self, hidden_dim, max_k):
        super().__init__()
        self.fc_n_hidden0 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_n_hidden1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_n_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_n_output = nn.Linear(hidden_dim, max_k + 1)

    def forward(self, ht):
        features = ht
        for layer in (self.fc_n_hidden0, self.fc_n_hidden1, self.fc_n_hidden2):
            features = layer(features)
        logits = self.fc_n_output(features)
        return logits.softmax(dim=1)
 

In [0]:
# import torchvision.models as models

# def print_model_parm_flops(model, input):

#     # prods = {}
#     # def save_prods(self, input, output):
#         # print 'flops:{}'.format(self.__class__.__name__)
#         # print 'input:{}'.format(input)
#         # print '_dim:{}'.format(input[0].dim())
#         # print 'input_shape:{}'.format(np.prod(input[0].shape))
#         # grads.append(np.prod(input[0].shape))

#     prods = {}
#     def save_hook(name):
#         def hook_per(self, input, output):
#             # print 'flops:{}'.format(self.__class__.__name__)
#             # print 'input:{}'.format(input)
#             # print '_dim:{}'.format(input[0].dim())
#             # print 'input_shape:{}'.format(np.prod(input[0].shape))
#             # prods.append(np.prod(input[0].shape))
#             prods[name] = np.prod(input[0].shape)
#             # prods.append(np.prod(input[0].shape))
#         return hook_per

#     list_1=[]
#     def simple_hook(self, input, output):
#         list_1.append(np.prod(input[0].shape))
#     list_2={}
#     def simple_hook2(self, input, output):
#         list_2['names'] = np.prod(input[0].shape)


#     multiply_adds = False
#     list_conv=[]
#     def conv_hook(self, input, output):
#         batch_size, input_channels, input_height, input_width = input[0].size()
#         output_channels, output_height, output_width = output[0].size()

#         kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
#         bias_ops = 1 if self.bias is not None else 0

#         params = output_channels * (kernel_ops + bias_ops)
#         flops = batch_size * params * output_height * output_width

#         list_conv.append(flops)


#     list_linear=[] 
#     def linear_hook(self, input, output):
#         batch_size = input[0].size(0) if input[0].dim() == 2 else 1

#         weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
#         bias_ops = self.bias.nelement()

#         flops = batch_size * (weight_ops + bias_ops)
#         list_linear.append(flops)

#     list_bn=[] 
#     def bn_hook(self, input, output):
#         list_bn.append(input[0].nelement())

#     list_relu=[] 
#     def relu_hook(self, input, output):
#         list_relu.append(input[0].nelement())

#     list_pooling=[]
#     def pooling_hook(self, input, output):
#         batch_size, input_channels, input_height, input_width = input[0].size()
#         output_channels, output_height, output_width = output[0].size()

#         kernel_ops = self.kernel_size * self.kernel_size
#         bias_ops = 0
#         params = output_channels * (kernel_ops + bias_ops)
#         flops = batch_size * params * output_height * output_width

#         list_pooling.append(flops)


            
#     def foo(net):
#         childrens = list(net.children())
#         if not childrens:
#             if isinstance(net, torch.nn.Conv2d):
#                 # net.register_forward_hook(save_hook(net.__class__.__name__))
#                 # net.register_forward_hook(simple_hook)
#                 # net.register_forward_hook(simple_hook2)
#                 net.register_forward_hook(conv_hook)
#             if isinstance(net, torch.nn.Linear):
#                 net.register_forward_hook(linear_hook)
#             if isinstance(net, torch.nn.BatchNorm2d):
#                 net.register_forward_hook(bn_hook)
#             if isinstance(net, torch.nn.ReLU):
#                 net.register_forward_hook(relu_hook)
#             if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d):
#                 net.register_forward_hook(pooling_hook)
#             return
#         for c in childrens:
#                 foo(c)

#     foo(model)
#     out = model(input.cuda())


#     total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling))
    
# #     print('  + Number of FLOPs: {0}'.format (total_flops))
#     return total_flops
#     # print list_bn


#     # print 'prods:{}'.format(prods)
#     # print 'list_1:{}'.format(list_1)
#     # print 'list_2:{}'.format(list_2)
#     # print 'list_final:{}'.format(list_final)

In [0]:
# INPUT_DIM = len(TEXT.vocab)
# EMBEDDING_DIM = 100
# KER_SIZE = 5
# HIDDEN_DIM = 128
# OUTPUT_DIM = 1
# CHUNCK_SIZE = 20
# TEXT_LEN = 400
# MAX_K = 3

# BATCH_SIZE = 1

# learning_rate = 0.001

# test_model = CNN_RNN(INPUT_DIM, EMBEDDING_DIM, KER_SIZE, HIDDEN_DIM).train().cuda()
# test_policy_s = Policy_S(HIDDEN_DIM, OUTPUT_DIM).train().cuda()
# test_policy_n = Policy_N(HIDDEN_DIM, MAX_K).train().cuda()
# test_policy_c = Policy_C(HIDDEN_DIM, OUTPUT_DIM).train().cuda()

# cnn_cost = -print_model_parm_flops(test_model, torch.randint(1,100, (1, 20)))
# p = torch.rand(1,128)
# s_cost = -print_model_parm_flops(test_policy_s, p)
# c_cost = -print_model_parm_flops(test_policy_c, p)
# n_cost = -print_model_parm_flops(test_policy_n, p)
# print('cnn_cost', cnn_cost)
# print('s_cost', s_cost)
# print('c_cost', c_cost)
# print('n_cost', n_cost)

In [0]:
# Per-step FLOP costs, used as negative reward terms in the training cell.
# Each value follows the counting rules of the commented-out `linear_hook`/`conv_hook`:
#   Linear:  in*out + out
#   Conv2d:  out_ch * (kh*kw*in_ch + 1) * out_h * out_w
# NOTE(review): the layer shapes below (e.g. two 128x128 hidden layers for pi_c,
# four for pi_s/pi_n) do not match the Policy_* classes currently defined, which
# have one (pi_c) or three (pi_s/pi_n) hidden layers — recompute if they should agree.
s_cost = -(4 * (128 * 128 + 128) + (128 * 1 + 1))    # 4 hidden 128x128 + 128->1 output
c_cost = -(2 * (128 * 128 + 128) + (128 * 1 + 1))    # 2 hidden 128x128 + 128->1 output
n_cost = -(4 * (128 * 128 + 128) + (128 * 4 + 4))    # 4 hidden 128x128 + 128->4 output
cnn_cost = -float(128 * (5 * 100 + 1) * 16)          # 128 filters, 5x100 kernel, 16 positions

In [0]:
# ---- Hyper-parameters --------------------------------------------------------
print(len(TEXT.vocab))
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
KER_SIZE = 5
N_FILTERS = 128          # convolution channels of the chunk encoder
HIDDEN_DIM = 128         # LSTM size and input width of every policy network
OUTPUT_DIM = 1           # policy_s / policy_c each emit a single value
CHUNCK_SIZE = 20         # tokens per chunk
TEXT_LEN = 400           # must equal fix_length of the TEXT field
N_CHUNKS = TEXT_LEN // CHUNCK_SIZE
MAX_K = 3                # policy_n picks a step in 0..MAX_K; 0 re-reads the current chunk

BATCH_SIZE = 1           # the rollout below handles exactly one review at a time

gamma = 0.99             # discount factor for returns
alpha = 0.2              # weight of the FLOP-cost term in the reward
learning_rate = 0.001

N_TRAIN_SAMPLES = 51     # smoke run: training reviews per epoch
N_VALID_SAMPLES = 3      # validation reviews printed per epoch

# ---- Models ------------------------------------------------------------------
# The encoder output feeds every policy, so its output size is HIDDEN_DIM.
# Everything is moved to `device` because the iterators yield tensors there.
model = CNN_LSTM(INPUT_DIM, EMBEDDING_DIM, KER_SIZE, N_FILTERS, HIDDEN_DIM, HIDDEN_DIM).to(device)
policy_s = Policy_S(HIDDEN_DIM, OUTPUT_DIM).to(device)
policy_n = Policy_N(HIDDEN_DIM, MAX_K).to(device)
policy_c = Policy_C(HIDDEN_DIM, OUTPUT_DIM).to(device)

loss_function = nn.BCEWithLogitsLoss()

params = (list(model.parameters()) + list(policy_s.parameters())
          + list(policy_n.parameters()) + list(policy_c.parameters()))
optimizer = optim.Adam(params, lr=learning_rate)


def _split_into_chunks(text):
    """Turn a torchtext (seq_len, batch) id tensor into (N_CHUNKS, batch, CHUNCK_SIZE)."""
    batch_size = text.size(1)
    return text.transpose(0, 1).reshape(batch_size, N_CHUNKS, CHUNCK_SIZE).transpose(0, 1)


for epoch in range(1):
    print('epoch', epoch)
    print('train')
    model.train()
    policy_s.train()
    policy_n.train()
    policy_c.train()

    for index, batch in enumerate(train_iterator):
        chunks = _split_into_chunks(batch.text)
        label = batch.label

        # ---- Roll out one episode: read a chunk, then stop-and-classify or jump.
        state_pool, action_pool, reward_pool = [], [], []
        curr_step = 0
        while curr_step < N_CHUNKS:
            ht = model(chunks[curr_step])                            # (1, HIDDEN_DIM)
            s_action = Bernoulli(policy_s(ht).detach()).sample()
            cost = cnn_cost + s_cost

            if int(s_action.item()) == 1:                            # stop and classify
                c = policy_c(ht)
                loss = loss_function(c.squeeze(1), label)
                # NOTE(review): the positive BCE loss is added to a negative cost term,
                # and `loss` is not detached, so gradients also flow through the return.
                # Confirm the intended sign convention of this REINFORCE objective.
                state_pool.append(ht)
                reward_pool.append(loss + alpha * cost)
                action_pool.append([s_action])
                break

            # Continue: sample how far to move the read position.
            n_action = Multinomial(1, policy_n(ht).detach()).sample()
            curr_step += int(torch.argmax(n_action))
            cost += n_cost
            state_pool.append(ht)
            reward_pool.append(alpha * cost)
            action_pool.append([s_action, n_action])

        # ---- Discounted returns, accumulated back-to-front (REINFORCE).
        running_add = 0
        for i in reversed(range(len(reward_pool))):
            if reward_pool[i] == 0:
                running_add = 0
            else:
                running_add = running_add * gamma + reward_pool[i]
                reward_pool[i] = running_add

        # ---- Policy update: sum of log-probs of the taken actions, scaled by G_0.
        optimizer.zero_grad()
        sum_pro = torch.zeros([1, 1], device=device)
        for state, action in zip(state_pool, action_pool):
            log_prob = Bernoulli(policy_s(state)).log_prob(action[0])
            if len(action) == 1:
                # Step where the episode classified.
                log_prob = log_prob + torch.log(torch.sigmoid(policy_c(state)))
            else:
                # Continue step (including the last one when the episode timed out).
                log_prob = log_prob + Multinomial(1, policy_n(state)).log_prob(action[1])
            sum_pro += log_prob
        sum_loss = sum_pro * reward_pool[0]
        sum_loss.backward()
        optimizer.step()

        if index + 1 >= N_TRAIN_SAMPLES:
            break

    # ---- Evaluation ----------------------------------------------------------
    model.eval()
    policy_s.eval()
    policy_c.eval()
    policy_n.eval()
    print('eval')

    valid_labels = []
    predicted_labels = []

    with torch.no_grad():
        for index, batch in enumerate(valid_iterator):
            chunks = _split_into_chunks(batch.text)
            label = batch.label.cpu()

            c = None
            curr_step = 0
            while curr_step < N_CHUNKS:
                ht = model(chunks[curr_step])
                s_action = Bernoulli(policy_s(ht)).sample()
                if int(s_action.item()) == 1:
                    c = policy_c(ht)
                    break
                n_action = Multinomial(1, policy_n(ht)).sample()
                curr_step += int(torch.argmax(n_action))

            if c is None:
                # Reader ran past the last chunk without stopping: classify from the
                # last state read rather than reusing a stale (or undefined) `c`.
                c = policy_c(ht)

            c_pro = torch.sigmoid(c).cpu()
            pre_label = int(c_pro.item() > 0.5)
            print('c_pro:{0}, pre_label:{1}, label:{2}'.format(c_pro, pre_label, label))

            valid_labels.append(label)
            predicted_labels.append(c_pro)
            if index + 1 >= N_VALID_SAMPLES:
                break
  

25002
epoch 0
train
s_pro tensor([[0.5131]], grad_fn=<SigmoidBackward>)
s_pro tensor([[0.5113]], grad_fn=<SigmoidBackward>)
s_pro tensor([[0.5132]], grad_fn=<SigmoidBackward>)
c and loss tensor([[0.4975]]) tensor(0.6982, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
s_pro tensor([[0.4970]], grad_fn=<SigmoidBackward>)
s_pro tensor([[0.4970]], grad_fn=<SigmoidBackward>)
c and loss tensor([[0.5081]]) tensor(0.6770, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
s_pro tensor([[0.4844]], grad_fn=<SigmoidBackward>)
c and loss tensor([[0.5082]]) tensor(0.6768, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
s_pro tensor([[0.4945]], grad_fn=<SigmoidBackward>)
s_pro tensor([[0.4870]], grad_fn=<SigmoidBackward>)
s_pro tensor([[0.4892]], grad_fn=<SigmoidBackward>)
s_pro tensor([[0.4892]], grad_fn=<SigmoidBackward>)
c and loss tensor([[0.5360]]) tensor(0.7678, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
s_pro tensor([[0.4808]], grad_fn=<SigmoidBackward>)
c and loss tensor([[0.5249]]) tensor

Testing functions

In [0]:
# Scratch check of the back-to-front discounted-return loop used in training.
reward_pool = [-1, -2, -3, -4, -5, -6, -7, -8]
running_add = 0
for i in range(len(reward_pool) - 1, -1, -1):
    print(i)
    if reward_pool[i] != 0:
        running_add = running_add * gamma + reward_pool[i]
        reward_pool[i] = running_add
    else:
        running_add = 0

print(reward_pool)

7
6
5
4
3
2
1
0
[-34.35730017846292, -33.694242604508, -32.0143864692, -29.30746108, -25.563092, -20.7708, -14.92, -8.0]


In [0]:
# List the direct sub-modules of the classifier head.
# NOTE(review): the saved output below shows three Linear layers, but Policy_C as
# defined in this notebook has two, so that output is from an older version.
children = [*policy_c.children()]
for child in children:
    print(child)

Linear(in_features=128, out_features=128, bias=True)
Linear(in_features=128, out_features=128, bias=True)
Linear(in_features=128, out_features=1, bias=True)


In [0]:
# import torch


# # ---- Public functions

# def add_flops_counting_methods(net_main_module):
#     """Adds flops counting functions to an existing model. After that
#     the flops count should be activated and the model should be run on an input
#     image.
    
#     Example:
    
#     fcn = add_flops_counting_methods(fcn)
#     fcn = fcn.cuda().train()
#     fcn.start_flops_count()
    
#     _ = fcn(batch)
    
#     fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch
    
#     Important: dividing by 2 only works for resnet models -- see below for the details
#     of flops computation.
    
#     Attention: we are counting multiply-add as two flops in this work, because in
#     most resnet models convolutions are bias-free (BN layers act as bias there)
#     and it makes sense to count muliply and add as separate flops therefore.
#     This is why in the above example we divide by 2 in order to be consistent with
#     most modern benchmarks. For example in "Spatially Adaptive Computatin Time for Residual
#     Networks" by Figurnov et al multiply-add was counted as two flops.
    
#     This module computes the average flops which is necessary for dynamic networks which
#     have different number of executed layers. For static networks it is enough to run the network
#     once and get statistics (above example).
    
#     Implementation:
#     The module works by adding batch_count to the main module which tracks the sum
#     of all batch sizes that were run through the network.
    
#     Also each convolutional layer of the network tracks the overall number of flops
#     performed.
    
#     The parameters are updated with the help of registered hook-functions which
#     are being called each time the respective layer is executed.
    
#     Parameters
#     ----------
#     net_main_module : torch.nn.Module
#         Main module containing network
        
#     Returns
#     -------
#     net_main_module : torch.nn.Module
#         Updated main module with new methods/attributes that are used
#         to compute flops.
#     """
    
#     # adding additional methods to the existing module object,
#     # this is done this way so that each function has access to self object
#     net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
#     net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
#     net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
#     net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
    
#     net_main_module.reset_flops_count()
    
#     # Adding varialbles necessary for masked flops computation
#     net_main_module.apply(add_flops_mask_variable_or_reset)
    
#     return net_main_module



# def compute_average_flops_cost(self):
#     """
#     A method that will be available after add_flops_counting_methods() is called
#     on a desired net object.
    
#     Returns current mean flops consumption per image.
    
#     """
    
#     batches_count = self.__batch_counter__
    
#     flops_sum = 0
    
#     for module in self.modules():

#         if isinstance(module, torch.nn.Conv2d):

#             flops_sum += module.__flops__
    
    
#     return flops_sum / batches_count


# def start_flops_count(self):
#     """
#     A method that will be available after add_flops_counting_methods() is called
#     on a desired net object.
    
#     Activates the computation of mean flops consumption per image.
#     Call it before you run the network.
    
#     """
    
#     add_batch_counter_hook_function(self)
    
#     self.apply(add_flops_counter_hook_function)

    
# def stop_flops_count(self):
#     """
#     A method that will be available after add_flops_counting_methods() is called
#     on a desired net object.
    
#     Stops computing the mean flops consumption per image.
#     Call whenever you want to pause the computation.
    
#     """
    
#     remove_batch_counter_hook_function(self)
    
#     self.apply(remove_flops_counter_hook_function)

    
# def reset_flops_count(self):
#     """
#     A method that will be available after add_flops_counting_methods() is called
#     on a desired net object.
    
#     Resets statistics computed so far.
    
#     """
    
#     add_batch_counter_variables_or_reset(self)
    
#     self.apply(add_flops_counter_variable_or_reset)


# def add_flops_mask(module, mask):
    
#     def add_flops_mask_func(module):
        
#         if isinstance(module, torch.nn.Conv2d):
            
#             module.__mask__ = mask
    
#     module.apply(add_flops_mask_func)

    
# def remove_flops_mask(module):
    
#     module.apply(add_flops_mask_variable_or_reset)

    
# # ---- Internal functions


# def conv_flops_counter_hook(conv_module, input, output):
        
#     # Can have multiple inputs, getting the first one
#     input = input[0]
    
#     batch_size = input.shape[0]
#     output_height, output_width = output.shape[2:]
    
#     kernel_height, kernel_width = conv_module.kernel_size
#     in_channels = conv_module.in_channels
#     out_channels = conv_module.out_channels
    
#     # We count multiply-add as 2 flops
#     conv_per_position_flops = 2 * kernel_height * kernel_width * in_channels * out_channels
    
#     active_elements_count = batch_size * output_height * output_width
    
#     if conv_module.__mask__ is not None:
        
#         # (b, 1, h, w)
#         flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width)
#         active_elements_count = flops_mask.sum()
        
    
#     overall_conv_flops = conv_per_position_flops * active_elements_count
      
#     bias_flops = 0
    
#     if conv_module.bias is not None:
        
#         bias_flops = out_channels * active_elements_count
    
#     overall_flops = overall_conv_flops + bias_flops
    
#     conv_module.__flops__ += overall_flops

    
# def batch_counter_hook(module, input, output):
    
#     # Can have multiple inputs, getting the first one
#     input = input[0]
    
#     batch_size = input.shape[0]
    
#     module.__batch_counter__ += batch_size


    
# def add_batch_counter_variables_or_reset(module):
    
#     module.__batch_counter__ = 0


# def add_batch_counter_hook_function(module):
    
#     if hasattr(module, '__batch_counter_handle__'):
        
#         return
    
#     handle = module.register_forward_hook(batch_counter_hook)
#     module.__batch_counter_handle__ = handle

    
# def remove_batch_counter_hook_function(module):
    
#     if hasattr(module, '__batch_counter_handle__'):
        
#         module.__batch_counter_handle__.remove()
        
#         del module.__batch_counter_handle__


# def add_flops_counter_variable_or_reset(module):
    
#     if isinstance(module, torch.nn.Conv2d):
        
#         module.__flops__ = 0

# def add_flops_counter_hook_function(module):
        
#     if isinstance(module, torch.nn.Conv2d):
        
#         if hasattr(module, '__flops_handle__'):
            
#             return

#         handle = module.register_forward_hook(conv_flops_counter_hook)
#         module.__flops_handle__ = handle

# def remove_flops_counter_hook_function(module):
    
#     if isinstance(module, torch.nn.Conv2d):
        
#         if hasattr(module, '__flops_handle__'):
            
#             module.__flops_handle__.remove()
            
#             del module.__flops_handle__

# # --- Masked flops counting


# # Also being run in the initialization
# def add_flops_mask_variable_or_reset(module):
    
#     if isinstance(module, torch.nn.Conv2d):
        
#         module.__mask__ = None

In [0]:
# fcn = model
# fcn = add_flops_counting_methods(fcn)
# fcn = fcn.cuda().train()
# fcn.start_flops_count()
# _ = fcn(text_input)    
# fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch

In [0]:

# How squeeze/unsqueeze change a (2, 1, 5, 1) shape; squeeze(d) is a no-op when dim d is not size 1.
t = torch.zeros([2, 1, 5, 1])
print(t.shape)
for reshaped in (t.squeeze(3), t.squeeze(0), t.squeeze(), t.squeeze(1), t.unsqueeze(0), t.unsqueeze(1)):
    print(reshaped.shape)

torch.Size([2, 1, 5, 1])
torch.Size([2, 1, 5])
torch.Size([2, 1, 5, 1])
torch.Size([2, 5])
torch.Size([2, 5, 1])
torch.Size([1, 2, 1, 5, 1])
torch.Size([2, 1, 1, 5, 1])


In [0]:
# Bernoulli: one draw, its logit, and the log-probability of that draw.
m = Bernoulli(torch.tensor([0.4]))
action = m.sample()
print()
for value in (action, m.logits, m.log_prob(action)):
    print(value)

# Multinomial with a single trial = one-hot categorical draw.
n = Multinomial(1, torch.tensor([0.1, 0.1, 0.1, 0.7]))
action = n.sample()
print()
for value in (action, n.logits, n.log_prob(action)):
    print(value)
print()

# A few more draws from the same categorical.
for i in range(5):
    action = n.sample()
    print(action)


tensor([0.])
tensor([-0.4055])
tensor([-0.5108])

tensor([0., 0., 0., 1.])
tensor([-2.3026, -2.3026, -2.3026, -0.3567])
tensor(-0.3567)

tensor([0., 0., 0., 1.])
tensor([0., 0., 0., 1.])
tensor([1., 0., 0., 0.])
tensor([0., 0., 1., 0.])
tensor([0., 0., 0., 1.])
