# FAST

A PyTorch model with Skimming, rereading, and early stopping.

Use REINFORCE with baseline.

Reward function:
Use a single reward for an episode.
If the prediction is correct, the reward is 1. Else the reward is -1.

## 1. Set up Environment

In [33]:
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Categorical
from torchtext import datasets
from torchtext import data
from torchtext.data import Field, Dataset, Example
import os
import time
import numpy as np 
import random
import argparse
import pandas as pd

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device: ", device)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
set_seed(2019)

Using device:  cuda


## 2. Prepare the Data

In [50]:
AG_news = pd.read_csv('AG_news_raw.csv', names=['Label', 'Text'], dtype={'Label':'int', 'Text':'str'}, index_col=False)
AG_news.head()

Unnamed: 0,Label,Text
0,3,"""Fears for T N pension after talks"",""Unions re..."
1,4,"""The Race is On: Second Private Team Sets Laun..."
2,4,"""Ky. Company Wins Grant to Study Peptides (AP)..."
3,4,"""Prediction Unit Helps Forecast Wildfires (AP)..."
4,4,"""Calif. Aims to Limit Farm-Related Smog (AP)"",..."


In [60]:
fields = {'Text':TEXT, 'Label':LABEL}
AG_news_dict = AG_news.to_dict()
fields.items

<function dict.items>

In [34]:
 class DataFrameDataset(Dataset):
     """Class for using pandas DataFrames as a datasource"""
     def __init__(self, examples, fields, filter_pred=None):
         """
         Create a dataset from a pandas dataframe of examples and Fields
         Arguments:
             examples pd.DataFrame: DataFrame of examples
             fields {str: Field}: The Fields to use in this tuple. The
                 string is a field name, and the Field is the associated field.
             filter_pred (callable or None): use only exanples for which
                 filter_pred(example) is true, or use all examples if None.
                 Default is None
         """
         self.examples = examples.apply(SeriesExample.fromSeries, args=(fields,), axis=1).tolist()
         if filter_pred is not None:
             self.examples = filter(filter_pred, self.examples)
         self.fields = dict(fields)
         # Unpack field tuples
         for n, f in list(self.fields.items()):
             if isinstance(n, tuple):
                 self.fields.update(zip(n, f))
                 del self.fields[n]

 class SeriesExample(Example):
     """Class to convert a pandas Series to an Example"""

     @classmethod
     def fromSeries(cls, data, fields):
         return cls.fromdict(data.to_dict(), fields)

     @classmethod
     def fromdict(cls, data, fields):
         ex = cls()

         for key, field in fields.items():
            setattr(ex, key, field.preprocess(data[key]))

    return ex


TypeError: 'DataFrame' object is not callable

In [27]:
TEXT = data.Field(sequential=True, tokenize='spacy', lower=True, fix_length=400) 
LABEL = data.LabelField(dtype=torch.float)

# Split the IMDB data into training, validation and testing sets
print('Splitting data...')
train, test_data = datasets.IMDB.splits(TEXT, LABEL) # 25,000 training and 25,000 testing data
train_data, valid_data = train.split(split_ratio=0.8) # split training data into 20,000 training and 5,000 vlidation sample

print("Number of training examples: ",{len(train_data)})
print("Number of validation examples: ",{len(valid_data)})
print("Number of testing examples: ",{len(test_data)})

MAX_VOCAB_SIZE = 25000

TypeError: splits() got multiple values for argument 'path'

In [4]:
# Define datatypes with instructions for converting to tensors 
# sequential=True:    tokenisation applied
# tokenize='spacy':   SpaCy tokenizer used to tokenize strings into sequential examples
# lower=True:         convert text to lowercase
# fix_length=400:     all examples padded to length 400
# dtype=torch.float:  torch.dtype class that represents a batch of examples of this kind of data
TEXT = data.Field(sequential=True, tokenize='spacy', lower=True, fix_length=400) 
LABEL = data.LabelField(dtype=torch.float)

# Split the IMDB data into training, validation and testing sets
print('Splitting data...')
train, test_data = datasets.IMDB.splits(TEXT, LABEL) # 25,000 training and 25,000 testing data
train_data, valid_data = train.split(split_ratio=0.8) # split training data into 20,000 training and 5,000 vlidation sample

print("Number of training examples: ",{len(train_data)})
print("Number of validation examples: ",{len(valid_data)})
print("Number of testing examples: ",{len(test_data)})

MAX_VOCAB_SIZE = 25000

# Construct vocab objects for the text and label fields using the training datasets only
# max_size=MAX_VOCAB_SIZE:         limit vocab size to 25,000 words
# vectors="glove.6B.100d"          produce word embeddings using glove.6B.100d (global vectors for word representation)
# unk_init = torch.Tensor.normal_  initialize out-of-vocabulary word vectors to a random sample from the tensor
print('Building vocabulary...')
TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE, vectors="glove.6B.100d", unk_init = torch.Tensor.normal_)
LABEL.build_vocab(train_data)

# Split the datasets into batches (how many items to send to the model in each iteration). 
# Model is updated in increments according to each batch to help prevent overfitting
BATCH_SIZE = 1  # the batch size for a dataset iterator
print('Building iterators...')
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size=BATCH_SIZE,
    device=device)


Splitting data...
Number of training examples:  {1301}
Number of validation examples:  {325}
Number of testing examples:  {25000}
Building vocabulary...
Building iterators...


In [6]:
print(train)

<torchtext.datasets.imdb.IMDB object at 0x7ffa9be62898>


## 3. Build Network Architectures

### 3.1. Generate Embeddings: CNN and LSTM Network
Takes input, creates embedding, applies convolutional layer and one-layer LSTM. Outputs hidden representation (ht).

In [None]:
class CNN_LSTM(nn.Module):

    def __init__(self, input_dim, embedding_dim, ker_size, n_filters, hidden_dim): # Initialise nn.Module 
        super().__init__()
        
        # Define layers
        # 1. Embedding layer stores word embeddings as indices for later retreival
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        
        # 2. 2D convolution layer
            # in_channels:  1 input channel
            # out_channels: number of channels produced by the convolution set in training 
            # kernel_size:  size of convolving kernel set in training
        self.conv = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(ker_size, embedding_dim))
        
        # 3. Single layer LSTM CNN with dropout (randomly turn off 10% of neurons to prevent co-adaptation)
            # n_filters:  input size i.e. the number of features in the input
            # hidden_dim: hidden size i.e. the number of features in the hidden state
        self.lstm = nn.LSTM(input_size=n_filters, hidden_size=hidden_dim)
        self.dropout = nn.Dropout(p=0.1)
        
        # 4. Run ReLU activation function over the embedding layer neurons
        self.relu = nn.ReLU()
   
    def forward(self, text, h_0): # Define forward pass with input text
        
        # Layer 1. Apply the embedding method to the text
        embedded = self.embedding(text)
        #print(embeded.size())
        
        # Layer 2. Apply convolution, ReLU and dropout
        conved = self.relu(self.conv(embedded.unsqueeze(1)))  # 1 * 128 * 16 * 1
        #print(conved.size())
        
        batch = conved.size()[0]
        conved = self.dropout(conved)
        conved = conved.squeeze(3)  # 1 * 128 * 16
        conved = torch.transpose(conved, 1, 2)  # 1 * 16 * 128
        conved = torch.transpose(conved, 1, 0)  # 16 * 1 * 128
        c_0 = torch.zeros([1, batch, 128]).to(device)
        output, (hidden, cell) = self.lstm(conved, (h_0, c_0))
        ht = hidden.squeeze(0)  # 1 * 128
        return ht

At every time step, the model reads one chunk which has a size of 20 words.

--- input & output dimension ---

Input text: 1 * 20

**Embedding**
1.Input: 1 * 20
2.Output: 1 * 20 * 100

**CNN**
1. Input(minibatch×in_channels×iH×iW): 1 * 1 * 20 * 100
2. Output(minibatch×out_channels×oH×oW): 1 * 128 * 16 * 1

**LSTM**
1. Inputs: input, (h_0, c_0)
input(seq_len, batch, input_size): (16, 1 , 128)
c_0(num_layers * num_directions, batch, hidden_size): (1 * 1, 1, 128)
h_0(num_layers * num_directions, batch, hidden_size): (1 * 1, 1, 128)
2. Outputs: output, (h_n, c_n)
output:
h_n(num_layers * num_directions, batch, hidden_size): (1 * 1, 1, 128)


### 3.2. Stopping Module: MLP
Three hidden-layer MLP with 128 hidden units per layer

In [None]:
class Policy_S(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()    
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
          
    def forward(self, ht):
        # Input -> hidden 1
        out = self.fc1(ht)
        out = self.dropout(out)
        out = self.relu(out)
        # Hidden 1 -> hidden 2
        out = self.fc2(out)
        out = self.dropout(out)
        out = self.relu(out)
        # Hidden 2 -> hidden 3
        out = self.fc3(out)
        out = self.dropout(out)
        out = self.relu(out)
        # Hidden 3 -> output
        out = self.fc4(out)
        out = torch.sigmoid(out)
        return out

### 3.3. Policy Module (re-reading and skipping): MLP
Three hidden-layer MLP with 128 hidden units per layer

In [None]:
class Policy_N(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()    
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
         
    def forward(self, ht):
        out = self.fc1(ht)
        out = self.dropout(out)
        out = self.relu(out)
        
        out = self.fc2(out)
        out = self.dropout(out)
        out = self.relu(out)
        
        out = self.fc3(out)
        out = self.dropout(out)
        out = self.relu(out)
        out = self.fc4(out)
        out = self.softmax(out)
        return out

### 3.4. Classifier: MLP
Single-layer MLP with 128 hidden units per layer

In [None]:
class Policy_C(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        #self.fc = nn.Linear(input_dim, output_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=0.1)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
         
    def forward(self, ht):
        #return self.fc(ht)
        out = self.fc1(ht)
        out = self.dropout(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out

### 3.5. Value Network??

In [None]:
class ValueNetwork(nn.Module):
    '''Baseline
    Reduce the variance.

    Single-layer MLP with 128 hidden units.
    '''
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        #self.fc = nn.Linear(input_dim, output_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=0.1)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        
    def forward(self, ht):
        out = self.fc1(ht)
        out = self.dropout(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out

## 4. Set Parameters for Model Training 

In [None]:
# set up parameters
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
KER_SIZE = 5
HIDDEN_DIM = 128
OUTPUT_DIM = 1
CHUNCK_SIZE = 20
MAX_K = 4  # the output dimension for step size 0, 1, 2, 3
LABEL_DIM = 2
N_FILTERS = 128
BATCH_SIZE = 1


#gamma = args.gamma
#alpha = args.alpha
#learning_rate = 0.001

gamma = 0.2
alpha = 0.99
learning_rate = 0.001


# the number of training epoches
num_of_epoch = 10
# the number of batch size for gradient descent when training
batch_sz = 50

# set up the criterion
criterion = nn.CrossEntropyLoss().to(device)
# set up models
clstm = CNN_LSTM(INPUT_DIM, EMBEDDING_DIM, KER_SIZE, N_FILTERS, HIDDEN_DIM).to(device)
print(clstm)
policy_s = Policy_S(HIDDEN_DIM, HIDDEN_DIM, OUTPUT_DIM).to(device)
policy_n = Policy_N(HIDDEN_DIM, HIDDEN_DIM, MAX_K).to(device)
policy_c = Policy_C(HIDDEN_DIM, HIDDEN_DIM, LABEL_DIM).to(device)
value_net = ValueNetwork(HIDDEN_DIM, HIDDEN_DIM, OUTPUT_DIM).to(device)


# set up optimiser
params_pg = list(policy_s.parameters()) + list(policy_c.parameters()) + list(policy_n.parameters())
optim_loss = optim.Adam(clstm.parameters(), lr=learning_rate)
optim_policy = optim.Adam(params_pg, lr=learning_rate)
optim_value = optim.Adam(value_net.parameters(), lr=learning_rate)

# add pretrained embeddings
pretrained_embeddings = TEXT.vocab.vectors
clstm.embedding.weight.data.copy_(pretrained_embeddings)
clstm.embedding.weight.requires_grad = True  # update the initial weights

# set the default tensor type for GPU
#torch.set_default_tensor_type('torch.cuda.FloatTensor')

def finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch):
    '''
    Called when a data sample has been processed.
    '''
    baseline_value_sum = torch.stack(baseline_value_batch).sum()
    policy_loss = torch.stack(policy_loss_sum).mean()
    encoder_loss = torch.stack(encoder_loss_sum).mean()
    objective_loss = encoder_loss - policy_loss + baseline_value_sum
    # set gradient to zero
    optim_loss.zero_grad()
    optim_policy.zero_grad()
    optim_value.zero_grad()
    # back propagation
    objective_loss.backward()
    # gradient update
    optim_loss.step()
    optim_policy.step()
    optim_value.step()

## 5. Prepare Model Evaluation Metrics

Compute FLOPs(Floating point operations) of the models.

cnn_cost: CNN model (which is separated from CNN_LSTM)
s_cost: policy s (stopping module)
c_cost: policy c (classifier)
lstm_cost: LSTM model (which is separated from CNN_LSTM)
cnn_whole: CNN model with whole reading(400 words).

The costs are based on the size of one chunk (20 words)

In [None]:
# set up parameters
CHUNK_SIZE = 20
BATCH_SIZE = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

cnn_cost = 1024000
s_cost = 50050
c_cost = 16770 
n_cost = 50310
lstm_cost = 286720
clstm_cost = cnn_cost + lstm_cost
cnn_whole = 25344000


def sample_policy_s(ht, policy_s):
    '''
    Draw a stopping decision from a Bernoulli distribution specified by policy s.
    '''
    s_prob = policy_s(ht)
    m = Bernoulli(s_prob)
    stop_decision = m.sample()
    # compute the log prob
    log_prob_s = m.log_prob(stop_decision)
    return stop_decision, log_prob_s

def sample_policy_c(output_c):
    '''
    Draw a label from a multinomial distribution specified by policy c.
    '''
    prob_c = F.softmax(output_c, dim=1)
    m = Categorical(prob_c)
    pred_label = m.sample()
    log_prob_c = m.log_prob(pred_label)
    return pred_label, log_prob_c

def sample_policy_n(ht, policy_n):
    '''
    Draw an action from a multinomial distribution specified by policy n.
    '''
    action_probs = policy_n(ht)
    m = Categorical(action_probs)
    step = m.sample()
    log_prob_n = m.log_prob(step)
    return step.item(), log_prob_n
    
def compute_policy_value_losses(cost_ep, loss, saved_log_probs, baseline_value_ep, alpha, gamma):
    '''compute the policy losses and value losses for the current episode
    '''
    # normalise cost
    norm_cost_ep = (cost_ep - np.mean(cost_ep)) / (np.std(cost_ep) + 1e-7)
    #print('norm_cost_ep:', norm_cost_ep)
    reward_ep = - alpha * norm_cost_ep
    reward_ep[-1] -= loss.item()
    # compute discounted rewards
    discounted_rewards = [r * gamma ** i for i, r in enumerate(reward_ep)]
    policy_loss_ep = []
    value_losses = []
    for i, log_prob in enumerate(saved_log_probs):
        # baseline_value_ep[i].item(): updating the policy loss doesn't include the gradient of baseline values
        advantage = sum(discounted_rewards) - baseline_value_ep[i].item()
        policy_loss_ep.append(log_prob * advantage)
        value_losses.append((sum(discounted_rewards) - baseline_value_ep[i]) ** 2)   
    return policy_loss_ep, value_losses


def evaluate(clstm, policy_s, policy_n, policy_c, iterator):
    '''
    Evaluate a model with skimming, rereading, and early stopping
    and compute the average FLOPs per data.
    '''
    # set the models in evaluation mode
    clstm.eval()
    policy_s.eval()
    policy_n.eval()
    policy_c.eval()
    count_all = 0
    count_correct = 0
    start = time.time()
    # the sum of FLOPs of the iterator set
    flops_sum = 0
    with torch.no_grad():
        for batch in iterator:
            label = batch.label.to(torch.long)  # for cross entropy loss, the long type is required
            text = batch.text.view(CHUNK_SIZE, BATCH_SIZE, CHUNK_SIZE) # transform 1*400 to 20*1*20
            curr_step = 0
            h_0 = torch.zeros([1,1,128]).to(device)
            count = 0
            while curr_step < 20 and count < 5: # loop until a text can be classified or currstep is up to 20
                count += 1
                # pass the input through cnn-lstm and policy s
                text_input = text[curr_step] # text_input 1*20
                ht = clstm(text_input, h_0)  # 1 * 128
                h_0 = ht.unsqueeze(0)  # 1 * 1 * 128, next input of lstm
                # draw a stop decision
                stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
                flops_sum += clstm_cost + s_cost
                stop_decision = stop_decision.item()
                if stop_decision == 1: # classify
                    break
                else:
                    # draw an action (reread or skip)
                    step, log_prob_n = sample_policy_n(ht, policy_n)
                    flops_sum += n_cost
                    curr_step += int(step)  # reread or skip
            # draw a predicted label
            output_c = policy_c(ht)
            flops_sum += c_cost
            # draw a predicted label 
            pred_label, log_prob_c = sample_policy_c(output_c)
            if pred_label.item() == label:
                count_correct += 1
            count_all += 1
    print('Evaluation time elapsed: %.2f s' % (time.time() - start))
    avg_flop_per_sample = int(flops_sum / len(iterator))
    print('Average FLOPs per sample: ', avg_flop_per_sample)
    return count_all, count_correct


def evaluate_earlystop(clstm, policy_s, policy_c, iterator):
    '''
    Evaluate a early stopping model with only a stopping module
    and compute the average FLOPs per data.
    '''
    # set the models in evaluation mode
    clstm.eval()
    policy_s.eval()
    policy_c.eval()
    count_all = 0
    count_correct = 0
    start = time.time()
    # the sum of FLOPs of the iterator set
    flops_sum = 0
    with torch.no_grad():
        for batch in iterator:
            label = batch.label.to(torch.long) # 64
            text = batch.text.view(CHUNK_SIZE, BATCH_SIZE, CHUNK_SIZE) # transform 1*400 to 20*1*20
            curr_step = 0
            # set up the initial input for lstm
            h_0 = torch.zeros([1,1,128]).to(device) 
            saved_log_probs = []
            while (curr_step < 20):
                '''
                loop until stop decision equals 1 
                or the whole text has been read
                '''
                # read a chunk
                text_input = text[curr_step]
                # hidden state
                ht = clstm(text_input, h_0)  # 1 * 128
                h_0 = ht.unsqueeze(0).cuda()  # 1 * 1 * 128, next input of lstm
                # draw a stop decision
                stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
                stop_decision = stop_decision.item()
                flops_sum += clstm_cost + s_cost
                if stop_decision == 1:
                    break
                else:
                    curr_step += 1
            # output of classifier       
            output_c = policy_c(ht)
            flops_sum += c_cost
            # draw a predicted label 
            pred_label, log_prob_c = sample_policy_c(output_c)
            if pred_label.item() == label:
                count_correct += 1
            count_all += 1     
    print('Evaluation time elapsed: %.2f s' % (time.time() - start))
    avg_flop_per_sample = int(flops_sum / len(iterator))
    print('Average FLOPs per sample: ', avg_flop_per_sample)  
    return count_all, count_correct


def print_model_parm_flops(model, input):
    '''
    Compute FLOPs of a model.
    '''
    multiply_adds = False
    list_conv=[]
    def conv_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        print('input size', input[0].size())
        print('output size:', output[0].size())
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)-1
        bias_ops = 1 if self.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width
        list_conv.append(flops)

    list_linear=[] 
    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        bias_ops = self.bias.nelement()
        flops = batch_size * (weight_ops + bias_ops)
        list_linear.append(flops)

    list_bn=[] 
    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement())
        
    list_sig = []
    def sig_hook(self, input, output):
        list_sig.append(input[0].nelement())
    
    list_softmax = []
    def softmax_hook(self, input, output):
        print(input[0].nelement())
        list_softmax.append(input[0].nelement())
        
    list_relu=[] 
    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling=[]
    def pooling_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = self.kernel_size * self.kernel_size
        bias_ops = 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width
        list_pooling.append(flops)
    
    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, torch.nn.Conv2d):
                net.register_forward_hook(conv_hook)
            if isinstance(net, torch.nn.Linear):
                net.register_forward_hook(linear_hook)
            if isinstance(net, torch.nn.BatchNorm2d):
                net.register_forward_hook(bn_hook)
            if isinstance(net, torch.nn.ReLU):
                net.register_forward_hook(relu_hook)
            if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d):
                net.register_forward_hook(pooling_hook)
            if isinstance(net, torch.nn.Sigmoid):
                net.register_forward_hook(sig_hook)
            if isinstance(net, torch.nn.Softmax):
                net.register_forward_hook(softmax_hook)
            return
        for c in childrens:
                foo(c)
    foo(model)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    out = model(input.to(device))
    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_sig) +sum(list_softmax)) 
    return total_flops

## 6. Train Model

In [None]:
def main():
    '''
    Training and evaluation of the model.
    '''
    print('Training starts...')
    for epoch in range(num_of_epoch):
        print('\nEpoch', epoch+1)
        # log the start time of the epoch
        start = time.time()
        # set the models in training mode
        clstm.train()
        policy_s.train()
        policy_n.train()
        policy_c.train()
        # reset the count of reread_or_skim_times
        reread_or_skim_times = 0
        policy_loss_sum = []
        encoder_loss_sum = []
        baseline_value_batch = []
        for index, train in enumerate(train_iterator):
            label = train.label.to(torch.long)  # for cross entropy loss, the long type is required
            text = train.text.view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
            curr_step = 0  # the position of the current chunk
            h_0 = torch.zeros([1,1,128]).to(device)  # run on GPU
            count = 0  # maximum skim/reread time: 5
            baseline_value_ep = []
            saved_log_probs = []  # for the use of policy gradient update
            # collect the computational costs for every time step
            cost_ep = []  
            while curr_step < CHUNCK_SIZE and count < 5: 
                # Loop until a text can be classified or currstep is up to 20 or count reach the maximum i.e. 5.
                # update count
                count += 1
                # pass the input through cnn-lstm and policy s
                text_input = text[curr_step] # text_input 1*20
                ht = clstm(text_input, h_0)  # 1 * 128
                # separate the value which is the input of value net
                ht_ = ht.clone().detach().requires_grad_(True)
                # compute a baseline value for the value network
                bi = value_net(ht_)
                # 1 * 1 * 128, next input of lstm
                h_0 = ht.unsqueeze(0)
                # draw a stop decision
                stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
                stop_decision = stop_decision.item()
                if stop_decision == 1: # classify
                    break
                else: 
                    reread_or_skim_times += 1
                    # draw an action (reread or skip)
                    step, log_prob_n = sample_policy_n(ht, policy_n)
                    curr_step += int(step)  # reread or skip
                    if curr_step < CHUNCK_SIZE and count < 5:
                        # If the code can still execute the next loop, it is not the last time step.
                        cost_ep.append(clstm_cost + s_cost + n_cost)
                        # add the baseline value
                        baseline_value_ep.append(bi)
                        # add the log prob for the current actions
                        saved_log_probs.append(log_prob_s + log_prob_n)
            # draw a predicted label
            output_c = policy_c(ht)
            # cross entrpy loss input shape: input(N, C), target(N)
            loss = criterion(output_c, label)  # positive value
            # draw a predicted label 
            pred_label, log_prob_c = sample_policy_c(output_c)
            if stop_decision == 1:
                # add the cost of the last time step
                cost_ep.append(clstm_cost + s_cost + c_cost)
                saved_log_probs.append(log_prob_s + log_prob_c)
            else:
                # add the cost of the last time step
                cost_ep.append(clstm_cost + s_cost + c_cost + n_cost)
                # At the moment, the probability of drawing a stop decision is 1,
                # so its log probability is zero which can be ignored in th sum.
                saved_log_probs.append(log_prob_c.unsqueeze(0))
            # add the baseline value
            baseline_value_ep.append(bi)
            # add the cross entropy loss
            encoder_loss_sum.append(loss)
            # compute the policy losses and value losses for the current episode
            policy_loss_ep, value_losses = compute_policy_value_losses(cost_ep, loss, saved_log_probs, baseline_value_ep, alpha, gamma)
            policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
            baseline_value_batch.append(torch.cat(value_losses).sum())
            # update gradients
            if (index + 1) % batch_sz == 0:  # take the average of 50 samples
                finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
                del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]
                
            if (index + 1) % 2000 == 0:
                print('\n current episode: ',{index + 1})
                # log the current position of the text which the agent has gone through
                print('curr_step: ', curr_step)
                # log the sum of the rereading and skimming times
                print('current reread_or_skim_times: ',{reread_or_skim_times})


        print('Epoch time elapsed: %.2f s' % (time.time() - start))
        print('reread_or_skim_times in this epoch:', reread_or_skim_times)
        count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, valid_iterator)
        print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
        count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, train_iterator)
        print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
        
    print('Compute the accuracy on the testing set...')
    count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, test_iterator)
    print('Accuracy on the testing set: %.2f' % (count_correct / count_all))

if __name__ == '__main__':
    main()