In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import DataUtils
import Masking

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 64

In [3]:
columns_excluded = [0, 1, 26]
columns_kept = [False, False, True, True, True,
               True, True, True, True, True,
               True, True, True, True, True,
               True, True, True, True, True,
               True, True, True, True, True,
               True, False]
print(columns_kept)

[False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False]


In [4]:
trainloader, testloader = DataUtils.get_dataloaders(batch_size)

In [5]:
example = next(enumerate(trainloader))

In [6]:
ex_X = example[1][0]
print(ex_X.shape)

ex_y = example[1][1]
print(ex_y.shape)

torch.Size([64, 543, 27])
torch.Size([64, 543])


In [7]:
# Should we use a linear attention transformer?
class ImputationTransformer(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        
        self.input_projection = nn.Linear(543*24, embed_dim)
        self.positional_embed = nn.Parameter(torch.randn(embed_dim))
        
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, activation="gelu")
        self.transformer_blocks = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
    
    def forward(self, x):
        x = x.flatten(1)
        z = self.input_projection(x)
        z = z + self.positional_embed
        z = self.transformer_blocks(z)
        
        return z

In [8]:
class ReconstructionImputationTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.embed_dim = 128
        self.imputation_transformer = ImputationTransformer(self.embed_dim)
        self.reconstruction = nn.Sequential(nn.Linear(self.embed_dim, 2048),
                                           nn.ReLU(),
                                           nn.Linear(2048, 543*24))
    
    def forward(self, x):
        z = self.imputation_transformer(x)
        z = self.reconstruction(z)
        
        return z.view(x.shape[0],543,24)

In [9]:
ex_X[:,:,columns_kept].shape

torch.Size([64, 543, 24])

In [10]:
r_tran = ReconstructionImputationTransformer()
print("Num params:",sum(p.numel() for p in r_tran.parameters() if p.requires_grad))

Num params: 31007208


In [11]:
r_tran(ex_X[:,:,columns_kept].float()).shape

torch.Size([64, 543, 24])

### Training Model

In [None]:
def mse_missing(xhat, masked_X, X):
    missing_idx = (masked_X==float(-1))
    error = xhat[missing_idx] - X[missing_idx]
    return torch.mean(torch.pow(error, 2))

In [13]:
objective = mse_missing

r_tran = ReconstructionImputationTransformer()

lr = 1e-4
n_epochs = 5
optim = torch.optim.Adam(r_tran.parameters(), lr=lr)
losses = []

for n in range(n_epochs):
    counter = 0
    for i, (X, y) in enumerate(tqdm(trainloader)):
        
        masked_X = Masking.mask_input(X)
        optim.zero_grad()
        xhat = r_tran(masked_X[:,:,columns_kept].float())
        loss = objective(xhat, masked_X[:,:,columns_kept], X[:,:,columns_kept])\
        loss.backward()
        losses.append(loss.item())
        optim.step()
        counter += 1
        
    print("Epoch:", n+1, "Loss:",np.mean(losses[-counter:][0]))

  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

tensor([-0.6659, -0.4141,  0.0926,  ..., -1.0164, -0.0238, -0.4041],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.4277, dtype=torch.float64, grad_fn=<MeanBackward0>)


  8%|██████▉                                                                            | 1/12 [00:03<00:40,  3.70s/it]

tensor([-0.8668, -0.9383, -0.9288,  ..., -0.0895, -0.1473, -0.1360],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.3640, dtype=torch.float64, grad_fn=<MeanBackward0>)


 17%|█████████████▊                                                                     | 2/12 [00:07<00:36,  3.69s/it]

tensor([-0.3076, -0.0623, -0.1278,  ..., -0.6020,  0.0729, -0.7282],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.3554, dtype=torch.float64, grad_fn=<MeanBackward0>)


 25%|████████████████████▊                                                              | 3/12 [00:11<00:33,  3.72s/it]

tensor([-0.1795, -0.1946, -0.8323,  ...,  0.2172, -0.3710, -0.6059],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.2996, dtype=torch.float64, grad_fn=<MeanBackward0>)


 33%|███████████████████████████▋                                                       | 4/12 [00:14<00:29,  3.72s/it]

tensor([-0.5821, -0.8921, -0.1407,  ..., -0.9583, -0.6301, -0.9998],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.2933, dtype=torch.float64, grad_fn=<MeanBackward0>)


 42%|██████████████████████████████████▌                                                | 5/12 [00:18<00:25,  3.70s/it]

tensor([-0.2359, -0.1642,  0.0438,  ..., -0.9795, -0.1067, -0.7473],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.2637, dtype=torch.float64, grad_fn=<MeanBackward0>)


 50%|█████████████████████████████████████████▌                                         | 6/12 [00:22<00:22,  3.72s/it]

tensor([-0.4503, -0.3309,  0.1367,  ..., -0.6395, -0.4720, -0.1877],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.2135, dtype=torch.float64, grad_fn=<MeanBackward0>)


 58%|████████████████████████████████████████████████▍                                  | 7/12 [00:26<00:18,  3.74s/it]

tensor([-0.4125, -0.8267, -0.0170,  ..., -0.3172, -0.8836, -0.2521],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1962, dtype=torch.float64, grad_fn=<MeanBackward0>)


 67%|███████████████████████████████████████████████████████▎                           | 8/12 [00:29<00:14,  3.71s/it]

tensor([-0.5123,  0.1828,  0.4844,  ..., -0.5734,  0.2191,  0.3788],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1904, dtype=torch.float64, grad_fn=<MeanBackward0>)


 75%|██████████████████████████████████████████████████████████████▎                    | 9/12 [00:33<00:11,  3.69s/it]

tensor([-0.5666, -0.9447, -0.8487,  ..., -0.5005, -0.5979, -0.5830],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1630, dtype=torch.float64, grad_fn=<MeanBackward0>)


 83%|████████████████████████████████████████████████████████████████████▎             | 10/12 [00:37<00:07,  3.67s/it]

tensor([ 0.0572, -0.8823, -0.0290,  ..., -0.7073, -0.4913,  0.2169],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.2111, dtype=torch.float64, grad_fn=<MeanBackward0>)


 92%|███████████████████████████████████████████████████████████████████████████▏      | 11/12 [00:40<00:03,  3.67s/it]

tensor([-0.7753, -0.7318, -0.7390,  ...,  0.3144,  0.2860,  0.1414],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1539, dtype=torch.float64, grad_fn=<MeanBackward0>)


100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:41<00:00,  3.43s/it]


Epoch: 1 Loss: 0.42772562746347287


  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

tensor([ 0.3747,  0.3296, -0.3179,  ...,  0.8463, -0.1998, -0.1929],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1472, dtype=torch.float64, grad_fn=<MeanBackward0>)


  8%|██████▉                                                                            | 1/12 [00:03<00:41,  3.75s/it]

tensor([-0.1296, -0.2841,  0.5887,  ..., -0.7630, -0.4738, -0.0096],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1408, dtype=torch.float64, grad_fn=<MeanBackward0>)


 17%|█████████████▊                                                                     | 2/12 [00:07<00:36,  3.69s/it]

tensor([-0.4470, -0.7758, -0.5148,  ...,  0.2976,  0.1224,  0.3410],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1372, dtype=torch.float64, grad_fn=<MeanBackward0>)


 25%|████████████████████▊                                                              | 3/12 [00:11<00:33,  3.74s/it]

tensor([-0.2373, -0.3620, -0.4558,  ...,  0.6442, -0.2932, -0.1955],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1406, dtype=torch.float64, grad_fn=<MeanBackward0>)


 33%|███████████████████████████▋                                                       | 4/12 [00:14<00:29,  3.70s/it]

tensor([-0.2777, -0.7041, -0.6452,  ...,  0.6392,  0.0633, -0.3459],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1136, dtype=torch.float64, grad_fn=<MeanBackward0>)


 42%|██████████████████████████████████▌                                                | 5/12 [00:18<00:25,  3.70s/it]

tensor([ 0.6659,  0.7604, -0.0184,  ..., -0.5226, -0.1106, -0.0814],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1286, dtype=torch.float64, grad_fn=<MeanBackward0>)


 50%|█████████████████████████████████████████▌                                         | 6/12 [00:22<00:22,  3.73s/it]

tensor([], dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(nan, dtype=torch.float64, grad_fn=<MeanBackward0>)


 58%|████████████████████████████████████████████████▍                                  | 7/12 [00:26<00:18,  3.75s/it]

tensor([-0.3427, -0.3706, -0.0587,  ...,  0.3783,  0.2690,  0.3239],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1309, dtype=torch.float64, grad_fn=<MeanBackward0>)


 67%|███████████████████████████████████████████████████████▎                           | 8/12 [00:29<00:15,  3.78s/it]

tensor([ 0.3319,  0.3067,  0.6309,  ..., -0.0947, -0.5314,  0.2723],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1178, dtype=torch.float64, grad_fn=<MeanBackward0>)


 75%|██████████████████████████████████████████████████████████████▎                    | 9/12 [00:33<00:11,  3.81s/it]

tensor([ 0.0277, -0.2443,  0.3483,  ...,  0.4118, -0.3974, -0.3479],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1097, dtype=torch.float64, grad_fn=<MeanBackward0>)


 83%|████████████████████████████████████████████████████████████████████▎             | 10/12 [00:37<00:07,  3.78s/it]

tensor([ 0.0920, -0.3416, -0.2260,  ...,  0.0789,  0.3705,  0.6778],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1046, dtype=torch.float64, grad_fn=<MeanBackward0>)


 92%|███████████████████████████████████████████████████████████████████████████▏      | 11/12 [00:41<00:03,  3.78s/it]

tensor([-0.4682, -0.4599, -0.0912,  ...,  0.3756, -0.2778, -0.2307],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1276, dtype=torch.float64, grad_fn=<MeanBackward0>)


100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:41<00:00,  3.49s/it]


Epoch: 2 Loss: 0.14724182338359063


  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

tensor([ 0.3112, -0.2100, -0.2971,  ..., -0.6171,  0.3241, -0.3835],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1076, dtype=torch.float64, grad_fn=<MeanBackward0>)


  8%|██████▉                                                                            | 1/12 [00:03<00:40,  3.69s/it]

tensor([ 0.2136,  0.0238, -0.0075,  ..., -0.0238,  0.4858,  0.3004],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1044, dtype=torch.float64, grad_fn=<MeanBackward0>)


 17%|█████████████▊                                                                     | 2/12 [00:07<00:37,  3.75s/it]

tensor([-0.2275, -0.3810, -0.2779,  ...,  0.0849,  0.1816, -0.2654],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0975, dtype=torch.float64, grad_fn=<MeanBackward0>)


 25%|████████████████████▊                                                              | 3/12 [00:11<00:33,  3.77s/it]

tensor([ 0.4087, -0.0975,  0.0608,  ...,  0.1465,  0.3867,  0.4058],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0989, dtype=torch.float64, grad_fn=<MeanBackward0>)


 33%|███████████████████████████▋                                                       | 4/12 [00:15<00:30,  3.76s/it]

tensor([-0.3138, -0.3460, -0.0160,  ..., -0.2931, -0.5374, -0.5672],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1214, dtype=torch.float64, grad_fn=<MeanBackward0>)


 42%|██████████████████████████████████▌                                                | 5/12 [00:18<00:26,  3.81s/it]

tensor([ 0.3982,  0.4234,  0.8006,  ...,  0.9287, -0.1812,  0.1420],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1245, dtype=torch.float64, grad_fn=<MeanBackward0>)


 50%|█████████████████████████████████████████▌                                         | 6/12 [00:22<00:23,  3.83s/it]

tensor([-0.2589, -0.4040, -0.2745,  ...,  0.0229,  0.0877, -0.5122],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0930, dtype=torch.float64, grad_fn=<MeanBackward0>)


 58%|████████████████████████████████████████████████▍                                  | 7/12 [00:26<00:19,  3.83s/it]

tensor([ 0.2724,  0.3162, -0.1786,  ...,  0.3371,  0.4676,  0.5442],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0947, dtype=torch.float64, grad_fn=<MeanBackward0>)


 67%|███████████████████████████████████████████████████████▎                           | 8/12 [00:30<00:15,  3.80s/it]

tensor([-0.1189,  0.3475,  0.2944,  ..., -0.5234, -0.3918,  0.0575],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0906, dtype=torch.float64, grad_fn=<MeanBackward0>)


 75%|██████████████████████████████████████████████████████████████▎                    | 9/12 [00:34<00:11,  3.75s/it]

tensor([ 0.6159,  0.6989, -0.0813,  ...,  0.7812,  0.1613,  0.0897],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1246, dtype=torch.float64, grad_fn=<MeanBackward0>)


 83%|████████████████████████████████████████████████████████████████████▎             | 10/12 [00:37<00:07,  3.76s/it]

tensor([-0.1095, -0.1259,  0.6865,  ...,  0.7945, -0.0385, -0.0760],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1232, dtype=torch.float64, grad_fn=<MeanBackward0>)


 92%|███████████████████████████████████████████████████████████████████████████▏      | 11/12 [00:41<00:03,  3.80s/it]

tensor([ 0.6002,  0.6873, -0.1830,  ..., -0.3027,  0.0411, -0.0180],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1088, dtype=torch.float64, grad_fn=<MeanBackward0>)


100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:42<00:00,  3.52s/it]


Epoch: 3 Loss: 0.10760912658420772


  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

tensor([-0.2129,  0.2755,  0.9066,  ...,  0.7686,  0.0689,  0.1834],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1218, dtype=torch.float64, grad_fn=<MeanBackward0>)


  8%|██████▉                                                                            | 1/12 [00:04<00:44,  4.04s/it]

tensor([-0.3756, -0.3601, -0.1891,  ..., -0.2519,  0.0939,  0.1276],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0864, dtype=torch.float64, grad_fn=<MeanBackward0>)


 17%|█████████████▊                                                                     | 2/12 [00:08<00:40,  4.02s/it]

tensor([ 0.1732,  0.4133, -0.0924,  ...,  0.0844,  0.3684,  0.4144],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0918, dtype=torch.float64, grad_fn=<MeanBackward0>)


 25%|████████████████████▊                                                              | 3/12 [00:11<00:35,  3.91s/it]

tensor([ 0.1260, -0.1733, -0.0566,  ...,  0.2289, -0.2115,  0.4517],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0930, dtype=torch.float64, grad_fn=<MeanBackward0>)


 33%|███████████████████████████▋                                                       | 4/12 [00:15<00:30,  3.82s/it]

tensor([-0.3641,  0.0942,  0.0791,  ..., -0.1819,  0.1265,  0.1346],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0946, dtype=torch.float64, grad_fn=<MeanBackward0>)


 42%|██████████████████████████████████▌                                                | 5/12 [00:19<00:26,  3.76s/it]

tensor([-0.3211, -0.2951, -0.3650,  ..., -0.1268, -0.0862, -0.0678],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1237, dtype=torch.float64, grad_fn=<MeanBackward0>)


 50%|█████████████████████████████████████████▌                                         | 6/12 [00:22<00:22,  3.73s/it]

tensor([-0.0760, -0.0507,  0.8134,  ..., -0.2797, -0.0244, -0.1604],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1220, dtype=torch.float64, grad_fn=<MeanBackward0>)


 58%|████████████████████████████████████████████████▍                                  | 7/12 [00:26<00:18,  3.71s/it]

tensor([ 0.3821,  0.5870,  0.0654,  ..., -0.1006, -0.3025, -0.3369],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1005, dtype=torch.float64, grad_fn=<MeanBackward0>)


 67%|███████████████████████████████████████████████████████▎                           | 8/12 [00:30<00:14,  3.68s/it]

tensor([-0.0362, -0.2419,  0.1733,  ..., -0.2599, -0.1140, -0.0929],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0876, dtype=torch.float64, grad_fn=<MeanBackward0>)


 75%|██████████████████████████████████████████████████████████████▎                    | 9/12 [00:33<00:10,  3.67s/it]

tensor([-0.0581, -0.0410,  0.2266,  ..., -0.5392, -0.2286, -0.1438],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0897, dtype=torch.float64, grad_fn=<MeanBackward0>)


 83%|████████████████████████████████████████████████████████████████████▎             | 10/12 [00:37<00:07,  3.67s/it]

tensor([-0.2243, -0.3388, -0.1287,  ...,  0.7956, -0.3591, -0.2502],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1220, dtype=torch.float64, grad_fn=<MeanBackward0>)


 92%|███████████████████████████████████████████████████████████████████████████▏      | 11/12 [00:41<00:03,  3.69s/it]

tensor([ 0.5786,  0.0986,  0.0500,  ..., -0.1556,  0.4753,  0.3571],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0960, dtype=torch.float64, grad_fn=<MeanBackward0>)


100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:41<00:00,  3.47s/it]


Epoch: 4 Loss: 0.12175750315029524


  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

tensor([ 0.3650,  0.4220, -0.1020,  ...,  0.4602, -0.2905,  0.2780],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0918, dtype=torch.float64, grad_fn=<MeanBackward0>)


  8%|██████▉                                                                            | 1/12 [00:03<00:39,  3.62s/it]

tensor([-0.0684,  0.3191,  0.8299,  ..., -0.4672, -0.1764, -0.1621],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1204, dtype=torch.float64, grad_fn=<MeanBackward0>)


 17%|█████████████▊                                                                     | 2/12 [00:07<00:36,  3.62s/it]

tensor([ 0.5443,  0.0226,  0.0764,  ..., -0.2930,  0.2296, -0.2430],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0874, dtype=torch.float64, grad_fn=<MeanBackward0>)


 25%|████████████████████▊                                                              | 3/12 [00:10<00:33,  3.67s/it]

tensor([-0.2514,  0.2328,  0.7604,  ...,  0.7626, -0.0537, -0.0932],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1206, dtype=torch.float64, grad_fn=<MeanBackward0>)


 33%|███████████████████████████▋                                                       | 4/12 [00:14<00:29,  3.67s/it]

tensor([ 0.1386,  0.2747, -0.3531,  ..., -0.1904, -0.2491, -0.2441],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0826, dtype=torch.float64, grad_fn=<MeanBackward0>)


 42%|██████████████████████████████████▌                                                | 5/12 [00:18<00:25,  3.65s/it]

tensor([-0.3748, -0.0083,  0.1411,  ..., -0.0686,  0.2655,  0.2197],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0907, dtype=torch.float64, grad_fn=<MeanBackward0>)


 50%|█████████████████████████████████████████▌                                         | 6/12 [00:21<00:22,  3.67s/it]

tensor([-0.0805,  0.0632,  0.9259,  ...,  0.7896,  0.0242,  0.0395],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1185, dtype=torch.float64, grad_fn=<MeanBackward0>)


 58%|████████████████████████████████████████████████▍                                  | 7/12 [00:25<00:18,  3.65s/it]

tensor([ 0.5815,  0.6195, -0.1584,  ..., -0.2793, -0.1409, -0.0913],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1163, dtype=torch.float64, grad_fn=<MeanBackward0>)


 67%|███████████████████████████████████████████████████████▎                           | 8/12 [00:29<00:14,  3.65s/it]

tensor([ 0.2466, -0.0877, -0.0168,  ...,  0.3448,  0.3990,  0.2776],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0945, dtype=torch.float64, grad_fn=<MeanBackward0>)


 75%|██████████████████████████████████████████████████████████████▎                    | 9/12 [00:32<00:10,  3.65s/it]

tensor([-0.1390, -0.0711,  0.7351,  ...,  0.8390,  0.2780,  0.2748],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1201, dtype=torch.float64, grad_fn=<MeanBackward0>)


 83%|████████████████████████████████████████████████████████████████████▎             | 10/12 [00:36<00:07,  3.65s/it]

tensor([ 0.2701,  0.3177, -0.0201,  ...,  0.7157,  0.1877,  0.2943],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.1118, dtype=torch.float64, grad_fn=<MeanBackward0>)


 92%|███████████████████████████████████████████████████████████████████████████▏      | 11/12 [00:40<00:03,  3.65s/it]

tensor([-0.2568, -0.1330,  0.2080,  ..., -0.0413,  0.3477,  0.5621],
       dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0889, dtype=torch.float64, grad_fn=<MeanBackward0>)


100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:40<00:00,  3.39s/it]

Epoch: 5 Loss: 0.09182464163939537





In [14]:
with torch.no_grad():
    print(mse_missing(xhat, masked_X[:,:,columns_kept], X[:,:,columns_kept]))

tensor([-0.2568, -0.1330,  0.2080,  ..., -0.0413,  0.3477,  0.5621],
       dtype=torch.float64)
tensor(0.0889, dtype=torch.float64)


In [15]:
test_mses = []
with torch.no_grad():
    for i, (X, y) in enumerate(tqdm(testloader)):
            masked_X = Masking.mask_input(X)
            xhat = r_tran(masked_X[:,:,columns_kept].float())
            test_mse = mse_missing(xhat, masked_X[:,:,columns_kept], X[:,:,columns_kept]).item()
            test_mses.append(test_mse.item())
print("Test MSE:", np.mean(test_mses))

  0%|                                                                                           | 0/12 [00:01<?, ?it/s]

tensor([ 0.0204, -0.0875, -0.2634,  ..., -0.1168, -0.0523, -0.0008],
       dtype=torch.float64)





TypeError: mse_missing() missing 1 required positional argument: 'X'