In [1]:
# Automatically reload edited modules (e.g. the project's Data / Models packages) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pdb
import time
# Make CUDA kernel launches synchronous so errors surface at the failing call
# (debugging aid; slows GPU execution). Set before importing torch so it is in
# place before CUDA is initialised.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import numpy as np
import torch

# For visualizer
import rospy
# NOTE(review): these star imports supply names used later in the notebook
# (e.g. LABELS_REMAP, publish_voxels, DiscreteBKI, MarkerArray); their exact
# origin is not visible from here.
from Data.utils import *
from visualization_msgs.msg import *
from Models.DiscreteBKI import *
# Register this kernel as a ROS node; disable_signals=True tells rospy not to
# install its own signal handlers (keeps Ctrl-C / interrupts with Jupyter).
rospy.init_node('talker',disable_signals=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CUDA timing events only exist on a GPU. Compare the device *type*: a
# torch.device never compares equal to the plain string "cuda", so checking
# `device == "cuda"` always fell through to the CPU branch.
if device.type == "cuda":
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
else:
    start = None
    end = None
print("device is ", device)

# Dataset root: ~/Data/Rellis-3D (joined per component so it is portable)
home_dir = os.path.expanduser('~')
dataset_loc = os.path.join(home_dir, "Data", "Rellis-3D")

device is  cuda


In [3]:
# Discrete BKI semantic map over a 256 x 256 x 16 grid spanning
# [-25.6, 25.6] x [-25.6, 25.6] x [-2.0, 1.2] (0.2 per cell on every axis).
grid_dims = torch.tensor([256, 256, 16]).to(device)
grid_lower = torch.tensor([-25.6, -25.6, -2.0]).to(device)
grid_upper = torch.tensor([25.6, 25.6, 1.2]).to(device)
bki_map = DiscreteBKI(grid_dims, grid_lower, grid_upper, device=device)

# Publisher used by the visualization helpers (MarkerArray on topic 'SemMap')
map_pub = rospy.Publisher('SemMap', MarkerArray, queue_size=10)

In [None]:
# Load one RELLIS-3D scan (index `end_frame_id` in sorted order, sequence 00000),
# run a single BKI update on its labelled points, and publish the ground-truth
# voxels for visualization.
velo_loc = os.path.join(dataset_loc, "00000", "os1_cloud_node_kitti_bin")
label_base_loc = os.path.join(dataset_loc, "00000", "salsa", "os1_cloud_node_semantickitti_label_id")
voxel_loc = os.path.join(dataset_loc, "00000", "voxels")
os_files = sorted(os.listdir(velo_loc))

current_map = bki_map.initialize_grid()

end_frame_id = 307
# Fail loudly instead of silently skipping every file (which left
# `posterior_map` undefined for the following cell).
if not 0 <= end_frame_id < len(os_files):
    raise IndexError(
        f"end_frame_id={end_frame_id} is out of range: {velo_loc} holds {len(os_files)} scans"
    )
# Index the sorted listing directly rather than iterating over earlier files.
curr_frame_id = end_frame_id
velo_file = os_files[curr_frame_id]
frame_stem = os.path.splitext(velo_file)[0]
print("Visualizing frame ", curr_frame_id)

# Scan is stored as float32 rows of 4 values; keep the first 3 (xyz) columns.
velo = np.fromfile(os.path.join(velo_loc, velo_file), dtype=np.float32).reshape(-1, 4)[:, :3]
velo = torch.from_numpy(velo).to(device)
labels = np.fromfile(os.path.join(label_base_loc, frame_stem + ".label"), dtype=np.uint32)
voxels = np.fromfile(os.path.join(voxel_loc, frame_stem + ".label"), dtype=np.uint8)
# NOTE(review): 256 x 256 x 16 must match the grid size bki_map was built with.
voxels = torch.from_numpy(LABELS_REMAP[voxels]).to(device).reshape(256, 256, 16)
labels_remapped = torch.from_numpy(LABELS_REMAP[labels]).to(device=device)  # Remap labels to be contiguous

# Drop points with remapped label 0 (noted as the ego vehicle) or label 20.
non_void = (labels_remapped != 0) & (labels_remapped != 20)
velo = velo[non_void]
labels_remapped = labels_remapped[non_void]

labeled_pc = torch.hstack((velo, labels_remapped.reshape(-1, 1)))

posterior_map = bki_map(current_map, labeled_pc)

publish_voxels(voxels, map_pub,
    bki_map.centroids,
    bki_map.min_bound.reshape(-1),
    bki_map.max_bound.reshape(-1),
    bki_map.grid_size.reshape(-1)
)

In [None]:

# Histogram of the most likely class per cell in the posterior map
map_argmax = torch.argmax(posterior_map, dim=-1)
print(torch.unique(map_argmax, return_counts=True))

In [None]:
# Test 3D conv: a 3x3x3 kernel with a single 1 at its centre (flat index 13)
# is the identity, so `output` should equal `inputs`.
# F was never imported in this notebook (it only worked if a star import
# happened to leak it), so import it explicitly.
import torch.nn.functional as F

num_classes = 20

# Kernel laid out as (out_channels, in_channels, 3, 3, 3); "X, Y, Z" spatial order
filters = torch.zeros(27, dtype=torch.float)
filters[13] = 1
filters = filters.view(1, 1, 3, 3, 3)

print(filters[0, 0, 1, :, :])

# One single-channel 5x5x5 volume per class, batched along dim 0
inputs = torch.ones(num_classes, 1, 5, 5, 5)

output = F.conv3d(inputs, filters, padding="same")
print(output[0, 0, :, :, 0])

In [7]:
import pdb
import time
import torch

from torch.utils.data import Dataset, DataLoader
from Data.dataset import Rellis3dDataset
from model_utils import *

train_dir = dataset_loc

rellis_ds = Rellis3dDataset(directory=train_dir, device=device, num_frames=10, remap=True, use_gt=False, model_setting="train")
dataloader_train = DataLoader(rellis_ds, batch_size=1, shuffle=False, collate_fn=rellis_ds.collate_fn, num_workers=2)

idx = 0
current_map = bki_map.initialize_grid()

# Count, per label id, the voxels that would be supervised during training:
# labelled (non-zero), not marked invalid, and updated away from the prior.
total_class_counts = torch.zeros((21,), device=device, dtype=torch.long)
for points, points_labels, voxels, invalid_voxels, occupied_voxels in dataloader_train:
    print("iteration ", idx)
    for f in range(len(points)):
        pc_np = np.vstack(np.array(points[f]))
        labels_np = np.vstack(np.array(points_labels[f]))
        labeled_pc = torch.from_numpy(np.hstack((pc_np, labels_np))).to(device)

        if labeled_pc.shape[0] == 0:  # Zero padded
            print("continue")
            continue

        preds = bki_map(current_map, labeled_pc)
        # Cells whose every class value still equals the prior were not updated
        prior_mask = torch.logical_not(torch.all(preds == bki_map.prior, dim=-1))

        voxels_labels = torch.from_numpy(np.array(voxels[f]).astype(np.uint8)).to(device)

        # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype
        invalid_voxels_np = np.array(invalid_voxels[f]).astype(bool)
        valid_voxels_mask = torch.logical_not(
            torch.from_numpy(invalid_voxels_np).to(device, dtype=torch.bool)
        )
        void_mask = voxels_labels != 0

        # Exclude free space, invalid voxels, and nonupdated map cells
        voxels_mask = void_mask & valid_voxels_mask & prior_mask
        valid_voxels_labels = voxels_labels[voxels_mask]

        # Accumulate per-class counts for this frame (indices from unique are distinct)
        label_indices, label_counts = torch.unique(valid_voxels_labels, return_counts=True)
        total_class_counts[label_indices.long()] += label_counts

    idx += 1

print("Class counts for train split ", total_class_counts)


iteration  0
iteration  1
iteration  2
iteration  3
iteration  4
iteration  5
iteration  6
iteration  7
iteration  8
iteration  9
iteration  10
iteration  11
iteration  12
iteration  13
iteration  14
iteration  15
iteration  16
iteration  17
iteration  18
iteration  19
iteration  20
iteration  21
iteration  22
iteration  23
iteration  24
iteration  25
iteration  26
iteration  27
iteration  28
iteration  29
iteration  30
iteration  31
iteration  32
iteration  33
iteration  34
iteration  35
iteration  36
iteration  37
iteration  38
iteration  39
iteration  40
iteration  41
iteration  42
iteration  43
iteration  44
iteration  45
iteration  46
iteration  47
iteration  48
iteration  49
iteration  50
iteration  51
iteration  52
iteration  53
iteration  54
iteration  55
iteration  56
iteration  57
iteration  58
iteration  59
iteration  60
iteration  61
iteration  62
iteration  63
iteration  64
iteration  65
iteration  66
iteration  67
iteration  68
iteration  69
iteration  70
iteration  71
it

In [9]:
import os
import pdb
import json
import time
import torch
import torch.optim as optim
from model_utils import *
from torch import nn
from Models.FocalLoss import FocalLoss
from torch.utils.data import Dataset, DataLoader
from Data.dataset import Rellis3dDataset, ray_trace_batch
from torch.utils.tensorboard import SummaryWriter

# CONSTANTS
SEED = 42
NUM_CLASSES = colors.shape[0]
TRAIN_DIR = dataset_loc
NUM_FRAMES = 8
MODEL_NAME = "DiscreteBKI"
model_name = MODEL_NAME + "_" + str(NUM_CLASSES)

MODEL_RUN_DIR = os.path.join("Models", "Runs", model_name)
os.makedirs(MODEL_RUN_DIR, exist_ok=True)  # os.listdir below fails on a missing directory
TRIAL_NUM = str(len(os.listdir(MODEL_RUN_DIR)))
NUM_WORKERS = 16
EPOCH_NUM = 500
FLOAT_TYPE = torch.float32
LABEL_TYPE = torch.uint8
VAL_PRINT_EVERY = 200  # print running validation metrics every N validation batches

# Seed before the model is built so its initial weights are reproducible too.
setup_seed(SEED)

# Model Parameters
# Copy so the shared CLASS_COUNTS_REMAPPED constant is not mutated in place.
class_frequencies = np.copy(CLASS_COUNTS_REMAPPED)
class_frequencies[DYNAMIC_LABELS] = 0  # dynamic object and void are filtered out
epsilon_w = 1e-6  # keeps log() finite for zero-count classes
# NOTE(review): zero-count classes get 1/log(1e-6) < 0, i.e. a negative weight;
# harmless only if those labels never reach the loss — confirm.
weights = torch.from_numpy(
    (1 / np.log(class_frequencies + epsilon_w)) *
    np.sum(class_frequencies) / class_frequencies.shape[0]
).to(torch.float32)
criterion = FocalLoss(2, weights, device)  # alternative: nn.CrossEntropyLoss(weight=weights.to(device))

scenes = [s for s in sorted(os.listdir(TRAIN_DIR)) if s.isdigit()]
model_params_file = os.path.join(TRAIN_DIR, scenes[-1], 'params.json')
with open(model_params_file) as f:
    grid_params = json.load(f)
grid_params['grid_size'] = [int(p) for p in grid_params['grid_size']]

# Load model
lr = 1e-1
BETA1 = 0.9
BETA2 = 0.999
model, B, decayRate = get_model(MODEL_NAME, grid_params=grid_params, device=device)

rellis_ds = Rellis3dDataset(directory=TRAIN_DIR, device=device, num_frames=NUM_FRAMES, remap=True, use_aug=True)
dataloader = DataLoader(rellis_ds, batch_size=B, shuffle=True, collate_fn=rellis_ds.collate_fn, num_workers=NUM_WORKERS)

# Validation must iterate the "val" split (this loader previously wrapped the
# training dataset), without augmentation and in a fixed order.
rellis_ds_val = Rellis3dDataset(directory=TRAIN_DIR, device=device, num_frames=NUM_FRAMES, remap=True, use_aug=False, model_setting="val")
dataloader_val = DataLoader(rellis_ds_val, batch_size=B, shuffle=False, collate_fn=rellis_ds_val.collate_fn, num_workers=NUM_WORKERS)

trial_dir = os.path.join(MODEL_RUN_DIR, "t" + TRIAL_NUM)
save_dir = os.path.join("Models", "Weights", model_name, "t" + TRIAL_NUM)
# makedirs also creates missing parents (e.g. Models/Weights/<model_name>)
os.makedirs(trial_dir, exist_ok=True)
os.makedirs(save_dir, exist_ok=True)

writer = SummaryWriter(trial_dir)

# Optimizer setup
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(BETA1, BETA2))
my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)


def _frame_predictions(model, current_map, frame_points, frame_labels, frame_voxels, frame_invalid):
    """Run the model on one frame and keep only the supervisable voxels.

    Returns (labels, expected_preds): labels is (M, 1) uint8 and expected_preds is
    (M, C) normalised so each row sums to 1, for voxels that are labelled
    (non-zero), not marked invalid, and updated away from the prior.
    Returns None for a zero-padded frame.
    """
    pc_np = np.vstack(np.array(frame_points))
    labels_np = np.vstack(np.array(frame_labels))
    labeled_pc = np.hstack((pc_np, labels_np))

    # Sample from free space (third argument is ray_trace_batch's 0.3 parameter)
    fs_pc = ray_trace_batch(pc_np, labels_np, 0.3, device)
    labeled_pc = torch.from_numpy(np.vstack((labeled_pc, fs_pc))).to(device=device)
    if labeled_pc.shape[0] == 0:  # Zero padded
        print("continue")
        return None

    preds = model(current_map, labeled_pc)
    # Cells whose every class value still equals the prior were not updated
    prior_mask = torch.logical_not(torch.all(preds == model.prior, dim=-1))

    voxels_labels = torch.from_numpy(np.array(frame_voxels).astype(np.uint8)).to(device)
    occupied_voxels_mask = voxels_labels != 0

    # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype
    invalid_voxels_np = np.array(frame_invalid).astype(bool)
    valid_voxels_mask = torch.logical_not(
        torch.from_numpy(invalid_voxels_np).to(device, dtype=torch.bool)
    )

    # Exclude free space, invalid voxels, and nonupdated map cells
    voxels_mask = occupied_voxels_mask & valid_voxels_mask & prior_mask
    valid_voxels_labels = voxels_labels[voxels_mask].view(-1, 1)
    preds_masked = preds[voxels_mask]
    expected_preds = preds_masked / torch.sum(preds_masked, dim=-1, keepdim=True)
    return valid_voxels_labels, expected_preds


def _batch_predictions(model, current_map, points, points_labels, voxels, invalid_voxels):
    """Concatenate the per-frame supervisable labels/predictions of one batch.

    Returns (labels of shape (N,), preds of shape (N, NUM_CLASSES)); N may be 0.
    """
    # Seed the lists with empty tensors so dtype/device match the original vstack chain.
    label_chunks = [torch.zeros((0, 1), device=device, dtype=LABEL_TYPE)]
    pred_chunks = [torch.zeros((0, NUM_CLASSES), device=device, dtype=FLOAT_TYPE)]
    for f in range(len(points)):
        frame = _frame_predictions(model, current_map, points[f], points_labels[f], voxels[f], invalid_voxels[f])
        if frame is None:
            continue
        label_chunks.append(frame[0])
        pred_chunks.append(frame[1])
    # Concatenate once instead of re-copying the growing tensor per frame
    return torch.vstack(label_chunks).reshape(-1), torch.vstack(pred_chunks)


train_count = 0
current_map = model.initialize_grid()
for epoch in range(EPOCH_NUM):
    # Training
    model.train()
    for points, points_labels, voxels, invalid_voxels, _ in dataloader:
        optimizer.zero_grad()
        batch_voxels_labels, batch_preds = _batch_predictions(
            model, current_map, points, points_labels, voxels, invalid_voxels)
        if batch_voxels_labels.numel() == 0:
            # No supervised voxels: the loss is undefined (NaN for a mean
            # reduction) and backward() on it would corrupt the weights.
            train_count += len(points)
            continue

        loss = criterion(batch_preds, batch_voxels_labels.long())
        loss.backward()
        optimizer.step()

        # Accuracy / IoU on the argmax of the expected class probabilities
        with torch.no_grad():
            max_batch_preds = torch.argmax(batch_preds, dim=-1)
            preds_np = max_batch_preds.cpu().numpy()
            voxels_np = batch_voxels_labels.cpu().numpy()
            accuracy = np.sum(preds_np == voxels_np) / voxels_np.shape[0]

            inter, union = iou_one_frame(max_batch_preds, batch_voxels_labels, n_classes=NUM_CLASSES)
            inter = inter[union > 0]
            union = union[union > 0]

        # Record
        writer.add_scalar(MODEL_NAME + '/Loss/Train', loss.item(), train_count)
        writer.add_scalar(MODEL_NAME + '/Accuracy/Train', accuracy, train_count)
        writer.add_scalar(MODEL_NAME + '/mIoU/Train', np.mean(inter / union), train_count)

        train_count += len(points)

    # Save model, decrease learning rate
    my_lr_scheduler.step()
    torch.save(model.state_dict(), os.path.join(save_dir, "Epoch" + str(epoch) + ".pt"))

    print("Testing inference on validation...")

    model.eval()
    with torch.no_grad():
        running_loss = 0.0
        counter = 0
        val_iter = 0
        num_correct = 0
        num_total = 0
        all_intersections = np.zeros(NUM_CLASSES)
        all_unions = np.zeros(NUM_CLASSES)

        for points, points_labels, voxels, invalid_voxels, _ in dataloader_val:
            batch_voxels_labels, batch_preds = _batch_predictions(
                model, current_map, points, points_labels, voxels, invalid_voxels)
            if batch_voxels_labels.numel() == 0:
                continue  # an empty batch would add a NaN loss to the running total

            loss = criterion(batch_preds, batch_voxels_labels.long())
            # NOTE(review): `counter` counts voxels while `loss` is one value per
            # batch; if FocalLoss mean-reduces, running_loss / counter is not a
            # per-voxel average — confirm the reduction.
            running_loss += loss
            counter += batch_preds.shape[0]

            max_batch_preds = torch.argmax(batch_preds, dim=-1)
            num_correct += np.sum(max_batch_preds.cpu().numpy() == batch_voxels_labels.cpu().numpy())
            num_total += batch_voxels_labels.shape[0]

            inter, union = iou_one_frame(max_batch_preds, batch_voxels_labels, n_classes=NUM_CLASSES)
            all_intersections += inter
            all_unions += union

            # Periodic progress report (previously inverted: it printed on every
            # batch except multiples of 200).
            if val_iter % VAL_PRINT_EVERY == 0:
                seen = all_unions > 0
                print(f'Epoch Num: {epoch} ------ average val loss: {running_loss/counter}')
                print(f'Epoch Num: {epoch} ------ average val accuracy: {num_correct/num_total}')
                print(f'Epoch Num: {epoch} ------ val miou: {np.mean(all_intersections[seen] / all_unions[seen])}')
            val_iter += 1

        # Log Epoch (mIoU over classes that appeared at least once)
        seen = all_unions > 0
        val_miou = np.mean(all_intersections[seen] / all_unions[seen])
        print(f'Epoch Num: {epoch} ------ average val loss: {running_loss/counter}')
        print(f'Epoch Num: {epoch} ------ average val accuracy: {num_correct/num_total}')
        print(f'Epoch Num: {epoch} ------ val miou: {val_miou}')
        writer.add_scalar(MODEL_NAME + '/Loss/Val', running_loss/counter, epoch)
        writer.add_scalar(MODEL_NAME + '/Accuracy/Val', num_correct/num_total, epoch)
        writer.add_scalar(MODEL_NAME + '/mIoU/Val', val_miou, epoch)

    print("Epoch ", epoch, " out of ", EPOCH_NUM, " complete.")

writer.close()
        

--Return--
None
> [0;32m/tmp/ipykernel_7578/1982406713.py[0m(40)[0;36m<cell line: 40>[0;34m()[0m
[0;32m     38 [0;31m[0mcriterion[0m [0;34m=[0m [0mFocalLoss[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0mweights[0m[0;34m,[0m [0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m    [0;31m# nn.CrossEntropyLoss(weight=weights.to(device))[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m[0mscenes[0m [0;34m=[0m [0;34m[[0m [0ms[0m [0;32mfor[0m [0ms[0m [0;32min[0m [0msorted[0m[0;34m([0m[0mos[0m[0;34m.[0m[0mlistdir[0m[0;34m([0m[0mTRAIN_DIR[0m[0;34m)[0m[0;34m)[0m [0;32mif[0m [0ms[0m[0;34m.[0m[0misdigit[0m[0;34m([0m[0;34m)[0m [0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     42 [0;31m[0mmodel_params_file[0m [0;34m=[0m [0mos[0m[0;34m.[0m[0mpath[0m[0;34m.[0m[0mjoin

BdbQuit: 

: 

In [None]:
from torch.autograd import gradcheck
from Models.BKIConvFilter import BKIConvFilter

# Sanity check: compare BKIConvFilter's analytical backward against finite
# differences. gradcheck expects double-precision inputs.
conv_filter = BKIConvFilter.apply
mid = torch.tensor([1], dtype=torch.long, requires_grad=False)
# Distinct names: `input` would shadow the builtin, and `weights` would clobber
# the class-weight tensor built in the training cell.
filter_weights = torch.randn((1, 1, 3, 3, 3), dtype=torch.double, requires_grad=True)
gradcheck_inputs = (filter_weights, mid)

# Pytorch autograd using gradient against analytical
gfilter = gradcheck(conv_filter, gradcheck_inputs, eps=1e-6, atol=1e-4)
print("filter ", gfilter)


In [None]:
from torch.autograd import gradcheck
from model_utils import *
from Models.BKIConvFilter import BKIConvFilter
import torch.optim as optim  # previously only imported in a different cell

# Smoke test: gradients flow from a BKI map update back to bki_map.weights,
# and one optimizer step changes them.
optimizer = optim.Adam(bki_map.parameters(), lr=1e-1, betas=(0.9, 0.999))
# backward() accumulates into .grad; clear leftovers from earlier runs of this cell.
optimizer.zero_grad()

# `test_map` rather than `map`, which would shadow the builtin for the rest of the notebook.
test_map = bki_map.initialize_grid()
# 12 random rows of 4 integers in [0, 10).
# NOTE(review): columns presumably x, y, z, label as elsewhere in this notebook,
# and values are int64 — confirm the model accepts integer input.
pc = torch.randint(0, 10, (12, 4), device=device)

update = bki_map(test_map, pc)

loss = torch.sum(update)
loss.backward()

print("orig weights ", bki_map.weights[0,0,0,0,0])
print("grad ", bki_map.weights.grad[0,0,0,0,0])
optimizer.step()
print("new weights ", bki_map.weights[0,0,0,0,0])
print("loss ", loss.item())