# Long short term memory

We previously explored RNNs, neural networks that are able to propagate some hidden state through a rolled out version of itself. A major problem with RNNs is exploding or vanishing gradients. Gradient clipping solves the exploding gradient problem, but the vanishing gradient problem is harder to solve. LSTMs propose a different architecture which benefits from a hidden state like RNNs, but mitigates the vanishing gradient problem. Another issue RNNs have is that the hidden state often forgets information from a while ago in the sequence and is more biased towards more recent tokens. LSTMs also address this issue with their gated structure. 


In [275]:
from torch import nn
import torch 
from torch.nn.utils.rnn import PackedSequence
import torch.nn.functional
from torch.utils.data import random_split
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWS
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

## LSTM model defintion
The goal here is to build a substitute for the torch.nn.rnn.lstm module. This should be able to handle packed sequences and batched data the same way the source code for that module does. For now, does not need to have multiple layers or be bidirectional.

In [276]:
class lstm(nn.Module): 
    def __init__(self, input_size, hidden_dim, output_dim=1) -> None:
        super().__init__()
        self.input_dim = input_size
        self.hidden_dim  = hidden_dim
        
        self.forget_gate = nn.Sequential(
            nn.Linear(input_size+hidden_dim, hidden_dim),
            nn.Sigmoid()
        )
        self.input_gate = nn.Sequential(
            nn.Linear(input_size+hidden_dim, hidden_dim),
            nn.Sigmoid()
        )
        self.input_node = nn.Sequential(
            nn.Linear(input_size+hidden_dim, hidden_dim),
            nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Linear(input_size+hidden_dim, hidden_dim),
            nn.Sigmoid()
        )
        self.tanh = nn.Tanh()

        # this output layer can be fancier if needed by the use case
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h_in=None, c_in=None):

        if isinstance(x, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = x
            max_batch_size = batch_sizes[0]
            if h_in is None: 
                h_in = self.init_h(max_batch_size, x)
                c_in = self.init_h(max_batch_size, x)
        
        data_offset = 0
        outputs = []
        for batch_size in batch_sizes:
            print(f"original input shape: {input.shape}")
            current_input = input[data_offset:data_offset + batch_size]
            data_offset += batch_size
            print(f"batch_size: {batch_size}")
            active_h = h_in[:,:batch_size,:]
            active_c = c_in[:,:batch_size,:]
            print(f"hidden_shape = {active_h.shape}")
            print(f"current input size = {current_input.shape}")
            combined = torch.cat([current_input.unsqueeze(0), active_h], dim=2)
            print(f"combined shape = {combined.shape}")
            
            i_gate_output = self.input_gate(combined)
            i_node_output = self.input_node(combined)
            o_gate_output = self.output_gate(combined)
            f_gate_output = self.forget_gate(combined)

            print(type(f_gate_output * active_c))
            c_out = (f_gate_output * active_c) + (i_node_output * i_gate_output)

            h_out = self.tanh(c_out) * o_gate_output
            out = self.output_layer(h_out)

            h_in[:batch_size] = h_out
            c_in[:batch_size] = c_out
            outputs.append(out)

            # Handle decreasing batch size
            h_in[batch_size:] = 0
            c_in[batch_size:] = 0

       
        if isinstance(x, PackedSequence):
            output_packed = PackedSequence(outputs, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, h_out, c_out
        
        return out, h_out, c_out
    
    
    def init_h(self, batch_size, x):
        #alternatives include but not limited to Xavier/Kaiminh initialization
        return torch.zeros(1, batch_size, self.hidden_dim, dtype=x.data.dtype, device=x.data.device)

In [277]:
class newsLSTM(nn.Module): 
    def __init__(self, vocab_size, embed_size, hidden_size) -> None:
        super(newsLSTM, self).__init__()
        
        self.encoder = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.hidden_size = hidden_size 
        
        self.lstm = lstm(input_size=embed_size, hidden_dim=hidden_size)
        
        self.hidden2label = nn.Linear(2*hidden_size, 4)
        self.softmax = nn.LogSoftmax(dim=1)
        self.dropoutLayer = nn.Dropout(p=0.5)

    def forward(self, x, x_len):
        embedded = self.encoder(x)
        x_packed = nn.utils.rnn.pack_padded_sequence(embedded, x_len, batch_first=True, enforce_sorted=False)
        output, h_t, c_t = self.lstm(x_packed)  # Pass the initial hidden state 'h' to the RNN
        print(h_t.shape)
        
        hidden = self.dropoutLayer(torch.cat((h_t[-2,:,:], h_t[-1,:,:]), dim=1))
        
        # Linear layer and softmax
        label_space = self.hidden2label(hidden)
        
        return label_space
    

We repeat the exact same news classification task performed by in the bidirectionalRNN subdir

In [278]:
train_iter = AG_NEWS(split='train')

# Convert to list to enable random splitting
train_dataset = list(train_iter)

#80-20 train-val split 
train_size = int(len(train_dataset) * 0.8)  
val_size = len(train_dataset) - train_size  
train_data, val_data = random_split(train_dataset, [train_size, val_size])

tokenizer = get_tokenizer("basic_english")

def yield_tokens(data_iter):
    for text in data_iter:
        yield tokenizer(text)

VOCAB_SIZE = 5000

# Build vocab based on the train_data
train_data_iter = (text for _, text in train_data)
vocab = build_vocab_from_iterator(yield_tokens(train_data_iter), specials=["<unk>"], max_tokens=VOCAB_SIZE)
vocab.set_default_index(vocab["<unk>"])

In [279]:
train_data[0]

(1,
 'Bush Courts Pa. Swing Voters on Economy (AP) AP - President Bush wooed suburban swing voters Thursday with hopeful words about the economy, contending his administration is making progress for American workers and portraying rival John Kerry as a tax-and-spend Democrat.')

In [280]:
vocab(['word', 'probably', 'unknown', 'gibberish'])

[2102, 1693, 0, 0]

In [281]:
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

In [282]:
vocab.lookup_tokens([4999])

['maine']

In [283]:
def collate_batch(batch):
    device = torch.device("cpu")
    label_list, text_list, lengths = [], [], []
    
    # Sort the batch in the descending order
    batch.sort(key=lambda x: len(x[1]), reverse=True)
    
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))
        
    label_list = torch.tensor(label_list, dtype=torch.int64)
    lengths = torch.tensor(lengths, dtype=torch.int64)
    
    # Pad sequences
    text_list = pad_sequence(text_list, batch_first=True)
    
    return label_list.to(device), text_list.to(device), lengths

In [284]:
train_loader = DataLoader(train_data, batch_size = 8, shuffle = True, collate_fn = collate_batch)
val_loader = DataLoader(val_data, batch_size = 8, shuffle = False, collate_fn = collate_batch)

In [285]:
batch = next(iter(train_loader))

# Inspect the shape of the input data
input_data = batch[1]  # Assuming the input data is the first element of the batch
input_shape = input_data.shape[0]

In [286]:
batch

(tensor([1, 3, 0, 1, 3, 0, 2, 1]),
 tensor([[1059,    0,    7, 4455,    6, 3730,  110,   13,   31,   14,   31,   15,
            25, 1826,   17, 2571, 1059,   33,   39,    4,   37, 1527,   18,  797,
           403,   52,   71,    1,    2,   50,   16,    9,   77,    1,  108,  433,
            86,  197,    0,  133,   20,  607, 2855,    3, 2133,   30,    0, 2335,
            17,  291,  285,  473,    0,  641,    0, 1213, 1084,    8, 4994,    0,
          2229,   66,   94,   16,    9,  211,  429,    6,    2,  723,  423,    1],
         [ 117, 3048,  986,   20,    9,    1, 2357,    1, 1592,   13, 3804,    1,
           171,   14, 3804,    1,  171,   15,    5,   97,  449,    0, 1216,  210,
             2,  605,    3,  466,  580,    1,  248,   81,  233,   55,   11,    5,
             0,    6,    0,   46, 2844, 1437, 3048,    0, 1085,   11,  117,  138,
          4978,    1,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0

In [287]:
batch[1].shape

torch.Size([8, 72])

In [288]:
len(batch)

3

In [289]:
a =  torch.ones(5, 50)
a =  torch.ones(5, 50)
a =  torch.ones(5, 50)

In [290]:
input_shape

8

In [291]:
print(batch[0])
print(batch[1].shape)

tensor([1, 3, 0, 1, 3, 0, 2, 1])
torch.Size([8, 72])


In [292]:
LEARNING_RATE = 1e-3
BATCH_SIZE = 128
NUM_EPOCHS = 50
DROPOUT = 0.5
DEVICE = torch.device('cpu')

EMBEDDING_DIM = 128
BIDIRECTIONAL = True
HIDDEN_DIM = 128
NUM_LAYERS = 2
OUTPUT_DIM = 4

In [293]:
model = newsLSTM(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [294]:
def train(model, train_loader, val_loader, loss_function, optim, epochs, device):
    losses = [] #group losses for loss visualization 
    running_loss = 0.0
    for epoch in range(epochs):
        model.train()
        print("Epoch %d / %d" % (epoch+1, epochs))
        print("-"*10)
    
        for i, batch_data in enumerate(train_loader):
            
            model.train()
            (y, x, x_size) = batch_data
            #print("Labels: {}, data: {}, x_size.cpu(): {}".format(batch_data[0], x.shape,x_size.cpu()))

            logits = model(x, x_size.cpu())
            #print("Target size: {}, pred_size: {}".format(y.size(), logits.size()))
            loss = loss_function(logits, y)
            optim.zero_grad()
            loss.backward()
            optim.step()
            
            running_loss += loss.item()
            losses.append(loss)

            if (i+1) % 1000 == 0:
                print("Step: {}, average training loss over last 2000 steps: {:.4f}".format(i+1, running_loss/1000))
                running_loss = 0.0
            
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            for i, batch_data in enumerate(val_loader):
                (y, x, x_size) = batch_data
                y, x, x_size = y.to(device), x.to(device), x_size.to(device)
                
                logits = model(x, x_size.cpu())
                loss = loss_function(logits, y)
                
                val_loss += loss.item()
        
        print("Epoch: {}, validation loss: {:.4f}".format(epoch+1, val_loss/len(val_loader)))

In [295]:
train(model, train_loader, val_loader, torch.nn.functional.cross_entropy, optimizer, NUM_EPOCHS, DEVICE)

Epoch 1 / 50
----------
original input shape: torch.Size([422, 128])
batch_size: 8
hidden_shape = torch.Size([1, 8, 128])
current input size = torch.Size([8, 128])
combined shape = torch.Size([1, 8, 256])
<class 'torch.Tensor'>
original input shape: torch.Size([422, 128])
batch_size: 8
hidden_shape = torch.Size([1, 8, 128])
current input size = torch.Size([8, 128])
combined shape = torch.Size([1, 8, 256])
<class 'torch.Tensor'>
original input shape: torch.Size([422, 128])
batch_size: 8
hidden_shape = torch.Size([1, 8, 128])
current input size = torch.Size([8, 128])
combined shape = torch.Size([1, 8, 256])
<class 'torch.Tensor'>
original input shape: torch.Size([422, 128])
batch_size: 8
hidden_shape = torch.Size([1, 8, 128])
current input size = torch.Size([8, 128])
combined shape = torch.Size([1, 8, 256])
<class 'torch.Tensor'>
original input shape: torch.Size([422, 128])
batch_size: 8
hidden_shape = torch.Size([1, 8, 128])
current input size = torch.Size([8, 128])
combined shape = tor

RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 7 but got size 8 for tensor number 1 in the list.