## This is my take on DeepSDF

To be completely honest, I'm not sure this is what I was supposed to do. Maybe a simple geometric solution would have worked, but here we go :)

This is basically my own implementation of [DeepSDF](https://github.com/facebookresearch/DeepSDF/).

To give credit where it's due: I looked at their implementation while writing mine (not that it was very helpful, though). I'd recommend diving into their repo for a more sophisticated approach.

### P.S.
Take this project with a grain of salt: for real-world applications it would be much more practical to take a pretrained DeepSDF (for example from [here](https://github.com/marian42/shapegan/tree/pretrained-deepsdf-shapenet/examples/deepsdf-shapenet-pretrained)) and get much better scores. This notebook is specifically about my experience getting to know mesh and SDF data and working with shape embeddings, aka DeepSDF.

## Data
You can download the prepared data [here](https://drive.google.com/drive/folders/1AE_mohNpxRg3JXoBX2oiN-8xaMcGYYuX?usp=sharing).

## TL;DR

### What I got

- Modeled and trained a DeepSDF model 
- 8 * float32 = 32-byte representations (`nn.Embedding` stores float32 by default)
- 4.6 ms per decoder forward pass for a batch of 1024 points
- Around 0.94 validation F1 for occupancy on points with |sdf| < 1e-3 (very dirty considering the sign distribution, though)
- Pretty bad inside-shape modeling
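
For reference, the near-surface occupancy F1 mentioned above can be sketched roughly like this: keep only points whose ground-truth |SDF| is below a threshold, treat the sign as a binary label (positive SDF as the positive class, as in the notebook), and compute F1 on that subset. This is a NumPy-only illustration; the function name `near_surface_occupancy_f1` and the toy values are hypothetical, and the notebook itself relies on `sklearn.metrics.f1_score`.

```python
import numpy as np

def near_surface_occupancy_f1(sdf_true, sdf_pred, threshold=1e-3):
    """F1 of the predicted sign, restricted to points with |sdf_true| < threshold."""
    sdf_true = np.asarray(sdf_true, dtype=float)
    sdf_pred = np.asarray(sdf_pred, dtype=float)

    mask = np.abs(sdf_true) < threshold   # near-surface points only
    y_true = sdf_true[mask] > 0           # positive class: point is outside the shape
    y_pred = sdf_pred[mask] > 0

    tp = np.sum(y_true & y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom else 0.0

# Toy data: the last point is far from the surface and gets masked out
sdf_true = np.array([5e-4, -5e-4, 2e-4, -8e-4, 0.3])
sdf_pred = np.array([1e-3, 2e-3, -1e-3, -1e-3, 0.2])
score = near_surface_occupancy_f1(sdf_true, sdf_pred)
```

Because the subset depends on the ground-truth sign distribution near the surface, the score is only comparable between runs that use the same mask.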

### What could help and improve my results
- **A more flexible and accurate solver and better data sampling for inside-point modeling (maybe sampling negative SDFs more aggressively)**
- **More RAM and time to model a much bigger dataset (30 training shapes doesn't sound like enough information for learning the latent shape space)**
- **Adequate validation (I believe all the hard stuff is done in this notebook, and it shouldn't be difficult to push the F1 score up from this point; it's just routine work with data)**
- Clamping the loss (as proposed in the original paper) for better surface and occupancy modeling
- More experiments with the architecture and training (weight decay as proposed in DeepSDF, LR scheduling)
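
The loss clamping from the list above could look roughly like the sketch below: both the prediction and the target are clamped to `[-delta, delta]` before taking the L1 difference, so the network spends its capacity on the region near the surface. The function name `clamped_l1_loss` and the default `delta=0.1` are illustrative assumptions, not something this notebook uses.

```python
import torch

def clamped_l1_loss(pred, target, delta=0.1):
    """L1 loss on SDF values clamped to [-delta, delta] (delta is a hypothetical default)."""
    pred_c = torch.clamp(pred, -delta, delta)
    target_c = torch.clamp(target, -delta, delta)
    return (pred_c - target_c).abs().mean()

# Toy values: the first pair is far from the surface, so clamping makes its error vanish
pred = torch.tensor([0.50, 0.05, -0.30])
target = torch.tensor([0.20, 0.02, -0.05])
loss = clamped_l1_loss(pred, target)
```

It would be a drop-in replacement for `nn.L1Loss()` in the training loop, since it takes the same `(prediction, target)` pair.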

In [1]:
import os
import torch
import trimesh
import numpy as np

from torch import nn
from tqdm import tqdm
from itertools import chain
from sklearn.metrics import f1_score
from torch.utils.data import Dataset, DataLoader
from mesh_to_sdf import sample_sdf_near_surface

In [2]:
# Subsampling the train and test data, since my PC isn't from NASA
SUBSAMPLE_TRAIN_VALID = 30
SUBSAMPLE_TEST = 5

# Network params
NUM_EPOCHS_TRAIN = 50
NUM_EPOCHS_TEST = 30

BATCH_SIZE = 1024
EMBEDDING_DIM = 8  # Higher values neither help nor make sense considering the number of subsampled shapes
DROPOUT = 0.2
HIDDEN_DIMS = [64] * 5

In [3]:
class SimpleBlock(nn.Module):

    ''' Linear -> BN -> Dropout -> LeakyReLU '''
    
    def __init__(self, inp_dim, out_dim, dropout=0.2):
        super(SimpleBlock, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(inp_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.Dropout(dropout),
            nn.LeakyReLU()
        )
        
    def forward(self, X):
        return self.model(X)
    
class SDFDecoder(nn.Module):
    
    ''' MLP decoder of SimpleBlock layers with tanh activation in prediction head '''
    
    def __init__(self, inp_dim, hidden_dims):
        super(SDFDecoder, self).__init__()
        
        self.model = nn.Sequential(
            SimpleBlock(inp_dim, hidden_dims[0]),
            *[SimpleBlock(i, o) for i, o in zip(hidden_dims, hidden_dims[1:])],
            nn.Linear(hidden_dims[-1], 1),
            nn.Tanh()
        )
        
    def forward(self, X):
        return self.model(X)
    
    
class SDFDataset(Dataset):
    
    ''' Dataset class for producing [latent_vector + xyz] -> SDF mapping'''
    
    def __init__(self, filenames, n_samples, encoder):
        """ Supports loading and sampling from @filenames paths but I didn't use it in the end """
        self.encoder = encoder
        self.indices = []
        self.X = torch.Tensor([])
        self.y = torch.Tensor([])
        
        for i, filepath in enumerate(tqdm(filenames)):
            self.indices.extend(n_samples * [i])
            mesh = trimesh.load(filepath)
            points, sdf = sample_sdf_near_surface(mesh, number_of_points=n_samples, sign_method='depth')
            self.X = torch.cat((self.X, torch.from_numpy(points)))
            self.y = torch.cat((self.y, torch.from_numpy(sdf)))
        
        self.indices = torch.Tensor(self.indices).int()
        
    # A bit tinky-winky save/load, zipping would make more sense but whatever
    def save(self, indices_name, X_name, y_name):
        torch.save(self.indices, indices_name)
        torch.save(self.X, X_name)
        torch.save(self.y, y_name)
    
    def load(self, indices_name, X_name, y_name):
        self.indices = torch.load(indices_name)
        self.X = torch.load(X_name)
        self.y = torch.load(y_name)
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        """ Get encoded latent vector + xyz pair with corresponding SDF """
        return torch.cat((self.encoder(self.indices[idx]), self.X[idx])), self.y[idx]
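
To make the input layout produced by `__getitem__` concrete, here is a minimal standalone sketch: the shape index is looked up in the embedding table and the resulting latent code is concatenated with the query point, giving `EMBEDDING_DIM + 3` values per sample. The index and coordinates below are made-up toy values.

```python
import torch
from torch import nn

EMBEDDING_DIM = 8
encoder = nn.Embedding(30, EMBEDDING_DIM)   # one learnable latent code per shape

shape_idx = torch.tensor(3)                 # which shape the point belongs to (toy value)
xyz = torch.tensor([0.1, -0.2, 0.3])        # query point (toy value)

latent = encoder(shape_idx)                 # shape: (8,)
decoder_input = torch.cat((latent, xyz))    # shape: (11,) = latent code followed by xyz
```

Since the lookup happens inside the dataset, gradients flow from the decoder loss back into the embedding table through the batches the `DataLoader` collates.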

In [4]:
# Encoders initialization
train_encoder = nn.Embedding(SUBSAMPLE_TRAIN_VALID, EMBEDDING_DIM)
test_encoder = nn.Embedding(SUBSAMPLE_TEST, EMBEDDING_DIM)

# We'll load the prepared data in the next cell
train_dataset = SDFDataset([], None, train_encoder)
valid_dataset = SDFDataset([], None, train_encoder)
test_dataset = SDFDataset([], None, test_encoder)

0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]


In [5]:
# Load already processed data
train_dataset.load('processed_data/train_idx.pt', 'processed_data/train_X.pt', 'processed_data/train_y.pt')
valid_dataset.load('processed_data/valid_idx.pt', 'processed_data/valid_X.pt', 'processed_data/valid_y.pt')
test_dataset.load('processed_data/test_idx.pt', 'processed_data/test_X.pt', 'processed_data/test_y.pt')

In [6]:
# Again, a bit tinky-winky subsampling
train_ind = train_dataset.indices < SUBSAMPLE_TRAIN_VALID
valid_ind = valid_dataset.indices < SUBSAMPLE_TRAIN_VALID
test_ind = test_dataset.indices < SUBSAMPLE_TEST

train_dataset.indices = train_dataset.indices[train_ind]
train_dataset.X = train_dataset.X[train_ind]
train_dataset.y = train_dataset.y[train_ind]

valid_dataset.indices = valid_dataset.indices[valid_ind]
valid_dataset.X = valid_dataset.X[valid_ind]
valid_dataset.y = valid_dataset.y[valid_ind]

test_dataset.indices = test_dataset.indices[test_ind]
test_dataset.X = test_dataset.X[test_ind]
test_dataset.y = test_dataset.y[test_ind]

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
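
One possible way to sample negative SDFs more aggressively, as suggested in the TL;DR, is to swap `shuffle=True` for a `WeightedRandomSampler`. The sketch below uses synthetic data and a hypothetical `NEGATIVE_WEIGHT`; it is not what this notebook ran.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Synthetic stand-in for the real data: 10% of points are inside (negative SDF)
xyz = torch.randn(10_000, 3)
sdf = torch.cat((torch.full((1_000,), -0.01), torch.full((9_000,), 0.01)))

NEGATIVE_WEIGHT = 9.0  # hypothetical: draw inside points 9x more often than outside ones
weights = torch.ones_like(sdf)
weights[sdf < 0] = NEGATIVE_WEIGHT

# Sampling with replacement according to the per-sample weights
sampler = WeightedRandomSampler(weights, num_samples=len(sdf), replacement=True)
loader = DataLoader(TensorDataset(xyz, sdf), batch_size=1024, sampler=sampler)

# Fraction of negative-SDF points actually seen during one epoch
negative_fraction = torch.cat([y for _, y in loader]).lt(0).float().mean().item()
```

With these weights roughly half of the drawn points are inside the shape, at the cost of seeing some inside points several times per epoch.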

In [8]:
print(f'{len(train_dataset)=}', f'{len(valid_dataset)=}', f'{len(test_dataset)=}', sep='\n')

len(train_dataset)=3000000
len(valid_dataset)=600000
len(test_dataset)=400000


In [9]:
decoder = SDFDecoder(EMBEDDING_DIM + 3, HIDDEN_DIMS)

criterion = nn.L1Loss()
optimizer = torch.optim.Adam(chain(train_encoder.parameters(), decoder.parameters()), lr = 0.0005)
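
The weight decay and LR scheduling mentioned in the TL;DR could be wired in along these lines. This is only a sketch under assumed values: the decay strength, step size, and the stand-in decoder are hypothetical and were not tuned or run on this data.

```python
import torch
from torch import nn

encoder = nn.Embedding(30, 8)
# Stand-in decoder, just to have some parameters to optimize
decoder = nn.Sequential(nn.Linear(11, 64), nn.LeakyReLU(), nn.Linear(64, 1), nn.Tanh())

# Separate parameter groups: decay the decoder weights, leave the latent codes alone
optimizer = torch.optim.Adam(
    [
        {"params": decoder.parameters(), "weight_decay": 1e-4},  # hypothetical value
        {"params": encoder.parameters(), "weight_decay": 0.0},
    ],
    lr=5e-4,
)
# Halve the learning rate every 10 epochs (hypothetical schedule)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

lrs = []
for epoch in range(30):
    # ... one training epoch would go here ...
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
```

The scheduler is stepped once per epoch, after the optimizer, which is the order PyTorch expects.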

In [10]:
train_encoder.train()
decoder.train()

for epoch in range(NUM_EPOCHS_TRAIN):
    pbar = tqdm(train_dataloader)
    for X, y in pbar:
        
        optimizer.zero_grad()
        
        out = decoder(X).squeeze()
    
        loss = criterion(out, y)
        pbar.set_description(f'Batch loss: {loss.item()}, f1: {f1_score(y>0, out>0)}', refresh=True)

        # Backward pass
        loss.backward()
        optimizer.step()
        
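    # Validation pass; no_grad avoids building an autograd graph over the whole validation set.
    # Note: the decoder is still in train mode here, so BatchNorm and Dropout stay active.
    with torch.no_grad():
        valid_y, valid_out = map(torch.cat, zip(*[(y, decoder(X).squeeze()) for X, y in valid_dataloader]))
    
    # Occupancy is scored only on points close to the surface
    ind = valid_y.abs() < 1e-3
    occupancy_y = valid_y[ind] > 0
    occupancy_out = valid_out[ind] > 0

    print(f'Valid {epoch=}, loss={criterion(valid_out, valid_y).item():.4f}, f1={f1_score(valid_y>0, valid_out>0):.3f},\
     occupancy_f1={f1_score(occupancy_y, occupancy_out):.3f}')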

Batch loss: 0.04344232380390167, f1: 0.8998435054773084: 100%|█| 2930/2930 [01:39<00:00, 29.47it/


Valid epoch=0, loss=0.0395, f1=0.898,     occupancy_f1=0.946


Batch loss: 0.04686539247632027, f1: 0.9051321928460342: 100%|█| 2930/2930 [01:26<00:00, 33.68it/


Valid epoch=1, loss=0.0396, f1=0.898,     occupancy_f1=0.946


Batch loss: 0.013404348865151405, f1: 0.9000000000000001: 100%|█| 2930/2930 [01:29<00:00, 32.62it


Valid epoch=2, loss=0.0137, f1=0.897,     occupancy_f1=0.944


Batch loss: 0.009753807447850704, f1: 0.8982785602503912: 100%|█| 2930/2930 [01:28<00:00, 33.24it


Valid epoch=3, loss=0.0103, f1=0.897,     occupancy_f1=0.945


Batch loss: 0.008623799309134483, f1: 0.9017160686427457: 100%|█| 2930/2930 [01:35<00:00, 30.80it


Valid epoch=4, loss=0.0092, f1=0.894,     occupancy_f1=0.941


Batch loss: 0.009868474677205086, f1: 0.8885375494071146: 100%|█| 2930/2930 [01:37<00:00, 30.13it


Valid epoch=5, loss=0.0086, f1=0.892,     occupancy_f1=0.940


Batch loss: 0.007716733496636152, f1: 0.8957345971563981: 100%|█| 2930/2930 [01:31<00:00, 31.91it


Valid epoch=6, loss=0.0083, f1=0.892,     occupancy_f1=0.938


Batch loss: 0.00811475794762373, f1: 0.8833865814696484: 100%|█| 2930/2930 [01:32<00:00, 31.72it/


Valid epoch=7, loss=0.0081, f1=0.877,     occupancy_f1=0.918


Batch loss: 0.00865777675062418, f1: 0.8885350318471338: 100%|█| 2930/2930 [01:32<00:00, 31.71it/


Valid epoch=8, loss=0.0078, f1=0.888,     occupancy_f1=0.933


Batch loss: 0.006950098555535078, f1: 0.88659793814433: 100%|█| 2930/2930 [01:32<00:00, 31.53it/s


Valid epoch=9, loss=0.0076, f1=0.895,     occupancy_f1=0.942


Batch loss: 0.009306528605520725, f1: 0.8908227848101266: 100%|█| 2930/2930 [01:29<00:00, 32.57it


Valid epoch=10, loss=0.0076, f1=0.892,     occupancy_f1=0.938


Batch loss: 0.00768881244584918, f1: 0.8892430278884463: 100%|█| 2930/2930 [01:30<00:00, 32.37it/


Valid epoch=11, loss=0.0074, f1=0.889,     occupancy_f1=0.934


Batch loss: 0.006368952803313732, f1: 0.8848292295472596: 100%|█| 2930/2930 [01:32<00:00, 31.80it


Valid epoch=12, loss=0.0074, f1=0.847,     occupancy_f1=0.875


Batch loss: 0.008971486240625381, f1: 0.8966061562746647: 100%|█| 2930/2930 [01:37<00:00, 30.17it


Valid epoch=13, loss=0.0073, f1=0.890,     occupancy_f1=0.936


Batch loss: 0.007169054355472326, f1: 0.8981191222570534: 100%|█| 2930/2930 [01:31<00:00, 31.94it


Valid epoch=14, loss=0.0077, f1=0.884,     occupancy_f1=0.929


Batch loss: 0.006643187720328569, f1: 0.9099378881987578: 100%|█| 2930/2930 [01:35<00:00, 30.71it


Valid epoch=15, loss=0.0072, f1=0.887,     occupancy_f1=0.932


Batch loss: 0.006735484581440687, f1: 0.8798076923076924: 100%|█| 2930/2930 [01:32<00:00, 31.84it


Valid epoch=16, loss=0.0072, f1=0.887,     occupancy_f1=0.931


Batch loss: 0.008244123309850693, f1: 0.8954041204437401: 100%|█| 2930/2930 [01:32<00:00, 31.74it


Valid epoch=17, loss=0.0072, f1=0.886,     occupancy_f1=0.929


Batch loss: 0.0086289057508111, f1: 0.8849840255591054: 100%|█| 2930/2930 [01:41<00:00, 28.86it/s


Valid epoch=18, loss=0.0071, f1=0.889,     occupancy_f1=0.934


Batch loss: 0.006356589961796999, f1: 0.8846459824980112: 100%|█| 2930/2930 [01:33<00:00, 31.28it


Valid epoch=19, loss=0.0071, f1=0.885,     occupancy_f1=0.929


Batch loss: 0.006907808594405651, f1: 0.9043887147335423: 100%|█| 2930/2930 [01:33<00:00, 31.50it


Valid epoch=20, loss=0.0070, f1=0.893,     occupancy_f1=0.939


Batch loss: 0.006666330620646477, f1: 0.8842443729903536: 100%|█| 2930/2930 [01:33<00:00, 31.48it


Valid epoch=21, loss=0.0070, f1=0.894,     occupancy_f1=0.940


Batch loss: 0.005677207838743925, f1: 0.9041309431021044: 100%|█| 2930/2930 [01:36<00:00, 30.51it


Valid epoch=22, loss=0.0070, f1=0.892,     occupancy_f1=0.938


Batch loss: 0.007726510986685753, f1: 0.8894192521877485: 100%|█| 2930/2930 [01:33<00:00, 31.35it


Valid epoch=23, loss=0.0070, f1=0.892,     occupancy_f1=0.939


Batch loss: 0.00893603265285492, f1: 0.8801287208366856: 100%|█| 2930/2930 [01:32<00:00, 31.70it/


Valid epoch=24, loss=0.0070, f1=0.884,     occupancy_f1=0.928


Batch loss: 0.008289656601846218, f1: 0.9106449106449106: 100%|█| 2930/2930 [01:41<00:00, 28.99it


Valid epoch=25, loss=0.0070, f1=0.889,     occupancy_f1=0.935


Batch loss: 0.0057388306595385075, f1: 0.903831118060985: 100%|█| 2930/2930 [01:32<00:00, 31.60it


Valid epoch=26, loss=0.0069, f1=0.885,     occupancy_f1=0.929


Batch loss: 0.0069833677262067795, f1: 0.8904761904761904: 100%|█| 2930/2930 [01:34<00:00, 30.97i


Valid epoch=27, loss=0.0069, f1=0.881,     occupancy_f1=0.924


Batch loss: 0.007301142904907465, f1: 0.8929421094369547: 100%|█| 2930/2930 [01:32<00:00, 31.61it


Valid epoch=28, loss=0.0069, f1=0.881,     occupancy_f1=0.924


Batch loss: 0.008149604313075542, f1: 0.8913560666137985: 100%|█| 2930/2930 [01:34<00:00, 30.89it


Valid epoch=29, loss=0.0068, f1=0.890,     occupancy_f1=0.935


Batch loss: 0.007708899211138487, f1: 0.8724939855653568: 100%|█| 2930/2930 [01:33<00:00, 31.31it


Valid epoch=30, loss=0.0068, f1=0.889,     occupancy_f1=0.934


Batch loss: 0.006839623674750328, f1: 0.8998435054773083: 100%|█| 2930/2930 [01:40<00:00, 29.18it


Valid epoch=31, loss=0.0068, f1=0.893,     occupancy_f1=0.939


Batch loss: 0.006167256738990545, f1: 0.877502001601281: 100%|█| 2930/2930 [01:34<00:00, 31.00it/


Valid epoch=32, loss=0.0067, f1=0.885,     occupancy_f1=0.929


Batch loss: 0.005818095523864031, f1: 0.8878281622911695: 100%|█| 2930/2930 [01:35<00:00, 30.72it


Valid epoch=33, loss=0.0067, f1=0.883,     occupancy_f1=0.926


Batch loss: 0.008353653363883495, f1: 0.8509575353871773: 100%|█| 2930/2930 [01:36<00:00, 30.38it


Valid epoch=34, loss=0.0068, f1=0.881,     occupancy_f1=0.923


Batch loss: 0.006025611888617277, f1: 0.8798076923076923: 100%|█| 2930/2930 [01:37<00:00, 30.16it


Valid epoch=35, loss=0.0067, f1=0.893,     occupancy_f1=0.939


Batch loss: 0.0064497594721615314, f1: 0.8974158183241974: 100%|█| 2930/2930 [01:34<00:00, 31.11i


Valid epoch=36, loss=0.0067, f1=0.895,     occupancy_f1=0.942


Batch loss: 0.006586477626115084, f1: 0.8908227848101267: 100%|█| 2930/2930 [01:35<00:00, 30.75it


Valid epoch=37, loss=0.0067, f1=0.886,     occupancy_f1=0.931


Batch loss: 0.005862246733158827, f1: 0.8913385826771653: 100%|█| 2930/2930 [01:34<00:00, 31.10it


Valid epoch=38, loss=0.0067, f1=0.893,     occupancy_f1=0.940


Batch loss: 0.005514708813279867, f1: 0.8839427662957074: 100%|█| 2930/2930 [01:35<00:00, 30.71it


Valid epoch=39, loss=0.0071, f1=0.887,     occupancy_f1=0.930


Batch loss: 0.005446005146950483, f1: 0.8915281076801267: 100%|█| 2930/2930 [01:34<00:00, 31.12it


Valid epoch=40, loss=0.0066, f1=0.891,     occupancy_f1=0.938


Batch loss: 0.005131710786372423, f1: 0.9007036747458951: 100%|█| 2930/2930 [01:33<00:00, 31.25it


Valid epoch=41, loss=0.0066, f1=0.895,     occupancy_f1=0.942


Batch loss: 0.006648695562034845, f1: 0.886762360446571: 100%|█| 2930/2930 [01:33<00:00, 31.23it/


Valid epoch=42, loss=0.0066, f1=0.874,     occupancy_f1=0.914


Batch loss: 0.005919893272221088, f1: 0.9024199843871976: 100%|█| 2930/2930 [01:34<00:00, 31.16it


Valid epoch=43, loss=0.0066, f1=0.890,     occupancy_f1=0.935


Batch loss: 0.006516121793538332, f1: 0.8986645718774549: 100%|█| 2930/2930 [01:35<00:00, 30.82it


Valid epoch=44, loss=0.0067, f1=0.889,     occupancy_f1=0.934


Batch loss: 0.007771013770252466, f1: 0.8881839809674861: 100%|█| 2930/2930 [01:37<00:00, 29.99it


Valid epoch=45, loss=0.0066, f1=0.885,     occupancy_f1=0.929


Batch loss: 0.006158776115626097, f1: 0.8990536277602522: 100%|█| 2930/2930 [01:35<00:00, 30.84it


Valid epoch=46, loss=0.0066, f1=0.891,     occupancy_f1=0.937


Batch loss: 0.007832656614482403, f1: 0.8894192521877484: 100%|█| 2930/2930 [01:39<00:00, 29.35it


Valid epoch=47, loss=0.0066, f1=0.887,     occupancy_f1=0.932


Batch loss: 0.005524642299860716, f1: 0.8930817610062893: 100%|█| 2930/2930 [01:34<00:00, 30.86it


Valid epoch=48, loss=0.0066, f1=0.894,     occupancy_f1=0.941


Batch loss: 0.006987675558775663, f1: 0.8937844217151848: 100%|█| 2930/2930 [01:36<00:00, 30.36it


Valid epoch=49, loss=0.0065, f1=0.891,     occupancy_f1=0.937


In [11]:
# Again, I don't have a NASA PC
del train_dataset, valid_dataset, train_dataloader, valid_dataloader

In [12]:
# Freeze the decoder layers to finetune the test representations
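# (assigning `decoder.requires_grad = False` only sets a plain attribute on the module;
# `requires_grad_` is what actually switches off gradients for every parameter)
decoder.requires_grad_(False)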

In [13]:
test_criterion = nn.L1Loss()
test_optimizer = torch.optim.Adam(test_encoder.parameters(), lr = 0.0003)
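
The idea behind this step, fitting a fresh latent code to an unseen shape while the decoder stays fixed, can be sketched in isolation as below. The tiny decoder, the toy sphere SDF, and the hyperparameters are all assumptions made for the illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for a trained decoder: 8 latent dims + 3 coordinates -> 1 SDF value
decoder = nn.Sequential(nn.Linear(8 + 3, 32), nn.LeakyReLU(), nn.Linear(32, 1), nn.Tanh())
decoder.requires_grad_(False)                      # freeze all decoder weights
frozen_before = [p.clone() for p in decoder.parameters()]

latent = nn.Embedding(1, 8)                        # one unseen shape -> one latent code
latent_before = latent.weight.detach().clone()
optimizer = torch.optim.Adam(latent.parameters(), lr=1e-2)

xyz = torch.rand(256, 3) * 2 - 1                   # query points in [-1, 1]^3
sdf = xyz.norm(dim=1) - 0.5                        # toy target: sphere of radius 0.5

for _ in range(20):
    optimizer.zero_grad()
    codes = latent(torch.zeros(len(xyz), dtype=torch.long))   # same code for every point
    pred = decoder(torch.cat((codes, xyz), dim=1)).squeeze(1)
    loss = (pred - sdf).abs().mean()
    loss.backward()                                # gradients reach only the latent code
    optimizer.step()
```

Only the latent code changes here. In the notebook the decoder is additionally left in train mode, so its BatchNorm and Dropout layers stay active during this step; calling `decoder.eval()` would be the way to fix the BatchNorm statistics and disable Dropout.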

In [14]:
test_encoder.train()

for epoch in range(NUM_EPOCHS_TEST):
    pbar = tqdm(test_dataloader)
    for X, y in pbar:
        
        test_optimizer.zero_grad()
        
        out = decoder(X).squeeze()
        
        test_loss = test_criterion(out, y)
        pbar.set_description(f'Batch loss: {test_loss.item()}, f1: {f1_score(y>0, out>0)}', refresh=True)

        # Backward pass
        test_loss.backward()
        test_optimizer.step()
    
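    # Evaluation pass; no_grad avoids building an autograd graph over the whole test set
    with torch.no_grad():
        test_y, test_out = map(torch.cat, zip(*[(y, decoder(X).squeeze()) for X, y in test_dataloader]))
    # NOTE: unlike the validation mask (valid_y.abs() < 1e-3), this one has no .abs(),
    # so it also keeps every negative-SDF point; the two occupancy_f1 numbers are not comparable
    ind = test_y < 1e-3
    occupancy_y = test_y[ind] > 0
    occupancy_out = test_out[ind] > 0

    print(f'Test {epoch=}, loss={test_criterion(test_out, test_y).item():.4f}, f1={f1_score(test_y>0, test_out>0):.3f},\
    occupancy_f1={f1_score(occupancy_y, occupancy_out):.3f}')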

Batch loss: 0.06008940190076828, f1: 0.6603550295857987: 100%|█| 391/391 [00:12<00:00, 32.56it/s]


Test epoch=0, loss=0.1445, f1=0.457,    occupancy_f1=0.319


Batch loss: 0.11212427914142609, f1: 0.4747191011235955: 100%|█| 391/391 [00:11<00:00, 33.88it/s]


Test epoch=1, loss=0.0856, f1=0.568,    occupancy_f1=0.437


Batch loss: 0.026577701792120934, f1: 0.8295454545454546: 100%|█| 391/391 [00:11<00:00, 32.85it/s


Test epoch=2, loss=0.0548, f1=0.685,    occupancy_f1=0.549


Batch loss: 0.04530294984579086, f1: 0.7005524861878454: 100%|█| 391/391 [00:11<00:00, 34.33it/s]


Test epoch=3, loss=0.0409, f1=0.754,    occupancy_f1=0.606


Batch loss: 0.14327748119831085, f1: 0.5142083897158322: 100%|█| 391/391 [00:11<00:00, 32.98it/s]


Test epoch=4, loss=0.0322, f1=0.805,    occupancy_f1=0.648


Batch loss: 0.034217238426208496, f1: 0.7749003984063745: 100%|█| 391/391 [00:11<00:00, 33.60it/s


Test epoch=5, loss=0.0282, f1=0.828,    occupancy_f1=0.665


Batch loss: 0.02553776279091835, f1: 0.8506375227686702: 100%|█| 391/391 [00:11<00:00, 33.72it/s]


Test epoch=6, loss=0.0254, f1=0.846,    occupancy_f1=0.680


Batch loss: 0.02668728493154049, f1: 0.823199251637044: 100%|██| 391/391 [00:11<00:00, 33.08it/s]


Test epoch=7, loss=0.0235, f1=0.861,    occupancy_f1=0.690


Batch loss: 0.02032265067100525, f1: 0.895104895104895: 100%|██| 391/391 [00:11<00:00, 33.15it/s]


Test epoch=8, loss=0.0218, f1=0.872,    occupancy_f1=0.698


Batch loss: 0.02156975492835045, f1: 0.8855895196506549: 100%|█| 391/391 [00:11<00:00, 34.02it/s]


Test epoch=9, loss=0.0207, f1=0.877,    occupancy_f1=0.702


Batch loss: 0.017775217071175575, f1: 0.8875326939843068: 100%|█| 391/391 [00:11<00:00, 32.65it/s


Test epoch=10, loss=0.0202, f1=0.879,    occupancy_f1=0.703


Batch loss: 0.017910804599523544, f1: 0.8818897637795275: 100%|█| 391/391 [00:11<00:00, 33.37it/s


Test epoch=11, loss=0.0193, f1=0.883,    occupancy_f1=0.706


Batch loss: 0.016597749665379524, f1: 0.8900523560209425: 100%|█| 391/391 [00:11<00:00, 34.04it/s


Test epoch=12, loss=0.0186, f1=0.884,    occupancy_f1=0.707


Batch loss: 0.01919686608016491, f1: 0.8645276292335117: 100%|█| 391/391 [00:11<00:00, 32.86it/s]


Test epoch=13, loss=0.0178, f1=0.887,    occupancy_f1=0.709


Batch loss: 0.014323596842586994, f1: 0.8939526730937775: 100%|█| 391/391 [00:11<00:00, 33.21it/s


Test epoch=14, loss=0.0172, f1=0.886,    occupancy_f1=0.708


Batch loss: 0.014604304917156696, f1: 0.8814749780509219: 100%|█| 391/391 [00:11<00:00, 33.31it/s


Test epoch=15, loss=0.0164, f1=0.888,    occupancy_f1=0.710


Batch loss: 0.013365899212658405, f1: 0.8969696969696969: 100%|█| 391/391 [00:11<00:00, 33.68it/s


Test epoch=16, loss=0.0157, f1=0.888,    occupancy_f1=0.710


Batch loss: 0.017211416736245155, f1: 0.8477666362807658: 100%|█| 391/391 [00:11<00:00, 32.77it/s


Test epoch=17, loss=0.0148, f1=0.889,    occupancy_f1=0.710


Batch loss: 0.01550893485546112, f1: 0.8695652173913043: 100%|█| 391/391 [00:12<00:00, 32.34it/s]


Test epoch=18, loss=0.0140, f1=0.889,    occupancy_f1=0.711


Batch loss: 0.010101839900016785, f1: 0.893913043478261: 100%|█| 391/391 [00:11<00:00, 33.10it/s]


Test epoch=19, loss=0.0133, f1=0.888,    occupancy_f1=0.710


Batch loss: 0.010683066211640835, f1: 0.8944636678200691: 100%|█| 391/391 [00:14<00:00, 27.81it/s


Test epoch=20, loss=0.0127, f1=0.888,    occupancy_f1=0.710


Batch loss: 0.013909241184592247, f1: 0.9090909090909091: 100%|█| 391/391 [00:11<00:00, 34.13it/s


Test epoch=21, loss=0.0123, f1=0.888,    occupancy_f1=0.709


Batch loss: 0.010546230711042881, f1: 0.9017241379310346: 100%|█| 391/391 [00:11<00:00, 33.16it/s


Test epoch=22, loss=0.0120, f1=0.889,    occupancy_f1=0.710


Batch loss: 0.01228383369743824, f1: 0.8873362445414846: 100%|█| 391/391 [00:11<00:00, 33.22it/s]


Test epoch=23, loss=0.0117, f1=0.888,    occupancy_f1=0.709


Batch loss: 0.011690007522702217, f1: 0.8894691035683202: 100%|█| 391/391 [00:11<00:00, 34.10it/s


Test epoch=24, loss=0.0115, f1=0.888,    occupancy_f1=0.710


Batch loss: 0.009921910241246223, f1: 0.8740088105726873: 100%|█| 391/391 [00:12<00:00, 32.38it/s


Test epoch=25, loss=0.0112, f1=0.888,    occupancy_f1=0.710


Batch loss: 0.01305618416517973, f1: 0.8747795414462082: 100%|█| 391/391 [00:12<00:00, 32.54it/s]


Test epoch=26, loss=0.0111, f1=0.888,    occupancy_f1=0.710


Batch loss: 0.013755199499428272, f1: 0.8869412795793163: 100%|█| 391/391 [00:12<00:00, 32.44it/s


Test epoch=27, loss=0.0110, f1=0.888,    occupancy_f1=0.710


Batch loss: 0.008971111848950386, f1: 0.8737864077669903: 100%|█| 391/391 [00:11<00:00, 32.81it/s


Test epoch=28, loss=0.0108, f1=0.888,    occupancy_f1=0.710


Batch loss: 0.009782833978533745, f1: 0.8990509059534081: 100%|█| 391/391 [00:11<00:00, 33.92it/s


Test epoch=29, loss=0.0107, f1=0.888,    occupancy_f1=0.710


In [15]:
len(occupancy_y), len(occupancy_out)

(171612, 171612)

In [16]:
occupancy_y.sum(), occupancy_out.sum()

(tensor(95473), tensor(167854))

In [17]:
batch = next(iter(test_dataloader))[0]
len(batch)

1024

In [18]:
%timeit decoder(batch)

4.63 ms ± 1.72 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
