In [1]:
import torch
import numpy as np
import sys
sys.path.append('..')
from deepSymmetry.src import load_data

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torchvision import datasets, transforms

from se3cnn import SE3Convolution, SE3Dropout, SE3BNConvolution
from se3cnn.blocks import GatedBlock
from se3cnn.non_linearities import ScalarActivation
from se3cnn.dropout import SE3Dropout
from se3cnn import kernel
from se3cnn.filter import low_pass_filter

from tensorflow.python.framework import dtypes

In [5]:
# Read the first training shard as float16 samples.
# NOTE(review): seed=1 presumably fixes the reader's shuffling — confirm in load_data.
train_name = 'data/dataReady0'
train_set = load_data.read_data_set(
    train_name,
    dtype=dtypes.float16,
    seed=1,
)

Extracting data/dataReady0
Extracting data/dataReady0_label
(39785, 13824)


Custom loss: weighted cross-entropy on the order logits plus summed squared error on the 6 axis components.

In [2]:
def weighted_custom_loss(output, target, num_classes=None, order_weight=0.5):
    """Combined classification + regression loss for the symmetry model.

    Each output row holds ``num_classes`` order logits followed by 6 axis
    components. Each target row holds the integer order label (stored as a
    float) in column 0 and the 6 axis components in columns 1-6.

    Args:
        output: tensor of shape (batch, num_classes + 6) produced by the model.
        target: tensor of shape (batch, >= 7) laid out as described above.
        num_classes: number of order logits; defaults to the notebook-level
            ``NUM_CLASSES``. It is looked up at call time, so it may be defined
            in a later cell.
        order_weight: weight applied to the cross-entropy term.

    Returns:
        Scalar tensor: ``order_weight * mean cross-entropy`` plus the summed
        squared error over the axis components.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    order_out = output[:, :num_classes]
    # Indexing column 0 directly gives shape (batch,) even when batch == 1.
    # The old `.squeeze_()` collapsed that case to a 0-d tensor.
    # `.long()` keeps the label on target's device; `.type(torch.LongTensor)`
    # would force it onto the CPU.
    order_target = target[:, 0].long()
    axis_out = output[:, num_classes:num_classes + 6]
    axis_target = target[:, 1:7]

    order_loss = nn.CrossEntropyLoss()(order_out, order_target)
    axis_loss = nn.MSELoss(reduction='sum')(axis_out, axis_target)
    return order_weight * order_loss + axis_loss

Residual network:

In [None]:
class ResEquiNet(nn.Module):
    """SE(3)-equivariant CNN with one residual skip connection.

    Two ``SE3BNConvolution`` + average-pool stages feed a three-layer MLP head.
    The single-channel input is broadcast over the channels of the first
    stage's output and added to it. Before the addition, that output is
    zero-padded back to the input's spatial size.

    Layer sizes are read from notebook globals defined in the configuration
    cell: ``repr_in_1``, ``repr_out_1``, ``repr_in_2``, ``repr_out_2``,
    ``pool_size``, ``pool_stride``, ``n_input_1``, ``n_output_1``,
    ``n_output_2``, ``prob`` and ``NUM_CLASSES``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = SE3BNConvolution(repr_in_1, repr_out_1, size=4)
        self.pool1 = nn.AvgPool3d(pool_size, pool_stride)
        self.conv2 = SE3BNConvolution(repr_in_2, repr_out_2, size=4)
        self.pool2 = nn.AvgPool3d(pool_size, pool_stride)

        # n_input_1 must equal the flattened per-sample size of pool2's output:
        # channels of repr_out_2 times the pooled grid volume.
        self.lin1 = nn.Linear(n_input_1, n_output_1)
        self.drop1 = nn.Dropout(prob)
        self.lin2 = nn.Linear(n_output_1, n_output_2)
        self.drop2 = nn.Dropout(prob)
        # NUM_CLASSES order logits followed by 6 axis components.
        self.lin3 = nn.Linear(n_output_2, NUM_CLASSES + 6)

    @staticmethod
    def _pad_to(x, spatial_size):
        """Zero-pad the spatial dims (2 onward) of ``x`` symmetrically to ``spatial_size``.

        When a size difference is odd, the extra element goes on the high side.
        The padding is created on x's device and dtype.
        """
        pads = []
        # F.pad expects (low, high) pairs starting from the LAST dimension.
        for target, current in zip(reversed(tuple(spatial_size)), reversed(x.shape[2:])):
            diff = target - current
            pads.extend((diff // 2, diff - diff // 2))
        return F.pad(x, pads)

    def forward(self, x):
        """Map a batch of voxel grids to order logits plus 6 axis components.

        Args:
            x: tensor of shape (batch, 1, D, H, W). ``train`` feeds
                (batch, 1, 24, 24, 24).

        Returns:
            Tensor of shape (batch, NUM_CLASSES + 6).
        """
        residual = x
        x = self.pool1(self.conv1(x))
        # conv1 + pooling shrink the grid; pad it back so the input can be
        # added as a skip connection. The old code hard-coded batch 100 and
        # 200 channels here.
        x = self._pad_to(x, residual.shape[2:])
        # The input has a single channel; broadcast it over all of x's channels.
        x = x + residual.expand_as(x)
        x = self.pool2(self.conv2(x))
        # Flatten per sample using the real batch size, not a global.
        x = x.flatten(1)
        x = F.leaky_relu(self.lin1(x))
        x = self.drop1(x)
        x = F.relu(self.lin2(x))
        x = self.drop2(x)
        return self.lin3(x)
    
def train(model, device, train_set, batch_size, optimizer, epoch, per_epoch, decr_rate):
    model.train()
    flag = True
    new_epoch = True
    
    if new_epoch:
        batch_idx = 1
        new_epoch = False
        data, target, _ = train_set.next_batch(batch_size)
        data = torch.from_numpy(data.reshape(batch_size,1,24,24,24)).type(torch.FloatTensor)
        target = torch.from_numpy(target.reshape(batch_size,-1)).type(torch.FloatTensor)
        cnt = epoch // per_epoch
        if ((epoch+1) // per_epoch > cnt) and flag:
            lr = 0
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
            lr *= decr_rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            flag = False
        if ((epoch+1) // per_epoch <= cnt):
            flag = True
            
        optimizer.zero_grad()
        output = model(data)
        # loss_fn = nn.MSELoss(reduction='sum')
        # loss = loss_fn(output, target)
        loss = weighted_custom_loss(output, target)
        loss.backward()
        optimizer.step()
        
    while (train_set._index_in_epoch + batch_size) < train_set._num_examples:
        batch_idx += 1
        data, target, _ = train_set.next_batch(batch_size)
        data = torch.from_numpy(data.reshape(batch_size,1,24,24,24)).type(torch.FloatTensor)
        target = torch.from_numpy(target.reshape(batch_size,-1)).type(torch.FloatTensor)
        cnt = epoch // per_epoch
        if ((epoch+1) // per_epoch > cnt) and flag:
            lr = 0
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
            lr *= decr_rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            flag = False
        if ((epoch+1) // per_epoch <= cnt):
            flag = True
            
        optimizer.zero_grad()
        output = model(data)
        # loss_fn = nn.MSELoss(reduction='sum')
        # loss = loss_fn(output, target)
        loss = weighted_custom_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * batch_size, train_set._num_examples,
                100. * batch_idx / train_set._num_examples, loss.item()))

def test(model, device, test_set):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_set:
            output = model(data)
            loss_fn = nn.MSELoss(reduction='sum')
            test_loss += loss_fn(output, target) 
            # pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += int(torch.argmax(output) == torch.argmax(target))

    test_loss /= len(test_set)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_set),
        100. * correct / len(test_set)))
    return test_loss

In [10]:
# Each (multiplicity, l) entry contributes multiplicity * (2l + 1) channels.
repr_in_1 = [(1, 0)]
repr_out_1 = [(2, l) for l in range(10)]   # 2 copies of l = 0..9 -> 200 channels
repr_in_2 = list(repr_out_1)               # conv2 consumes conv1's output
repr_out_2 = [(1, 0), (2, 1), (2, 2), (2, 3)]  # 1 + 6 + 10 + 14 = 31 channels
size = 4
activation = (None, F.leaky_relu)  # not referenced by ResEquiNet
pool_size = 2
pool_stride = 2
bias = True  # not referenced by ResEquiNet

GRID_SIZE = 24  # input voxel grid edge (24**3 = 13824 features per sample)
# A conv of edge `size` followed by pooling maps 24 -> 21 -> 10. The residual
# block pads conv1's output back to 24, so conv2 sees a 24**3 grid too.
# The old hard-coded 837 (= 31 * 3**3) matched a network without that padding.
pooled_edge = (GRID_SIZE - size + 1 - pool_size) // pool_stride + 1
n_input_1 = sum(mult * (2 * l + 1) for mult, l in repr_out_2) * pooled_edge ** 3
n_output_1 = 1000
n_output_2 = 50

batch_size = 100
prob = 0.5
NUM_CLASSES = 10

In [None]:
# Run hyperparameters.
epochs = 1
device = torch.device('cpu')
learning_rate = 5e-3
per_epoch = 10      # LR decay period, in epochs
decr_rate = 0.995   # multiplicative LR decay factor

# Seed before building the model so weight initialisation is reproducible.
torch.manual_seed(1)
model_tenth_order = ResEquiNet().to(device)
optimizer = torch.optim.Adam(model_tenth_order.parameters(), lr=learning_rate)

for epoch in range(1, epochs + 1):
    train(model_tenth_order, device, train_set, batch_size, optimizer,
          epoch, per_epoch, decr_rate)
    # Evaluation disabled (was: test(model_hard, device, test_set));
    # no test_set is defined in this notebook.

In [78]:
c = torch.zeros(2, 1, 3)
c[..., 1] = 1.0  # set the middle entry of every row to one
c

tensor([[[0., 1., 0.]],

        [[0., 1., 0.]]])

In [47]:
def tile(a, dim, n_tile):
    """Repeat each slice of ``a`` along ``dim`` ``n_tile`` times, consecutively.

    ``Tensor.repeat`` tiles the whole tensor (a0 a1 a0 a1). This function
    instead groups the copies of each slice (a0 a0 a1 a1), which is the same
    result as ``torch.repeat_interleave(a, n_tile, dim)``.

    Args:
        a: input tensor (any device).
        dim: dimension to tile along; negative values are allowed.
        n_tile: number of consecutive copies of each slice (may be 0).

    Returns:
        Tensor on ``a``'s device with ``a.size(dim) * n_tile`` entries along ``dim``.
    """
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    repeated = a.repeat(*repeat_idx)
    # In `repeated`, copy j of slice i sits at j * init_dim + i. Transposing an
    # (n_tile, init_dim) grid of positions lists each slice's copies in order.
    # The index is built with torch on a's device: the old NumPy/CPU index
    # broke index_select for CUDA inputs, and np.concatenate([]) raised for
    # empty dims.
    order_index = (
        torch.arange(n_tile * init_dim, device=a.device)
        .view(n_tile, init_dim)
        .t()
        .reshape(-1)
    )
    return torch.index_select(repeated, dim, order_index)

In [50]:
# Each slice of c along dim 0 repeated twice, consecutively.
tile(c, dim=0, n_tile=2)

tensor([[1., 1.],
        [1., 1.],
        [0., 0.],
        [0., 0.]])

In [73]:
# Split c into chunks of 2 along dim 1, then stack two copies of the chunk sequence along dim 0.
torch.cat(torch.split(c, 2, dim=1) * 2, dim=0)

tensor([[[0., 1., 0.],
         [0., 2., 0.]],

        [[0., 1., 0.],
         [0., 2., 0.]],

        [[0., 1., 0.],
         [0., 2., 0.]],

        [[0., 1., 0.],
         [0., 2., 0.]]])

In [91]:
# Sandwich c (broadcast to 2x2x3) between two blocks of ones along dim 1.
torch.cat((torch.ones(2, 2, 3), c.expand(2, 2, 3), torch.ones(2, 2, 3)), dim=1)

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 1., 1.],
         [1., 1., 1.]]])