# we are doing this masking procedure and wrote a loss, let's test it

In [1]:
from src.dataloaders.datasets.general_dataset import GeneralDataset
import torch
import numpy as np
import matplotlib.pyplot as plt

dataset = GeneralDataset(
    split='train',
    preprocess=False,
    data_path='/data1/lesliec/sarthak/data/DK_zarr/zarr_arrays/cell_type_arrays/GM12878_DNase.npz',
    data_is_zarr=False,
    sequences_bed_file='/data1/lesliec/sarthak/data/DK_zarr/sequences_enformer.bed',
    length=524288,
    load_in=False
)

dataset_mask_cont = GeneralDataset(
    split='train',
    preprocess=False,
    data_path='/data1/lesliec/sarthak/data/DK_zarr/zarr_arrays/cell_type_arrays/GM12878_DNase.npz',
    data_is_zarr=False,
    sequences_bed_file='/data1/lesliec/sarthak/data/DK_zarr/sequences_enformer.bed',
    length=524288,
    load_in=False,
    mlm=0.25,  # increased masking percentage
    acc_mask=0.25,  # increased accessibility masking percentage
    weight_peaks=True  # weight peaks more (this is the new parameter we added to the dataset class
)

dataset_mask_cat = GeneralDataset(
    split='train',
    preprocess=False,
    data_path='/data1/lesliec/sarthak/data/DK_zarr/zarr_arrays/cell_type_arrays/GM12878_DNase.npz',
    data_is_zarr=False,
    sequences_bed_file='/data1/lesliec/sarthak/data/DK_zarr/sequences_enformer.bed',
    length=524288,
    load_in=False,
    mlm=0.25,  # increased masking percentage
    acc_mask=0.25,  # increased accessibility masking percentage
    acc_type='category',  # categorical accessibility values
)

In [2]:
seq, acc, seq_unmask, acc_unmask = dataset_mask_cont[0]
print('seq:', seq.shape)
print('seq_unmask:', seq_unmask.shape)
print('acc:', acc.shape)
print('acc_unmask:', acc_unmask.shape)

seq: torch.Size([6, 524288])
seq_unmask: torch.Size([524288, 6])
acc: torch.Size([2, 524288])
acc_unmask: torch.Size([524288, 2])


In [3]:
seq = torch.rand(1, 524288, 5)
seq.shape

torch.Size([1, 524288, 5])

In [None]:
seq_unmask = seq_unmask.unsqueeze(0)
seq_unmask.shape

torch.Size([1, 524288, 6])

In [8]:
seq = torch.rand(2, 524288, 5) #now simulate batch size of 2
#duplicate seq_unmask to match batch size
seq_unmask = torch.cat([seq_unmask, seq_unmask], dim=0)
print(seq.shape, seq_unmask.shape)

torch.Size([2, 524288, 5]) torch.Size([2, 524288, 6])


In [None]:
import torch.nn.functional as F
def ce_loss_mask_seq(seq,seq_unmask,acc,acc_unmask):
    '''cross entropy loss function for sequence and accessibility classification
    seq: (batch_size, seq_len, vocab_size)
    seq_unmask: (batch_size, seq_len, vocab_size+1) #the last one is the mask
    acc: Not used
    acc_unmask: Not used
    '''
    
    #mask out useless elements, note this will collapse the batch dimension but that's ok
    mask = seq_unmask[:,:,-1] == 1
    seq = seq[mask]
    seq_unmask = seq_unmask[mask]
    
    # print(seq.shape, seq_unmask.shape)
    
    seq_unmask = seq_unmask[:,:-1] #remove the mask dim
    
    #now compute the loss
    loss = F.cross_entropy(seq, seq_unmask)
    return loss

loss = ce_loss_mask_seq(seq,seq_unmask,acc,acc_unmask)
print('loss:', loss)

torch.Size([261766, 5]) torch.Size([261766, 6])
loss: tensor(1.6420)


In [20]:
seq

tensor([[[0.5911, 0.8155, 0.3790, 0.4603, 0.2671],
         [0.4750, 0.2600, 0.8024, 0.9460, 0.8289],
         [0.1431, 0.6778, 0.2760, 0.9449, 0.1958],
         ...,
         [0.9885, 0.6013, 0.0174, 0.9407, 0.4589],
         [0.7305, 0.9701, 0.9287, 0.1451, 0.9505],
         [0.5324, 0.3756, 0.8453, 0.3226, 0.3585]],

        [[0.4535, 0.3415, 0.2855, 0.9170, 0.7606],
         [0.9668, 0.6466, 0.3859, 0.1298, 0.8456],
         [0.6013, 0.5249, 0.2071, 0.4356, 0.7695],
         ...,
         [0.4311, 0.7603, 0.1298, 0.6016, 0.8848],
         [0.2243, 0.7400, 0.8467, 0.0974, 0.5645],
         [0.9781, 0.8408, 0.5260, 0.0680, 0.8351]]])

In [29]:
seq_unmask[0,:10,:]

tensor([[1., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1.],
        [0., 0., 1., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.]])

In [None]:
#let's manually calculate this to make sure it works like we would expect!

loss_list = []
for b in range(seq.shape[0]):
    tempseq = seq[b]
    tempseq_unmask = seq_unmask[b]
    for i in range(seq.shape[1]):
        if tempseq_unmask[i,-1] == 0: #only looks at masked values!
            continue
        #now manually calculate the loss
        #raise to the power of exp
        pred = tempseq[i]
        target = tempseq_unmask[i,:-1]
        # print(pred.shape, target.shape) #both are 5 as you would expect
        # break
        print(i,target)
        true = torch.argmax(target)
        print(true)
        if i > 4:
            break
        # break

0 tensor([1., 0., 0., 0., 0.])
tensor(0)
1 tensor([1., 0., 0., 0., 0.])
tensor(0)
5 tensor([0., 0., 1., 0., 0.])
tensor(2)
0 tensor([1., 0., 0., 0., 0.])
tensor(0)
1 tensor([1., 0., 0., 0., 0.])
tensor(0)
5 tensor([0., 0., 1., 0., 0.])
tensor(2)


In [None]:
from tqdm import tqdm
loss_list = []
for b in range(seq.shape[0]):
    tempseq = seq[b]
    tempseq_unmask = seq_unmask[b]
    for i in range(seq.shape[1]):
        if tempseq_unmask[i,-1] == 0: #only looks at masked values!
            continue
        #now manually calculate the loss
        #raise to the power of exp
        pred = tempseq[i]
        target = tempseq_unmask[i,:-1]
        # print(pred.shape, target.shape) #both are 5 as you would expect
        # break
        # print(i,target)
        true = torch.argmax(target)
        # print(true)
        exp = torch.exp(pred)
        exp_sum = torch.sum(exp)
        s = exp[true]/exp_sum
        L = -torch.log(s)
        loss_list.append(L)
        

100%|██████████| 524288/524288 [00:09<00:00, 56930.46it/s]
100%|██████████| 524288/524288 [00:08<00:00, 58837.10it/s]


In [None]:
sum(loss_list)/len(loss_list) #bro it's literally exactly accurate lmao!

tensor(1.6420)

In [None]:
#technicallyy we should be providing class indices not one hot encoded values
def ce_loss_mask_seq(seq,seq_unmask,acc,acc_unmask):
    '''cross entropy loss function for sequence and accessibility classification
    seq: (batch_size, seq_len, vocab_size)
    seq_unmask: (batch_size, seq_len, vocab_size+1) #the last one is the mask
    acc: Not used
    acc_unmask: Not used
    '''
    
    #mask out useless elements, note this will collapse the batch dimension but that's ok
    mask = seq_unmask[:,:,-1] == 1
    seq = seq[mask]
    seq_unmask = seq_unmask[mask]
    
    # print(seq.shape, seq_unmask.shape)
    
    seq_unmask = seq_unmask[:,:-1] #remove the mask dim
    #now convert to class indices
    seq_unmask = torch.argmax(seq_unmask, dim=-1)
    
    #now compute the loss
    loss = F.cross_entropy(seq, seq_unmask)
    return loss

loss = ce_loss_mask_seq(seq,seq_unmask,acc,acc_unmask)
print('loss:', loss) #literallyy the same...

loss: tensor(1.6420)


In [35]:
loss.item()

1.6419885158538818

In [None]:
(sum(loss_list)/len(loss_list)).item()
#

1.6419991254806519

In [None]:
(sum(loss_list)/len(loss_list)) - loss #so tiny it doesn't even matter lmao!

tensor(1.0610e-05)

In [39]:
#now let's look at accessibility
print(acc.shape, acc_unmask.shape)

torch.Size([2, 524288]) torch.Size([524288, 2])


In [44]:
seq, acc, seq_unmask, acc_unmask = dataset_mask_cont[0]

In [45]:
acc = torch.rand(2, 524288, 1)
acc_unmask = acc_unmask.unsqueeze(0)
acc_unmask = torch.cat([acc_unmask, acc_unmask], dim=0)
print(acc.shape, acc_unmask.shape)

torch.Size([2, 524288, 1]) torch.Size([2, 524288, 2])


In [52]:
def poisson_loss_mask(seq,seq_unmask,acc,acc_unmask):
    '''poisson loss function for sequence and accessibility regression
    seq: Not used
    seq_unmask: Not used
    acc: (batch_size, seq_len, 1)
    acc_unmask: (batch_size, seq_len, 2)
    '''
    #subset it to the values that are beign evaluated
    acc = acc.squeeze(-1)
    mask = acc_unmask[:,:,1] == 1
    acc = acc[mask]
    acc_unmask = acc_unmask[mask][:,0] #remove the mask dim
    acc = F.softplus(acc)
    print(acc.shape, acc_unmask.shape)
    
    #and now compute the loss
    loss = F.poisson_nll_loss(acc, acc_unmask, log_input=False, full=False)
    return loss
loss = poisson_loss_mask(seq,seq_unmask,acc,acc_unmask)
print('loss:', loss)

torch.Size([279000]) torch.Size([279000])
loss: tensor(0.9876)


In [63]:
#let's manually calculate the nll loss too
loss_list = []
for b in range(acc.shape[0]):
    tempseq = acc[b]
    tempseq_unmask = acc_unmask[b]
    for i in tqdm(range(acc.shape[1])):
        if tempseq_unmask[i,-1] == 0: #only looks at masked values!
            continue
        #now manually calculate the loss
        #raise to the power of exp
        pred = tempseq[i]
        target = tempseq_unmask[i,:-1]
        # print(pred.shape, target.shape) #both are 1 as we exxpect here since it's elementwise!
        # break
        # print(i,target)
        #now do softplus on the prediction
        pred = torch.log(1+torch.exp(pred))
        L = pred - target*torch.log(pred+1e-8)
        loss_list.append(L)

100%|██████████| 524288/524288 [00:09<00:00, 53249.06it/s]
100%|██████████| 524288/524288 [00:10<00:00, 52059.39it/s]


In [64]:
sum(loss_list)/len(loss_list) #hmmm, why is it different here... OH, it's because we forgot the softplus lmao!

tensor([0.9876])

In [None]:
len(loss_list) #so same number of elements, why is it different...

279000

In [65]:
(sum(loss_list)/len(loss_list)).item() - loss.item() #so it's the same, just need to add the softplus!

9.119510650634766e-06

In [66]:
#final one is the weirdedst which is for cross entropy
seq, acc, seq_unmask, acc_unmask = dataset_mask_cat[0]
print(seq.shape, acc.shape, seq_unmask.shape, acc_unmask.shape)

torch.Size([6, 524288]) torch.Size([3, 524288]) torch.Size([524288, 6]) torch.Size([524288, 3])


In [None]:
acc = torch.rand(2, 524288, 1)
acc_unmask = acc_unmask.unsqueeze(0)
acc_unmask = torch.cat([acc_unmask, acc_unmask], dim=0)
print(acc.shape, acc_unmask.shape) #matches what we want!

torch.Size([2, 524288, 1]) torch.Size([2, 524288, 3])


In [77]:
def ce_loss_mask_acc(seq,seq_unmask,acc,acc_unmask): #separate so we can profile them separately, also, we have a single value, so use binary cross entropy
    '''cross entropy loss function for sequence and accessibility classification
    seq: Not used
    seq_unmask: Not used
    acc: (batch_size, seq_len, 1)
    acc_unmask: (batch_size, seq_len, 3) #the last one is the mask
    '''    
    #mask out useless elements
    acc = acc.squeeze(-1)
    mask = acc_unmask[:,:,2] == 1
    acc = acc[mask]
    acc_unmask = acc_unmask[mask]
    
    acc = acc.squeeze(0)
    acc_unmask = acc_unmask[:,1] #removes mask dim and just gets the values where it is accessible!
    
    #now compute the loss
    loss = F.binary_cross_entropy_with_logits(acc, acc_unmask)
    return loss

loss = ce_loss_mask_acc(seq,seq_unmask,acc,acc_unmask)
print('loss:', loss)

loss: tensor(0.9825)


In [None]:
#now let's manually calculate this!

loss_list = []
for b in range(acc.shape[0]):
    tempseq = acc[b]
    tempseq_unmask = acc_unmask[b]
    for i in tqdm(range(acc.shape[1])):
        if tempseq_unmask[i,-1] == 0: #only looks at masked values!
            continue
        
        #can think of BCE as CE where one class prob is sigmoid(pred) and the other is 1-sigmoid(pred)
        pred = tempseq[i]
        target = tempseq_unmask[i,:-1]
        # print(pred.shape, target.shape) #1 and 2 as we expect
        # break
        prob = torch.sigmoid(pred) #sigmoid is basically the probability of 1, already normalized too!
        prob2 = 1-prob
        true = torch.argmax(target)
        probs = torch.stack([prob2, prob]) #importantly we have to assign probabilityy of 1 as the sigmoid as higher means more likely 1
        # exp = torch.exp(probs)
        # exp_sum = torch.sum(exp)
        # s = exp[true]/exp_sum
        s = probs[true]
        L = -torch.log(s)
        loss_list.append(L)

100%|██████████| 524288/524288 [00:09<00:00, 52945.08it/s]
100%|██████████| 524288/524288 [00:10<00:00, 52075.25it/s]


In [79]:
sum(loss_list)/len(loss_list) #so it's the same!

tensor([0.9825])

In [80]:
(sum(loss_list)/len(loss_list)).item() - loss.item() #so it's the same!

9.47713851928711e-06

In [None]:
#what if we calculate the traditional way? eh not worth testing, it's fine!!


In [85]:
#let's see what perfect loss is like

def ce_loss_mask_seq(x, y):
    """
    Cross entropy loss for sequence classification.
    
    x: tuple (seq, dummy)
         - seq: (batch_size, seq_len, vocab_size)
    y: tuple (seq_unmask, dummy)
         - seq_unmask: (batch_size, seq_len, vocab_size+1)  (last channel is the mask)
    """
    seq = x[0]
    seq_unmask = y[0]
    
    # Create mask from last column of seq_unmask
    mask = seq_unmask[:, :, -1] == 1
    seq_pred = seq[mask]
    # Remove mask channel from target; resulting shape is (N, vocab_size)
    seq_target = seq_unmask[mask][:, :-1]
    
    loss = F.cross_entropy(seq_pred, seq_target)
    return loss

seq_unmask.shape

torch.Size([524288, 6])

In [86]:
seq_unmask = seq_unmask.unsqueeze(0)
seq_unmask.shape

torch.Size([1, 524288, 6])

In [91]:
#now input for perfect loss
ce_loss_mask_seq((seq_unmask[:,:,:-1], None), (seq_unmask, None))

tensor(0.9048)

In [93]:
mask = seq_unmask[:,:,-1] == 1
s1 = seq_unmask[mask][:, :-1]

In [94]:
s1.shape

torch.Size([131039, 5])

In [95]:
F.cross_entropy(s1, s1)

tensor(0.9048)

In [96]:
#let's define s1 in terms of indices
s1_idx = torch.argmax(s1, dim=-1)
s1_idx.shape

torch.Size([131039])

In [97]:
F.cross_entropy(s1, s1_idx)

tensor(0.9048)

In [None]:
s1.sum(1).max() #they're all 1s because it's 1 hot true values...

tensor(1.)

In [100]:
rand_idx = torch.randint(0, 5, (131039,))
rand_onehot = torch.zeros(131039, 5)
rand_onehot[torch.arange(131039), rand_idx] = 1
print(rand_onehot.shape, rand_idx.shape)

torch.Size([131039, 5]) torch.Size([131039])


In [101]:
F.cross_entropy(rand_onehot, rand_idx)

tensor(0.9048)

In [None]:
#oh that isn't perfect, here's what perfect is
#logits aren't already softmaxed, so can keep going down
F.cross_entropy(rand_onehot*1e9, rand_idx)

tensor(0.)

In [110]:
F.cross_entropy(rand_onehot*1000, rand_onehot) #so rand idx or rand onehot is fine

tensor(-0.)