In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from Data.Library import Library
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [22]:
class S4Layer(nn.Module):
    """State-space layer with fixed HiPPO dynamics and a learned linear read-out.

    The latent state is driven by the matrices produced by ``gen_A`` / ``gen_B``
    and discretised with a step of ``1 / (t + 1)`` at time index ``t``. The state
    is computed without autograd, so only the read-out ``C`` (latent -> output)
    and the skip connection ``D`` (input -> output) are trained.

    ``A`` and ``B`` are registered as buffers and ``C`` / ``D`` as parameters, so
    all four appear in ``state_dict()``.

    Args:
        latent_size: dimension N of the latent state.
        in_channels: number of input channels.
        out_channels: number of output channels.
        device: device on which the state computation and read-out run.

    NOTE(review): the convolutional and recurrent paths coincide at the final
    time step; for earlier steps the convolution reuses matrix products built
    for the full sequence length, so the two are not guaranteed to match there.
    Confirm this is intended.
    """

    def __init__(self, latent_size=8, in_channels = 1, out_channels=10, device=torch.device('cpu')):
        super().__init__()
        self.latent_size = latent_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.device = device

        # Use hippo matrices for A and B. They are never optimised, so they are
        # buffers: saved with the module but excluded from parameters().
        # A stays on the CPU because discretize() solves there anyway.
        self.register_buffer('A', self.gen_A(self.latent_size))
        self.register_buffer('B', self.gen_B(self.latent_size))

        # Trainable read-out and skip matrices. Wrapping them in nn.Parameter
        # *after* the device move keeps them leaf tensors that the module
        # registers. The torch.rand draw is kept so the random stream consumed
        # during construction is unchanged; xavier init overwrites the values.
        self.C = nn.Parameter(torch.rand(self.latent_size, self.out_channels).to(self.device))
        nn.init.xavier_uniform_(self.C)

        self.D = nn.Parameter(torch.rand(self.in_channels, self.out_channels).to(self.device))
        nn.init.xavier_uniform_(self.D)

        self.K = None
        # Cached convolution kernel; see get_legendre_kernel().
        self.A_stack = None
        self.B_stack = None

    def gen_A(self, N):
        """Return the N x N lower-triangular state matrix.

        ``A[n, k] = -sqrt(2n + 1) * sqrt(2k + 1)`` for ``n > k``,
        ``A[n, n] = -(n + 1)`` and zero above the diagonal.
        """
        # Build in float64 and round once, like the original scalar computation.
        idx = torch.arange(N, dtype=torch.float64)
        root = torch.sqrt(2 * idx + 1)
        A = torch.tril(-torch.outer(root, root), diagonal=-1)
        A = A - torch.diag(idx + 1)
        return A.to(torch.get_default_dtype())

    def gen_B(self, N):
        """Return the input vector ``B[n] = sqrt(2n + 1)`` with shape [N, 1] on ``self.device``."""
        idx = torch.arange(N, dtype=torch.float64)
        B = torch.sqrt(2 * idx + 1).to(torch.get_default_dtype())
        return B.unsqueeze(1).to(self.device)

    def discretize(self, step):
        """Discretise (A, B) with the bilinear transform for the given step.

        Returns:
            ``A_bar = (I - step/2 A)^-1 (I + step/2 A)`` with shape [N, N] and
            ``B_bar = (I - step/2 A)^-1 (step B)`` with shape [N, 1], both on
            ``self.device``.
        """
        # The triangular solves are done on the CPU, as in the original code.
        A = self.A.to('cpu')
        B = self.B.to('cpu')
        N = A.shape[0]
        I = torch.eye(N, dtype=A.dtype)
        half_step_A = (step / 2.0) * A
        # A is lower triangular, hence so is the system matrix.
        system = I - half_step_A
        A_bar = torch.linalg.solve_triangular(system, I + half_step_A, upper=False)
        B_bar = torch.linalg.solve_triangular(system, B * step, upper=False)
        return A_bar.to(self.device), B_bar.to(self.device)

    def get_legendre_kernel(self, seq_length):
        """Return the cached ``(A_stack, B_stack)`` kernel for ``seq_length``.

        ``A_stack`` has shape [N, N, seq_length]; slice ``k`` is the product of
        the discretised ``A_bar`` for time indices ``k+1 .. seq_length-1``.
        ``B_stack`` has shape [seq_length, N, in_channels]; row ``k`` is the
        ``B_bar`` of time index ``k`` broadcast over the input channels.

        The products depend on the total length, so the kernel is rebuilt
        whenever the requested length differs from the cached one (a kernel
        cached for a longer sequence cannot be reused for a shorter one).
        """
        have_cache = self.A_stack is not None and self.B_stack is not None
        if have_cache and self.A_stack.shape[2] == seq_length:
            return self.A_stack, self.B_stack
        if have_cache:
            print('recalcing kernel')

        A_stack = torch.eye(self.latent_size).repeat(seq_length, 1, 1).to(self.device)
        B_stack = torch.zeros(seq_length, self.latent_size, self.in_channels).to(self.device)
        for idx in range(seq_length):
            A_bar, B_bar = self.discretize(step=1.0/(idx + 1))
            # Every earlier time index accumulates this step's transition.
            A_stack[:idx] = A_stack[:idx] @ A_bar
            B_stack[idx] = B_bar
        # Stored on the CPU; get_legendre_conv moves them to the device on use.
        self.A_stack = A_stack.permute(1, 2, 0).contiguous().cpu()
        self.B_stack = B_stack.cpu()
        return self.A_stack, self.B_stack

    def forward(self, u, conv=True):
        """Apply the layer.

        Args:
            u: input of shape [batch_size, in_channels, sequence_length], or
                [batch_size, sequence_length] (treated as a single channel).
            conv: use the convolutional kernel if True, otherwise the
                step-by-step recurrence.

        Returns:
            Tensor of shape [batch_size, out_channels, sequence_length].

        Raises:
            ValueError: if ``u`` is neither 2-D nor 3-D.
        """
        if u.dim() == 2:
            u = u.unsqueeze(1)
        elif u.dim() != 3:
            raise ValueError('Unknown number of dimensions in forward_conv')
        # Move once so the skip connection below sees the same device as D.
        u = u.to(self.device)
        # The latent state is a fixed feature map: no gradient flows through it.
        with torch.no_grad():
            if conv:
                state = self.get_legendre_conv(u)
            else:
                state = self.get_legendre_rec(u)
        # state: [seq_length, batch_size, latent_size]
        output = state @ self.C + u.permute(2, 0, 1) @ self.D
        return output.permute(1, 2, 0)

    def get_legendre_rec(self, sequence):
        """Compute the latent state with the explicit recurrence.

        Args:
            sequence: tensor of shape [batch_size, in_channels, seq_length].

        Returns:
            Tensor of shape [seq_length, batch_size, latent_size], scaled
            row-wise by ``B``.
        """
        batch_size, _, seq_length = sequence.shape
        x = torch.zeros([batch_size, self.latent_size, seq_length + 1]).to(self.device)
        for idx in range(seq_length):
            A_bar, B_bar = self.discretize(step=1.0/(1+idx))
            # B_bar is shared by all input channels (the convolutional path
            # broadcasts it the same way), so the channels are summed first.
            # This is the identity for a single channel.
            drive = sequence[:, :, idx].sum(dim=1, keepdim=True)
            x[:, :, idx+1] = x[:, :, idx] @ A_bar.mT + drive @ B_bar.mT
        # Drop the initial zero state: [batch_size, latent_size, seq_length]
        output = x[:, :, 1:]
        return (self.B * output).permute(2, 0, 1)

    def get_legendre_conv(self, sequence):
        """Compute the latent state with a 1-D convolution over the cached kernel.

        Args:
            sequence: tensor of shape [batch_size, in_channels, seq_length].

        Returns:
            Tensor of shape [seq_length, batch_size, latent_size], scaled
            row-wise by ``B``.
        """
        seq_length = sequence.shape[2]
        A_stack, B_stack = self.get_legendre_kernel(seq_length)
        # Apply B stack: one batched matmul per time index.
        u = torch.bmm(sequence.permute(2, 0, 1), B_stack.to(self.device).mT).permute(1, 2, 0)
        # u has shape [batch_size, latent_size, seq_length]
        # Pad front with zeros to create correct convolutional form
        u_pad = F.pad(u, (seq_length-1, 0))
        # u_pad has shape [batch_size, latent_size, 2*seq_length-1]
        output = F.conv1d(u_pad, A_stack.to(self.device))
        # Output is of shape [batch_size, latent_size, seq_length]
        return (self.B * output).permute(2, 0, 1)
class S4Model(nn.Module):
    """Token model: embedding lookup, one S4Layer, log-softmax over the vocabulary.

    Args:
        vocab_size: number of distinct tokens; also the number of output classes.
        d_model: embedding width, used as the S4 layer's input channel count.
        d_internal: latent state size of the S4 layer.
        device: device handed to the S4 layer. The embedding table is created
            without a device argument and its output is moved to ``device``
            before the S4 layer runs.
    """

    def __init__(self, vocab_size, d_model, d_internal, device = torch.device('cpu')):
        super().__init__()
        self.vocab_size, self.d_model, self.d_internal = vocab_size, d_model, d_internal
        self.device = device
        # Normalises over the second-to-last axis, which holds the vocabulary
        # scores in the [batch, vocab, sequence] output of the S4 layer.
        self.log_softmax = nn.LogSoftmax(-2)

        # Model shape: embedding table followed by a single S4 layer.
        self.embeddings = nn.Embedding(self.vocab_size, self.d_model)
        self.S4 = S4Layer(
            latent_size=self.d_internal,
            in_channels=self.d_model,
            out_channels=self.vocab_size,
            device=self.device,
        )

    def forward(self, sequence):
        """Map token ids [batch_size, seq_length] to log-probabilities.

        Returns:
            Tensor of shape [batch_size, vocab_size, seq_length], moved to the CPU.
        """
        # Unpacking doubles as a check that the input is 2-D.
        n_rows, n_tokens = sequence.shape
        embedded = self.embeddings(sequence)
        # [batch_size, seq_length, channels] -> [batch_size, channels, seq_length]
        channels_first = embedded.permute(0, 2, 1).to(self.device)
        log_probs = self.log_softmax(self.S4(channels_first))
        return log_probs.to(torch.device('cpu'))

In [None]:
# Hyperparams
epochs = 128
lr = .001
seq_length=512
batch_size=128
d_model=64
d_internal=64
train_size = 2**16
encoding=76
torch.manual_seed(0)

# Setup
# Use the GPU when one is visible, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
library = Library(encoding = encoding, train_size = train_size, streaming=False)

model = S4Model(encoding, d_model, d_internal, device)
loss_fn = nn.NLLLoss()
# Only the S4 read-out (C) and skip (D) matrices are optimised.
optim = torch.optim.Adam([model.S4.C, model.S4.D], lr=lr)
# Each training item has seq_length entries; inputs are items[:-1] and targets
# are items[1:], hence seq_length-1 columns.
x_batch = torch.zeros([batch_size, seq_length-1])
y_batch = torch.zeros([batch_size, seq_length-1])
losses = torch.zeros(epochs)
perplexities = torch.zeros(epochs)
print('Training')
# Training
for epoch in range(epochs):
    dataloader = library.get_train_dataloader(seq_length)
    for idx, data in enumerate(dataloader):
        # Items are collected one by one into the batch buffers.
        mod_idx = idx % batch_size
        if data.shape[0] != seq_length:
            break # End of usable dataloader
        x_batch[mod_idx] = data[:-1]
        y_batch[mod_idx] = data[1:]
        if mod_idx == batch_size-1:
            # Buffer is full: update weights
            optim.zero_grad()
            y_pred = model(x_batch.long())
            loss = loss_fn(y_pred, y_batch.long())
            # Accumulate a plain float. Adding the loss tensor itself would
            # attach autograd history to `losses` and keep it alive for the
            # whole run.
            losses[epoch] += loss.item()

            print(f'{epoch}:{idx+1}:{losses[epoch]:.4f}', end='\r')
            loss.backward()
            optim.step()
            
    # Test
    perplexities[epoch] = library.calc_perplexity(model)
    print(f'{epoch}:Total Loss:{losses[epoch]:.2f}:Perplexity:{perplexities[epoch]:.2f}')
    # NOTE(review): state_dict() only contains registered parameters and
    # buffers; confirm that S4.C and S4.D are part of the saved checkpoint.
    torch.save(model.state_dict(), f'Models/{encoding}.pkl')

Training
0:Total Loss:686.59:Perplexity:4015.70
1:Total Loss:345.05:Perplexity:66.24
2:Total Loss:197.54:Perplexity:20.06
3:Total Loss:162.22:Perplexity:14.94
4:Total Loss:151.50:Perplexity:13.22
5:Total Loss:146.44:Perplexity:12.37
6:Total Loss:143.48:Perplexity:11.87
7:Total Loss:141.54:Perplexity:11.54
8:Total Loss:140.18:Perplexity:11.32
9:Total Loss:139.18:Perplexity:11.15
10:Total Loss:138.42:Perplexity:11.03
11:Total Loss:137.82:Perplexity:10.93
12:Total Loss:137.34:Perplexity:10.85
13:Total Loss:136.93:Perplexity:10.79
14:Total Loss:136.60:Perplexity:10.73
15:Total Loss:136.31:Perplexity:10.69
16:Total Loss:136.07:Perplexity:10.65
17:Total Loss:135.87:Perplexity:10.62
18:Total Loss:135.70:Perplexity:10.59
19:Total Loss:135.55:Perplexity:10.57
20:Total Loss:135.42:Perplexity:10.54
21:Total Loss:135.31:Perplexity:10.53
22:Total Loss:135.21:Perplexity:10.51
23:Total Loss:135.12:Perplexity:10.49
24:Total Loss:135.04:Perplexity:10.48
25:Total Loss:134.96:Perplexity:10.47
26:Total Lo