In [1]:
import torch
import torch.nn as nn

In [2]:
# Seed torch's global RNG so the random decoder weights, inputs and targets
# drawn in this and later cells are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
# Toy decoder head: maps 50 input features to 100 logits (one per bucket).
decoder = nn.Linear(50, 100)

In [3]:
# Dummy decoder input shaped (batch_size, seq_len, features). The feature size is
# taken from the decoder itself so the two cannot drift apart.
batch_size, seq_len = 32, 100
decoder_input = torch.rand(batch_size, seq_len, decoder.in_features)
decoder_output = decoder(decoder_input)
print(decoder_output.shape)

torch.Size([32, 100, 100])


In [4]:
# Flatten all leading dims so each row holds the logits for one position:
# (batch, seq_len, n_logits) -> (batch * seq_len, n_logits).
# The logit count comes from the tensor itself rather than a hard-coded 100.
loss_input = decoder_output.reshape(-1, decoder_output.shape[-1])
print(loss_input[0])
print(loss_input.shape)

tensor([-0.2922, -0.2054, -0.3956, -0.3311,  0.0612,  0.4051, -0.2804, -0.1187,
        -0.3056, -0.3681,  0.3640,  0.3437,  0.2199, -0.1579,  0.1786,  0.1638,
         0.0344,  0.0934,  0.0931, -0.4408, -0.8414,  0.0593,  0.2242,  0.1501,
        -0.4293,  0.0312, -0.2293, -0.3472, -0.1272, -0.2919,  0.5080,  0.1690,
         0.4366,  0.4158, -0.0054,  0.2943, -0.5481, -0.0244, -0.1767,  0.1667,
         0.3313,  0.3906, -0.4446,  0.5121, -0.1443,  0.0738, -0.1822, -0.0846,
        -0.7640, -0.4090,  0.3801, -1.0907, -0.0948,  0.1942, -0.2738, -0.2938,
         0.0042, -0.0948, -0.2106, -0.3535,  0.1637, -0.0181, -0.3271,  0.0500,
         0.4417,  0.0418,  0.2129,  0.1864,  0.1630, -0.5445, -0.1679,  0.0298,
         0.2825, -0.8279,  0.4854, -1.0426,  0.2697, -0.2426,  0.7397, -0.1191,
        -0.0151, -0.1437,  0.5148,  0.3006,  0.0264,  0.1464, -0.4273, -0.1789,
        -0.4207, -0.2534,  0.6260, -0.1082, -0.3113, -0.0452,  0.0843, -0.4715,
        -0.4042, -0.1171, -0.1131,  0.32

In [5]:
def get_bucket_limits(num_outputs: int, full_range: tuple = None, ys: torch.Tensor = None) -> torch.Tensor:
    """Compute the ``num_outputs + 1`` borders of ``num_outputs`` buckets.

    Two modes:
      * ``ys`` given: borders are placed so every bucket holds the same number of
        the (sorted) sample values ``ys``. Each inner border is the midpoint between
        the last sample of one bucket and the first sample of the next.
      * only ``full_range`` given: ``num_outputs`` equal-width buckets spanning it.

    Args:
        num_outputs: number of buckets.
        full_range: ``(low, high)`` outer borders. Required when ``ys`` is None.
            With ``ys`` it must enclose every value of ``ys``. If omitted,
            ``(ys.min(), ys.max())`` is used.
        ys: sample values. They are flattened, and the last ``len(ys) % num_outputs``
            values (in original order) are dropped so the count divides evenly.

    Returns:
        1-D tensor of length ``num_outputs + 1`` whose first/last entries equal
        ``full_range``.

    Raises:
        AssertionError: if neither ``ys`` nor ``full_range`` is given, if ``ys`` has
            fewer than ``num_outputs`` values, or if ``ys`` falls outside ``full_range``.
    """
    assert (ys is not None) or (full_range is not None)
    if ys is not None:
        ys = ys.flatten()
        assert len(ys) >= num_outputs, \
            f'Need at least {num_outputs} ys to estimate {num_outputs} buckets, got {len(ys)}.'
        # Compute the remainder *before* truncating so the message reports the
        # actual number of dropped values (it was always 0 otherwise).
        num_cut = len(ys) % num_outputs
        if num_cut:
            ys = ys[:-num_cut]
        print(f'Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {num_cut} ys.')
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
            full_range = torch.tensor(full_range)
        ys_sorted, _ = ys.sort(0)
        # Last sample of bucket i and first sample of bucket i+1, for i = 0..num_outputs-2.
        last_in_bucket = ys_sorted[ys_per_bucket - 1::ys_per_bucket][:-1]
        first_in_next = ys_sorted[ys_per_bucket::ys_per_bucket]
        bucket_limits = (last_in_bucket + first_in_next) / 2
        bucket_limits = torch.cat([full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0)
    else:
        class_width = (full_range[1] - full_range[0]) / num_outputs
        # Build the upper end separately so it equals full_range[1] exactly
        # (no accumulated float error). as_tensor also accepts tensor inputs without warnings.
        bucket_limits = torch.cat([
            full_range[0] + torch.arange(num_outputs).float() * class_width,
            torch.as_tensor(full_range[1], dtype=torch.float).unsqueeze(0),
        ], 0)

    assert len(bucket_limits) - 1 == num_outputs and full_range[0] == bucket_limits[0] and full_range[-1] == bucket_limits[-1]
    return bucket_limits

In [6]:
# One bucket per logit: derive the bucket count from the logits instead of
# repeating the literal 100. Buckets are equal-width over the target range [-5, 5].
borders = get_bucket_limits(loss_input.shape[-1], full_range=(-5, 5))
print(borders)

tensor([-5.0000, -4.9000, -4.8000, -4.7000, -4.6000, -4.5000, -4.4000, -4.3000,
        -4.2000, -4.1000, -4.0000, -3.9000, -3.8000, -3.7000, -3.6000, -3.5000,
        -3.4000, -3.3000, -3.2000, -3.1000, -3.0000, -2.9000, -2.8000, -2.7000,
        -2.6000, -2.5000, -2.4000, -2.3000, -2.2000, -2.1000, -2.0000, -1.9000,
        -1.8000, -1.7000, -1.6000, -1.5000, -1.4000, -1.3000, -1.2000, -1.1000,
        -1.0000, -0.9000, -0.8000, -0.7000, -0.6000, -0.5000, -0.4000, -0.3000,
        -0.2000, -0.1000,  0.0000,  0.1000,  0.2000,  0.3000,  0.4000,  0.5000,
         0.6000,  0.7000,  0.8000,  0.9000,  1.0000,  1.1000,  1.2000,  1.3000,
         1.4000,  1.5000,  1.6000,  1.7000,  1.8000,  1.9000,  2.0000,  2.1000,
         2.2000,  2.3000,  2.4000,  2.5000,  2.6000,  2.7000,  2.8000,  2.9000,
         3.0000,  3.1000,  3.2000,  3.3000,  3.4000,  3.5000,  3.6000,  3.7000,
         3.8000,  3.9000,  4.0000,  4.1000,  4.2000,  4.3000,  4.4000,  4.5000,
         4.6000,  4.7000,  4.8000,  4.90

In [7]:
# Log-probability of each bucket for every position (normalized over the last dim).
bucket_log_probs = loss_input.log_softmax(dim=-1)
print(bucket_log_probs[0])
print(bucket_log_probs.size())

# Convert bucket probabilities into log-densities: log(p_i / width_i) = log p_i - log width_i.
bucket_widths = torch.diff(borders)
scaled_bucket_log_probs = bucket_log_probs - bucket_widths.log()
print(scaled_bucket_log_probs[0])
print(scaled_bucket_log_probs.size())

tensor([-4.9032, -4.8164, -5.0066, -4.9421, -4.5498, -4.2059, -4.8914, -4.7297,
        -4.9166, -4.9791, -4.2470, -4.2673, -4.3911, -4.7689, -4.4324, -4.4472,
        -4.5766, -4.5176, -4.5179, -5.0518, -5.4524, -4.5517, -4.3868, -4.4609,
        -5.0403, -4.5798, -4.8403, -4.9582, -4.7382, -4.9029, -4.1030, -4.4420,
        -4.1745, -4.1952, -4.6164, -4.3168, -5.1591, -4.6354, -4.7877, -4.4443,
        -4.2797, -4.2205, -5.0556, -4.0989, -4.7553, -4.5372, -4.7932, -4.6956,
        -5.3750, -5.0201, -4.2309, -5.7017, -4.7058, -4.4168, -4.8848, -4.9048,
        -4.6068, -4.7058, -4.8216, -4.9645, -4.4473, -4.6291, -4.9382, -4.5610,
        -4.1693, -4.5692, -4.3981, -4.4246, -4.4480, -5.1555, -4.7789, -4.5812,
        -4.3285, -5.4389, -4.1256, -5.6536, -4.3413, -4.8536, -3.8713, -4.7301,
        -4.6261, -4.7547, -4.0962, -4.3104, -4.5846, -4.4646, -5.0383, -4.7899,
        -5.0318, -4.8644, -3.9850, -4.7192, -4.9224, -4.6562, -4.5268, -5.0825,
        -5.0152, -4.7281, -4.7241, -4.28

In [8]:
def map_to_bucket_idx(borders, y):
    """Map each target value in ``y`` to the index of the bucket that contains it.

    Bucket ``i`` is the interval ``(borders[i], borders[i+1]]``. The lowest border
    itself is assigned to bucket 0, so both ends of the range are included.

    Args:
        borders: 1-D ascending tensor of ``num_buckets + 1`` bucket borders.
        y: tensor of target values, expected to lie within ``[borders[0], borders[-1]]``.

    Returns:
        Integer tensor shaped like ``y``, with indices in ``[0, num_buckets - 1]`` for
        in-range values. Out-of-range values are NOT clamped: they yield -1 (below)
        or ``num_buckets`` (above).
    """
    num_buckets = len(borders) - 1
    # searchsorted (right=False) returns i with borders[i-1] < y <= borders[i];
    # subtracting 1 gives the bucket whose upper border is borders[i].
    target_sample = torch.searchsorted(borders, y) - 1
    # y == borders[0] would map to -1; put it into the first bucket instead.
    target_sample[y == borders[0]] = 0
    # Pin the top border to the last bucket, derived from ``borders``
    # rather than a hard-coded bucket count.
    target_sample[y == borders[-1]] = num_buckets - 1
    return target_sample

In [13]:
# Targets y of shape [batch * seq_len] (here [3200]), one per row of the
# [batch * seq_len, n_buckets] logits `loss_input` above. The shape is derived from
# the decoder output instead of hard-coding 32 x 100.
targets = torch.rand(*decoder_output.shape[:2], 1).flatten()
assert targets.shape[0] == loss_input.shape[0], 'one target per logit row expected'
print(targets)
print(targets.shape)
target_sample = map_to_bucket_idx(borders, targets)
print(target_sample)
print(target_sample.shape)
print(target_sample.unsqueeze(-1).shape)
# Spot-check: negative log-density of the target bucket for two individual rows.
print(target_sample[0])
print(-scaled_bucket_log_probs[0, target_sample[0]])
print(target_sample[-2])
print(-scaled_bucket_log_probs[-2, target_sample[-2]])

tensor([0.1873, 0.5338, 0.9764,  ..., 0.9534, 0.5121, 0.1801])
torch.Size([3200])
tensor([51, 55, 59,  ..., 59, 55, 51])
torch.Size([3200])
torch.Size([3200, 1])
tensor(51)
tensor(3.3991, grad_fn=<NegBackward0>)
tensor(55)
tensor(2.9402, grad_fn=<NegBackward0>)


In [20]:
# Per-row loss: negative log-density at the bucket each target falls into
# (row i picks column target_sample[i]).
losses = torch.gather(scaled_bucket_log_probs, -1, target_sample.unsqueeze(-1)).squeeze(-1).neg()
print(losses)
print(losses.size())


tensor([3.3991, 2.8007, 2.6417,  ..., 2.6891, 2.9402, 2.9419],
       grad_fn=<NegBackward0>)
torch.Size([3200])


In [21]:
# Restore the (batch, seq_len) layout of the per-position losses. `view` already
# yields exactly these two dims; a trailing `squeeze(-1)` would wrongly drop the
# sequence dim whenever seq_len == 1.
losses = losses.view(decoder_output.shape[:2])
print(losses)
print(losses.shape)
# Scalar loss: mean over all batch elements and positions.
loss = losses.mean()
print(loss)
print(loss.shape)

tensor([[3.3991, 2.8007, 2.6417,  ..., 2.0208, 2.2635, 2.8726],
        [2.6442, 2.4853, 3.0302,  ..., 2.1433, 2.3291, 2.4181],
        [2.0150, 2.5583, 2.8525,  ..., 2.4902, 2.3388, 2.8720],
        ...,
        [1.9488, 2.7554, 2.6354,  ..., 2.3348, 1.9084, 2.2315],
        [2.4966, 2.2723, 2.2883,  ..., 2.9811, 2.6794, 2.6002],
        [2.2101, 2.2639, 2.5728,  ..., 2.6891, 2.9402, 2.9419]],
       grad_fn=<SqueezeBackward1>)
torch.Size([32, 100])
tensor(2.5179, grad_fn=<MeanBackward0>)
torch.Size([])
