In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

## Attention encoding

In [2]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads): 
        super(SelfAttention, self).__init__() 
        self.embed_size = embed_size 
        self.heads = heads
        self.head_dim = embed_size // heads
        
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False) 
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False) 
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False) 
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values) 
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        QK = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) 
        
        if mask is not None:
            QK = QK.masked_fill(mask == 0, float("-1e20"))
        
        attention = torch.softmax(QK / (self.embed_size**(1/2)), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
        out = self.fc_out(out)
        return out

In [3]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion): 
        super(TransformerBlock, self).__init__() 
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size) 
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size), 
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)
      
    def forward(self, value, key, query, mask): 
        attention = self.attention(value, key, query, mask)
        
        x = self.dropout(self.norm1(attention + query)) 
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x)) 
        return out

In [4]:
class Encoder(nn.Module): 
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length
    ):
        super(Encoder, self).__init__() 
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size) 
        self.position_embedding = nn.Embedding(max_length, embed_size)
    
        self.layers = nn.ModuleList([
            TransformerBlock(embed_size, heads, dropout=dropout, forward_expansion=forward_expansion) for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        
        for layer in self.layers:
            out = layer(out, out, out, mask)
            
        return out

## sLSTM

In [5]:
"""
From Mudit Bhargava
"""
class sLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(sLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.randn(4 * hidden_size))
        
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.zeros_(self.bias)

    def forward(self, input, hx):
        h, c = hx
        gates = F.linear(input, self.weight_ih, self.bias) + F.linear(h, self.weight_hh)
        
        z, i, f, o = gates.chunk(4, 1)
        
        z = torch.tanh(z)
        i = torch.exp(i)  # Exponential input gate
        f = torch.exp(f)  # Exponential forget gate
        o = torch.sigmoid(o)
        
        c = f * c + i * z
        h = o * torch.tanh(c)
        
        return h, c
        

class sLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
        super(sLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout

        self.layers = nn.ModuleList([sLSTMCell(input_size if i == 0 else hidden_size, hidden_size) 
                                     for i in range(num_layers)])  # multiple memory
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, input_seq, hidden_state=None):
        batch_size, seq_length, _ = input_seq.size()
        
        if hidden_state is None:
            hidden_state = self.init_hidden(batch_size)
        
        outputs = []
        for t in range(seq_length):
            x = input_seq[:, t, :]
            for layer_idx, layer in enumerate(self.layers):
                h, c = hidden_state[layer_idx]
                h, c = layer(x, (h, c))
                hidden_state[layer_idx] = (h, c)
                x = self.dropout_layer(h) if layer_idx < self.num_layers - 1 else h
            outputs.append(x)
        
        return torch.stack(outputs, dim=1), hidden_state

    def init_hidden(self, batch_size):
        return [(torch.zeros(batch_size, self.hidden_size, device=self.layers[0].weight_ih.device),
                 torch.zeros(batch_size, self.hidden_size, device=self.layers[0].weight_ih.device))
                for _ in range(self.num_layers)]

In [6]:
class sLSTMBlock(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
        super(sLSTMBlock, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout

        self.lstm = sLSTM(input_size, hidden_size, num_layers, dropout)
        self.norm = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout_layer = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_size, input_size)

    def forward(self, input_seq, hidden_state=None):
        lstm_output, hidden_state = self.lstm(input_seq, hidden_state)
        output = self.activation(lstm_output)
        output = self.norm(output)
        output = self.proj(output)
        output = self.dropout_layer(output + input_seq)  # Residual connection
        return output, hidden_state

## AttensLSTM

In [7]:
class AttensLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, num_layers, num_heads, num_blocks, seq_length, device, dropout=0.0):
        super(AttensLSTM, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.seq_length = seq_length
        self.device = device
        self.dropout = dropout

        self.encoder = Encoder(src_vocab_size=self.vocab_size, embed_size=self.embedding_size, num_layers=self.num_layers,
                               heads=self.num_heads, device=self.device, forward_expansion=2, 
                               dropout=self.dropout, max_length=self.seq_length)
                                
        self.blocks = nn.ModuleList([
            sLSTMBlock(embedding_size, hidden_size, num_layers, dropout)
            for _ in range(num_blocks)
        ])
        self.output_layer = nn.Linear(embedding_size, vocab_size)

    def forward(self, input_seq, hidden_states=None):
        # embedded_seq = self.embedding(input_seq)
        encoded_seq = self.encoder(input_seq, mask=None)
        
        if hidden_states is None:
            hidden_states = [None] * self.num_blocks
        
        # output_seq = embedded_seq
        output_seq = encoded_seq
        for i, block in enumerate(self.blocks):
            output_seq, hidden_states[i] = block(output_seq, hidden_states[i])
        
        output_seq = self.output_layer(output_seq)
        return output_seq, hidden_states

## Shape verification

In [8]:
# Hyperparameters
vocab_size = 10000
batch_size = 4
seq_length = 10
embedding_size = 256
hidden_size = 512
num_layers = 2
num_heads = 4
num_blocks = 3
dropout = 0.1

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# Generate random input sequence
input_seq = torch.randint(0, vocab_size, (batch_size, seq_length)).to(device)

model = AttensLSTM(vocab_size=vocab_size, embedding_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, num_heads=num_heads, 
                   num_blocks=num_blocks, seq_length=seq_length, device=device, dropout=dropout).to(device)

output, hidden_state = model(input_seq)

In [9]:
output.shape

torch.Size([4, 10, 10000])

## Regular task: even pairs

In [10]:
import numpy as np

def even_pairs(
    num_samples, vocab_size, # =2 for binary seq
    seq_length, seed
):
    rng = np.random.default_rng(seed)
    input_seq = torch.tensor(rng.integers(vocab_size, size=[num_samples, seq_length]))
    output_seq = torch.zeros(num_samples, seq_length)

    for sample_idx in range(num_samples):
        output_seq[sample_idx, -1] = int(input_seq[sample_idx, 0] == input_seq[sample_idx, -1])
        
    return input_seq, output_seq

In [11]:
# Test
vocab_size = 2
num_samples = 10
seq_length = 5

even_pairs(num_samples=num_samples, vocab_size=vocab_size, seq_length=seq_length, seed=28)

(tensor([[1, 1, 0, 1, 1],
         [1, 1, 0, 0, 0],
         [1, 0, 0, 1, 0],
         [1, 0, 1, 1, 1],
         [0, 0, 1, 1, 1],
         [1, 0, 0, 0, 1],
         [1, 0, 1, 0, 1],
         [1, 1, 1, 1, 0],
         [1, 0, 1, 0, 1],
         [0, 0, 1, 0, 0]]),
 tensor([[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]]))

In [12]:
# Hyperparameters
vocab_size = 2
embedding_size = 128
hidden_size = 256

num_layers = 1
num_blocks = 2

num_samples = 300
batch_size = 64
seq_length = 20

num_epochs = 20
learning_rate = 0.0001
clip_value = 1

class even_pairsDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples, vocab_size, seq_length):
        self.input, self.target = even_pairs(num_samples=num_samples, vocab_size=vocab_size,
                                             seq_length=seq_length, seed=31)
    
    def __len__(self):
        return len(self.target)
    
    def __getitem__(self, idx):
        input = self.input[idx]
        target = self.target[idx]
        return input, target

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
model = AttensLSTM(vocab_size=vocab_size, embedding_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, num_heads=num_heads, 
                   num_blocks=num_blocks, seq_length=seq_length, device=device, dropout=dropout).to(device)

def init_weights(m):
    # print(m)
    if type(m) in [nn.Linear, nn.Embedding]:
        nn.init.xavier_uniform_(m.weight)
        if hasattr(m, 'bias') and m.bias is not None:
            nn.init.constant_(m.bias, 0)
model.apply(init_weights)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train_dataset = even_pairsDataset(num_samples, vocab_size, seq_length)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = even_pairsDataset(num_samples, vocab_size, seq_length)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [13]:
def train(epoch):
    model.train()
    total_loss = 0
    for batch_idx, (input, target) in enumerate(train_loader):
        input_seq = input.to(device)
        target_seq = target.to(device)
        output, _ = model(input_seq)
        output = output.contiguous().view(-1, vocab_size)
        target_seq = target_seq.contiguous().view(-1)
        
        loss = criterion(output, target_seq.long())
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        total_loss += loss.item()

        #print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_loader)}, Training Loss: {loss.item()}")
        
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Training Loss: {avg_loss:.4f}")

In [14]:
def val(epoch):
    model.eval()
    val_loss = 0
    gt_seq = []
    pred_seq = []
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(test_loader):
            input_seq = input.to(device)
            target_seq = target.to(device)
            output, _ = model(input_seq)
            output = output.contiguous().view(-1, vocab_size)
            target_seq = target_seq.contiguous().view(-1)
            loss = criterion(output, target_seq.long())
            val_loss += loss.item()
    
            preds = torch.argmax(output, 1)
            pred_seq.append(preds.cpu().data.numpy())
            gt_seq.append(target_seq.cpu().data.numpy())

            #print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_loader)}, Validation Loss: {loss.item()}")

    val_loss = val_loss/len(test_loader)
    gt_seq, pred_seq = np.concatenate(gt_seq), np.concatenate(pred_seq)
    acc = np.sum(gt_seq==pred_seq)/len(pred_seq)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Validation Loss: {val_loss:.4f}, Accuracy: {acc:.4f}")
    # print('Epoch: {} \tValidation Loss: {:.6f}, Accuracy: {:6f}'.format(epoch, val_loss, acc))

In [15]:
start_time = time.time()
for epoch in range(num_epochs):
    train(epoch)
    val(epoch)
    print()

end_time = time.time()
print(f"Training completed! Total time: {end_time - start_time:.2f} seconds")

Epoch 1/20, Average Training Loss: 0.2233
Epoch 1/20, Average Validation Loss: 0.2249, Accuracy: 0.9757

Epoch 2/20, Average Training Loss: 0.1800
Epoch 2/20, Average Validation Loss: 0.1311, Accuracy: 0.9757

Epoch 3/20, Average Training Loss: 0.1425
Epoch 3/20, Average Validation Loss: 0.0870, Accuracy: 0.9753

Epoch 4/20, Average Training Loss: 0.1289
Epoch 4/20, Average Validation Loss: 0.0902, Accuracy: 0.9757

Epoch 5/20, Average Training Loss: 0.1027
Epoch 5/20, Average Validation Loss: 0.0714, Accuracy: 0.9757

Epoch 6/20, Average Training Loss: 0.0863
Epoch 6/20, Average Validation Loss: 0.0390, Accuracy: 0.9803

Epoch 7/20, Average Training Loss: 0.0776
Epoch 7/20, Average Validation Loss: 0.0389, Accuracy: 0.9798

Epoch 8/20, Average Training Loss: 0.0721
Epoch 8/20, Average Validation Loss: 0.0398, Accuracy: 0.9788

Epoch 9/20, Average Training Loss: 0.0667
Epoch 9/20, Average Validation Loss: 0.0413, Accuracy: 0.9785

Epoch 10/20, Average Training Loss: 0.0678
Epoch 10/20,