In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import CubeObstacle, CylinderObstacle, BlockageDataset
from utils.config import Hyperparameters as hp
from utils.config import Hyperparameters as hp

In [2]:
# Seed torch and numpy up front so dataset sampling and weight
# initialisation are reproducible under Restart Kernel -> Run All.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Mini-batch size (1024 samples).
batch_size = 1 << 10

In [3]:
class Net(nn.Module):
    """Fully connected network with batch-norm, leaky-ReLU and dropout blocks.

    Layout: an input ``Linear(num_node, hidden_N)`` followed by ``hidden_L``
    ``Linear(hidden_N, hidden_N)`` layers. Each of these ``hidden_L + 1``
    linear layers is followed by ``BatchNorm1d -> leaky_relu -> Dropout``.
    A final ``Linear(hidden_N, output_N)`` passed through a sigmoid produces
    outputs in [0, 1].

    Args:
        num_node: number of input features per sample.
        hidden_N: width of every hidden layer.
        hidden_L: number of hidden-to-hidden layers after the input layer.
        output_N: number of outputs (default 3).
        dropout: dropout probability after each block (default 0.3).
        negative_slope: leaky-ReLU negative slope (default 0.05).

    Note:
        A previous version built ``hidden_L + 1`` linear layers but only
        ``hidden_L`` batch-norm/dropout modules, so ``zip`` in ``forward``
        silently skipped the last hidden linear layer (and ``hidden_L=0``
        crashed on a shape mismatch). State dicts saved from that version lack
        the ``batches.<hidden_L>.*`` entries and were trained without the last
        hidden layer, so they are not drop-in compatible.
    """

    def __init__(self, num_node, hidden_N, hidden_L, output_N=3,
                 dropout=0.3, negative_slope=0.05):
        super().__init__()
        self.hidden_N = hidden_N
        self.hidden_L = hidden_L
        self.negative_slope = negative_slope

        # Input projection followed by ``hidden_L`` hidden-to-hidden layers.
        # Registration order matches the original, so seeded weight init is unchanged.
        self.layers = nn.ModuleList([nn.Linear(num_node, hidden_N)])
        self.layers.extend(nn.Linear(hidden_N, hidden_N) for _ in range(hidden_L))

        # Exactly one dropout and one batch-norm per linear layer, so the zip
        # in ``forward`` visits every linear layer instead of truncating.
        n_blocks = len(self.layers)
        self.dropouts = nn.ModuleList(nn.Dropout(dropout) for _ in range(n_blocks))
        self.batches = nn.ModuleList(nn.BatchNorm1d(hidden_N) for _ in range(n_blocks))

        self.output = nn.Linear(hidden_N, output_N)

    def forward(self, x):
        """Run the network.

        Args:
            x: input batch of shape (batch, num_node). In training mode
                BatchNorm1d needs batch > 1.

        Returns:
            Tensor of shape (batch, output_N) with sigmoid-squashed values.
        """
        z = x
        for layer, batch_norm, dropout in zip(self.layers, self.batches, self.dropouts):
            z = layer(z)
            z = batch_norm(z)
            z = F.leaky_relu(z, self.negative_slope)
            z = dropout(z)
        return torch.sigmoid(self.output(z))

In [4]:
# Static scene geometry: four cubes followed by one cylinder, in that order.
# NOTE(review): the meaning of each positional argument is defined by the
# project's `datasets` module (presumably position/size values followed by a
# trailing 0.1 material/attenuation-like parameter) — confirm against the
# CubeObstacle / CylinderObstacle signatures.
obstacle_ls = [
    CubeObstacle(*cube_args)
    for cube_args in (
        (-30, 25, 35, 60, 20, 0.1),
        (-30, -25, 45, 10, 35, 0.1),
        (-30, -60, 35, 60, 20, 0.1),
        (50, -20, 35, 25, 25, 0.1),
    )
]
obstacle_ls.append(CylinderObstacle(10, -5, 70, 15, 0.1))

In [5]:
# Sample the training scenes and move them to the configured device.
# NOTE(review): the positional 4 is presumably the number of ground nodes per
# sample — confirm against BlockageDataset's signature.
N_SAMPLES = 100_000
dataset = BlockageDataset(N_SAMPLES, obstacle_ls, 4, dtype=torch.float32).to(hp.device)
len(dataset)

100%|██████████| 100000/100000 [00:01<00:00, 78094.92it/s]


100000

In [6]:
# Copy the ground-node tensor to host memory (``.numpy()`` requires a CPU
# tensor), drop index 2 of axis 2 (presumably the z coordinate — TODO confirm
# against BlockageDataset), then flatten each sample's nodes into one row.
# Reshaping per sample (instead of a hard-coded width of 8) keeps one row per
# sample even if the number of nodes per sample changes.
gnd_nodes_np = dataset.gnd_nodes.cpu().numpy()
data = np.delete(gnd_nodes_np, 2, axis=2).reshape(gnd_nodes_np.shape[0], -1)
data

array([[ 48.473988, -54.333748, -88.3084  , ..., -66.39599 , -88.52163 ,
         16.66145 ],
       [ 76.86868 , -69.65142 , -17.853691, ...,  38.344025,  32.268536,
        -92.319824],
       [-26.34998 ,  95.806366, -15.852688, ...,  43.452374,  16.414026,
         59.560936],
       ...,
       [ 37.135468,  16.901632,  -9.638972, ...,  44.28218 , -84.911   ,
         21.507885],
       [ 71.97478 , -88.20765 , -63.707054, ..., -63.35672 , -69.68139 ,
         51.659626],
       [ 50.732513, -89.22347 , -46.146835, ..., -20.157822, -80.070595,
        -85.88436 ]], dtype=float32)

In [7]:
# Wrap the flattened node array in a DataFrame (default integer column labels).
df = pd.DataFrame(data=data)
df

Unnamed: 0,0,1,2,3,4,5,6,7
0,48.473988,-54.333748,-88.308403,-40.014122,-5.138267,-66.395988,-88.521629,16.661449
1,76.868683,-69.651421,-17.853691,0.927458,-98.529388,38.344025,32.268536,-92.319824
2,-26.349979,95.806366,-15.852688,0.389081,81.986031,43.452374,16.414026,59.560936
3,-4.978327,-22.789812,93.341225,39.433250,-83.355385,72.591927,-3.888012,-86.249222
4,75.549881,-59.201080,65.292923,11.537738,-89.077942,94.369057,-6.118632,71.571556
...,...,...,...,...,...,...,...,...
99995,52.765034,-74.727470,-38.952068,-71.071815,-14.429907,72.474571,-40.082649,-29.646507
99996,34.372471,-14.616413,-48.689114,60.765034,-73.032761,-78.029633,9.310032,-95.530640
99997,37.135468,16.901632,-9.638972,45.678440,-36.138901,44.282181,-84.911003,21.507885
99998,71.974777,-88.207649,-63.707054,8.195306,-26.512190,-63.356720,-69.681389,51.659626


In [8]:
# Persist the flattened table without the row index; the header row holds the
# DataFrame's column labels. ``to_csv`` does not create missing directories,
# so make sure the target folder exists first.
DATASET_CSV = Path("./data/dataset.csv")
DATASET_CSV.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(DATASET_CSV, index=False)