In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from spherical_coordinates import *

In [2]:
# Load one 2D (deformed) segment and one 3D segment from disk.
# NOTE(review): the absolute paths below are machine-specific; consider moving
# them to a configurable data directory.
directory_2D = "D:\\CTA data\\Segments_deformed_2\\"
directory_3D = "D:\\CTA data\\Segments renamed\\"

# os.listdir returns entries in arbitrary order, so sort to make "file 0"
# reproducible across runs and machines.
# NOTE(review): index 0 of both listings is assumed to be the same segment --
# confirm that the two directories use matching file names.
file_names_2D = sorted(os.listdir(directory_2D))
file_names_3D = sorted(os.listdir(directory_3D))

data_2d = np.genfromtxt(os.path.join(directory_2D, file_names_2D[0]), delimiter=",")
# Drop the first row and keep columns 1..3 of the 3D file
# (presumably a header row and an index column -- TODO confirm).
data_3d = np.genfromtxt(os.path.join(directory_3D, file_names_3D[0]), delimiter=",")[1:, 1:4]

# Append a column of zeros so the 2D points have three coordinates as well.
data_2d = np.hstack((data_2d, np.zeros((data_2d.shape[0], 1))))

# Shift the 3D points so that the minimum along every axis is 0.
data_3d -= data_3d.min(axis=0)

# Convert both point sets to spherical coordinates.
origin_2D, spherical_2D = convert_to_spherical(data_2d)
origin_3D, spherical_3D = convert_to_spherical(data_3d)
spherical_2D

tensor([[0.0228, 1.5708, 0.8024],
        [0.0235, 1.5708, 0.8314],
        [0.0228, 1.5708, 0.7965],
        ...,
        [0.0234, 1.5708, 0.3989],
        [0.0233, 1.5708, 0.3569],
        [0.0234, 1.5708, 0.3985]])

In [3]:
def _to_channels_first(points):
    """Return a 2-D point tensor as a contiguous float tensor of shape (3, N).

    Accepts either (N, 3) (one point per row) or an already channels-first
    (3, N) tensor, so re-running the cell does not flip the layout back.
    A (3, 3) input is ambiguous and is always treated as (N, 3).

    Raises:
        ValueError: if the tensor is not 2-D or has no axis of size 3.
    """
    if points.ndim != 2 or 3 not in points.shape:
        raise ValueError(f"expected an (N, 3) or (3, N) tensor, got {tuple(points.shape)}")
    if points.shape[1] == 3:
        # A real transpose keeps each coordinate in its own channel.
        # torch.reshape(points, (3, N)) would only reinterpret the row-major
        # memory and interleave the three coordinates across channels.
        points = points.transpose(0, 1)
    return points.contiguous().float()


# Conv1d consumes channels-first input: (channels, length).
origin_2D = torch.reshape(origin_2D, (3, 1)).float()
origin_3D = torch.reshape(origin_3D, (3, 1)).float()

spherical_2D = _to_channels_first(spherical_2D)
spherical_3D = _to_channels_first(spherical_3D)

# The downstream feature concatenation needs both point sets to have the same
# length (this was previously implied by the hard-coded 349).
assert spherical_2D.shape == spherical_3D.shape, (
    f"2D and 3D point counts differ: {tuple(spherical_2D.shape)} vs {tuple(spherical_3D.shape)}"
)

origin_2D

tensor([[0.],
        [0.],
        [0.]])

In [16]:
# Point-wise (kernel_size=1) encoder for the origin: 3 -> 8 -> 16 -> 32 -> 64
# channels, with a ReLU after every convolution except the last.
# BatchNorm1d is intentionally not used between the layers for now.
origin_branch = nn.Sequential(
    *[
        layer
        for c_in, c_out in ((3, 8), (8, 16), (16, 32))
        for layer in (nn.Conv1d(c_in, c_out, kernel_size=1), nn.ReLU())
    ],
    nn.Conv1d(32, 64, kernel_size=1),
)

# Shape encoder: two 3-wide convolutions (padding=1 keeps the sequence length)
# lifting the 3 input channels to 64, with a single ReLU in between.
# BatchNorm1d is intentionally not used here for now.
shape_branch = nn.Sequential(
    nn.Conv1d(3, 64, 3, padding=1),
    nn.ReLU(),
    nn.Conv1d(64, 64, 3, padding=1),
)

class Downsample(nn.Module):
    """Residual 1-D convolution block that halves the sequence length.

    Main path (``conv_layers``): 1x1 conv -> two 3-wide convs -> one 3-wide
    conv with stride 2 -> 1x1 conv to ``out`` channels, with a ReLU between
    consecutive convolutions.
    Skip path (``conv2``): a single 3-wide, stride-2 conv from ``in_channels``
    to ``out`` channels, so both paths produce the same shape.
    The two paths are summed and passed through a final ReLU
    (``bn_relu_add``). BatchNorm1d layers are currently not used.

    Args:
        in_channels: number of channels of the input tensor.
        out: number of channels of the output tensor.
        hidden_channels: width of the intermediate convolutions on the main
            path. Defaults to 64, the value that used to be hard-coded.

    Input of length L yields output of length floor((L - 1) / 2) + 1.
    """

    def __init__(self, in_channels, out, hidden_channels=64):
        super().__init__()
        # Submodule names and positions are kept stable so that existing
        # state_dict keys (conv_layers.0.weight, ...) remain valid.
        self.conv_layers = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            # The only strided layer on the main path: this is where the length is halved.
            nn.Conv1d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, padding=1, stride=2),
            nn.ReLU(),
            nn.Conv1d(in_channels=hidden_channels, out_channels=out, kernel_size=1)
        )

        # Activation applied after the residual sum.
        self.bn_relu_add = nn.Sequential(
            nn.ReLU()
        )

        # Skip connection; uses the same kernel/padding/stride as the strided
        # main-path conv so the two outputs can be added element-wise.
        self.conv2 = nn.Sequential(nn.Conv1d(in_channels=in_channels, out_channels=out, kernel_size=3, padding=1, stride=2))

    def forward(self, x):
        """Return ReLU(main_path(x) + skip_path(x))."""
        return self.bn_relu_add(torch.add(self.conv_layers(x), self.conv2(x)))


# Encode origin and shape of the 2D and 3D inputs with the shared branches.
org_2 = origin_branch(origin_2D.float())
shape_2 = shape_branch(spherical_2D.float())

org_3 = origin_branch(origin_3D.float())
shape_3 = shape_branch(spherical_3D.float())

# Combine origin and shape features element-wise; the origin features have
# length 1 and are broadcast along the point axis of the shape features.
features_2 = torch.add(org_2, shape_2)
features_3 = torch.add(org_3, shape_3)

# Stack the 3D and 2D features along the channel axis (64 + 64 = 128 channels).
total_features = torch.cat((features_3, features_2), dim=0)

dwn_sampler_1 = Downsample(128, 128)
dwn_sampler_2 = Downsample(128, 256)
dwn_sampler_3 = Downsample(256, 512)

# Call the modules themselves rather than .forward(): nn.Module.__call__ runs
# registered hooks, which a direct .forward() call silently skips.
ds1 = dwn_sampler_1(total_features)
ds2 = dwn_sampler_2(ds1)
ds3 = dwn_sampler_3(ds2)
print(total_features.shape, ds1.shape, ds2.shape, ds3.shape)

torch.Size([128, 349]) torch.Size([128, 175]) torch.Size([256, 88]) torch.Size([512, 44])
