# ShapeNet Dataset explorer

## Plot some samples

In [3]:
import torch
from torch import nn
from torch import optim
import os.path as osp

import pytorch_lightning as pl
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader

from pytorch_lightning.loggers import WandbLogger


In [1]:
category = 'Airplane'  # Pass in `None` to train on all categories.
path = osp.realpath(osp.join('..', 'data', 'ShapeNet'))

# `pre_transform` is run by PyG when the processed dataset is first built;
# `transform` is run each time a sample is indexed from the dataset.
pre_transform = T.NormalizeScale()
transform = T.FixedPoints(128)

# Every split shares the same transforms; only the `split` name differs.
dataset_kwargs = dict(transform=transform, pre_transform=pre_transform)
train_dataset = ShapeNet(path, category, split='train', **dataset_kwargs)
valid_dataset = ShapeNet(path, category, split='val', **dataset_kwargs)
test_dataset = ShapeNet(path, category, split='test', **dataset_kwargs)

# Only the training loader shuffles; evaluation loaders keep a stable order.
loader_kwargs = dict(batch_size=4, num_workers=6)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

NameError: name 'osp' is not defined

In [2]:
# Index with `[]` rather than `.get()`: PyG's `Dataset.__getitem__` applies the
# dataset's `transform` (here `FixedPoints(128)`), whereas `.get()` returns the
# stored sample without it.
data = train_dataset[0]

NameError: name 'train_dataset' is not defined

In [None]:
from hpcs.utils.viz import plot_cloud

In [None]:
# `data.pos` holds the point coordinates, so it is what defines the cloud's
# geometry. `data.x` holds per-point input features (presumably surface normals
# with PyG's default `include_normals=True` for ShapeNet — confirm), which are
# not positions. Each point is coloured by its part label `data.y`.
plotter = plot_cloud(data.pos.numpy(), scalars=data.y.numpy(), point_size=3.0, notebook=True)

## Define a PyTorch Lightning module

In [None]:
class SimpleModel(pl.LightningModule):
    """Per-point MLP baseline for point-cloud part segmentation.

    Each point's feature vector (``batch.x``) is mapped independently to
    ``num_classes`` logits by a two-layer MLP. There is no neighbourhood
    aggregation, so every point is classified in isolation.

    Args:
        in_channels: Size of each point's feature vector in ``batch.x``.
        num_classes: Number of part labels predicted per point. The default of
            4 presumably matches the single category this notebook trains on
            (Airplane) — confirm against the dataset when changing category.
        hidden_channels: Width of the hidden layer.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, in_channels=3, num_classes=4, hidden_channels=64, lr=1e-3):
        super().__init__()
        # Record constructor args so `load_from_checkpoint` can rebuild the model.
        self.save_hyperparameters()
        self.lr = lr
        # Attribute name `nn` kept so existing checkpoints' state_dict keys still load.
        self.nn = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, num_classes),
        )

    def forward(self, x):
        """Return per-point class logits for the per-point features ``x``."""
        y_pred = self.nn(x)
        return y_pred

    def configure_optimizers(self):
        """Optimise all parameters with Adam at learning rate ``self.lr``."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def _shared_step(self, batch, stage):
        """Compute and log the cross-entropy loss and accuracy for one batch.

        ``stage`` is the metric prefix ('train', 'val' or 'test').
        """
        x, y = batch.x, batch.y
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=-1) == y).float().mean()
        # A PyG `Batch` is not a plain tensor, so Lightning cannot reliably infer
        # the batch size used to aggregate logged values; pass it explicitly.
        batch_size = batch.num_graphs
        self.log(f"{stage}/loss", loss, batch_size=batch_size)
        self.log(f"{stage}/acc", acc, batch_size=batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logged as ``train/loss``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy; runs only if a val dataloader is given."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy when used with ``trainer.test``."""
        return self._shared_step(batch, "test")
        

In [None]:
# Build the baseline model with its default configuration.
simpl_mod = SimpleModel()

## Training the model

In [None]:
# Weights & Biases logger for this run ('SegTest'), saving under ../logs/.
wandb_logger = WandbLogger(
    save_dir='../logs/',
    name='SegTest',
)

In [None]:
# Short smoke-test run: 10 epochs, each capped at 10 training batches.
trainer = pl.Trainer(
    max_epochs=10,
    limit_train_batches=10,
    logger=wandb_logger,
)

In [None]:
# Fit the model; only the training dataloader is passed here.
trainer.fit(simpl_mod, train_dataloaders=train_loader)