In [1]:
import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler
import numpy as np

In [2]:
from neuralcompress.utils.tpc_dataloader import get_tpc_dataloaders
from neuralcompress.models.bcae import BCAE
from neuralcompress.models.losses import get_tpc_losses

In [21]:
train_loader, valid_loader, test_loader = get_tpc_dataloaders(
    '/data/datasets/sphenix/highest_framedata_3d/outer/',
    batch_size  = 32,
    train_sz    = 960,
    valid_sz    = 320,
    test_sz     = 320,
    is_random   = True
)

In [26]:
conv_layer_1 = {
    'out_channels': 8,
    'kernel_size' : [4, 3, 3],
    'padding'     : [1, 0, 1],
    'stride'      : [2, 2, 1]
}
conv_layer_2 = {
    'out_channels': 16,
    'kernel_size' : [4, 4, 3],
    'padding'     : [1, 1, 1],
    'stride'      : [2, 2, 1]
}
conv_layer_3 = {
    'out_channels': 32,
    'kernel_size' : [4, 4, 3],
    'padding'     : [1, 1, 1],
    'stride'      : [2, 2, 1]
}
conv_layer_4 = {
    'out_channels': 32,
    'kernel_size' : [4, 3, 3],
    'padding'     : [1, 0, 1],
    'stride'      : [2, 2, 1]
}

# Construct decoder network
deconv_layer_1 = {
    'out_channels': 16,
    'kernel_size' : [4, 3, 3],
    'padding'     : [1, 0, 1],
    'stride'      : [2, 2, 1],
    'output_padding': 0
}
deconv_layer_2 = {
    'out_channels': 8,
    'kernel_size' : [4, 4, 3],
    'padding'     : [1, 1, 1],
    'stride'      : [2, 2, 1],
    'output_padding': 0
}
deconv_layer_3 = {
    'out_channels': 4,
    'kernel_size' : [4, 4, 3],
    'padding'     : [1, 1, 1],
    'stride'      : [2, 2, 1],
    'output_padding': 0
}
deconv_layer_4 = {
    'out_channels': 2,
    'kernel_size' : [4, 3, 3],
    'padding'     : [1, 0, 1],
    'stride'      : [2, 2, 1],
    'output_padding': 0
}
conv_args_list = [
    conv_layer_1, conv_layer_2,
    conv_layer_3, conv_layer_4
]
deconv_args_list = [
    deconv_layer_1, deconv_layer_2,
    deconv_layer_3, deconv_layer_4
]

In [12]:
# conv_layer_1 = {
#     'out_channels': 8,
#     'kernel_size' : [4, 3, 4],
#     'padding'     : [1, 0, 1],
#     'stride'      : [2, 2, 2]
# }
# conv_layer_2 = {
#     'out_channels': 16,
#     'kernel_size' : [4, 4, 4],
#     'padding'     : [1, 1, 1],
#     'stride'      : [2, 2, 2]
# }
# conv_layer_3 = {
#     'out_channels': 32,
#     'kernel_size' : [4, 4, 4],
#     'padding'     : [1, 1, 1],
#     'stride'      : [2, 2, 2]
# }

# # Construct decoder network
# deconv_layer_1 = {
#     'out_channels': 8,
#     'kernel_size' : [4, 4, 4],
#     'padding'     : [1, 1, 1],
#     'stride'      : [2, 2, 2],
#     'output_padding': 0
# }
# deconv_layer_2 = {
#     'out_channels': 4,
#     'kernel_size' : [4, 4, 4],
#     'padding'     : [1, 1, 1],
#     'stride'      : [2, 2, 2],
#     'output_padding': 0
# }
# deconv_layer_3 = {
#     'out_channels': 2,
#     'kernel_size' : [4, 3, 4],
#     'padding'     : [1, 0, 1],
#     'stride'      : [2, 2, 2],
#     'output_padding': 0
# }

# conv_args_list = [conv_layer_1, conv_layer_2, conv_layer_3]
# deconv_args_list = [deconv_layer_1, deconv_layer_2, deconv_layer_3]

In [27]:
bcae = BCAE(
    image_channels   = 1,
    code_channels    = 8,
    conv_args_list   = conv_args_list,
    deconv_args_list = deconv_args_list,
    activ            = {'name': 'leakyrelu', 'negative_slope': .2},
    norm             = 'instance'
).cuda()

def winit_func(m, init_gain=.2):
    classname = m.__class__.__name__
    if (
        hasattr(m, 'weight') and
        (classname.find('Conv') != -1 or classname.find('Linear') != -1)
    ):
        init.xavier_normal_(m.weight.data, init_gain)

bcae.apply(lambda m: winit_func(m))

BCAE(
  (encoder): Encoder(
    (layers): Sequential(
      (encoder_block_0): TPCResidualBlock(
        (main_block): Sequential(
          (0): Conv3d(1, 8, kernel_size=(4, 3, 3), stride=(2, 2, 1), padding=(1, 0, 1))
          (1): LeakyReLU(negative_slope=0.2)
          (2): InstanceNorm3d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (side_block): Sequential(
          (0): Conv3d(1, 8, kernel_size=(4, 3, 3), stride=(2, 2, 1), padding=(1, 0, 1))
          (1): LeakyReLU(negative_slope=0.2)
          (2): InstanceNorm3d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        )
        (activ): LeakyReLU(negative_slope=0.2)
        (norm): InstanceNorm3d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (encoder_block_1): TPCResidualBlock(
        (main_block): Sequential(
          (0): Conv3d(8, 16, ke

In [28]:
lr_initial, ratio_initial = 0.01, 20000
ratio, epochs = ratio_initial, 2000

optimizer = torch.optim.AdamW(bcae.parameters(), lr=lr_initial)
# scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=200, eta_min=0.00001)
def get_lr(
    epoch, 
    lr, 
    rate=.95, 
    every_num_epochs=20, 
    minimum_lr=0.0001
):
    lr = lr * (rate ** (epoch // every_num_epochs))
    if lr < minimum_lr:
        lr = minimum_lr
    return lr

def transform(x):
    return torch.exp(x) * 6 + 64

loss_args = {
    'transform'        : transform,
    'weight_pow'       : .1, 
    'clf_threshold'    : .5,
    'target_threshold' : 64,
    'gamma'            : 2, 
    'eps'              : 1e-8
}

In [30]:
metric = nn.MSELoss()
train_losses, valid_losses = [], []
MSEs = []
for e in range(epochs):
    # Adjust learning rate
    lr = get_lr(e, lr_initial)
    for g in optimizer.param_groups:
        g['lr'] = lr
    
    # training
    for x in train_loader:
        x = x.cuda()
        y_c, y_r = bcae(x)
        loss_c, loss_r = get_tpc_losses(y_c, y_r, x, loss_args)
        loss = loss_r +  ratio * loss_c
        train_losses.append([loss_c.item(), loss_r.item(), loss.item()])
        
        # back-propagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # scheduler.step()
    
    # validation
    with torch.no_grad():
        for x in valid_loader:
            x = x.cuda()
            y_c, y_r = bcae(x)        
            val_loss_c, val_loss_r = get_tpc_losses(y_c, y_r, x, loss_args)
            val_loss = val_loss_r +  ratio * val_loss_c

            m = metric(transform(y_r) * (y_c > .5), x)
            MSEs.append(m.item())
            valid_losses.append([val_loss_c.item(), val_loss_r.item(), val_loss.item()])

        train_loss_avg = np.mean(train_losses[-len(train_loader):], axis=0)
        valid_loss_avg = np.mean(valid_losses[-len(valid_loader):], axis=0)
        mse_avg = np.mean(MSEs[-len(valid_loader):])
    
    for param_group in optimizer.param_groups:
        lr = param_group['lr']
        break
    
    print(f'Epoch {e}: {lr: .6f}, {ratio:.6f}')
    print(f'\ttrain loss = {train_loss_avg[0]: .6f}, {train_loss_avg[1]: .6f}, {train_loss_avg[2]: .6f}')
    print(f'\tvalid loss = {valid_loss_avg[0]: .6f}, {valid_loss_avg[1]: .6f}, {valid_loss_avg[2]: .6f}, {mse_avg: .6f}\n')
    
    ratio = train_loss_avg[1] / train_loss_avg[0]
    # # scheduler.step(loss)
    # if e % 20 == 19:
    #     torch.save(bcae.state_dict(), f'{path_prefix}/epoch-{e}.pt')

Epoch 0:  0.010000, 59742.071456
	train loss =  0.255234,  18201.121830,  33449.343750
	valid loss =  0.170423,  16212.396191,  26393.806250,  1608.325183



KeyboardInterrupt: 