### Imports

In [1]:
import matplotlib.pyplot as plt
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm

from torchvision import datasets, transforms

import os
import spect

### Helper functions

In [2]:
def get_context_points(data, context_points=100):
    
    mask = np.zeros_like(data[0])
    
    n,m = mask.shape
    mask = mask.reshape(-1)

    mask[:context_points] = 1
    np.random.shuffle(mask)

    mask = mask.reshape(n,m)
    
    data = np.array(data.tolist())

    data[0][mask != 1] = 0
    data[1][mask != 1] = 0

    data = torch.tensor(data)
    
    return data

def normalize(x):
    x = (x - x.min())/(x.max() - x.min())
    return x

def get_log_p(data, mu, sigma):
    return -torch.log(torch.sqrt(2*math.pi*sigma**2)) - (data - mu)**2/(2*sigma**2)

### Model Definitions

In [3]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128,128)
        self.fc3 = nn.Linear(128,128)
        

    def forward(self, x):
        
        # this gets the x,y coordinates of the zero values
        cntx = x[0].nonzero()[:,[-2,-1]]
        
        x_points = cntx[:,0]
        y_points = cntx[:,1]
        
        # we then need to pull out real, imag intensities
        real = x[0][x_points, y_points]
        imag = x[1][x_points, y_points]

        x = torch.stack((normalize(x_points.float()), normalize(y_points.float()), real.reshape(-1), imag.reshape(-1)))
        
        x.transpose_(0,1)
        output = torch.empty((x.shape[0], 128)).to(device)
        
        # aggregate all outputs of intensity data found in x at the various context points
        for i, row in enumerate(x):
            output[i] = self.fc3(F.relu(self.fc2(self.fc1(row))))
        
        output = output.mean(0) # aggregation
        return output
    
class Decoder(nn.Module):
    def __init__(self, m,n):
        super(Decoder, self).__init__()
        self.m = m
        self.n = n
        self.fc1 = nn.Linear(130, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 2)

    def forward(self, r):

        # we only take in r, because in this case x is all points in size of image (28,28)
        x = torch.tensor([[i, j] for i in range(0,self.m) for j in range(0,self.n)]).float().to(device)
        x = torch.cat((x, r.view(1,-1).repeat(1,self.m*self.n).view(self.m*self.n,128)), 1)
        
        h = self.fc4(F.relu(self.fc3(F.relu(self.fc2(F.relu(self.fc1(x)))))))
        
        mu = h[:,0]
        log_sigma = h[:,1]
        
        # bound the variance
        sigma = 0.1 + 0.9 * F.softplus(log_sigma)
        
        return mu, sigma



### Load Data

In [4]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(1)
device = torch.device("cuda" if use_cuda else "cpu")
use_mnist=False

In [5]:
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

if use_mnist:
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                       ])),
        batch_size=batch_size, shuffle=True, **kwargs)


    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                       ])),
        batch_size=test_batch_size, shuffle=True, **kwargs)
else:
    dir_ = "/mnt/pccfs/backed_up/andrew/hearables/clean/"
    data = torch.tensor([spect.get_stft(dir_ + x, split=True)[2] for x in os.listdir(dir_)])
    np.random.shuffle(data)
    train_loader = data[:300]
    test_loader = data[301:]
    


### Hyperparameters

In [6]:
m,n = 129,1252 #28, 28
num_pixels = m*n

batch_size = 1
test_batch_size = 1000
epochs = 10

log_interval = 500


min_context_points = num_pixels * 0.05 # always have at least 5% of all pixels
max_context_points = num_pixels * 0.95 # always have at most 95% of all pixels

### Train

In [8]:
encoder = Encoder().to(device)
decoder = Decoder(m, n).to(device)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))


for epoch in range(1, epochs+1):
    encoder.train()
    decoder.train()
    progress = tqdm(enumerate(train_loader))
    for batch_idx, data, *rest in progress:
        
        num_context_points = np.random.randint(min_context_points, max_context_points)
        cntx = get_context_points(data, context_points=num_context_points)
        cntx = cntx.to(device).float()
        optimizer.zero_grad()
        
        # run the model to get r
        output = encoder(cntx)
        mu, sigma = decoder(output)
        
        # ok, now we need to think about multidimensionally distributed log probability
        log_p = get_log_p(data.view(2,m,n).to(device).view(-1), mu, sigma)
                
        
        loss = -log_p.mean()
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            progress.set_description('Loss: {:.6f} Mean: {}/{} Sig: {}/{}'.format(loss.item(), mu.max(), mu.min(), sigma.max(), sigma.min()))
    
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        for data, *rest in test_loader:
            data = data[0].view(1,1,m, n)
            cntx = get_context_points(data, context_points=200)
            data, cntx = data.to(device), cntx.to(device)

            output = encoder(cntx)
            mu, sigma = decoder(output)

            if use_mnist:
                plt.imshow(cntx.reshape(m,n), cmap='gray')
                plt.axis("off")
                plt.show()

                plt.imshow(data.reshape(m,n), cmap='gray')
                plt.axis("off")
                plt.show()

                plt.imshow(mu.detach().reshape(m,n), cmap='gray')
                plt.axis("off")
                plt.title("mean")
                plt.show()

                plt.imshow(sigma.detach().reshape(n,b), cmap='gray')
                plt.axis("off")
                plt.title("variance")
                plt.show()
            else:
                spect.plot_spect(cntx.reshape(m,n))
                spect.plot_spect(data.reshape(m,n))
                spect.plot_spect(mu.detach().reshape(m,n))
                spect.plot_spect(sigma.detach().reshape(m,n))


0it [00:00, ?it/s][A

RuntimeError: The size of tensor a (323016) must match the size of tensor b (161508) at non-singleton dimension 0

In [None]:
import pickle

with open("encoder_spect.pkl", "wb") as of:
    pickle.dump(encoder, of)

with open("decoder_spect.pkl", "wb") as of:
    pickle.dump(decoder, of)