In [10]:
import numpy as np
import torch
import torch.nn.functional as F
# Compute device shared by the notebook: the default CUDA device when one is
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [11]:
def my_risk_concordance_loss(risk, indctr, tcnv):
    """Pairwise concordance (ranking) loss over every comparable pair in a batch.

    The batch is expected to hold 3*B entries stacked as three groups of B
    (rsk1 from img1 of each pair, rsk2 from img2 of each pair, rsk3 from the
    ODE), so entries k, k+B and k+2B share the same eye id.

    An ordered pair (row, col) is used when
      1. tcnv[row] < tcnv[col], and
      2. indctr[row] == 1 (event occurred, i.e. row is not censored)
         OR row and col come from the same eye.

    For every selected pair, risk[row] - risk[col] is passed through
    `rank_model` and scored with binary cross-entropy on logits.

    NOTE(review): `rank_model` is a module-level callable defined outside this
    function; it is presumably a small network mapping a (P, 1) tensor of risk
    differences to (P, 1) logits -- confirm against its definition.

    Args:
        risk:   predicted risk scores, shape (3B, 1) or (3B,).
        indctr: event indicator of the same length; 1 => event occurred,
                0 => censored.
        tcnv:   conversion (or censoring) time of the same length.

    Returns:
        Scalar tensor with the mean BCE-with-logits over the selected pairs.
        When no pair is comparable, a zero that is still attached to the
        autograd graph is returned instead of NaN.

    Raises:
        ValueError: if the inputs differ in length or the length is not a
            multiple of 3.
    """
    # Work on column vectors so that `x` vs `x.t()` broadcasts to a (3B, 3B)
    # row/col grid; flattening that grid row-major gives the same pair order as
    # the explicit repeat()/view(-1) construction would.
    risk = risk.reshape(-1, 1)
    indctr = indctr.reshape(-1, 1)
    tcnv = tcnv.reshape(-1, 1)

    n = indctr.shape[0]
    B = n // 3
    if n != 3 * B or risk.shape[0] != n or tcnv.shape[0] != n:
        raise ValueError(
            'risk, indctr and tcnv must share a length that is a multiple of 3, got '
            + str(risk.shape[0]) + ', ' + str(n) + ', ' + str(tcnv.shape[0])
        )

    # eye_flag is an id such that all risks of scans of the same eye share it.
    # It is created on the inputs' device rather than a global one, so the
    # function works wherever the caller placed the tensors.
    eye_flag = torch.arange(B, device=indctr.device).repeat(3).unsqueeze(1)  # 3B,1

    same_eye = eye_flag == eye_flag.t()       # 3B,3B
    row_earlier = tcnv < tcnv.t()             # 3B,3B  tcnv_row < tcnv_col
    comparable = row_earlier & ((indctr == 1) | same_eye)

    # Difference in risk (row - col) for the comparable pairs only.
    risk_mtrx = (risk - risk.t())[comparable]  # P

    if risk_mtrx.numel() == 0:
        # The mean over zero pairs would be NaN; the sum of an empty selection
        # is 0 and keeps a valid (zero) gradient path back to `risk`.
        return risk_mtrx.sum()

    risk_mtrx = rank_model(risk_mtrx.unsqueeze(1))  # P,1

    # Randomly reverse the sign of the logit together with its target for about
    # half of the pairs. Un-flipped pairs get target 0 and flipped pairs get
    # target 1 (the opposite labelling existed only as commented-out code).
    # NOTE(review): negating a logit and flipping its target at the same time
    # leaves the per-element BCE-with-logits value unchanged, so this flip does
    # not alter the loss. It is kept to preserve the existing behaviour and RNG
    # consumption; if the intent was to flip the risk difference *before*
    # rank_model, confirm with the author.
    r = torch.randint_like(risk_mtrx, low=0, high=2)  # 0 or 1
    gt = r
    # Out-of-place sign flip (+1 where r == 0, -1 where r == 1): writing into
    # rank_model's output in place can break autograd when the model's last op
    # needs its own output for the backward pass.
    risk_mtrx = risk_mtrx * (1 - 2 * r)

    return F.binary_cross_entropy_with_logits(risk_mtrx, gt)
    

# Smoke test on a batch of 3*10 entries: every risk score and event indicator
# is 1; the first group of ten has tcnv = 3 and the other two groups tcnv = 1.
a = torch.ones(10, 1)
pred_rsk = torch.ones(30, 1).to(device)

indctr = torch.ones(30, 1).to(device)
tcnv = torch.cat([a * 3, a, a], dim=0).to(device)
loss = my_risk_concordance_loss(risk=pred_rsk, indctr=indctr, tcnv=tcnv)
print(loss)