In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import CubeObstacle, CylinderObstacle, BlockageDataset
from utils.config import Hyperparameters as hp

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed both RNGs used by this notebook so a fresh "Restart & Run All" reproduces results.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Mini-batch size (1024).
batch_size = 1 << 10

In [3]:
class Net(nn.Module):
    """Fully connected network with a sigmoid output head.

    Layout: an input ``Linear(num_node, hidden_N)`` followed by ``hidden_L``
    ``Linear(hidden_N, hidden_N)`` layers. Every linear layer is followed by
    BatchNorm1d -> LeakyReLU -> Dropout. A final ``Linear(hidden_N, output_N)``
    is passed through a sigmoid.

    Args:
        num_node: Number of input features per sample.
        hidden_N: Width of every hidden layer.
        hidden_L: Number of hidden-to-hidden layers after the input layer (may be 0).
        output_N: Number of outputs. Defaults to 3.
        dropout: Dropout probability applied after each hidden activation.
            Defaults to 0.3.
        negative_slope: LeakyReLU negative slope. Defaults to 0.05.
    """

    def __init__(self, num_node, hidden_N, hidden_L, output_N=3,
                 dropout=0.3, negative_slope=0.05):
        super().__init__()
        self.hidden_N = hidden_N
        self.hidden_L = hidden_L
        self.negative_slope = negative_slope

        # Input layer plus hidden_L hidden layers: hidden_L + 1 linear layers in total.
        self.layers = nn.ModuleList(
            [nn.Linear(num_node, hidden_N)]
            + [nn.Linear(hidden_N, hidden_N) for _ in range(hidden_L)]
        )

        # One Dropout and one BatchNorm per linear layer. Creating only hidden_L of
        # each made zip() in forward() drop the last linear layer (and with
        # hidden_L == 0 fed raw inputs straight into the output head).
        n_blocks = len(self.layers)
        self.dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(n_blocks)])
        self.batches = nn.ModuleList([nn.BatchNorm1d(hidden_N) for _ in range(n_blocks)])

        self.output = nn.Linear(hidden_N, output_N)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor whose last dimension is ``num_node``. BatchNorm1d here
                treats dim 1 as channels, so a 2-D ``(batch, num_node)`` input is
                expected. In training mode BatchNorm needs more than one sample.

        Returns:
            Tensor of shape ``(batch, output_N)`` after a sigmoid.
        """
        z = x
        # All three ModuleLists have the same length, so every linear layer is used.
        for layer, dropout, batch_norm in zip(self.layers, self.dropouts, self.batches):
            z = layer(z)
            z = batch_norm(z)
            z = F.leaky_relu(z, self.negative_slope)
            z = dropout(z)

        return torch.sigmoid(self.output(z))

In [4]:
# Scene obstacles: four cuboids followed by one cylinder, in that order.
# The positional arguments are interpreted by the CubeObstacle / CylinderObstacle
# constructors in `datasets` (their meaning is not visible here — TODO document).
obstacle_ls = [
    CubeObstacle(*cube_args)
    for cube_args in (
        (-30, 25, 35, 60, 20, 0.1),
        (-30, -25, 45, 10, 35, 0.1),
        (-30, -60, 35, 60, 20, 0.1),
        (50, -20, 35, 25, 25, 0.1),
    )
]
obstacle_ls.append(CylinderObstacle(10, -5, 70, 15, 0.1))

In [5]:
# Build the blockage dataset from the obstacles above and move it to the configured device.
# NOTE(review): 100000 looks like the sample count (it matches len(dataset) below) and
# `4` presumably the number of ground nodes per sample — confirm against BlockageDataset.
dataset = BlockageDataset(100000, obstacle_ls, 4, dtype=torch.float32).to(hp.device)
len(dataset)

100%|██████████| 100000/100000 [00:01<00:00, 78923.73it/s]


100000

In [6]:
# Ground-node coordinates as a NumPy array, presumably shaped (n_samples, n_nodes, 3);
# the unpacking below fails loudly if it is not 3-D.
gnd_nodes = dataset.gnd_nodes.cpu().numpy()
n_samples, n_nodes, n_coords = gnd_nodes.shape

# Drop coordinate index 2 (presumably z) and flatten each sample's remaining node
# coordinates into a single row. The row width is derived from the array shape rather
# than a hard-coded 8, so rows stay aligned with samples if the node count changes.
data = np.delete(gnd_nodes, 2, axis=2).reshape(n_samples, n_nodes * (n_coords - 1))
data

array([[ 48.473988 , -54.333748 , -88.3084   , ..., -66.39599  ,
        -29.040543 , -20.003397 ],
       [-88.52163  ,  16.66145  ,  76.86868  , ...,  32.94784  ,
        -16.133064 ,  40.219517 ],
       [-17.853691 ,   0.9274576, -98.52939  , ..., -92.319824 ,
        -26.34998  ,  95.806366 ],
       ...,
       [ 80.559944 , -48.176567 ,  52.49426  , ...,  67.79585  ,
         74.88452  , -13.334251 ],
       [ 22.601303 ,  88.610504 ,  87.54375  , ..., -92.502686 ,
        -54.01157  ,  21.334442 ],
       [ 12.519815 , -78.57199  , -55.36997  , ...,  74.07114  ,
         94.96994  ,  30.149664 ]], dtype=float32)

In [7]:
# Wrap the flattened coordinate matrix in a DataFrame. No column names are given, so
# the default integer labels (0..N-1) are used, and these become the CSV header below.
df = pd.DataFrame(data=data)
df

Unnamed: 0,0,1,2,3,4,5,6,7
0,48.473988,-54.333748,-88.308403,-40.014122,-5.138267,-66.395988,-29.040543,-20.003397
1,-88.521629,16.661449,76.868683,-69.651421,19.561604,32.947842,-16.133064,40.219517
2,-17.853691,0.927458,-98.529388,38.344025,32.268536,-92.319824,-26.349979,95.806366
3,-15.852688,0.389081,81.986031,43.452374,16.414026,59.560936,72.873253,-8.622386
4,-4.978327,-22.789812,93.341225,39.433250,-83.355385,72.591927,-3.888012,-86.249222
...,...,...,...,...,...,...,...,...
99995,11.768341,82.952278,54.814850,-22.163610,-75.248955,-9.538043,-41.995087,98.432884
99996,-17.263592,-94.097000,-86.016670,-82.419418,-52.576508,47.734035,45.475471,-86.085663
99997,80.559944,-48.176567,52.494259,73.261276,24.939924,67.795853,74.884521,-13.334251
99998,22.601303,88.610504,87.543747,62.734123,-93.997139,-92.502686,-54.011570,21.334442


In [8]:
# Persist the flattened samples without the index column. Create the output
# directory first so a fresh checkout does not fail with FileNotFoundError.
csv_path = Path("./data/dataset.csv")
csv_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(csv_path, index=False)