In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import DataUtils
import Masking

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 64

In [3]:
columns_excluded = [0, 1, 26]
columns_kept = [False, False, True, True, True,
               True, True, True, True, True,
               True, True, True, True, True,
               True, True, True, True, True,
               True, True, True, True, True,
               True, False]
print(columns_kept)

[False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False]


In [4]:
trainloader, testloader = DataUtils.get_dataloaders(batch_size)

In [5]:
example = next(enumerate(trainloader))

In [6]:
ex_X = example[1][0]
print(ex_X.shape)

ex_y = example[1][1]
print(ex_y.shape)

torch.Size([64, 542, 27])
torch.Size([64])


In [7]:
# Should we use a linear attention transformer?
class ImputationTransformer(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        
        self.input_projection = nn.Linear(542*24, embed_dim)
        self.positional_embed = nn.Parameter(torch.randn(embed_dim))
        
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, activation="gelu")
        self.transformer_blocks = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
    
    def forward(self, x):
        x = x.flatten(1)
        z = self.input_projection(x)
        z = z + self.positional_embed
        z = self.transformer_blocks(z)
        
        return z

In [17]:
class ReconstructionImputationTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.embed_dim = 128
        self.imputation_transformer = ImputationTransformer(self.embed_dim)
        self.reconstruction = nn.Sequential(nn.Linear(self.embed_dim, 2048),
                                           nn.ReLU(),
                                           nn.Linear(2048, 542*24))
    
    def forward(self, x):
        z = self.imputation_transformer(x)
        z = self.reconstruction(z)
        
        return z.view(x.shape[0],542,24)

In [9]:
ex_X[:,:,columns_kept].shape

torch.Size([64, 542, 24])

In [19]:
r_tran = ReconstructionImputationTransformer()
print("Num params:",sum(p.numel() for p in r_tran.parameters() if p.requires_grad))

Num params: 30954960


In [20]:
r_tran(ex_X[:,:,columns_kept].float()).shape

torch.Size([64, 542, 24])

### Training Model

In [13]:
def mse_missing(xhat, masked_X, X):
    missing_idx = (masked_X==float(-1))
    error = xhat[missing_idx] - X[missing_idx]
    return torch.nanmean(torch.pow(error, 2))

In [21]:
objective = mse_missing

r_tran = ReconstructionImputationTransformer()

lr = 1e-4
n_epochs = 25
optim = torch.optim.Adam(r_tran.parameters(), lr=lr)
losses = []

for n in range(n_epochs):
    counter = 0
    for i, (X, y) in enumerate(tqdm(trainloader)):
        masked_X = Masking.mask_input(X)
        optim.zero_grad()
        xhat = r_tran(masked_X[:,:,columns_kept].float())
        loss = objective(xhat, masked_X[:,:,columns_kept], X[:,:,columns_kept])
        loss.backward()
        losses.append(loss.item())
        optim.step()
        counter += 1
        
    print("Epoch:", n+1, "Loss:",np.mean(losses[-counter:][0]))

100%|████████████████████████████████████████████████████████████████████████████████| 323/323 [20:43<00:00,  3.85s/it]


Epoch: 1 Loss: 0.23848849606029304


100%|████████████████████████████████████████████████████████████████████████████████| 323/323 [20:16<00:00,  3.76s/it]


Epoch: 2 Loss: 0.009394740167889418


100%|████████████████████████████████████████████████████████████████████████████████| 323/323 [20:29<00:00,  3.81s/it]


Epoch: 3 Loss: 0.00914702090447563


100%|████████████████████████████████████████████████████████████████████████████████| 323/323 [20:28<00:00,  3.80s/it]


Epoch: 4 Loss: 0.008450185747194925


100%|████████████████████████████████████████████████████████████████████████████████| 323/323 [20:32<00:00,  3.82s/it]


Epoch: 5 Loss: 0.00818822823173568


100%|████████████████████████████████████████████████████████████████████████████████| 323/323 [20:34<00:00,  3.82s/it]


Epoch: 6 Loss: 0.008966166759786372


100%|████████████████████████████████████████████████████████████████████████████████| 323/323 [20:45<00:00,  3.86s/it]


Epoch: 7 Loss: 0.008078510104771635


100%|████████████████████████████████████████████████████████████████████████████████| 323/323 [20:35<00:00,  3.82s/it]


Epoch: 8 Loss: 0.007845037007920162


100%|████████████████████████████████████████████████████████████████████████████████| 323/323 [20:38<00:00,  3.84s/it]


Epoch: 9 Loss: 0.008686192886435542


 27%|██████████████████████                                                           | 88/323 [05:35<14:56,  3.81s/it]


KeyboardInterrupt: 

In [22]:
torch.save(r_tran.state_dict(), './saved_models/recons_imputation_transformer_prototype.pt')
torch.save(r_tran.imputation_transformer.state_dict(), './saved_models/imputation_transformer_prototype.pt')

In [23]:
test_mses = []
with torch.no_grad():
    for i, (X, y) in enumerate(tqdm(testloader)):
            masked_X = Masking.mask_input(X)
            xhat = r_tran(masked_X[:,:,columns_kept].float())
            test_mse = mse_missing(xhat, masked_X[:,:,columns_kept], X[:,:,columns_kept])
            test_mses.append(test_mse.item())
print("Test MSE:", np.mean(test_mses))

100%|████████████████████████████████████████████████████████████████████████████████| 205/205 [06:21<00:00,  1.86s/it]

Test MSE: 0.015520776898235857





In [27]:
model = ImputationTransformer(embed_dim=128)
model.load_state_dict(torch.load('./saved_models/imputation_transformer_prototype.pt'))

<All keys matched successfully>

In [30]:
model(ex_X[:,:,columns_kept].float()).shape

torch.Size([64, 128])

In [54]:
import pandas as pd

In [24]:
def interpolate(tensor):
    interpolated = []
    for i in range(tensor.shape[0]):
        df = pd.DataFrame(tensor[i])
        df[df == -1] = float('nan')
        df = df.interpolate(method='linear', axis=0).fillna(method='bfill')
        interpolated.append(df.to_numpy())
        
    return torch.FloatTensor(np.array(interpolated))

In [76]:
test_mses = []
for i, (X, y) in enumerate(tqdm(testloader)):
        masked_X = Masking.mask_input(X)
        xhat = interpolate(masked_X[:,:,columns_kept])
        test_mse = mse_missing(xhat, masked_X[:,:,columns_kept], X[:,:,columns_kept])
        test_mses.append(test_mse.item())
print("Test MSE:", np.mean(test_mses))

100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:02<00:00,  5.17it/s]

Test MSE: 0.13215842927828966





In [79]:
def backfill_impute(tensor):
    interpolated = []
    for i in range(tensor.shape[0]):
        df = pd.DataFrame(tensor[i])
        df[df == -1] = float('nan')
        df = df.fillna(method='bfill').fillna(method='ffill')
        interpolated.append(df.to_numpy())
        
    return torch.FloatTensor(np.array(interpolated))

In [80]:
test_mses = []
for i, (X, y) in enumerate(tqdm(testloader)):
        masked_X = Masking.mask_input(X)
        xhat = backfill_impute(masked_X[:,:,columns_kept])
        test_mse = mse_missing(xhat, masked_X[:,:,columns_kept], X[:,:,columns_kept])
        test_mses.append(test_mse.item())
print("Test MSE:", np.mean(test_mses))

100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  9.70it/s]

Test MSE: 0.1667811141694234



