In [None]:
%matplotlib inline


In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn as nn
from math import factorial
import random
import torch.nn.functional as F
import numpy as np
import seaborn as sn
import pandas as pd
import os 
from os.path import join
import glob
from math import factorial
# Legacy tensor constructors used throughout the notebook: the CUDA variants
# when a GPU is visible, otherwise the CPU ones.
if torch.cuda.is_available():
    ttype, ctype = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    ttype, ctype = torch.FloatTensor, torch.LongTensor
print(ttype)
from torch.nn.utils import weight_norm
from lmu import LegendreMemoryUnit

from tqdm.notebook import tqdm
import pickle
sn.set_context("poster")
import itertools
from csv import DictWriter
import matplotlib.pylab as plt
import csv
import numpy as np
import pandas as pd
import os
import seaborn as sn
sn.set_context('talk')

In [None]:
def generate_noise(maxn=18):
    """Generate a random train of Morse-like dots, dashes and gaps.

    Three symbol kinds are shuffled together and concatenated:

    * dash ``[1, 1, 1, 0]``
    * dot  ``[1, 0]``
    * gap  ``[0, 0]``

    A dash count is drawn from ``[int(.5*maxn), int(.75*maxn))``. A third of
    those dashes become gaps, and the rest of the budget is filled with dots.
    The result is therefore always a 1-D integer array of length ``4 * maxn``.

    All randomness comes from NumPy's global RNG, so ``np.random.seed`` makes
    the output reproducible.

    Args:
        maxn: length budget; the output has ``4 * maxn`` steps. ``maxn <= 0``
            yields an empty array.

    Returns:
        np.ndarray of 0/1 integers.
    """
    if maxn <= 0:
        return np.zeros(0, dtype=int)

    low = int(.5 * maxn)
    # For tiny maxn the two bounds coincide and randint would raise, so the
    # range is widened to contain exactly one value.
    high = max(int(.75 * maxn), low + 1)
    n_dash_drawn = np.random.randint(low, high)
    n_gap = n_dash_drawn // 3                    # a third of dashes -> gaps
    n_dot = 2 * (maxn - n_dash_drawn) + n_gap    # keeps total length at 4*maxn

    pieces = ([[0, 0]] * n_gap
              + [[1, 1, 1, 0]] * (n_dash_drawn - n_gap)
              + [[1, 0]] * n_dot)
    # Shuffle via NumPy (not Python's `random`) so seeding numpy suffices.
    order = np.random.permutation(len(pieces))
    return np.concatenate([pieces[k] for k in order])
# Quick visual check: plot a single noise sample drawn with the default maxn.
noise = generate_noise()
print(noise.shape)
plt.plot(np.arange(len(noise)), noise)


In [None]:
# Class labels, one per template row below.
sig_lets = ["A", "B", "C", "D", "E", "F", "G", "H"]

# One binary template per class, 17 steps each. A dot is "1", a dash is
# "1 1 1", and symbols are separated by a single "0". Each row starts with
# a 0 and ends with trailing zeros.
signal_rows = [
    [0,1,1,1,0,1,1,1,0,1,0,1,0,1,0,0,0],
    [0,1,1,1,0,1,0,1,1,1,0,1,0,1,0,0,0],
    [0,1,1,1,0,1,0,1,0,1,1,1,0,1,0,0,0],
    [0,1,1,1,0,1,0,1,0,1,0,1,1,1,0,0,0],

    [0,1,0,1,1,1,0,1,1,1,0,1,1,1,0,0,0],
    [0,1,1,1,0,1,0,1,1,1,0,1,1,1,0,0,0],
    [0,1,1,1,0,1,1,1,0,1,0,1,1,1,0,0,0],
    [0,1,1,1,0,1,1,1,0,1,1,1,0,1,0,0,0],
]
assert len(signal_rows) == len(sig_lets), "need exactly one template row per label"
assert len({len(row) for row in signal_rows}) == 1, "templates must share one length"

# Shape (n_classes, 1, 1, template_len); the class count comes from sig_lets
# rather than a hard-coded 8.
signals = ttype(signal_rows).view(len(sig_lets), 1, 1, -1)

fig, ax = plt.subplots()
ax.imshow(signals[:, 0, 0, :].detach().cpu())
ax.set_yticks(range(len(sig_lets)))
ax.set_yticklabels(sig_lets)
ax.set(xlabel="time step", ylabel="class", title="Class templates");


In [None]:
# Reseed every RNG the data pipeline may touch so the split is reproducible:
# - torch, for the randperm below;
# - numpy, which generate_noise draws from;
# - Python's `random`, which generate_noise may use for shuffling.
torch.manual_seed(12345)
np.random.seed(12345)
random.seed(12345)
training_samples = 32

training_signals = []
training_class = []

for class_idx, sig in enumerate(signals):
    template = sig[0, 0]  # the 1-D class template
    class_seqs = []
    for _ in range(training_samples):
        # Redraw noise until the sequence is unique within this class.
        while True:
            candidate = torch.cat([template, ttype(generate_noise())]).unsqueeze(0)
            if not any((candidate == prev).all() for prev in class_seqs):
                break
        class_seqs.append(candidate)
    training_signals.extend(class_seqs)
    training_class.extend([class_idx] * training_samples)

batch_rand = torch.randperm(training_samples * signals.shape[0])
# No explicit .cuda(): ttype/ctype already place tensors on the GPU when one
# is available, and .cuda() would crash on CPU-only machines.
training_signals = torch.cat(training_signals).unsqueeze(-1)[batch_rand]
training_class = ctype(training_class).unsqueeze(-1)[batch_rand]

dataset = torch.utils.data.TensorDataset(training_signals, training_class)
dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)



In [None]:
testing_samples = 10
testing_signals = []
testing_class = []

# Flatten the training set to (N, L) so a (1, L) candidate can be checked
# against every training sequence with one broadcasted comparison.
# Comparing against the raw (L, 1) rows would broadcast into an (L, L)
# outer comparison that practically never matches, silently disabling the
# leakage check.
train_flat = training_signals.reshape(training_signals.shape[0], -1)

for class_idx, sig in enumerate(signals):
    template = sig[0, 0]
    class_seqs = []
    for _ in range(testing_samples):
        # Redraw noise until the sequence is new within this class AND absent
        # from the training set (no train/test leakage).
        while True:
            candidate = torch.cat([template, ttype(generate_noise())]).unsqueeze(0)
            seen_here = any((candidate == prev).all() for prev in class_seqs)
            if not seen_here and not bool((train_flat == candidate).all(dim=1).any()):
                break
        class_seqs.append(candidate)
    testing_signals.extend(class_seqs)
    testing_class.extend([class_idx] * testing_samples)

batch_rand = torch.randperm(testing_samples * signals.shape[0])
# No explicit .cuda(): ttype/ctype already target the GPU when available.
testing_signals = torch.cat(testing_signals).unsqueeze(-1)[batch_rand]
testing_class = ctype(testing_class).unsqueeze(-1)[batch_rand]

dataset_valid = torch.utils.data.TensorDataset(testing_signals, testing_class)
dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=False)


In [None]:
class LMUModel(nn.Module):
    """Stacked ``LegendreMemoryUnit`` layers followed by a linear readout.

    Args:
        n_out: number of output features (classes) produced by the readout.
        layer_params: list of keyword-argument dicts, one per
            ``LegendreMemoryUnit`` layer, applied in order. The last dict
            must contain ``'hidden_size'``, which sizes the readout input.
    """

    def __init__(self, n_out, layer_params):
        super().__init__()
        lmu_layers = [LegendreMemoryUnit(**cfg) for cfg in layer_params]
        self.layers = nn.ModuleList(lmu_layers)
        last_hidden = layer_params[-1]['hidden_size']
        self.dense = nn.Linear(last_hidden, n_out)

    def forward(self, x):
        """Feed ``x`` through every LMU layer, then the linear readout.

        Each layer returns a pair; only the first element is passed on.
        ``nn.Linear`` acts on the last dimension of the final layer's output.
        """
        h = x
        for layer in self.layers:
            h, _ = layer(h)
        return self.dense(h)

In [None]:
def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,
          permute=None, loss_buffer_size=64, batch_size=4, device='cuda',
          prog_bar=None, maxn=6):
    
    assert(loss_buffer_size%batch_size==0)
        
    losses = []
    perfs = []
    last_test_perf = 0
    best_test_perf = -1
    
    for batch_idx, (data, target) in enumerate(train_loader):
        model.train()
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = loss_func(out[:,-1],
                         target[:, 0])
        
        loss.backward()
        optimizer.step()

        perfs.append((torch.argmax(out[:,-1], dim=-1) == 
                      target[:, 0]).sum().item())
        perfs = perfs[int(-loss_buffer_size/batch_size):]
        losses.append(loss.detach().cpu().numpy())
        losses = losses[int(-loss_buffer_size/batch_size):]
        if not (prog_bar is None):
            # Update progress_bar
            s = "{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}"
            format_list = [e,batch_idx*batch_size, np.mean(losses), 
                           np.sum(perfs)/((len(perfs))*batch_size), last_test_perf]         
            s = s.format(*format_list)
            prog_bar.set_description(s)
        
        if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):
            loss_track = {}
            last_test_perf = test(model, 'cuda', test_loader, 
                                  batch_size=batch_size, 
                                  permute=permute)
            loss_track['avg_loss'] = np.mean(losses)
            loss_track['last_test'] = last_test_perf
            loss_track['epoch'] = epoch
            loss_track['maxn'] = maxn
            loss_track['batch_idx'] = batch_idx
            loss_track['pres_num'] = batch_idx*batch_size + epoch*len(train_loader.dataset)
            loss_track['train_perf']= np.sum(perfs)/((len(perfs))*batch_size)
            with open(perf_file, 'a+') as fp:
                csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))
                if fp.tell() == 0:
                    csv_writer.writeheader()
                csv_writer.writerow(loss_track)
                fp.flush()
            if best_test_perf < last_test_perf:
                torch.save(model.state_dict(), perf_file[:-4]+".pt")
                best_test_perf = last_test_perf

                
def test(model, device, test_loader, batch_size=4, permute=None):
    model.eval()
    correct = 0
    count = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            
            out = model(data)
            pred = out[:,-1].argmax(dim=-1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            count += 1
    return correct / len(test_loader.dataset)

# Training and testing

In [None]:
# Training budget and loss-smoothing window.
# - epochs: generous on purpose; far fewer are probably enough.
# - loss_buffer_size: only for visualising average loss through time.
#   NOTE(review): appears not to be passed to `train`. Also 100 is not a
#   multiple of a batch size of 32, which `train` asserts on. Confirm
#   before wiring it in.
epochs, loss_buffer_size = 400, 100

In [None]:
# Sweep the noise length. For each maxn: build fresh, leak-free train/test
# splits, train a new single-layer LMU, and log validation accuracy under perf/.
test_noise_lengths = [6, 7, 9, 13, 21, 37]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
os.makedirs('perf', exist_ok=True)


def _make_split(n_per_class, maxn, exclude=None):
    """Build ``n_per_class`` unique noisy sequences for each class template.

    Each sequence is a template from ``signals`` followed by
    ``generate_noise(maxn)``. Sequences are unique within a class. If
    ``exclude`` (an ``(N, L, 1)`` tensor) is given, any candidate equal to
    one of its rows is rejected too.

    Returns ``(sequences, labels)`` shaped ``(C*n, L, 1)`` and ``(C*n, 1)``,
    reordered by one ``torch.randperm`` drawn after all noise is generated.
    """
    # Compare (1, L) candidates against (N, L) rows; comparing against
    # (L, 1) rows would broadcast to an (L, L) outer test that never matches.
    exclude_flat = None if exclude is None else exclude.reshape(exclude.shape[0], -1)
    seqs, labels = [], []
    for class_idx, sig in enumerate(signals):
        template = sig[0, 0]
        class_seqs = []
        for _ in range(n_per_class):
            while True:
                cand = torch.cat([template, ttype(generate_noise(maxn))]).unsqueeze(0)
                if any((cand == prev).all() for prev in class_seqs):
                    continue
                if exclude_flat is not None and bool((exclude_flat == cand).all(dim=1).any()):
                    continue
                break
            class_seqs.append(cand)
        seqs.extend(class_seqs)
        labels.extend([class_idx] * n_per_class)
    order = torch.randperm(n_per_class * signals.shape[0])
    # ttype/ctype already target the GPU when available; no .cuda() needed.
    return torch.cat(seqs).unsqueeze(-1)[order], ctype(labels).unsqueeze(-1)[order]


for maxn in test_noise_lengths:
    # Reseed per run so each noise length gets a reproducible split and init.
    torch.manual_seed(12345)
    np.random.seed(12345)
    random.seed(12345)
    training_samples = 32
    testing_samples = 10
    batch_size = 32
    epochs = 400
    hz = 125

    training_signals, training_class = _make_split(training_samples, maxn)
    testing_signals, testing_class = _make_split(
        testing_samples, maxn, exclude=training_signals)

    dataset = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(training_signals, training_class),
        batch_size=batch_size, shuffle=True)
    dataset_valid = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(testing_signals, testing_class),
        batch_size=batch_size, shuffle=False)

    # theta spans the full sequence (template + noise).
    lmu_params = [dict(input_dim=1, hidden_size=hz, order=40,
                       theta=training_signals.shape[1])]
    model = LMUModel(len(signals), lmu_params).to(device)

    print("Total Weights:", sum(p.numel() for p in model.parameters()))
    print(model)
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())

    # One CSV (and best-model .pt checkpoint) per noise length. A fixed file
    # name would make every run append to and overwrite the same files.
    perf_file = join('perf', f'h8_LMU_length_{maxn}.csv')
    progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')
    # The loop variable stays a module-level `e` in case `train` reads it as
    # a global.
    for e in progress_bar:
        train(model, ttype, dataset, dataset_valid,
              optimizer, loss_func, batch_size=batch_size,
              epoch=e, perf_file=perf_file,
              prog_bar=progress_bar, maxn=maxn, device=device)
    