In [1]:
from models import ResnetEncoder_Streamlines
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [2]:
batch_size = 8
inputs_shape = (batch_size, 45, 6, 6, 6)
labels_shape = (batch_size,)
inputs = torch.randn(inputs_shape).to(device)
labels = torch.randint(low=0, high=26, size=labels_shape).to(device)

cnn_flattened_size = inputs_shape[1] * inputs_shape[2] * inputs_shape[3] * inputs_shape[4]

print("inputs shape", inputs.shape)
print("labels shape", labels.shape)

inputs shape torch.Size([8, 45, 6, 6, 6])
labels shape torch.Size([8])


In [3]:
num_classes = 27
model = ResnetEncoder_Streamlines(num_classes=num_classes, task='classification')

In [4]:
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model.train()

ResnetEncoder_Streamlines(
  (channelwise_conv): Sequential(
    (0): ReflectionPad3d((3, 3, 3, 3, 3, 3))
    (1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), bias=False)
    (2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): Conv3d(64, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (5): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): ResnetBlock(
      (res_block): Sequential(
        (0): ReflectionPad3d((1, 1, 1, 1, 1, 1))
        (1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
        (2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): ReflectionPad3d((1, 1, 1, 1, 1, 1))
        (5): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
        (6): BatchNorm3d(256, eps=1e-05, mo

In [5]:
previous_prediction_1 = torch.randn((batch_size, num_classes))
previous_prediction_2 = torch.randn((batch_size, num_classes))
# Concatenate the previous predictions together along dimension 2
previous_predictions = torch.cat((previous_prediction_1, previous_prediction_2), dim=1).cuda()
print("Previous prediction size outside", previous_predictions.shape)

original_shapes = torch.tensor([129, 171, 119])

Previous prediction size outside torch.Size([8, 54])


In [6]:
outputs = model(inputs, previous_predictions, original_shapes)

In [7]:
# Print the sum of each row in the matrix
print(outputs.cpu().detach().numpy().sum(axis=1))

[1.         1.0000001  1.0000001  0.99999994 1.         1.
 1.0000001  1.        ]


In [6]:
for epoch in range(10):
    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = model(inputs, previous_predictions, original_shapes)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # print statistics
    print('[%d] loss: %.3f' % (epoch + 1, loss.item()))

print('Finished Training')

[1] loss: 3.318
[2] loss: 3.312
[3] loss: 3.290
[4] loss: 3.265
[5] loss: 3.251
[6] loss: 3.242
[7] loss: 3.236
[8] loss: 3.230
[9] loss: 3.222
[10] loss: 3.214
Finished Training
