In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from Data.Library import Library
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class S4Layer(nn.Module):
    def __init__(self, latent_size=8, in_channels = 1, out_channels=10, device=torch.device('cpu'), max_seq_length=1000):
        super().__init__()
        self.latent_size = latent_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.device = device

         # Use hippo matrices for A and B
        self.A = self.gen_A(self.latent_size)
        self.B = self.gen_B(self.latent_size)
        self.C = nn.Linear(self.latent_size, self.out_channels).to(self.device)
        self.D = nn.Linear(self.in_channels, self.out_channels).to(self.device)
        self.A_stack = None
        self.B_stack = None
        self.log_softmax = nn.LogSoftmax(-1)

    def gen_A(self, N):
        A = torch.zeros(N, N)
        for n in range(N):
            for k in range(N):
                if n > k:
                    A[n, k] = -(2*n + 1)**.5 * (2*k + 1)**.5
                elif n == k:
                    A[n, k] = -(n+1)
        return A
        
    def gen_B(self, N):
        B = torch.zeros(N)
        for n in range(N):
            B[n] = (2*n+1)**.5
        return B.unsqueeze(1)
        
    def discretize(self, step):
        A = self.A.to('cpu')
        B = self.B.to('cpu')
        N = A.shape[0]
        I = torch.eye(N)
        A_bar = torch.linalg.solve_triangular(I - (step / 2.0) * A, (I + (step / 2.0) * A), upper=False)
        B_bar = torch.linalg.solve_triangular((I - (step / 2.0) * A), B * step, upper=False)
        return A_bar.to(self.device), B_bar.to(self.device)
    
    def get_legendre_kernel(self, seq_length):
        if self.A_stack is None or self.B_stack is None:
            A_stack = torch.eye(self.latent_size).unsqueeze(-1).repeat((1, 1, seq_length)).permute(2, 0, 1).to(self.device)
            B_stack = torch.zeros(seq_length, self.latent_size, self.in_channels).to(self.device)
            A_bar, B_bar = self.discretize(step=1.0/seq_length)
            for idx in range(0, seq_length):
                A_stack[:idx] = A_stack[:idx] @ A_bar
                B_stack[idx] = B_bar
            A_stack = A_stack.permute(1, 2, 0).cpu()
            self.A_stack = A_stack
            self.B_stack = B_stack
            return A_stack, B_stack
        else:
            if seq_length > self.A_stack.shape[2]:
                print('recaslcing kernel')
                self.A_stack = None
                self.B_stack = None
                return self.get_legendre_kernel(seq_length)
            else:
                return self.A_stack, self.B_stack
    
    def forward(self, u, conv=True):
        # u has shape [batch_size, in_channels, sequence_length]
        if len(u.shape) == 2: # If 1 dimensional
            u = u.unsqueeze(1)
        if conv:
            output = self.get_legendre_conv(u.to(self.device))
        else:
            output = self.get_legendre_rec(u.to(self.device))
        # output has shape [batch_size, in_channels, latent_size, sequence_length]
        output = output.detach()
        print(output.shape)
        assert False
        output =  self.C(output) + self.D(u.permute(2, 0, 1))
        return output.permute(1, 2, 0)
    
    def get_legendre_rec(self, sequence):
        batch_size, in_channels, seq_length = sequence.shape
        x = torch.zeros([batch_size,self.latent_size, in_channels,  seq_length + 1]).to(self.device)
        A_bar, B_bar = self.discretize(step=1.0/seq_length)
        B_bar = B_bar.permute(1, 0).repeat(batch_size, 1).unsqueeze(2) # Shape of [batch_size, latent_sizem, 1]
        A_bar = A_bar.unsqueeze(0).repeat(batch_size, 1, 1)
        for idx in range(seq_length):
            seq = sequence[:,:,idx].unsqueeze(1)
            B_term = torch.matmul(B_bar, seq)
            x_term = x[:,:,:,idx]
            A_term = torch.matmul(A_bar, x_term)
            x[:, :, :, idx+1] = A_term + B_term
        # x has shape [batch_size, self.latent_size, in_channels, seq_length + 1]
        output = x[:, :, :, 1:]
        # output has shape [batch_size, self.latent_size, in_channels, seq_length]
        B_scale = self.B.unsqueeze(0).unsqueeze(2).to(self.device)
        # output has shape [batch_size, in_channels, latent_size, sequence_length]
        output = (B_scale * output).permute(0, 2, 1, 3)
        return output
    
    def get_legendre_conv(self, sequence):
        # Sequence of shape [batch_size, in_channels, seq_length]
        batch_size, in_channels, seq_length = sequence.shape
        A_stack, B_stack = self.get_legendre_kernel(seq_length)
        # Apply B stack
        u = sequence.permute(2,0,1)
        print(u.shape)
        assert False
        B_stack = B_stack.permute(0, 2, 1).to(self.device)
        u = torch.matmul(B_stack, u)
        print(u.shape)
        u = u.permute(1, 2, 0)
        # u has shape [batch_size, latent_size, seq_length]
        output = F.conv1d(u, A_stack.to(self.device), padding=seq_length-1)[:, :, :seq_length]
        # Output is of shape [batch_size, latent_size, seq_length]
        return (self.B * output).permute(2, 0, 1)
class S4Model(nn.Module):
    def __init__(self, vocab_size, d_model, d_internal, device = torch.device('cpu')):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_internal = d_internal
        self.device = device
        self.log_softmax = nn.LogSoftmax(-2)

        # Define model shape
        self.embeddings = nn.Embedding(self.vocab_size, self.d_model).to(self.device)
        self.S4 = S4Layer(latent_size = self.d_internal, in_channels = self.d_model, out_channels = self.vocab_size, device=self.device)
    
    def forward(self, sequence, conv=True):
        batch_size, seq_length = sequence.shape
        x = self.embeddings(sequence.to(self.device))
        # Current shape is [batch_size, sequence_length, channels]
        x = x.permute(0, 2, 1)#.to(self.device)
         # Current shape is [batch_size, channels, sequence_length]
        x = self.S4(x, conv=conv)
        x = self.log_softmax(x)
        x = x.to(torch.device('cpu')) 
        return x
model = S4Model(vocab_size, d_model, d_internal, device)
model(x_batch.long(), conv=True)

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [255, 128] but got: [255, 64].

In [378]:
# Hyperparams
epochs = 64
lr = .001
seq_length=256
batch_size=64
d_model=32
d_internal=128
train_size = 2**20
test_size = 2**16
encoding=76
vocab_size = 76
torch.manual_seed(0)
conv = True

# Setup
device = torch.device('mps')
if False:
    library = Library(encoding=encoding, train_size=train_size, test_size=test_size, download_new=True)

model = S4Model(vocab_size, d_model, d_internal, device)
loss_fn = nn.NLLLoss()
optim = torch.optim.Adam(model.parameters(), lr=lr)
x_batch = torch.zeros([batch_size, seq_length-1])
y_batch = torch.zeros([batch_size, seq_length-1])
losses = torch.zeros(epochs)
perplexities = torch.zeros(epochs)
print('Training')
import time
tic = time.time()
# Training
for epoch in range(epochs):
    dataloader = library.get_train_dataloader(seq_length)
    for idx, data in enumerate(dataloader):
        mod_idx = idx % batch_size
        if data.shape[0] != seq_length:
            break # End of usable dataloader
        x_batch[mod_idx] = data[:-1]
        y_batch[mod_idx] = data[1:]
        if mod_idx == batch_size-1:
            # Update weights
            optim.zero_grad()
            y_pred = model(x_batch.long(), conv=conv)
            loss = loss_fn(y_pred, y_batch.long())
            losses[epoch] += loss

            print(f'{epoch}:{idx+1}:{losses[epoch]:.4f}', end='\r')
            loss.backward()
            optim.step()
            
    # Test
    perplexities[epoch] = library.calc_perplexity(model, batch_size=64, seq_length=256)
    print(f'{epoch}:Total Loss:{losses[epoch]:.2f}:Perplexity:{perplexities[epoch]:.2f}')
    torch.save(model.state_dict(), f'Models/{encoding}.pkl')
print(time.time()-tic)

Training


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [255, 1] but got: [255, 32].